"""
财务报告服务测试

测试FinancialReportsService的核心功能
"""

import pytest
import asyncio
import pandas as pd
from datetime import datetime, timedelta
from unittest.mock import Mock, AsyncMock, patch, MagicMock

from quickstock.services.financial_reports_service import FinancialReportsService
from quickstock.models import (
    FinancialReport, EarningsForecast, FlashReport,
    FinancialReportsRequest, EarningsForecastRequest, FlashReportsRequest
)
from quickstock.core.errors import (
    FinancialDataError, ReportNotFoundError, ForecastDataError, 
    FlashReportError, ValidationError, NetworkError, RateLimitError
)


# 模块级别的fixtures
@pytest.fixture
def mock_data_manager():
    """Provide a fresh stand-in data manager whose attributes are awaitable mocks."""
    manager = AsyncMock()
    return manager

@pytest.fixture
def mock_cache_layer():
    """Provide an awaitable cache mock whose ``get`` yields ``None`` (a miss) by default.

    Tests that need a cache hit override ``get.return_value`` themselves.
    """
    cache_mock = AsyncMock(**{'get.return_value': None})
    return cache_mock

@pytest.fixture
def service(mock_data_manager, mock_cache_layer):
    """Build the service under test, wired to the mocked data manager and cache layer."""
    instance = FinancialReportsService(mock_data_manager, mock_cache_layer)
    return instance

@pytest.fixture
def sample_financial_report_data():
    """Two financial-report rows for 000001.SZ (report dates 20231231 and 20230930)."""
    # Column-oriented construction: each key lists the values for both rows,
    # in row order (20231231 row first).
    columns = {
        'ts_code': ['000001.SZ', '000001.SZ'],
        'report_date': ['20231231', '20230930'],
        'report_type': ['A', 'Q3'],
        'total_revenue': [1000000.0, 750000.0],
        'net_profit': [100000.0, 75000.0],
        'total_assets': [5000000.0, 4800000.0],
        'total_liabilities': [3000000.0, 2900000.0],
        'shareholders_equity': [2000000.0, 1900000.0],
        'operating_cash_flow': [150000.0, 120000.0],
        'eps': [1.25, 0.94],
        'roe': [5.0, 3.9],
    }
    return pd.DataFrame(columns)

@pytest.fixture
def sample_earnings_forecast_data():
    """A single earnings-forecast row for 000001.SZ, returned as a one-row DataFrame."""
    row = dict(
        ts_code='000001.SZ',
        forecast_date='20240115',
        forecast_period='20231231',
        forecast_type='预增',
        net_profit_min=90000.0,
        net_profit_max=110000.0,
        growth_rate_min=10.0,
        growth_rate_max=20.0,
        forecast_summary='预计净利润同比增长10%-20%',
    )
    return pd.DataFrame([row])

@pytest.fixture
def sample_flash_report_data():
    """A single earnings-flash-report row for 000001.SZ, returned as a one-row DataFrame."""
    # One-element column lists -> exactly one row.
    return pd.DataFrame({
        'ts_code': ['000001.SZ'],
        'report_date': ['20240130'],
        'publish_date': ['20240130'],
        'report_period': ['20231231'],
        'total_revenue': [1000000.0],
        'net_profit': [100000.0],
        'revenue_growth': [15.0],
        'profit_growth': [12.0],
        'eps': [1.25],
        'report_summary': ['2023年业绩快报'],
    })


class TestFinancialReportsService:
    """Umbrella placeholder for FinancialReportsService tests.

    It defines no test methods of its own; the concrete cases are grouped
    in the other ``Test*`` classes of this module.
    """


class TestGetFinancialReports:
    """Tests for ``FinancialReportsService.get_financial_reports``."""

    @pytest.mark.asyncio
    async def test_get_financial_reports_success(self, service, mock_data_manager,
                                               sample_financial_report_data):
        """Reports are fetched via the data manager and parsed into ``FinancialReport``s."""
        request = FinancialReportsRequest(
            ts_code='000001.SZ',
            start_date='20230101',
            end_date='20231231'
        )

        # mock_cache_layer.get returns None by default, so the service has to
        # go to the data manager for this data.
        mock_data_manager.get_data.return_value = sample_financial_report_data

        result = await service.get_financial_reports(request)

        assert len(result) == 2
        assert all(isinstance(report, FinancialReport) for report in result)
        assert result[0].ts_code == '000001.SZ'
        assert result[0].total_revenue == 1000000.0
        assert result[1].report_type == 'Q3'

        # The data manager is called once; its first positional argument is
        # the request object the service built for it.
        mock_data_manager.get_data.assert_called_once()
        call_args = mock_data_manager.get_data.call_args[0][0]
        assert call_args.data_type == 'financial_reports'
        assert call_args.ts_code == '000001.SZ'

    @pytest.mark.asyncio
    async def test_get_financial_reports_from_cache(self, service, mock_data_manager,
                                                  mock_cache_layer,
                                                  sample_financial_report_data):
        """A cache hit is served from the cache without touching the data manager."""
        request = FinancialReportsRequest(
            ts_code='000001.SZ',
            start_date='20230101',
            end_date='20231231'
        )

        # Simulate a cache hit.
        mock_cache_layer.get.return_value = sample_financial_report_data

        result = await service.get_financial_reports(request)

        assert len(result) == 2
        assert service._stats['cache_hits'] == 1
        assert service._stats['cache_misses'] == 0
        # A hit must short-circuit the upstream fetch; without this check a
        # service that reads the cache and then refetches anyway would pass.
        mock_data_manager.get_data.assert_not_called()

    @pytest.mark.asyncio
    async def test_get_financial_reports_empty_data(self, service, mock_data_manager):
        """An empty upstream result raises ``ReportNotFoundError`` naming the stock."""
        request = FinancialReportsRequest(
            ts_code='000001.SZ',
            start_date='20230101',
            end_date='20231231'
        )

        mock_data_manager.get_data.return_value = pd.DataFrame()

        with pytest.raises(ReportNotFoundError) as exc_info:
            await service.get_financial_reports(request)

        assert '000001.SZ' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_get_financial_reports_invalid_request(self, service):
        """An empty stock code is rejected with ``ValidationError``."""
        request = FinancialReportsRequest(
            ts_code='',  # empty stock code
            start_date='20230101',
            end_date='20231231'
        )

        with pytest.raises(ValidationError):
            await service.get_financial_reports(request)


class TestGetEarningsForecast:
    """Tests for ``FinancialReportsService.get_earnings_forecast``."""

    @staticmethod
    def _january_2024_request():
        """Build the forecast request (000001.SZ, 20240101-20240131) used by every test here."""
        return EarningsForecastRequest(
            ts_code='000001.SZ',
            start_date='20240101',
            end_date='20240131',
        )

    @pytest.mark.asyncio
    async def test_get_earnings_forecast_success(self, service, mock_data_manager,
                                               sample_earnings_forecast_data):
        """A one-row upstream frame becomes a single ``EarningsForecast``."""
        forecast_request = self._january_2024_request()
        mock_data_manager.get_data.return_value = sample_earnings_forecast_data

        forecasts = await service.get_earnings_forecast(forecast_request)

        assert len(forecasts) == 1
        forecast = forecasts[0]
        assert isinstance(forecast, EarningsForecast)
        assert forecast.ts_code == '000001.SZ'
        assert forecast.forecast_type == '预增'
        assert forecast.net_profit_min == 90000.0

    @pytest.mark.asyncio
    async def test_get_earnings_forecast_empty_data(self, service, mock_data_manager):
        """An empty upstream frame raises ``ForecastDataError`` naming the stock."""
        forecast_request = self._january_2024_request()
        mock_data_manager.get_data.return_value = pd.DataFrame()

        with pytest.raises(ForecastDataError) as raised:
            await service.get_earnings_forecast(forecast_request)

        assert '000001.SZ' in str(raised.value)


class TestGetEarningsFlashReports:
    """Tests for ``FinancialReportsService.get_earnings_flash_reports``."""

    @pytest.mark.asyncio
    async def test_get_flash_reports_success(self, service, mock_data_manager,
                                           sample_flash_report_data):
        """A one-row upstream frame becomes a single ``FlashReport``."""
        mock_data_manager.get_data.return_value = sample_flash_report_data
        flash_request = FlashReportsRequest(
            ts_code='000001.SZ',
            start_date='20240101',
            end_date='20240131',
        )

        reports = await service.get_earnings_flash_reports(flash_request)

        assert len(reports) == 1
        report = reports[0]
        assert isinstance(report, FlashReport)
        assert report.ts_code == '000001.SZ'
        assert report.total_revenue == 1000000.0
        assert report.revenue_growth == 15.0

    @pytest.mark.asyncio
    async def test_get_flash_reports_empty_data(self, service, mock_data_manager):
        """An empty upstream frame raises ``FlashReportError`` naming the stock."""
        mock_data_manager.get_data.return_value = pd.DataFrame()
        flash_request = FlashReportsRequest(
            ts_code='000001.SZ',
            start_date='20240101',
            end_date='20240131',
        )

        with pytest.raises(FlashReportError) as raised:
            await service.get_earnings_flash_reports(flash_request)

        assert '000001.SZ' in str(raised.value)


class TestBatchFinancialData:
    """Tests for ``FinancialReportsService.get_batch_financial_data``."""

    @pytest.mark.asyncio
    async def test_get_batch_financial_data_success(self, service):
        """When every stock succeeds, all are counted as successes and returned in ``data``."""
        stock_codes = ['000001.SZ', '000002.SZ']
        data_types = ['financial_reports']

        # patch.object substitutes an AsyncMock automatically when the patched
        # attribute is an async function (Python 3.8+).
        with patch.object(service, '_process_single_stock_batch') as mock_process:
            mock_process.return_value = {
                'financial_reports': [{'ts_code': '000001.SZ', 'total_revenue': 1000000.0}]
            }

            result = await service.get_batch_financial_data(
                stock_codes, data_types, '20230101', '20231231'
            )

            assert result['success_count'] == 2
            assert result['failed_count'] == 0
            assert result['total_count'] == 2
            assert len(result['data']) == 2
            assert '000001.SZ' in result['data']
            assert '000002.SZ' in result['data']
            # One per-stock processing call for each requested code.
            assert mock_process.call_count == 2

    @pytest.mark.asyncio
    async def test_get_batch_financial_data_partial_failure(self, service):
        """A failure for one stock is recorded without aborting the others."""
        stock_codes = ['000001.SZ', '000002.SZ']
        data_types = ['financial_reports']

        # The first positional argument is the stock code: succeed for
        # 000001.SZ, fail for everything else.
        def mock_process_side_effect(stock_code, *args, **kwargs):
            if stock_code == '000001.SZ':
                return {'financial_reports': [{'ts_code': '000001.SZ'}]}
            raise Exception("处理失败")

        with patch.object(service, '_process_single_stock_batch') as mock_process:
            mock_process.side_effect = mock_process_side_effect

            result = await service.get_batch_financial_data(
                stock_codes, data_types, '20230101', '20231231'
            )

            assert result['success_count'] == 1
            assert result['failed_count'] == 1
            assert result['total_count'] == 2
            assert '000001.SZ' in result['data']
            assert '000002.SZ' in result['failed_stocks']

    @pytest.mark.asyncio
    async def test_get_batch_financial_data_empty_stock_codes(self, service):
        """An empty stock-code list is rejected with ``ValidationError``."""
        with pytest.raises(ValidationError) as exc_info:
            await service.get_batch_financial_data([], ['financial_reports'])

        assert "股票代码列表不能为空" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_get_batch_financial_data_exceed_limit(self, service):
        """A batch larger than the size limit is rejected with ``ValidationError``."""
        # 100 distinct, well-formed six-digit codes (000000.SZ ... 000099.SZ).
        # Zero-padding keeps every code the same length, so the request can only
        # fail on batch size and not on code format. 100 is well above the
        # expected limit of 50.
        stock_codes = [f'{i:06d}.SZ' for i in range(100)]

        with pytest.raises(ValidationError) as exc_info:
            await service.get_batch_financial_data(stock_codes, ['financial_reports'])

        assert "批量大小超过限制" in str(exc_info.value)


class TestRetryMechanism:
    """Tests for the retry helpers ``_fetch_with_retry`` and ``_should_retry``."""

    @pytest.mark.asyncio
    async def test_fetch_with_retry_success_after_failure(self, service):
        """A transient ``NetworkError`` is retried and the second attempt succeeds."""
        call_count = 0

        async def mock_fetch_func(request):
            # Fail on the first call only; succeed on every later call.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise NetworkError("网络错误")
            return pd.DataFrame([{'ts_code': '000001.SZ'}])

        request = FinancialReportsRequest(ts_code='000001.SZ')

        # Stub out asyncio.sleep so a back-off delay between attempts costs no
        # wall-clock time. NOTE(review): only effective if the service waits via
        # ``asyncio.sleep``; it is harmless otherwise.
        with patch('asyncio.sleep', new_callable=AsyncMock):
            result = await service._fetch_with_retry(mock_fetch_func, request)

        assert not result.empty
        assert call_count == 2
        assert service._stats['retry_attempts'] == 1

    @pytest.mark.asyncio
    async def test_fetch_with_retry_max_retries_exceeded(self, service):
        """A persistent ``NetworkError`` is re-raised once retries are exhausted."""
        async def mock_fetch_func(request):
            raise NetworkError("网络错误")

        request = FinancialReportsRequest(ts_code='000001.SZ')

        # Same back-off stub as above, so exhausting the retries is instant.
        with patch('asyncio.sleep', new_callable=AsyncMock), \
                pytest.raises(NetworkError):
            await service._fetch_with_retry(mock_fetch_func, request, max_retries=2)

    def test_should_retry_logic(self, service):
        """Transient errors are retryable; validation and not-found errors are not."""
        # ``_should_retry`` is synchronous, so this test needs no event loop.
        assert service._should_retry(NetworkError("网络错误"))
        assert service._should_retry(RateLimitError("速率限制"))

        assert not service._should_retry(ValidationError("参数错误"))
        assert not service._should_retry(ReportNotFoundError('000001.SZ', '20231231'))


class TestCacheManagement:
    """Tests for cache TTL configuration, cache-key generation and cache clearing."""

    def test_cache_ttl_configuration(self, service):
        """Each data type has its own entry in ``CACHE_TTL_CONFIG``."""
        # Plain attribute reads: no coroutine or extra fixture is needed.
        # NOTE(review): the values look like hours; confirm against the service.
        assert service.CACHE_TTL_CONFIG['financial_reports'] == 24
        assert service.CACHE_TTL_CONFIG['earnings_forecast'] == 4
        assert service.CACHE_TTL_CONFIG['flash_reports'] == 1

    def test_cache_key_generation(self, service):
        """Cache keys carry a data-type prefix followed by a request-derived suffix."""
        request = FinancialReportsRequest(
            ts_code='000001.SZ',
            start_date='20230101',
            end_date='20231231'
        )

        # ``_generate_cache_key`` is synchronous, so this test is a plain function.
        cache_key = service._generate_cache_key('financial_reports', request)

        assert cache_key.startswith('financial_financial_reports_')
        assert len(cache_key) > 20  # should include a hash component

    @pytest.mark.asyncio
    async def test_clear_cache(self, service, mock_cache_layer):
        """``clear_cache`` delegates to the cache layer's ``clear`` exactly once."""
        await service.clear_cache()
        mock_cache_layer.clear.assert_called_once()


class TestDataProcessing:
    """Tests for the synchronous DataFrame <-> model conversion helpers."""

    def test_process_financial_reports_data(self, service, sample_financial_report_data):
        """Every row of the sample frame is converted into a ``FinancialReport``."""
        processed = service._process_financial_reports_data(
            sample_financial_report_data,
            FinancialReportsRequest(ts_code='000001.SZ'),
        )

        assert len(processed) == 2
        assert all(isinstance(item, FinancialReport) for item in processed)
        head = processed[0]
        assert head.ts_code == '000001.SZ'
        assert head.total_revenue == 1000000.0

    def test_process_financial_reports_data_invalid_data(self, service):
        """A frame with no usable rows raises ``ReportNotFoundError``."""
        bad_row = {
            'ts_code': 'INVALID',  # malformed stock code
            'report_date': '20231231',
            'report_type': 'A',
            'total_revenue': 'invalid',  # non-numeric value
            'net_profit': 100000.0,
        }
        bad_frame = pd.DataFrame([bad_row])
        bad_request = FinancialReportsRequest(ts_code='INVALID')

        # No valid rows survive processing, so the helper is expected to raise.
        with pytest.raises(ReportNotFoundError):
            service._process_financial_reports_data(bad_frame, bad_request)

    def test_convert_to_dataframe_and_back(self, service):
        """Reports -> DataFrame -> reports keeps the stock code and revenue intact."""
        fields = dict(
            ts_code='000001.SZ',
            report_date='20231231',
            report_type='A',
            total_revenue=1000000.0,
            net_profit=100000.0,
            total_assets=5000000.0,
            total_liabilities=3000000.0,
            shareholders_equity=2000000.0,
            operating_cash_flow=150000.0,
            eps=1.25,
            roe=5.0,
        )
        source_reports = [FinancialReport(**fields)]

        frame = service._convert_financial_reports_to_dataframe(source_reports)
        assert len(frame) == 1
        assert frame.iloc[0]['ts_code'] == '000001.SZ'

        restored = service._convert_to_financial_reports(frame)
        assert len(restored) == 1
        assert restored[0].ts_code == '000001.SZ'
        assert restored[0].total_revenue == 1000000.0


class TestServiceStats:
    """Tests for the counters and rates reported by ``get_service_stats``."""

    def test_initial_stats(self, service):
        """A freshly built service reports zero for every counter and rate."""
        snapshot = service.get_service_stats()
        for key in ('total_requests', 'successful_requests', 'failed_requests',
                    'success_rate', 'cache_hit_rate'):
            assert snapshot[key] == 0

    def test_update_stats(self, service):
        """One successful request plus one hit and one miss yields the expected rates."""
        service._update_request_stats()
        service._update_success_stats(1.5)
        for counter in ('cache_hits', 'cache_misses'):
            service._stats[counter] = 1

        snapshot = service.get_service_stats()
        expected = {
            'total_requests': 1,
            'successful_requests': 1,
            'success_rate': 1.0,
            'cache_hit_rate': 0.5,
            'average_processing_time': 1.5,
        }
        for key, value in expected.items():
            assert snapshot[key] == value

    def test_reset_stats(self, service):
        """``reset_stats`` brings the request counters back to zero."""
        service._update_request_stats()
        service._update_success_stats(1.0)

        service.reset_stats()

        snapshot = service.get_service_stats()
        assert snapshot['total_requests'] == 0
        assert snapshot['successful_requests'] == 0


class TestServiceLifecycle:
    """Lifecycle tests: shutting the service down closes its cache layer."""

    @pytest.mark.asyncio
    async def test_service_close(self, service, mock_cache_layer):
        """``close()`` closes the underlying cache layer exactly once."""
        await service.close()
        close_mock = mock_cache_layer.close
        close_mock.assert_called_once()


# 集成测试
class TestIntegration:
    """End-to-end check of one financial-report request through cache, data manager and stats."""

    @pytest.mark.asyncio
    async def test_full_workflow_financial_reports(self, service, mock_data_manager,
                                                  mock_cache_layer, sample_financial_report_data):
        """Cache miss -> upstream fetch -> cache write -> stats update, in one request."""
        mock_cache_layer.get.return_value = None  # force a cache miss
        mock_data_manager.get_data.return_value = sample_financial_report_data
        workflow_request = FinancialReportsRequest(
            ts_code='000001.SZ',
            start_date='20230101',
            end_date='20231231',
            report_type='A',
        )

        reports = await service.get_financial_reports(workflow_request)

        assert len(reports) == 2
        assert all(isinstance(item, FinancialReport) for item in reports)

        # One cache lookup, one cache write and one upstream fetch, checked in that order.
        for collaborator in (mock_cache_layer.get, mock_cache_layer.set,
                             mock_data_manager.get_data):
            collaborator.assert_called_once()

        snapshot = service.get_service_stats()
        assert snapshot['total_requests'] == 1
        assert snapshot['successful_requests'] == 1
        assert snapshot['cache_misses'] == 1


if __name__ == '__main__':
    # Propagate pytest's exit status so shell/CI callers see test failures;
    # a bare pytest.main() call would always exit with status 0.
    raise SystemExit(pytest.main([__file__]))