"""
QuickStockClient 涨跌分布统计接口集成测试

测试客户端API接口的参数验证、返回格式支持和缓存管理功能
"""

import json
import pytest
import pandas as pd
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from datetime import datetime, timedelta

from quickstock.client import QuickStockClient
from quickstock.config import Config
from quickstock.core.errors import ValidationError, QuickStockError
from quickstock.models.price_distribution_models import (
    PriceDistributionRequest, 
    PriceDistributionStats,
    DistributionRange
)


# 全局测试fixtures
@pytest.fixture
def mock_config():
    """Return a ``Config``-spec'd mock preset with code-conversion and logging settings."""
    settings = {
        'enable_auto_code_conversion': True,
        'log_code_conversions': False,
        'code_conversion_error_strategy': 'strict',
        'log_level': 'INFO',
        'log_file': None,
    }
    # Mock(**kwargs) applies the settings via configure_mock after the spec is set.
    return Mock(spec=Config, **settings)

@pytest.fixture
def mock_data_manager():
    """Return a data-manager mock with a ``config`` attribute and empty-dict stats getters."""
    manager = Mock(config=Mock())
    for getter_name in ('get_cache_stats', 'get_memory_stats', 'get_data_source_health'):
        # A fresh empty dict per getter, so the return values are independent.
        getattr(manager, getter_name).return_value = {}
    return manager

@pytest.fixture
def client(mock_config, mock_data_manager):
    """Build a ``QuickStockClient`` wired to the mock data manager.

    ``quickstock.client.DataManager`` is patched only while the client is
    constructed; the mock is also assigned to ``data_manager`` explicitly.
    """
    with patch('quickstock.client.DataManager', return_value=mock_data_manager):
        quick_client = QuickStockClient(mock_config)
        quick_client.data_manager = mock_data_manager
    return quick_client

@pytest.fixture
def sample_stats():
    """Build a ``PriceDistributionStats`` for 20240315 with a Shanghai/Shenzhen breakdown.

    NOTE: the overall positive counts sum to 600 and the negative counts to
    450 (1050 in total, with percentages summing to 105%), which exceeds
    ``total_stocks=1000``. Tests in this file assert the 600/450 totals, so
    keep them in sync if the data is ever made self-consistent.
    """
    positive_ranges = {
        '0-3%': 200,
        '3-5%': 150,
        '5-7%': 100,
        '7-10%': 80,
        '>=10%': 70,
    }
    positive_percentages = {
        '0-3%': 20.0,
        '3-5%': 15.0,
        '5-7%': 10.0,
        '7-10%': 8.0,
        '>=10%': 7.0,
    }
    negative_ranges = {
        '0到-3%': 180,
        '-3到-5%': 120,
        '-5到-7%': 70,
        '-7到-10%': 50,
        '<=-10%': 30,
    }
    negative_percentages = {
        '0到-3%': 18.0,
        '-3到-5%': 12.0,
        '-5到-7%': 7.0,
        '-7到-10%': 5.0,
        '<=-10%': 3.0,
    }
    shanghai = {
        'total_stocks': 400,
        'positive_ranges': {'0-3%': 80, '3-5%': 60},
        'negative_ranges': {'0到-3%': 70, '-3到-5%': 50},
        'positive_percentages': {'0-3%': 20.0, '3-5%': 15.0},
        'negative_percentages': {'0到-3%': 17.5, '-3到-5%': 12.5},
    }
    shenzhen = {
        'total_stocks': 600,
        'positive_ranges': {'0-3%': 120, '3-5%': 90},
        'negative_ranges': {'0到-3%': 110, '-3到-5%': 70},
        'positive_percentages': {'0-3%': 20.0, '3-5%': 15.0},
        'negative_percentages': {'0到-3%': 18.3, '-3到-5%': 11.7},
    }
    return PriceDistributionStats(
        trade_date='20240315',
        total_stocks=1000,
        positive_ranges=positive_ranges,
        positive_percentages=positive_percentages,
        negative_ranges=negative_ranges,
        negative_percentages=negative_percentages,
        market_breakdown={'shanghai': shanghai, 'shenzhen': shenzhen},
        processing_time=2.5,
        data_quality_score=0.95,
    )


class TestQuickStockClientPriceDistribution:
    """Placeholder grouping for QuickStockClient price-distribution tests.

    It defines no tests of its own; the concrete cases live in the test
    classes that follow.
    """


class TestPriceDistributionStatsMethod:
    """Tests for ``QuickStockClient.price_distribution_stats``.

    Every test swaps ``PriceDistributionStatsService`` for a mock class whose
    instance carries a hand-written async ``get_price_distribution_stats``.
    Stubs that inspect the request assert on the fields the client built
    before returning the sample statistics.
    """

    _SERVICE_TARGET = 'quickstock.services.price_distribution_stats_service.PriceDistributionStatsService'

    @classmethod
    def _patch_service(cls, **async_methods):
        """Return a ``patch`` of the service class whose instance exposes *async_methods*."""
        service = Mock()
        for attr_name, impl in async_methods.items():
            setattr(service, attr_name, impl)
        return patch(cls._SERVICE_TARGET, return_value=service)

    def test_basic_dataframe_format(self, client, sample_stats):
        """The default return format is a non-empty DataFrame."""
        async def fake_get_stats(request):
            return sample_stats

        with self._patch_service(get_price_distribution_stats=fake_get_stats):
            df = client.price_distribution_stats('20240315')

        assert isinstance(df, pd.DataFrame)
        assert not df.empty
        for column in ('trade_date', 'market', 'total_stocks'):
            assert column in df.columns
        assert df['trade_date'].iloc[0] == '20240315'

    def test_dict_format_return(self, client, sample_stats):
        """``format='dict'`` returns a plain dict."""
        async def fake_get_stats(request):
            return sample_stats

        with self._patch_service(get_price_distribution_stats=fake_get_stats):
            payload = client.price_distribution_stats('20240315', format='dict')

        assert isinstance(payload, dict)
        assert 'trade_date' in payload
        assert 'total_stocks' in payload
        assert payload['trade_date'] == '20240315'
        assert payload['total_stocks'] == 1000

    def test_json_format_return(self, client, sample_stats):
        """``format='json'`` returns a JSON string."""
        async def fake_get_stats(request):
            return sample_stats

        with self._patch_service(get_price_distribution_stats=fake_get_stats):
            text = client.price_distribution_stats('20240315', format='json')

        assert isinstance(text, str)
        decoded = json.loads(text)
        assert 'trade_date' in decoded
        assert 'total_stocks' in decoded
        assert decoded['trade_date'] == '20240315'

    def test_parameter_validation(self, client):
        """Bad arguments are reported as QuickStockError."""
        # An unsupported output format is rejected.
        with pytest.raises(QuickStockError) as exc_info:
            client.price_distribution_stats('20240315', format='invalid')
        assert 'Invalid format' in str(exc_info.value)

        # A ValidationError raised by the service reaches the caller as a
        # QuickStockError carrying the original message.
        async def rejecting_get_stats(request):
            raise ValidationError("Invalid trade date format")

        with self._patch_service(get_price_distribution_stats=rejecting_get_stats):
            with pytest.raises(QuickStockError) as exc_info:
                client.price_distribution_stats('invalid_date')
        assert 'Invalid trade date format' in str(exc_info.value)

    def test_custom_ranges_parameter(self, client, sample_stats):
        """``custom_ranges`` is forwarded as ``request.distribution_ranges``."""
        custom_ranges = {
            '0-2%': (0.0, 2.0),
            '2-5%': (2.0, 5.0),
            '>=5%': (5.0, float('inf'))
        }

        async def checking_get_stats(request):
            assert request.distribution_ranges is not None
            assert '0-2%' in request.distribution_ranges
            return sample_stats

        with self._patch_service(get_price_distribution_stats=checking_get_stats):
            result = client.price_distribution_stats('20240315', custom_ranges=custom_ranges)
        assert isinstance(result, pd.DataFrame)

    def test_market_filter_parameter(self, client, sample_stats):
        """``market_filter`` is forwarded unchanged on the request."""
        async def checking_get_stats(request):
            assert request.market_filter == ['shanghai', 'shenzhen']
            return sample_stats

        with self._patch_service(get_price_distribution_stats=checking_get_stats):
            result = client.price_distribution_stats(
                '20240315',
                market_filter=['shanghai', 'shenzhen']
            )
        assert isinstance(result, pd.DataFrame)

    def test_include_st_parameter(self, client, sample_stats):
        """``include_st`` is forwarded on the request."""
        async def checking_get_stats(request):
            assert request.include_st == False
            return sample_stats

        with self._patch_service(get_price_distribution_stats=checking_get_stats):
            result = client.price_distribution_stats('20240315', include_st=False)
        assert isinstance(result, pd.DataFrame)

    def test_cache_parameters(self, client, sample_stats):
        """``force_refresh`` and ``save_to_db`` are forwarded on the request."""
        async def checking_get_stats(request):
            assert request.force_refresh == True
            assert request.save_to_db == False
            return sample_stats

        with self._patch_service(get_price_distribution_stats=checking_get_stats):
            result = client.price_distribution_stats(
                '20240315',
                use_cache=False,
                force_refresh=True,
                save_to_db=False
            )
        assert isinstance(result, pd.DataFrame)

    def test_timeout_parameter(self, client, sample_stats):
        """``timeout`` is forwarded on the request."""
        async def checking_get_stats(request):
            assert request.timeout == 60
            return sample_stats

        with self._patch_service(get_price_distribution_stats=checking_get_stats):
            result = client.price_distribution_stats('20240315', timeout=60)
        assert isinstance(result, pd.DataFrame)

    def test_date_format_support(self, client, sample_stats):
        """A YYYY-MM-DD date reaches the service normalized to YYYYMMDD."""
        async def checking_get_stats(request):
            assert request.trade_date == '20240315'
            return sample_stats

        with self._patch_service(get_price_distribution_stats=checking_get_stats):
            result = client.price_distribution_stats('2024-03-15')
        assert isinstance(result, pd.DataFrame)

    def test_error_handling(self, client):
        """An arbitrary service exception is reported as QuickStockError."""
        async def failing_get_stats(request):
            raise Exception("Service error")

        with self._patch_service(get_price_distribution_stats=failing_get_stats):
            with pytest.raises(QuickStockError) as exc_info:
                client.price_distribution_stats('20240315')
        assert 'Service error' in str(exc_info.value)


class TestCacheManagementMethods:
    """Tests for the client's price-distribution cache management methods.

    Each test patches ``PriceDistributionStatsService`` so the client works
    against a ``Mock`` service whose async cache methods are replaced with
    local coroutine functions. The stub signatures are deliberately
    permissive (optional parameters, or ``*args, **kwargs``) so a stub never
    fails merely because of how the client passes its arguments.
    """

    def test_get_distribution_cache_info(self, client):
        """Cache info is passed through for both all-dates and single-date queries."""
        with patch('quickstock.services.price_distribution_stats_service.PriceDistributionStatsService') as mock_service_class:
            mock_service = Mock()
            mock_service_class.return_value = mock_service

            expected_info = {
                'cache_entries': 10,
                'total_size': 1024,
                'hit_rate': 0.85
            }

            # trade_date defaults to None so the no-date query works whether the
            # client forwards None explicitly or omits the argument.
            async def mock_get_cache_info(trade_date=None):
                return expected_info

            mock_service.get_cache_info = mock_get_cache_info

            # All cached dates
            result = client.get_distribution_cache_info()
            assert result == expected_info

            # One specific date
            result = client.get_distribution_cache_info('20240315')
            assert result == expected_info

    def test_clear_distribution_cache(self, client):
        """Clearing by nothing / date / pattern returns the service's deleted count."""
        with patch('quickstock.services.price_distribution_stats_service.PriceDistributionStatsService') as mock_service_class:
            mock_service = Mock()
            mock_service_class.return_value = mock_service

            # A distinct count per branch shows which selector the client forwarded.
            async def mock_clear_cache(trade_date=None, pattern=None):
                if trade_date:
                    return 1
                if pattern:
                    return 5
                return 10

            mock_service.clear_cache = mock_clear_cache

            # Clear everything
            result = client.clear_distribution_cache()
            assert result == 10

            # Clear one date
            result = client.clear_distribution_cache('20240315')
            assert result == 1

            # Clear by pattern
            result = client.clear_distribution_cache(pattern='2024*')
            assert result == 5

    def test_refresh_distribution_cache(self, client):
        """Normal and forced refreshes return the service's result dict."""
        with patch('quickstock.services.price_distribution_stats_service.PriceDistributionStatsService') as mock_service_class:
            mock_service = Mock()
            mock_service_class.return_value = mock_service

            expected_result = {
                'trade_date': '20240315',
                'success': True,
                'deleted_entries': 1,
                'processing_time': 2.5
            }

            async def mock_refresh_cache(trade_date, force=False):
                return expected_result

            mock_service.refresh_cache = mock_refresh_cache

            # Normal refresh
            result = client.refresh_distribution_cache('20240315')
            assert result == expected_result

            # Forced refresh
            result = client.refresh_distribution_cache('20240315', force=True)
            assert result == expected_result

    def test_validate_distribution_cache(self, client):
        """Validation of a cached date reports consistency plus metadata."""
        with patch('quickstock.services.price_distribution_stats_service.PriceDistributionStatsService') as mock_service_class:
            mock_service = Mock()
            mock_service_class.return_value = mock_service

            cache_info = {
                'cache_entries': 1,
                'total_size': 512
            }

            # Named cached_stats so it does not shadow the sample_stats fixture.
            cached_stats = PriceDistributionStats(
                trade_date='20240315',
                total_stocks=1000,
                positive_ranges={'0-3%': 200},
                positive_percentages={'0-3%': 20.0},
                negative_ranges={'0到-3%': 180},
                negative_percentages={'0到-3%': 18.0},
                market_breakdown={}
            )

            async def mock_get_cache_info(trade_date=None):
                return cache_info

            async def mock_get_cached_stats(trade_date=None):
                return cached_stats

            mock_service.get_cache_info = mock_get_cache_info
            mock_service.get_cached_stats = mock_get_cached_stats

            result = client.validate_distribution_cache('20240315')

            assert result['trade_date'] == '20240315'
            assert result['is_consistent'] == True
            assert 'cache_info' in result
            assert 'validation_time' in result

    def test_get_distribution_cache_stats(self, client):
        """Cache stats combine service performance, cache info and a derived hit rate."""
        with patch('quickstock.services.price_distribution_stats_service.PriceDistributionStatsService') as mock_service_class:
            mock_service = Mock()
            mock_service_class.return_value = mock_service

            service_stats = {
                'performance': {
                    'cache_hits': 85,
                    'cache_misses': 15,
                    'total_requests': 100
                },
                'analyzer': {'processing_time': 2.5},
                'classifier': {'market_rules': 10},
                'data_manager': {'cache_stats': {}}
            }

            cache_info = {
                'total_entries': 50,
                'total_size': 2048
            }

            mock_service.get_service_stats.return_value = service_stats

            # Accept an optional trade_date: a zero-argument stub would fail with
            # TypeError if the client forwards trade_date=None.
            async def mock_get_cache_info(trade_date=None):
                return cache_info

            mock_service.get_cache_info = mock_get_cache_info

            result = client.get_distribution_cache_stats()

            assert 'service_performance' in result
            assert 'cache_info' in result
            assert 'hit_rate' in result
            assert result['hit_rate'] == pytest.approx(0.85)  # 85 / (85 + 15)

    def test_cache_method_error_handling(self, client):
        """Service failures are turned into error results rather than raised."""
        with patch('quickstock.services.price_distribution_stats_service.PriceDistributionStatsService') as mock_service_class:
            mock_service = Mock()
            mock_service_class.return_value = mock_service

            # One stub replaces three methods with different parameters, so it
            # must accept any call shape. A zero-argument signature would make
            # every call fail with TypeError before "Cache error" is raised, and
            # the test would then not exercise the intended failure.
            async def mock_error(*args, **kwargs):
                raise Exception("Cache error")

            mock_service.get_cache_info = mock_error
            mock_service.clear_cache = mock_error
            mock_service.refresh_cache = mock_error

            # get_distribution_cache_info reports the failure in the result
            result = client.get_distribution_cache_info()
            assert 'error' in result

            # clear_distribution_cache reports zero deleted entries
            result = client.clear_distribution_cache()
            assert result == 0

            # refresh_distribution_cache reports an unsuccessful refresh
            result = client.refresh_distribution_cache('20240315')
            assert result['success'] == False
            assert 'error' in result


class TestDataFrameConversion:
    """Tests for ``QuickStockClient._convert_stats_to_dataframe``."""

    # Summary columns that share the positive_/negative_ prefix with the
    # per-range columns (test_convert_stats_to_dataframe_basic requires them).
    # Range-column counting must exclude them.
    _SUMMARY_COLUMNS = frozenset({
        'positive_count', 'positive_percentage',
        'negative_count', 'negative_percentage',
    })

    def test_convert_stats_to_dataframe_basic(self, client, sample_stats):
        """The frame has the core columns and a correct 'total' row."""
        df = client._convert_stats_to_dataframe(sample_stats)

        assert isinstance(df, pd.DataFrame)
        assert not df.empty

        # Core columns
        expected_columns = [
            'trade_date', 'formatted_date', 'market', 'market_name',
            'total_stocks', 'positive_count', 'negative_count',
            'positive_percentage', 'negative_percentage'
        ]

        for col in expected_columns:
            assert col in df.columns

        # Whole-market row
        total_row = df[df['market'] == 'total'].iloc[0]
        assert total_row['total_stocks'] == 1000
        assert total_row['positive_count'] == 600  # 200+150+100+80+70
        assert total_row['negative_count'] == 450  # 180+120+70+50+30

    def test_convert_stats_to_dataframe_with_ranges(self, client, sample_stats):
        """There is one count column and one ``_pct`` column per distribution range."""
        df = client._convert_stats_to_dataframe(sample_stats)

        # Without this exclusion positive_count / positive_percentage would be
        # counted as range columns, contradicting the basic-columns test.
        range_columns = [col for col in df.columns if col not in self._SUMMARY_COLUMNS]
        expected_positive = len(sample_stats.positive_ranges)  # 5 in the fixture
        expected_negative = len(sample_stats.negative_ranges)  # 5 in the fixture

        # Positive-range count columns
        positive_range_columns = [col for col in range_columns if col.startswith('positive_') and not col.endswith('_pct')]
        assert len(positive_range_columns) == expected_positive

        # Negative-range count columns
        negative_range_columns = [col for col in range_columns if col.startswith('negative_') and not col.endswith('_pct')]
        assert len(negative_range_columns) == expected_negative

        # Percentage columns
        positive_pct_columns = [col for col in range_columns if col.startswith('positive_') and col.endswith('_pct')]
        assert len(positive_pct_columns) == expected_positive

        negative_pct_columns = [col for col in range_columns if col.startswith('negative_') and col.endswith('_pct')]
        assert len(negative_pct_columns) == expected_negative

    def test_convert_stats_to_dataframe_market_breakdown(self, client, sample_stats):
        """Each market in the breakdown gets its own row alongside 'total'."""
        df = client._convert_stats_to_dataframe(sample_stats)

        # Market rows
        markets = df['market'].unique()
        assert 'total' in markets
        assert 'shanghai' in markets
        assert 'shenzhen' in markets

        # Shanghai market row
        shanghai_row = df[df['market'] == 'shanghai'].iloc[0]
        assert shanghai_row['total_stocks'] == 400
        assert shanghai_row['market_name'] == '上海证券交易所'

    def test_convert_stats_to_dataframe_metadata(self, client, sample_stats):
        """Metadata columns are present and carry the stats' values."""
        df = client._convert_stats_to_dataframe(sample_stats)

        # Metadata columns
        metadata_columns = ['processing_time', 'data_quality_score', 'created_at']
        for col in metadata_columns:
            assert col in df.columns

        # Metadata values
        assert df['processing_time'].iloc[0] == 2.5
        assert df['data_quality_score'].iloc[0] == 0.95
        assert df['trade_date'].iloc[0] == '20240315'
        assert df['formatted_date'].iloc[0] == '2024-03-15'

    def test_convert_stats_to_dataframe_error_handling(self, client):
        """A failing ``get_summary`` yields a simplified frame with an 'error' column."""
        # Stats object whose summary call fails
        invalid_stats = Mock()
        invalid_stats.trade_date = '20240315'
        invalid_stats.total_stocks = 1000
        invalid_stats.get_summary.side_effect = Exception("Summary error")
        invalid_stats.get_total_positive_count.return_value = 600
        invalid_stats.get_total_negative_count.return_value = 400

        df = client._convert_stats_to_dataframe(invalid_stats)

        # A simplified DataFrame is expected
        assert isinstance(df, pd.DataFrame)
        assert not df.empty
        assert 'error' in df.columns
        assert df['trade_date'].iloc[0] == '20240315'


class TestClientInitializationAndValidation:
    """Tests for client initialization checks and import-failure handling."""

    def test_client_not_initialized_error(self):
        """Calling the API on an instance that skipped ``__init__`` raises QuickStockError."""
        # __new__ bypasses __init__, leaving the client uninitialized.
        client = QuickStockClient.__new__(QuickStockClient)

        with pytest.raises(QuickStockError) as exc_info:
            client.price_distribution_stats('20240315')
        assert '客户端未正确初始化' in str(exc_info.value)

    def test_import_error_handling(self, client):
        """An ImportError raised during the call is reported as QuickStockError."""
        with patch('quickstock.client.QuickStockClient._ensure_initialized'):
            # While this patch is active *every* import fails, so it wraps only
            # the call under test. The assertion below (and pytest's failure
            # reporting) runs with the real import machinery restored.
            with patch('builtins.__import__', side_effect=ImportError("Module not found")):
                with pytest.raises(QuickStockError) as exc_info:
                    client.price_distribution_stats('20240315')
        assert 'Module not found' in str(exc_info.value)


class TestIntegrationScenarios:
    """End-to-end style scenarios over the client's price-distribution API.

    The stats service class is patched so its instance exposes local async
    stubs; the scenarios then drive the client through full workflows.
    """

    _SERVICE_TARGET = 'quickstock.services.price_distribution_stats_service.PriceDistributionStatsService'

    @classmethod
    def _patch_service(cls, **async_methods):
        """Return a ``patch`` of the service class whose instance exposes *async_methods*."""
        service = Mock()
        for attr_name, impl in async_methods.items():
            setattr(service, attr_name, impl)
        return patch(cls._SERVICE_TARGET, return_value=service)

    def test_complete_workflow_dataframe(self, client, sample_stats):
        """All request options reach the service and a DataFrame comes back."""
        async def checking_get_stats(request):
            # Every option passed to the client must appear on the request.
            assert request.trade_date == '20240315'
            assert request.include_st == False
            assert request.market_filter == ['shanghai']
            assert request.force_refresh == True
            return sample_stats

        with self._patch_service(get_price_distribution_stats=checking_get_stats):
            frame = client.price_distribution_stats(
                '2024-03-15',  # exercises date normalization
                include_st=False,
                market_filter=['shanghai'],
                force_refresh=True,
                format='dataframe'
            )

        assert isinstance(frame, pd.DataFrame)
        assert not frame.empty
        assert frame['trade_date'].iloc[0] == '20240315'
        assert 'shanghai' in frame['market'].values

    def test_complete_workflow_json(self, client, sample_stats):
        """The JSON output carries the top-level stats fields."""
        async def fake_get_stats(request):
            return sample_stats

        with self._patch_service(get_price_distribution_stats=fake_get_stats):
            text = client.price_distribution_stats('20240315', format='json')

        assert isinstance(text, str)
        decoded = json.loads(text)
        assert decoded['trade_date'] == '20240315'
        assert decoded['total_stocks'] == 1000
        assert 'positive_ranges' in decoded
        assert 'market_breakdown' in decoded

    def test_cache_management_workflow(self, client):
        """Info, clear and refresh run back-to-back against one patched service."""
        async def fake_get_cache_info(trade_date=None):
            return {'cache_entries': 5, 'total_size': 1024}

        async def fake_clear_cache(trade_date=None, pattern=None):
            return 5

        async def fake_refresh_cache(trade_date, force=False):
            return {'success': True, 'deleted_entries': 1}

        with self._patch_service(
            get_cache_info=fake_get_cache_info,
            clear_cache=fake_clear_cache,
            refresh_cache=fake_refresh_cache,
        ):
            info = client.get_distribution_cache_info()
            assert info['cache_entries'] == 5

            cleared = client.clear_distribution_cache()
            assert cleared == 5

            refreshed = client.refresh_distribution_cache('20240315')
            assert refreshed['success'] == True


if __name__ == '__main__':
    # pytest.main returns the session exit code; propagate it so running this
    # file directly exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__, '-v']))