"""
财务报告综合集成测试

测试财务报告功能的完整工作流程，包括：
- 端到端财务数据工作流测试
- 带TTL验证的缓存集成测试
- 带速率限制场景的批处理测试
- 各种失败模式的错误处理集成测试
- 并发请求处理和线程安全测试
"""

import pytest
import asyncio
import pandas as pd
import time
from datetime import datetime, timedelta
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from concurrent.futures import ThreadPoolExecutor
import threading
import random

from quickstock.client import QuickStockClient
from quickstock.config import Config
from quickstock.services.financial_reports_service import FinancialReportsService
from quickstock.core.data_manager import DataManager
from quickstock.core.cache import CacheLayer
from quickstock.models import (
    FinancialReport, EarningsForecast, FlashReport,
    FinancialReportsRequest, EarningsForecastRequest, FlashReportsRequest
)
from quickstock.core.errors import (
    FinancialDataError, ReportNotFoundError, ForecastDataError, 
    FlashReportError, NetworkError, RateLimitError, ValidationError
)


class TestFinancialReportsEndToEndWorkflow:
    """财务报告端到端工作流程测试"""
    
    @pytest.fixture
    def config(self):
        """Build the configuration object used by the end-to-end tests."""
        settings = {
            'cache_enabled': True,
            'cache_expire_hours': 24,
            'max_concurrent_requests': 10,
            'enable_baostock': True,
            'log_level': 'INFO',
        }
        return Config(**settings)
    
    @pytest.fixture
    def client(self, config):
        """Return a ``QuickStockClient`` constructed from the ``config`` fixture."""
        test_client = QuickStockClient(config)
        return test_client
    
    @pytest.fixture
    def sample_financial_data(self):
        """Provide sample frames for reports, earnings forecasts and flash reports.

        Returns a dict with the keys ``financial_reports``, ``earnings_forecast``
        and ``flash_reports``; each value is a two-row DataFrame for 000001.SZ.
        """
        reports_frame = pd.DataFrame({
            'ts_code': ['000001.SZ', '000001.SZ'],
            'report_date': ['20231231', '20230930'],
            'report_type': ['A', 'Q3'],
            'total_revenue': [1000000.0, 750000.0],
            'net_profit': [200000.0, 150000.0],
            'total_assets': [5000000.0, 4800000.0],
            'total_liabilities': [3000000.0, 2900000.0],
            'shareholders_equity': [2000000.0, 1900000.0],
            'operating_cash_flow': [300000.0, 250000.0],
            'eps': [2.5, 1.8],
            'roe': [10.0, 8.5],
        })

        forecast_frame = pd.DataFrame({
            'ts_code': ['000001.SZ', '000001.SZ'],
            'forecast_date': ['20240315', '20240215'],
            'forecast_period': ['20231231', '20231231'],
            'forecast_type': ['预增', '预增'],
            'net_profit_min': [180000.0, 170000.0],
            'net_profit_max': [220000.0, 210000.0],
            'growth_rate_min': [10.0, 8.0],
            'growth_rate_max': [20.0, 18.0],
            'forecast_summary': ['预计净利润同比增长10%-20%', '预计净利润同比增长8%-18%'],
        })

        flash_frame = pd.DataFrame({
            'ts_code': ['000001.SZ', '000001.SZ'],
            'report_date': ['20240430', '20240331'],
            'publish_date': ['20240430', '20240331'],
            'report_period': ['20240331', '20240331'],
            'total_revenue': [250000.0, 240000.0],
            'net_profit': [50000.0, 48000.0],
            'revenue_growth': [15.0, 12.0],
            'profit_growth': [25.0, 22.0],
            'eps': [0.6, 0.58],
            'report_summary': ['一季度业绩快报', '一季度业绩快报修正'],
        })

        return {
            'financial_reports': reports_frame,
            'earnings_forecast': forecast_frame,
            'flash_reports': flash_frame,
        }
    
    def test_complete_financial_analysis_workflow(self, client, sample_financial_data):
        """Exercise reports, forecasts, flash reports and a batch call in sequence.

        The service class is patched with a ``Mock`` whose async methods return
        model objects built from the ``sample_financial_data`` frames. Each
        client call is then checked for length and selected field values.
        """
        def make_stub(model_cls, dataset_key):
            """Return an async stub yielding ``model_cls`` objects for one sample frame."""
            async def _stub(request):
                frame = sample_financial_data[dataset_key]
                return [model_cls(**record) for _, record in frame.iterrows()]
            return _stub

        target = 'quickstock.services.financial_reports_service.FinancialReportsService'
        with patch(target) as service_cls:
            fake_service = Mock()
            service_cls.return_value = fake_service

            # Wire the three single-stock endpoints to the sample data.
            fake_service.get_financial_reports = make_stub(FinancialReport, 'financial_reports')
            fake_service.get_earnings_forecast = make_stub(EarningsForecast, 'earnings_forecast')
            fake_service.get_earnings_flash_reports = make_stub(FlashReport, 'flash_reports')

            # Step 1: financial reports.
            financial_reports = client.get_financial_reports(
                '000001.SZ', start_date='20230101', end_date='20231231'
            )
            assert len(financial_reports) == 2
            latest_report = financial_reports[0]
            assert latest_report['ts_code'] == '000001.SZ'
            assert latest_report['report_type'] == 'A'
            assert latest_report['total_revenue'] == 1000000.0

            # Step 2: earnings forecasts.
            earnings_forecast = client.get_earnings_forecast(
                '000001.SZ', start_date='20240101', end_date='20240331'
            )
            assert len(earnings_forecast) == 2
            latest_forecast = earnings_forecast[0]
            assert latest_forecast['forecast_type'] == '预增'
            assert latest_forecast['net_profit_min'] == 180000.0

            # Step 3: flash reports.
            flash_reports = client.get_earnings_flash_reports(
                '000001.SZ', start_date='20240301', end_date='20240430'
            )
            assert len(flash_reports) == 2
            latest_flash = flash_reports[0]
            assert latest_flash['revenue_growth'] == 15.0
            assert latest_flash['profit_growth'] == 25.0

            # Step 4: batch retrieval. The stub returns a direct
            # stock-code -> data-type mapping rebuilt from the results above.
            async def batch_stub(**kwargs):
                per_stock = {
                    'financial_reports': [FinancialReport(**financial_reports[0])],
                    'earnings_forecast': [EarningsForecast(**earnings_forecast[0])],
                    'flash_reports': [FlashReport(**flash_reports[0])],
                }
                return {'000001.SZ': per_stock}

            fake_service.get_batch_financial_data = batch_stub

            batch_data = client.get_batch_financial_data(
                ['000001.SZ'],
                data_types=['financial_reports', 'earnings_forecast', 'flash_reports']
            )

            assert '000001.SZ' in batch_data
            stock_entry = batch_data['000001.SZ']
            assert 'financial_reports' in stock_entry
            assert 'earnings_forecast' in stock_entry
            assert 'flash_reports' in stock_entry
    
    def test_multi_stock_comprehensive_analysis(self, client, sample_financial_data):
        """Check a batch request covering several stocks at once.

        The service class is patched with a ``Mock`` whose batch method returns
        one financial report per requested stock, wrapped in a summary dict
        (``success_count``, ``failed_count``, ``total_count``, ``data``,
        ``failed_stocks``, ``processing_time``). The test asserts the counts and
        that every requested code has exactly one report.
        """
        with patch('quickstock.services.financial_reports_service.FinancialReportsService') as mock_service_class:
            mock_service = Mock()
            mock_service_class.return_value = mock_service

            async def mock_batch_get(**kwargs):
                """Return one synthetic annual report per requested stock code."""
                stock_codes = kwargs['stock_codes']
                result = {}

                for code in stock_codes:
                    # Vary the figures per stock using the numeric part of the
                    # code. hash(code) is salted per interpreter process, so it
                    # would make the generated data differ from run to run.
                    code_number = int(code.split('.')[0])
                    result[code] = {
                        'financial_reports': [
                            FinancialReport(
                                ts_code=code,
                                report_date='20231231',
                                report_type='A',
                                total_revenue=1000000.0 + code_number % 100000,
                                net_profit=200000.0 + code_number % 50000,
                                total_assets=5000000.0,
                                total_liabilities=3000000.0,
                                shareholders_equity=2000000.0,
                                operating_cash_flow=300000.0,
                                eps=2.5,
                                roe=10.0
                            ).to_dict()
                        ],
                        'earnings_forecast': [],
                        'flash_reports': []
                    }

                return {
                    'success_count': len(stock_codes),
                    'failed_count': 0,
                    'total_count': len(stock_codes),
                    'data': result,
                    'failed_stocks': {},
                    'processing_time': 1.5
                }

            mock_service.get_batch_financial_data = mock_batch_get

            # Batch analysis over four stocks from both exchanges.
            stock_codes = ['000001.SZ', '000002.SZ', '600000.SH', '600036.SH']
            batch_result = client.get_batch_financial_data(
                stock_codes,
                data_types=['financial_reports']
            )

            assert batch_result['success_count'] == 4
            assert batch_result['failed_count'] == 0
            assert len(batch_result['data']) == 4

            # Every requested stock must be present with exactly one report.
            for code in stock_codes:
                assert code in batch_result['data']
                assert len(batch_result['data'][code]['financial_reports']) == 1
    
    def test_data_consistency_across_requests(self, client, sample_financial_data):
        """Verify that repeating the same request yields equal results.

        The patched service always returns the same single report; the client
        is queried three times and every later result is compared with the first.
        """
        target = 'quickstock.services.financial_reports_service.FinancialReportsService'
        with patch(target) as service_cls:
            fake_service = Mock()
            service_cls.return_value = fake_service

            # One fixed report that every stubbed call hands back.
            fixed_report = FinancialReport(
                ts_code='000001.SZ',
                report_date='20231231',
                report_type='A',
                total_revenue=1000000.0,
                net_profit=200000.0,
                total_assets=5000000.0,
                total_liabilities=3000000.0,
                shareholders_equity=2000000.0,
                operating_cash_flow=300000.0,
                eps=2.5,
                roe=10.0
            )
            fixed_payload = [fixed_report]

            async def always_same(request):
                return fixed_payload

            fake_service.get_financial_reports = always_same

            # Issue the identical request three times.
            snapshots = [client.get_financial_reports('000001.SZ') for _ in range(3)]

            # Compare every later snapshot against the first one.
            baseline, *later = snapshots
            for snapshot in later:
                assert snapshot == baseline
                assert snapshot[0]['total_revenue'] == baseline[0]['total_revenue']
                assert snapshot[0]['net_profit'] == baseline[0]['net_profit']


class TestFinancialReportsCacheIntegration:
    """财务报告缓存集成测试"""
    
    @pytest.fixture
    def service_with_cache(self):
        """Create a FinancialReportsService backed by mocked manager and cache.

        The cache mock reports a miss on every ``get`` and accepts any ``set``.
        """
        # NOTE(review): this Config is built but never passed to the service;
        # confirm whether it is needed at all.
        _config = Config(cache_enabled=True, cache_expire_hours=1)

        cache_stub = Mock(spec=CacheLayer)
        cache_stub.get = AsyncMock(return_value=None)  # always a cache miss
        cache_stub.set = AsyncMock()
        manager_stub = Mock(spec=DataManager)

        return FinancialReportsService(manager_stub, cache_stub)
    
    @pytest.mark.asyncio
    async def test_financial_reports_cache_ttl_validation(self, service_with_cache):
        """Check that a cache miss stores the fetched reports with a TTL of 24.

        The data manager is stubbed to return one report row. After a single
        request the cache must have been read once and written once, and the
        third positional argument of the write must equal 24.
        """
        report_row = {
            'ts_code': '000001.SZ',
            'report_date': '20231231',
            'report_type': 'A',
            'total_revenue': 1000000.0,
            'net_profit': 200000.0,
            'total_assets': 5000000.0,
            'total_liabilities': 3000000.0,
            'shareholders_equity': 2000000.0,
            'operating_cash_flow': 300000.0,
            'eps': 2.5,
            'roe': 10.0,
        }
        frame = pd.DataFrame({column: [value] for column, value in report_row.items()})
        service_with_cache.data_manager.get_data = AsyncMock(return_value=frame)

        query = FinancialReportsRequest(
            ts_code='000001.SZ',
            start_date='20231201',
            end_date='20231231',
        )

        # First request: expected to miss the cache and then populate it.
        reports = await service_with_cache.get_financial_reports(query)

        assert len(reports) == 1
        assert reports[0].ts_code == '000001.SZ'

        cache = service_with_cache.cache_layer
        cache.get.assert_called_once()
        cache.set.assert_called_once()

        # The TTL is taken from the third positional argument of cache.set
        # (presumably hours; confirm against CacheLayer.set).
        positional_args, _ = cache.set.call_args
        assert positional_args[2] == 24
    
    @pytest.mark.asyncio
    async def test_earnings_forecast_cache_ttl(self, service_with_cache):
        """Check that earnings forecasts are cached with a TTL of 4."""
        forecast_row = {
            'ts_code': '000001.SZ',
            'forecast_date': '20240315',
            'forecast_period': '20231231',
            'forecast_type': '预增',
            'net_profit_min': 180000.0,
            'net_profit_max': 220000.0,
            'growth_rate_min': 10.0,
            'growth_rate_max': 20.0,
            'forecast_summary': '预计净利润同比增长10%-20%',
        }
        frame = pd.DataFrame({column: [value] for column, value in forecast_row.items()})
        service_with_cache.data_manager.get_data = AsyncMock(return_value=frame)

        query = EarningsForecastRequest(
            ts_code='000001.SZ',
            start_date='20240301',
            end_date='20240331',
        )

        forecasts = await service_with_cache.get_earnings_forecast(query)

        assert len(forecasts) == 1

        # Third positional argument of cache.set carries the TTL
        # (presumably hours; confirm against CacheLayer.set).
        positional_args, _ = service_with_cache.cache_layer.set.call_args
        assert positional_args[2] == 4
    
    @pytest.mark.asyncio
    async def test_flash_reports_cache_ttl(self, service_with_cache):
        """Check that flash reports are cached with a TTL of 1."""
        flash_row = {
            'ts_code': '000001.SZ',
            'report_date': '20240430',
            'publish_date': '20240430',
            'report_period': '20240331',
            'total_revenue': 250000.0,
            'net_profit': 50000.0,
            'revenue_growth': 15.0,
            'profit_growth': 25.0,
            'eps': 0.6,
            'report_summary': '一季度业绩快报',
        }
        frame = pd.DataFrame({column: [value] for column, value in flash_row.items()})
        service_with_cache.data_manager.get_data = AsyncMock(return_value=frame)

        query = FlashReportsRequest(
            ts_code='000001.SZ',
            start_date='20240401',
            end_date='20240430',
        )

        flashes = await service_with_cache.get_earnings_flash_reports(query)

        assert len(flashes) == 1

        # Third positional argument of cache.set carries the TTL
        # (presumably hours; confirm against CacheLayer.set).
        positional_args, _ = service_with_cache.cache_layer.set.call_args
        assert positional_args[2] == 1
    
    @pytest.mark.asyncio
    async def test_cache_hit_behavior(self, service_with_cache):
        """Check that a cache hit is served without calling the data manager.

        The cache mock is overridden to return a one-row frame. The test then
        asserts the result, that the data manager was not called, and that the
        service statistics report one hit and no misses.
        """
        cached_row = {
            'ts_code': '000001.SZ',
            'report_date': '20231231',
            'report_type': 'A',
            'total_revenue': 1000000.0,
            'net_profit': 200000.0,
            'total_assets': 5000000.0,
            'total_liabilities': 3000000.0,
            'shareholders_equity': 2000000.0,
            'operating_cash_flow': 300000.0,
            'eps': 2.5,
            'roe': 10.0,
        }
        cached_frame = pd.DataFrame({column: [value] for column, value in cached_row.items()})

        # Replace the fixture's always-miss cache read with a hit.
        service_with_cache.cache_layer.get = AsyncMock(return_value=cached_frame)

        query = FinancialReportsRequest(
            ts_code='000001.SZ',
            start_date='20231201',
            end_date='20231231',
        )

        reports = await service_with_cache.get_financial_reports(query)

        assert len(reports) == 1
        assert reports[0].ts_code == '000001.SZ'

        # The cache is consulted exactly once; the data manager is never touched.
        service_with_cache.cache_layer.get.assert_called_once()
        service_with_cache.data_manager.get_data.assert_not_called()

        # Hit/miss counters exposed by the service.
        service_stats = service_with_cache.get_service_stats()
        assert service_stats['cache_hits'] == 1
        assert service_stats['cache_misses'] == 0
    
    @pytest.mark.asyncio
    async def test_cache_invalidation_on_error(self, service_with_cache):
        """Check that nothing is written to the cache when the fetch fails.

        The data manager raises ``NetworkError``; the service call is expected
        to raise ``FinancialDataError`` and ``cache_layer.set`` must stay uncalled.
        """
        failing_fetch = AsyncMock(side_effect=NetworkError("网络错误"))
        service_with_cache.data_manager.get_data = failing_fetch

        query = FinancialReportsRequest(
            ts_code='000001.SZ',
            start_date='20231201',
            end_date='20231231',
        )

        with pytest.raises(FinancialDataError):
            await service_with_cache.get_financial_reports(query)

        # A failed fetch must not populate the cache.
        service_with_cache.cache_layer.set.assert_not_called()


class TestFinancialReportsBatchProcessing:
    """财务报告批处理测试"""
    
    @pytest.fixture
    def batch_service(self):
        """Create a FinancialReportsService for the batch-processing tests.

        Uses a mocked data manager and a mocked cache that always misses.
        """
        # NOTE(review): this Config is built but never passed to the service;
        # confirm whether max_concurrent_requests is meant to take effect.
        _config = Config(
            max_concurrent_requests=5,
            cache_enabled=True
        )

        manager_stub = Mock(spec=DataManager)
        cache_stub = Mock(spec=CacheLayer)
        cache_stub.get = AsyncMock(return_value=None)
        cache_stub.set = AsyncMock()

        service = FinancialReportsService(manager_stub, cache_stub)
        return service
    
    @pytest.mark.asyncio
    async def test_batch_processing_with_rate_limiting(self, batch_service):
        """Check batch behaviour when the data source starts rate limiting.

        The stubbed ``get_data`` serves its first three calls and raises
        ``RateLimitError`` afterwards. A batch of seven stocks is expected to
        end with some failures, each carrying a rate-limit message.
        """
        rate_limit_threshold = 3
        calls_seen = 0

        async def rate_limited_get_data(request):
            nonlocal calls_seen
            calls_seen += 1

            # Calls beyond the threshold hit the simulated rate limit.
            if calls_seen > rate_limit_threshold:
                raise RateLimitError("API调用频率超限", details={'retry_after': 1})

            row = {
                'ts_code': request.ts_code,
                'report_date': '20231231',
                'report_type': 'A',
                'total_revenue': 1000000.0,
                'net_profit': 200000.0,
                'total_assets': 5000000.0,
                'total_liabilities': 3000000.0,
                'shareholders_equity': 2000000.0,
                'operating_cash_flow': 300000.0,
                'eps': 2.5,
                'roe': 10.0,
            }
            return pd.DataFrame({column: [value] for column, value in row.items()})

        batch_service.data_manager.get_data = rate_limited_get_data

        # Seven stocks: 000001.SZ through 000007.SZ.
        stock_codes = [f'00000{i}.SZ' for i in range(1, 8)]

        outcome = await batch_service.get_batch_financial_data(
            stock_codes=stock_codes,
            data_types=['financial_reports']
        )

        # Partial success: no more successes than calls served, some failures.
        assert outcome['total_count'] == 7
        assert outcome['success_count'] <= rate_limit_threshold
        assert outcome['failed_count'] > 0

        # Every failure must mention the rate limit.
        failures = outcome['failed_stocks']
        assert len(failures) > 0
        for error_msg in failures.values():
            assert 'API调用频率超限' in error_msg or 'RateLimitError' in error_msg
    
    @pytest.mark.asyncio
    async def test_batch_size_limit_enforcement(self, batch_service):
        """Check that an oversized batch is rejected with ``ValidationError``.

        Raises (expected): ``ValidationError`` whose message matches
        "批量大小超过限制".
        """
        # 51 codes, intended to exceed the service's batch-size limit (the limit
        # value itself is not visible in this file). Zero-padding keeps every
        # code at six digits: 000010.SZ, not the malformed 0000010.SZ that
        # f'00000{i}.SZ' produces for i >= 10. That way the request can only
        # fail on batch size, not on a bad code.
        large_stock_list = [f'{i:06d}.SZ' for i in range(1, 52)]

        with pytest.raises(ValidationError, match="批量大小超过限制"):
            await batch_service.get_batch_financial_data(
                stock_codes=large_stock_list,
                data_types=['financial_reports']
            )
    
    @pytest.mark.asyncio
    async def test_batch_timeout_handling(self, batch_service):
        """Check that slow fetches are reported as timeouts in a batch.

        ``get_data`` is stubbed to sleep two seconds while ``batch_timeout`` is
        temporarily lowered to one. The batch result must contain at least one
        failure whose message includes "超时".
        """
        async def slow_get_data(request):
            await asyncio.sleep(2)  # longer than the shortened timeout below
            row = {
                'ts_code': request.ts_code,
                'report_date': '20231231',
                'report_type': 'A',
                'total_revenue': 1000000.0,
                'net_profit': 200000.0,
                'total_assets': 5000000.0,
                'total_liabilities': 3000000.0,
                'shareholders_equity': 2000000.0,
                'operating_cash_flow': 300000.0,
                'eps': 2.5,
                'roe': 10.0,
            }
            return pd.DataFrame({column: [value] for column, value in row.items()})

        batch_service.data_manager.get_data = slow_get_data

        # Shorten the timeout for this test only; the finally clause restores it.
        saved_timeout = batch_service.BATCH_CONFIG['batch_timeout']
        batch_service.BATCH_CONFIG['batch_timeout'] = 1

        try:
            outcome = await batch_service.get_batch_financial_data(
                stock_codes=['000001.SZ', '000002.SZ'],
                data_types=['financial_reports']
            )

            # At least one stock must have failed with a timeout message.
            assert outcome['failed_count'] > 0
            assert any('超时' in error for error in outcome['failed_stocks'].values())
        finally:
            batch_service.BATCH_CONFIG['batch_timeout'] = saved_timeout
    
    @pytest.mark.asyncio
    async def test_concurrent_limit_enforcement(self, batch_service):
        """Check that a batch never exceeds the configured concurrency limit.

        The stubbed ``get_data`` counts how many calls are in flight at once
        and records the peak. After a ten-stock batch, the peak must not exceed
        ``BATCH_CONFIG['concurrent_limit']`` and all ten stocks must succeed.
        """
        active_requests = 0
        max_concurrent = 0
        lock = asyncio.Lock()

        async def concurrent_tracking_get_data(request):
            nonlocal active_requests, max_concurrent

            # Register this call and update the observed peak.
            async with lock:
                active_requests += 1
                max_concurrent = max(max_concurrent, active_requests)

            try:
                await asyncio.sleep(0.1)  # simulated work, so calls can overlap
                return pd.DataFrame({
                    'ts_code': [request.ts_code],
                    'report_date': ['20231231'],
                    'report_type': ['A'],
                    'total_revenue': [1000000.0],
                    'net_profit': [200000.0],
                    'total_assets': [5000000.0],
                    'total_liabilities': [3000000.0],
                    'shareholders_equity': [2000000.0],
                    'operating_cash_flow': [300000.0],
                    'eps': [2.5],
                    'roe': [10.0]
                })
            finally:
                # Always release the slot, even if the body raised.
                async with lock:
                    active_requests -= 1

        batch_service.data_manager.get_data = concurrent_tracking_get_data

        # Ten stocks, 000001.SZ through 000010.SZ. Zero-padding keeps the tenth
        # code at six digits; f'00000{i}.SZ' would produce 0000010.SZ.
        stock_codes = [f'{i:06d}.SZ' for i in range(1, 11)]

        result = await batch_service.get_batch_financial_data(
            stock_codes=stock_codes,
            data_types=['financial_reports']
        )

        # The observed peak must respect the service's own limit.
        concurrent_limit = batch_service.BATCH_CONFIG['concurrent_limit']
        assert max_concurrent <= concurrent_limit
        assert result['success_count'] == 10


class TestFinancialReportsErrorHandling:
    """财务报告错误处理集成测试"""
    
    @pytest.fixture
    def error_service(self):
        """Create a FinancialReportsService for the error-handling tests.

        Uses a mocked data manager and a mocked cache that always misses.
        """
        # NOTE(review): this Config is built but never passed to the service;
        # confirm whether it is needed at all.
        _config = Config(cache_enabled=True)

        cache_stub = Mock(spec=CacheLayer)
        cache_stub.get = AsyncMock(return_value=None)
        cache_stub.set = AsyncMock()
        manager_stub = Mock(spec=DataManager)

        return FinancialReportsService(manager_stub, cache_stub)
    
    @pytest.mark.asyncio
    async def test_network_error_retry_mechanism(self, error_service):
        """Check that the service retries after network errors and then succeeds.

        ``get_data`` raises ``NetworkError`` on its first two calls and returns
        one row on the third. The test expects a single report, three calls in
        total, and ``retry_attempts == 2`` in the service statistics.
        """
        attempts = 0

        async def network_error_then_success(request):
            nonlocal attempts
            attempts += 1

            # The first two attempts fail.
            if attempts <= 2:
                raise NetworkError("网络连接失败")

            row = {
                'ts_code': request.ts_code,
                'report_date': '20231231',
                'report_type': 'A',
                'total_revenue': 1000000.0,
                'net_profit': 200000.0,
                'total_assets': 5000000.0,
                'total_liabilities': 3000000.0,
                'shareholders_equity': 2000000.0,
                'operating_cash_flow': 300000.0,
                'eps': 2.5,
                'roe': 10.0,
            }
            return pd.DataFrame({column: [value] for column, value in row.items()})

        error_service.data_manager.get_data = network_error_then_success

        query = FinancialReportsRequest(
            ts_code='000001.SZ',
            start_date='20231201',
            end_date='20231231',
        )

        # Expected to succeed once the retries get through.
        reports = await error_service.get_financial_reports(query)

        assert len(reports) == 1
        assert reports[0].ts_code == '000001.SZ'
        assert attempts == 3  # two failures followed by one success

        # Retry counter exposed by the service.
        service_stats = error_service.get_service_stats()
        assert service_stats['retry_attempts'] == 2
    
    @pytest.mark.asyncio
    async def test_rate_limit_error_handling(self, error_service):
        """Check that a persistent rate limit ends in ``FinancialDataError``.

        ``get_data`` always raises ``RateLimitError``. The service call must
        raise ``FinancialDataError`` and the statistics must show at least one
        retry attempt.
        """
        async def always_rate_limited(request):
            raise RateLimitError("API调用频率超限", details={'retry_after': 1})

        error_service.data_manager.get_data = always_rate_limited

        query = FinancialReportsRequest(
            ts_code='000001.SZ',
            start_date='20231201',
            end_date='20231231',
        )

        # Still failing after retries, so the service error surfaces.
        with pytest.raises(FinancialDataError):
            await error_service.get_financial_reports(query)

        # At least one retry must have been recorded.
        service_stats = error_service.get_service_stats()
        assert service_stats['retry_attempts'] > 0
    
    @pytest.mark.asyncio
    async def test_validation_error_no_retry(self, error_service):
        """Check that a ``ValidationError`` from the data source is not retried.

        ``get_data`` counts its calls and always raises ``ValidationError``.
        The service must raise ``FinancialDataError`` after exactly one call.
        """
        invocations = 0

        async def always_invalid(request):
            nonlocal invocations
            invocations += 1
            raise ValidationError("参数验证失败")

        error_service.data_manager.get_data = always_invalid

        query = FinancialReportsRequest(
            ts_code='000001.SZ',
            start_date='20231201',
            end_date='20231231',
        )

        with pytest.raises(FinancialDataError):
            await error_service.get_financial_reports(query)

        # Exactly one call means no retry happened.
        assert invocations == 1
    
    @pytest.mark.asyncio
    async def test_mixed_error_scenarios_in_batch(self, error_service):
        """Check a batch in which each stock triggers a different outcome.

        Per stock code the stubbed ``get_data`` does the following:
        000001.SZ returns one row, 000002.SZ raises ``NetworkError``,
        000003.SZ raises ``RateLimitError``, 000004.SZ returns an empty frame,
        and any other code raises ``ValidationError``.
        """
        async def mixed_errors(request):
            code = request.ts_code

            if code == '000001.SZ':
                # The only stock that yields data.
                row = {
                    'ts_code': code,
                    'report_date': '20231231',
                    'report_type': 'A',
                    'total_revenue': 1000000.0,
                    'net_profit': 200000.0,
                    'total_assets': 5000000.0,
                    'total_liabilities': 3000000.0,
                    'shareholders_equity': 2000000.0,
                    'operating_cash_flow': 300000.0,
                    'eps': 2.5,
                    'roe': 10.0,
                }
                return pd.DataFrame({column: [value] for column, value in row.items()})
            if code == '000002.SZ':
                raise NetworkError("网络连接失败")
            if code == '000003.SZ':
                raise RateLimitError("API调用频率超限")
            if code == '000004.SZ':
                # No data available: an empty frame.
                return pd.DataFrame()
            raise ValidationError("无效的股票代码")

        error_service.data_manager.get_data = mixed_errors

        stock_codes = ['000001.SZ', '000002.SZ', '000003.SZ', '000004.SZ', '000005.SZ']

        outcome = await error_service.get_batch_financial_data(
            stock_codes=stock_codes,
            data_types=['financial_reports']
        )

        # Overall counts: at least 000001.SZ succeeds, the other four fail.
        assert outcome['total_count'] == 5
        assert outcome['success_count'] >= 1
        assert outcome['failed_count'] >= 4

        # Inspect the successful stock's payload when it is present.
        batch_payload = outcome['data']
        if '000001.SZ' in batch_payload:
            assert len(batch_payload['000001.SZ']['financial_reports']) == 1

        # Each failed stock must have an entry in failed_stocks.
        assert len(outcome['failed_stocks']) >= 4
    
    @pytest.mark.asyncio
    async def test_error_recovery_strategies(self, error_service):
        """Check recovery from an intermittently failing data source.

        ``get_data`` raises ``NetworkError`` except on every third call, which
        returns one row. The service is expected to deliver a single report.
        """
        call_number = 0

        async def intermittent_failure(request):
            nonlocal call_number
            call_number += 1

            # Only every third call gets through.
            if call_number % 3 != 0:
                raise NetworkError("临时网络错误")

            row = {
                'ts_code': request.ts_code,
                'report_date': '20231231',
                'report_type': 'A',
                'total_revenue': 1000000.0,
                'net_profit': 200000.0,
                'total_assets': 5000000.0,
                'total_liabilities': 3000000.0,
                'shareholders_equity': 2000000.0,
                'operating_cash_flow': 300000.0,
                'eps': 2.5,
                'roe': 10.0,
            }
            return pd.DataFrame({column: [value] for column, value in row.items()})

        error_service.data_manager.get_data = intermittent_failure

        query = FinancialReportsRequest(
            ts_code='000001.SZ',
            start_date='20231201',
            end_date='20231231',
        )

        # Expected to succeed once a retry lands on a passing call.
        reports = await error_service.get_financial_reports(query)

        assert len(reports) == 1
        assert reports[0].ts_code == '000001.SZ'


class TestFinancialReportsConcurrency:
    """Concurrency tests for ``FinancialReportsService``.

    Each test swaps ``data_manager.get_data`` (and sometimes the cache
    methods) on a service built from mocks, then drives several
    ``get_financial_reports`` calls at the same time.
    """

    @staticmethod
    def _make_report_frame(ts_code, total_revenue=1000000.0):
        """Return a one-row report DataFrame for ``ts_code``.

        Args:
            ts_code: Stock code written to the ``ts_code`` column.
            total_revenue: Value for the ``total_revenue`` column; all other
                columns carry fixed sample figures.
        """
        return pd.DataFrame({
            'ts_code': [ts_code],
            'report_date': ['20231231'],
            'report_type': ['A'],
            'total_revenue': [total_revenue],
            'net_profit': [200000.0],
            'total_assets': [5000000.0],
            'total_liabilities': [3000000.0],
            'shareholders_equity': [2000000.0],
            'operating_cash_flow': [300000.0],
            'eps': [2.5],
            'roe': [10.0]
        })

    @staticmethod
    def _make_request(ts_code):
        """Return a request for ``ts_code`` covering 20231201 to 20231231."""
        return FinancialReportsRequest(
            ts_code=ts_code,
            start_date='20231201',
            end_date='20231231'
        )

    @staticmethod
    def _sequential_codes(count):
        """Return ``count`` six-digit ``.SZ`` codes starting at 000001.SZ.

        Zero-padding the number keeps every code six digits wide; plain
        string concatenation would produce a seven-digit code from 10 on.
        """
        return [f'{number:06d}.SZ' for number in range(1, count + 1)]

    @pytest.fixture
    def concurrent_service(self):
        """Return a service wired to a mocked data manager and cache.

        The cache mock always misses (``get`` resolves to ``None``) and
        accepts any ``set`` call.
        """
        data_manager = Mock(spec=DataManager)
        cache_layer = Mock(spec=CacheLayer)
        cache_layer.get = AsyncMock(return_value=None)
        cache_layer.set = AsyncMock()

        return FinancialReportsService(data_manager, cache_layer)

    @pytest.mark.asyncio
    async def test_concurrent_requests_thread_safety(self, concurrent_service):
        """Ten overlapping requests each get their own data back."""
        # Shared state mutated by every stub call, guarded by a lock.
        shared_counter = {'value': 0}
        lock = asyncio.Lock()

        async def thread_safe_get_data(request):
            async with lock:
                shared_counter['value'] += 1
                current_value = shared_counter['value']

            # Simulated processing time so the calls overlap.
            await asyncio.sleep(0.01)

            return self._make_report_frame(
                request.ts_code,
                total_revenue=1000000.0 + current_value
            )

        concurrent_service.data_manager.get_data = thread_safe_get_data

        requests = [self._make_request(code) for code in self._sequential_codes(10)]

        tasks = [concurrent_service.get_financial_reports(req) for req in requests]
        results = await asyncio.gather(*tasks)

        # gather() keeps input order, so results pair up with requests.
        assert len(results) == len(requests)
        for req, result in zip(requests, results):
            assert len(result) == 1
            assert result[0].ts_code == req.ts_code

        # Every request reached the stub exactly once.
        assert shared_counter['value'] == len(requests)

    @pytest.mark.asyncio
    async def test_concurrent_cache_access(self, concurrent_service):
        """Five identical requests all consult the cache and all succeed."""
        cache_access_count = {'get': 0, 'set': 0}

        async def counting_cache_get(key):
            cache_access_count['get'] += 1
            await asyncio.sleep(0.001)  # simulated cache read latency
            return None  # always a miss

        async def counting_cache_set(key, data, ttl):
            cache_access_count['set'] += 1
            await asyncio.sleep(0.001)  # simulated cache write latency

        concurrent_service.cache_layer.get = counting_cache_get
        concurrent_service.cache_layer.set = counting_cache_set

        async def get_data(request):
            return self._make_report_frame(request.ts_code)

        concurrent_service.data_manager.get_data = get_data

        # The same request object is issued five times to provoke
        # contention on a single cache key.
        same_request = self._make_request('000001.SZ')

        tasks = [concurrent_service.get_financial_reports(same_request) for _ in range(5)]
        results = await asyncio.gather(*tasks)

        assert len(results) == 5
        for result in results:
            assert len(result) == 1
            assert result[0].ts_code == '000001.SZ'

        assert cache_access_count['get'] == 5  # one lookup per request
        assert cache_access_count['set'] >= 1  # written at least once

    def test_thread_pool_concurrent_execution(self, concurrent_service):
        """Requests issued from worker threads each get their own data."""
        async def get_data(request):
            return self._make_report_frame(request.ts_code)

        # Installed once, before any worker starts. Reassigning the stub
        # inside each worker would let one thread overwrite the stub another
        # thread is about to await and hand it the wrong stock's frame.
        concurrent_service.data_manager.get_data = get_data

        def sync_financial_request(stock_code):
            """Run one request to completion on a private event loop."""
            request = self._make_request(stock_code)
            # asyncio.run creates a fresh loop for this thread and closes
            # it again when the coroutine finishes.
            return asyncio.run(concurrent_service.get_financial_reports(request))

        stock_codes = self._sequential_codes(5)

        with ThreadPoolExecutor(max_workers=3) as executor:
            # map() yields results in the order of stock_codes.
            results = list(executor.map(sync_financial_request, stock_codes))

        assert len(results) == len(stock_codes)
        for stock_code, result in zip(stock_codes, results):
            assert len(result) == 1
            assert result[0].ts_code == stock_code

    @pytest.mark.asyncio
    async def test_concurrent_error_handling(self, concurrent_service):
        """A mix of failing and succeeding requests is reported per request."""
        async def random_error_get_data(request):
            # The outcome is derived from the numeric part of the code.
            # hash() of a str is salted per interpreter process, so it
            # cannot guarantee that both outcomes occur on every run.
            code_number = int(''.join(filter(str.isdigit, request.ts_code)))
            error_seed = code_number % 3

            if error_seed == 0:
                raise NetworkError("网络错误")
            elif error_seed == 1:
                raise RateLimitError("速率限制")
            else:
                return self._make_report_frame(request.ts_code)

        concurrent_service.data_manager.get_data = random_error_get_data

        # Codes 000001-000009 give three of each outcome from the stub.
        requests = [self._make_request(code) for code in self._sequential_codes(9)]

        # return_exceptions=True keeps one failure from cancelling the rest.
        tasks = [concurrent_service.get_financial_reports(req) for req in requests]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        success_count = sum(1 for r in results if not isinstance(r, Exception))
        error_count = sum(1 for r in results if isinstance(r, Exception))

        assert success_count > 0  # some requests must succeed
        assert error_count > 0    # some requests must fail
        assert success_count + error_count == len(requests)

        # Successful results must belong to the request that produced them.
        for req, result in zip(requests, results):
            if not isinstance(result, Exception):
                assert len(result) == 1
                assert result[0].ts_code == req.ts_code

    @pytest.mark.asyncio
    async def test_performance_under_high_concurrency(self, concurrent_service):
        """Fifty overlapping requests finish within one second."""
        async def fast_get_data(request):
            await asyncio.sleep(0.001)  # 1 ms simulated latency
            return self._make_report_frame(request.ts_code)

        concurrent_service.data_manager.get_data = fast_get_data

        requests = [self._make_request(code) for code in self._sequential_codes(50)]

        start_time = time.time()

        tasks = [concurrent_service.get_financial_reports(req) for req in requests]
        results = await asyncio.gather(*tasks)

        elapsed_time = time.time() - start_time

        assert len(results) == 50
        for result in results:
            assert len(result) == 1

        # The stub alone would take about 50 ms if the calls ran one after
        # another; one second is a generous ceiling for the overlapped run.
        assert elapsed_time < 1.0

        # The service's own counters must account for every request.
        stats = concurrent_service.get_service_stats()
        assert stats['total_requests'] == 50
        assert stats['successful_requests'] == 50
        assert stats['failed_requests'] == 0


if __name__ == "__main__":
    # pytest.main returns the run's exit code; hand it to the interpreter so
    # that running this file directly exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))