"""
提供者优化测试

测试数据提供者的优化功能
"""

import asyncio
import time
from datetime import datetime, timedelta
from unittest.mock import Mock, AsyncMock, patch

import pandas as pd
import pytest

from quickstock.providers.optimizations import (
    ProviderOptimizer, ProviderOptimizationManager, ProviderOptimizationConfig,
    BatchConfig, OptimizationStrategy
)
from quickstock.providers.manager import DataSourceManager
from quickstock.models import LimitUpStatsRequest
from quickstock.config import Config


class TestProviderOptimizations:
    """Tests for the provider optimization layer.

    Covers batching, parallel requests, caching, rate limiting, error and
    timeout handling, configuration management, statistics, cache management,
    and integration with ``DataSourceManager``.
    """

    @staticmethod
    def _make_manager_config():
        """Build a mock ``DataSourceManager`` config for tests.

        Sets every ``enable_*`` flag to False and ``tushare_token`` to None
        (presumably so no built-in data source is created). It also makes
        ``get_data_source_priority`` report only ``'test_provider'``.
        """
        config = Mock()
        config.enable_baostock = False
        config.enable_eastmoney = False
        config.enable_tonghuashun = False
        config.tushare_token = None
        config.get_data_source_priority = Mock(return_value=['test_provider'])
        return config

    @staticmethod
    def _make_healthy_provider():
        """Build a mock provider named ``'test_provider'`` whose async health check returns True."""
        provider = Mock()
        provider.get_provider_name = Mock(return_value='test_provider')
        provider.health_check = AsyncMock(return_value=True)
        return provider

    @pytest.fixture
    def mock_provider(self):
        """Mock data provider returning fixed daily bars and basic info for three stocks."""
        provider = Mock()
        provider.get_provider_name = Mock(return_value='test_provider')

        # Daily bars: every stock closes +10%, i.e. a limit-up day.
        stock_data = pd.DataFrame({
            'ts_code': ['000001.SZ', '000002.SZ', '600000.SH'],
            'trade_date': ['20241015', '20241015', '20241015'],
            'open': [10.0, 20.0, 8.0],
            'high': [11.0, 22.0, 8.8],
            'low': [9.8, 19.8, 7.9],
            'close': [11.0, 22.0, 8.8],
            'pre_close': [10.0, 20.0, 8.0],
            'change': [1.0, 2.0, 0.8],
            'pct_chg': [10.0, 10.0, 10.0],
            'vol': [1000, 2000, 1500],
            'amount': [11000.0, 44000.0, 13200.0]
        })

        # Basic stock info (code -> name) for the same three stocks.
        basic_data = pd.DataFrame({
            'ts_code': ['000001.SZ', '000002.SZ', '600000.SH'],
            'name': ['平安银行', '万科A', '浦发银行']
        })

        provider.get_stock_daily = AsyncMock(return_value=stock_data)
        provider.get_stock_basic = AsyncMock(return_value=basic_data)

        return provider

    @pytest.fixture
    def optimization_config(self):
        """Optimization config with batching and smart caching enabled for 'test_provider'."""
        return ProviderOptimizationConfig(
            provider_name='test_provider',
            enabled_strategies=[
                OptimizationStrategy.BATCH_PROCESSING,
                OptimizationStrategy.SMART_CACHING
            ],
            batch_config=BatchConfig(
                batch_size=2,
                max_concurrent_batches=2,
                batch_delay=0.01,
                timeout_per_batch=10
            ),
            cache_ttl=60,
            rate_limit_per_second=10.0,
            timeout=30
        )

    @pytest.fixture
    def optimizer(self, optimization_config):
        """A ``ProviderOptimizer`` with ``optimization_config`` registered."""
        optimizer = ProviderOptimizer()
        optimizer.register_optimization(optimization_config)
        return optimizer

    @pytest.mark.asyncio
    async def test_batch_processing_optimization(self, optimizer, mock_provider):
        """Batched requests return all rows merged with stock names."""
        request = LimitUpStatsRequest(trade_date='20241015')

        result = await optimizer.optimize_stock_data_request(mock_provider, request)

        assert isinstance(result, pd.DataFrame)
        assert not result.empty
        assert len(result) == 3
        # 'name' only exists in the basic-info frame, so its presence shows a merge happened.
        assert 'name' in result.columns

        # Both provider endpoints must have been used to build the result.
        mock_provider.get_stock_basic.assert_called()
        mock_provider.get_stock_daily.assert_called()

    @pytest.mark.asyncio
    async def test_parallel_processing_optimization(self, mock_provider):
        """The parallel-requests strategy returns all rows."""
        config = ProviderOptimizationConfig(
            provider_name='test_provider',
            enabled_strategies=[OptimizationStrategy.PARALLEL_REQUESTS],
            timeout=30
        )

        optimizer = ProviderOptimizer()
        optimizer.register_optimization(config)

        request = LimitUpStatsRequest(trade_date='20241015')

        result = await optimizer.optimize_stock_data_request(mock_provider, request)

        assert isinstance(result, pd.DataFrame)
        assert not result.empty
        assert len(result) == 3

    @pytest.mark.asyncio
    async def test_caching_optimization(self, optimizer, mock_provider):
        """A repeated identical request is served from cache with identical data."""
        request = LimitUpStatsRequest(trade_date='20241015')

        result1 = await optimizer.optimize_stock_data_request(mock_provider, request)
        # Same request again: expected to hit the cache.
        result2 = await optimizer.optimize_stock_data_request(mock_provider, request)

        pd.testing.assert_frame_equal(result1, result2)

        stats = optimizer.get_optimization_stats('test_provider')
        assert stats['cache_hits'] > 0

    @pytest.mark.asyncio
    async def test_rate_limiting(self, optimizer, mock_provider):
        """Concurrent requests all succeed while the rate limiter is active."""
        request = LimitUpStatsRequest(trade_date='20241015')

        tasks = [
            optimizer.optimize_stock_data_request(mock_provider, request)
            for _ in range(5)
        ]

        # perf_counter is monotonic and high-resolution. Wall-clock datetime.now()
        # can jump (NTP/DST) and has coarse resolution on some platforms.
        start = time.perf_counter()
        results = await asyncio.gather(*tasks)
        total_time = time.perf_counter() - start

        assert len(results) == 5
        for result in results:
            assert isinstance(result, pd.DataFrame)
            assert not result.empty

        # NOTE(review): this only checks that some time elapsed. It does not
        # assert a specific delay derived from rate_limit_per_second.
        assert total_time > 0

    @pytest.mark.asyncio
    async def test_error_handling_in_optimization(self, optimizer, mock_provider):
        """A provider failure propagates and is counted as a failed request."""
        mock_provider.get_stock_daily.side_effect = Exception("数据源暂时不可用")

        request = LimitUpStatsRequest(trade_date='20241015')

        with pytest.raises(Exception):
            await optimizer.optimize_stock_data_request(mock_provider, request)

        stats = optimizer.get_optimization_stats('test_provider')
        assert stats['failed_requests'] > 0

    @pytest.mark.asyncio
    async def test_timeout_handling(self, mock_provider):
        """A provider slower than the configured timeout yields an empty DataFrame."""
        config = ProviderOptimizationConfig(
            provider_name='test_provider',
            enabled_strategies=[OptimizationStrategy.PARALLEL_REQUESTS],
            timeout=0.1  # deliberately shorter than slow_response's sleep
        )

        optimizer = ProviderOptimizer()
        optimizer.register_optimization(config)

        async def slow_response(*args, **kwargs):
            await asyncio.sleep(0.2)  # exceeds the 0.1s timeout above
            return pd.DataFrame()

        mock_provider.get_stock_daily = slow_response
        mock_provider.get_stock_basic = slow_response

        request = LimitUpStatsRequest(trade_date='20241015')

        # Timeouts are expected to be handled by returning an empty frame, not raising.
        result = await optimizer.optimize_stock_data_request(mock_provider, request)
        assert isinstance(result, pd.DataFrame)
        assert result.empty

    def test_optimization_config_creation(self):
        """ProviderOptimizationConfig stores the values it is constructed with."""
        config = ProviderOptimizationConfig(
            provider_name='test_provider',
            enabled_strategies=[OptimizationStrategy.BATCH_PROCESSING],
            cache_ttl=3600
        )

        assert config.provider_name == 'test_provider'
        assert OptimizationStrategy.BATCH_PROCESSING in config.enabled_strategies
        assert config.cache_ttl == 3600

    def test_optimization_manager_initialization(self):
        """Initializing the manager loads default per-provider configs."""
        manager = ProviderOptimizationManager()
        manager.initialize_optimizations()

        optimizer = manager.get_optimizer()
        stats = optimizer.get_optimization_stats()

        # At least one of the known built-in providers should have a default config.
        assert len(stats) > 0
        assert 'baostock' in stats or 'eastmoney' in stats or 'tonghuashun' in stats

    def test_config_updates(self):
        """Flat update keys are applied to the provider config, including nested batch settings."""
        manager = ProviderOptimizationManager()
        manager.initialize_optimizations()

        updates = {
            'cache_ttl': 7200,
            'batch_size': 200,
            'rate_limit_per_second': 10.0
        }

        manager.update_provider_config('baostock', updates)

        config = manager.get_provider_config('baostock')
        assert config.cache_ttl == 7200
        # 'batch_size' is a flat update key but lives on the nested batch_config.
        assert config.batch_config.batch_size == 200
        assert config.rate_limit_per_second == 10.0

    def test_stats_collection(self, optimizer):
        """Derived statistics (success rate, average response time) are computed from raw counters."""
        initial_stats = optimizer.get_optimization_stats('test_provider')
        assert initial_stats['total_requests'] == 0

        # Seed the raw counters directly rather than issuing real requests.
        optimizer._performance_stats['test_provider']['total_requests'] = 10
        optimizer._performance_stats['test_provider']['successful_requests'] = 8
        optimizer._performance_stats['test_provider']['failed_requests'] = 2
        optimizer._performance_stats['test_provider']['total_response_time'] = 20.0

        updated_stats = optimizer.get_optimization_stats('test_provider')
        assert updated_stats['total_requests'] == 10
        # Floats are compared with approx so the test does not depend on exact rounding.
        assert updated_stats['success_rate'] == pytest.approx(0.8)
        # 20.0 total time / 8 successful requests.
        assert updated_stats['average_response_time'] == pytest.approx(2.5)

    def test_cache_management(self, optimizer):
        """Cached data can be read back and is removed by clear_cache."""
        request = LimitUpStatsRequest(trade_date='20241015')
        test_data = pd.DataFrame({'test': [1, 2, 3]})

        optimizer._cache_data('test_provider', request, test_data, 60)

        cached_data = optimizer._get_cached_data('test_provider', request)
        assert cached_data is not None
        pd.testing.assert_frame_equal(cached_data, test_data)

        optimizer.clear_cache('test_provider')

        cached_data = optimizer._get_cached_data('test_provider', request)
        assert cached_data is None

    @pytest.mark.asyncio
    async def test_integration_with_data_source_manager(self):
        """DataSourceManager exposes the optimization manager and its configuration hooks."""
        config = self._make_manager_config()
        manager = DataSourceManager(config)

        mock_provider = self._make_healthy_provider()
        manager.register_provider('test_provider', mock_provider)

        assert hasattr(manager, 'optimization_manager')
        assert manager.optimization_manager is not None

        stats = manager.get_optimization_stats()
        assert isinstance(stats, dict)

        # The two calls below are smoke checks: the test only verifies they do not raise.
        manager.update_provider_optimization('test_provider', {'cache_ttl': 1800})
        manager.enable_provider_optimization('test_provider', ['batch_processing'])

    @pytest.mark.asyncio
    async def test_performance_benchmark(self):
        """benchmark_provider_performance reports the provider name and test date."""
        config = self._make_manager_config()
        manager = DataSourceManager(config)

        mock_provider = self._make_healthy_provider()

        test_data = pd.DataFrame({
            'ts_code': ['000001.SZ'],
            'trade_date': ['20241015'],
            'close': [10.0]
        })

        mock_provider.get_stock_daily = AsyncMock(return_value=test_data)
        mock_provider.get_stock_basic = AsyncMock(return_value=test_data)

        manager.register_provider('test_provider', mock_provider)

        benchmark_result = await manager.benchmark_provider_performance('test_provider', '20241015')

        assert isinstance(benchmark_result, dict)
        assert 'test_date' in benchmark_result
        assert 'provider_name' in benchmark_result
        assert benchmark_result['provider_name'] == 'test_provider'
        assert benchmark_result['test_date'] == '20241015'


if __name__ == '__main__':
    # Propagate pytest's exit code so that running this file directly exits
    # non-zero when any test fails (pytest.main returns the code instead of exiting).
    raise SystemExit(pytest.main([__file__, '-v']))