"""
客户端涨停统计集成测试

测试QuickStockClient的涨停统计相关方法，包括数据库集成和缓存优化
"""

import pytest
import asyncio
import tempfile
import os
from unittest.mock import Mock, patch, AsyncMock
from datetime import datetime, timedelta
import pandas as pd

from quickstock.client import QuickStockClient
from quickstock.config import Config
from quickstock.models import LimitUpStats, StockDailyData
from quickstock.core.database import DatabaseManager
from quickstock.core.repository import LimitUpStatsRepository


class TestClientLimitUpIntegration:
    """客户端涨停统计集成测试"""
    
    @pytest.fixture
    def temp_db_path(self):
        """创建临时数据库文件"""
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
            db_path = f.name
        yield db_path
        # 清理
        if os.path.exists(db_path):
            os.unlink(db_path)
    
    @pytest.fixture
    def test_config(self, temp_db_path):
        """创建测试配置"""
        config = Config()
        config.database_path = temp_db_path
        config.cache_enabled = True
        config.cache_expire_hours = 1
        return config
    
    @pytest.fixture
    def client(self, test_config):
        """创建测试客户端"""
        return QuickStockClient(test_config)
    
    @pytest.fixture
    def sample_stock_data(self):
        """创建示例股票数据"""
        return pd.DataFrame([
            {
                'ts_code': '000001.SZ',
                'trade_date': '20241015',
                'name': '平安银行',
                'open': 10.0,
                'high': 11.0,
                'low': 10.0,
                'close': 11.0,
                'pre_close': 10.0,
                'change': 1.0,
                'pct_chg': 10.0,
                'vol': 1000000,
                'amount': 11000000.0
            },
            {
                'ts_code': '600000.SH',
                'trade_date': '20241015',
                'name': '浦发银行',
                'open': 8.0,
                'high': 8.8,
                'low': 8.0,
                'close': 8.8,
                'pre_close': 8.0,
                'change': 0.8,
                'pct_chg': 10.0,
                'vol': 2000000,
                'amount': 17600000.0
            },
            {
                'ts_code': '688001.SH',
                'trade_date': '20241015',
                'name': '华兴源创',
                'open': 21.0,
                'high': 25.2,
                'low': 21.0,
                'close': 25.2,
                'pre_close': 21.0,
                'change': 4.2,
                'pct_chg': 20.0,
                'vol': 500000,
                'amount': 12600000.0
            },
            {
                'ts_code': '000002.SZ',
                'trade_date': '20241015',
                'name': 'ST万科',
                'open': 6.3,
                'high': 6.6,
                'low': 6.3,
                'close': 6.6,
                'pre_close': 6.3,
                'change': 0.3,
                'pct_chg': 4.76,
                'vol': 800000,
                'amount': 5280000.0
            }
        ])
    
    @pytest.fixture
    def expected_stats(self):
        """期望的统计结果"""
        return {
            'total': 3,  # 3只涨停（不包括ST万科，因为涨幅不够10%）
            'non_st': 3,
            'shanghai': 1,  # 浦发银行
            'shenzhen': 1,  # 平安银行
            'star': 1,      # 华兴源创
            'beijing': 0,
            'st': 0
        }
    
    @pytest.mark.asyncio
    async def test_daily_limit_up_stats_basic(self, client, sample_stock_data, expected_stats):
        """测试基本的涨停统计功能"""
        # Mock数据管理器返回示例数据
        with patch.object(client.data_manager, 'get_data', new_callable=AsyncMock) as mock_get_data:
            # 第一次调用返回股票数据，第二次调用返回基础信息
            mock_get_data.side_effect = [
                sample_stock_data,  # 股票日线数据
                pd.DataFrame([      # 股票基础信息
                    {'ts_code': '000001.SZ', 'name': '平安银行'},
                    {'ts_code': '600000.SH', 'name': '浦发银行'},
                    {'ts_code': '688001.SH', 'name': '华兴源创'},
                    {'ts_code': '000002.SZ', 'name': 'ST万科'}
                ])
            ]
            
            # 调用客户端方法
            result = client.daily_limit_up_stats('20241015')
            
            # 验证结果
            assert isinstance(result, dict)
            assert result['total'] >= 3  # 至少3只涨停
            assert result['non_st'] >= 3
            assert 'shanghai' in result
            assert 'shenzhen' in result
            assert 'star' in result
            assert 'beijing' in result
            assert 'st' in result
    
    @pytest.mark.asyncio
    async def test_daily_limit_up_stats_with_database_save(self, client, sample_stock_data, temp_db_path):
        """测试涨停统计数据保存到数据库"""
        # 初始化数据库
        db_manager = DatabaseManager(temp_db_path)
        await db_manager.initialize()
        
        with patch.object(client.data_manager, 'get_data', new_callable=AsyncMock) as mock_get_data:
            mock_get_data.side_effect = [
                sample_stock_data,
                pd.DataFrame([
                    {'ts_code': '000001.SZ', 'name': '平安银行'},
                    {'ts_code': '600000.SH', 'name': '浦发银行'},
                    {'ts_code': '688001.SH', 'name': '华兴源创'},
                    {'ts_code': '000002.SZ', 'name': 'ST万科'}
                ])
            ]
            
            # 调用客户端方法（保存到数据库）
            result = client.daily_limit_up_stats('20241015', save_to_db=True)
            
            # 验证数据库中是否保存了数据
            repository = LimitUpStatsRepository(db_manager)
            saved_stats = await repository.get_limit_up_stats('20241015')
            
            assert saved_stats is not None
            assert saved_stats.trade_date == '20241015'
            assert saved_stats.total >= 3
    
    @pytest.mark.asyncio
    async def test_daily_limit_up_stats_cache_hit(self, client, sample_stock_data, temp_db_path):
        """测试缓存命中情况"""
        # 初始化数据库
        db_manager = DatabaseManager(temp_db_path)
        await db_manager.initialize()
        
        with patch.object(client.data_manager, 'get_data', new_callable=AsyncMock) as mock_get_data:
            mock_get_data.side_effect = [
                sample_stock_data,
                pd.DataFrame([
                    {'ts_code': '000001.SZ', 'name': '平安银行'},
                    {'ts_code': '600000.SH', 'name': '浦发银行'},
                    {'ts_code': '688001.SH', 'name': '华兴源创'},
                    {'ts_code': '000002.SZ', 'name': 'ST万科'}
                ])
            ]
            
            # 第一次调用（数据源获取并保存）
            result1 = client.daily_limit_up_stats('20241015', save_to_db=True)
            
            # 第二次调用（应该从缓存获取）
            result2 = client.daily_limit_up_stats('20241015', save_to_db=True)
            
            # 验证结果一致
            assert result1 == result2
            
            # 验证数据源只被调用了一次（第二次从缓存获取）
            assert mock_get_data.call_count == 2  # 第一次调用时的两次get_data
    
    @pytest.mark.asyncio
    async def test_query_limit_up_stats(self, client, temp_db_path):
        """测试查询涨停统计数据"""
        # 初始化数据库并插入测试数据
        db_manager = DatabaseManager(temp_db_path)
        await db_manager.initialize()
        
        repository = LimitUpStatsRepository(db_manager)
        test_stats = LimitUpStats(
            trade_date='20241015',
            total=5,
            non_st=4,
            shanghai=2,
            shenzhen=2,
            star=1,
            beijing=0,
            st=1,
            limit_up_stocks=['000001.SZ', '600000.SH', '688001.SH', '000002.SZ', '000003.SZ'],
            market_breakdown={
                'shanghai': ['600000.SH', '600001.SH'],
                'shenzhen': ['000001.SZ', '000002.SZ'],
                'star': ['688001.SH'],
                'beijing': [],
                'st': ['000002.SZ']
            }
        )
        await repository.save_limit_up_stats(test_stats)
        
        # 测试查询特定日期
        result = client.query_limit_up_stats(trade_date='20241015')
        
        assert len(result) == 1
        assert result[0]['trade_date'] == '20241015'
        assert result[0]['total'] == 5
        assert result[0]['non_st'] == 4
    
    @pytest.mark.asyncio
    async def test_delete_limit_up_stats(self, client, temp_db_path):
        """测试删除涨停统计数据"""
        # 初始化数据库并插入测试数据
        db_manager = DatabaseManager(temp_db_path)
        await db_manager.initialize()
        
        repository = LimitUpStatsRepository(db_manager)
        test_stats = LimitUpStats(
            trade_date='20241015',
            total=3,
            non_st=3,
            shanghai=1,
            shenzhen=1,
            star=1,
            beijing=0,
            st=0,
            limit_up_stocks=['000001.SZ', '600000.SH', '688001.SH'],
            market_breakdown={
                'shanghai': ['600000.SH'],
                'shenzhen': ['000001.SZ'],
                'star': ['688001.SH'],
                'beijing': []
            }
        )
        await repository.save_limit_up_stats(test_stats)
        
        # 验证数据存在
        existing_stats = await repository.get_limit_up_stats('20241015')
        assert existing_stats is not None
        
        # 删除数据
        success = client.delete_limit_up_stats('20241015')
        assert success is True
        
        # 验证数据已删除
        deleted_stats = await repository.get_limit_up_stats('20241015')
        assert deleted_stats is None
    
    @pytest.mark.asyncio
    async def test_list_limit_up_dates(self, client, temp_db_path):
        """测试列出涨停统计日期"""
        # 初始化数据库并插入多个日期的测试数据
        db_manager = DatabaseManager(temp_db_path)
        await db_manager.initialize()
        
        repository = LimitUpStatsRepository(db_manager)
        
        # 插入多个日期的数据
        test_dates = ['20241015', '20241016', '20241017']
        for date in test_dates:
            test_stats = LimitUpStats(
                trade_date=date,
                total=3,
                non_st=3,
                shanghai=1,
                shenzhen=1,
                star=1,
                beijing=0,
                st=0,
                limit_up_stocks=['000001.SZ', '600000.SH', '688001.SH'],
                market_breakdown={
                    'shanghai': ['600000.SH'],
                    'shenzhen': ['000001.SZ'],
                    'star': ['688001.SH'],
                    'beijing': []
                }
            )
            await repository.save_limit_up_stats(test_stats)
        
        # 测试列出所有日期
        dates = client.list_limit_up_dates()
        
        assert len(dates) == 3
        assert set(dates) == set(test_dates)
        
        # 测试日期范围过滤
        filtered_dates = client.list_limit_up_dates(
            start_date='20241016',
            end_date='20241017'
        )
        
        assert len(filtered_dates) == 2
        assert '20241015' not in filtered_dates
        assert '20241016' in filtered_dates
        assert '20241017' in filtered_dates
    
    @pytest.mark.asyncio
    async def test_batch_get_limit_up_stats(self, client, temp_db_path):
        """测试批量获取涨停统计数据"""
        # 初始化数据库并插入测试数据
        db_manager = DatabaseManager(temp_db_path)
        await db_manager.initialize()
        
        repository = LimitUpStatsRepository(db_manager)
        
        # 插入测试数据
        test_data = {
            '20241015': {'total': 5, 'non_st': 4, 'st': 1},
            '20241016': {'total': 3, 'non_st': 3, 'st': 0},
            '20241017': {'total': 7, 'non_st': 6, 'st': 1}
        }
        
        for date, stats_data in test_data.items():
            test_stats = LimitUpStats(
                trade_date=date,
                total=stats_data['total'],
                non_st=stats_data['non_st'],
                shanghai=1,
                shenzhen=1,
                star=1,
                beijing=0,
                st=stats_data['st'],
                limit_up_stocks=[f'stock_{i}' for i in range(stats_data['total'])],
                market_breakdown={'shanghai': [], 'shenzhen': [], 'star': [], 'beijing': []}
            )
            await repository.save_limit_up_stats(test_stats)
        
        # 测试批量获取
        dates = ['20241015', '20241016', '20241017', '20241018']  # 最后一个日期不存在
        results = client.batch_get_limit_up_stats(dates)
        
        assert len(results) == 4
        assert results['20241015']['total'] == 5
        assert results['20241016']['total'] == 3
        assert results['20241017']['total'] == 7
        assert results['20241018'] is None  # 不存在的日期
    
    @pytest.mark.asyncio
    async def test_cache_management_methods(self, client):
        """测试缓存管理方法"""
        # 测试同步缓存
        sync_result = client.sync_limit_up_cache(days=7)
        assert isinstance(sync_result, dict)
        assert 'synced_count' in sync_result
        assert 'date_range' in sync_result
        
        # 测试获取缓存统计
        cache_stats = client.get_limit_up_cache_stats()
        assert isinstance(cache_stats, dict)
        assert 'total_requests' in cache_stats
        assert 'memory_hit_rate' in cache_stats
        
        # 测试清理缓存
        cleanup_result = client.cleanup_limit_up_cache()
        assert isinstance(cleanup_result, dict)
        assert 'cleanup_time' in cleanup_result
    
    @pytest.mark.asyncio
    async def test_database_stats(self, client, temp_db_path):
        """测试数据库统计信息"""
        # 初始化数据库
        db_manager = DatabaseManager(temp_db_path)
        await db_manager.initialize()
        
        # 获取数据库统计信息
        stats = client.get_limit_up_database_stats()
        
        assert isinstance(stats, dict)
        assert 'total_records' in stats
    
    @pytest.mark.asyncio
    async def test_error_handling(self, client):
        """测试错误处理"""
        # 测试无效日期格式
        with pytest.raises(Exception):  # 应该抛出ValidationError
            client.daily_limit_up_stats('invalid_date')
        
        # 测试未来日期
        future_date = (datetime.now() + timedelta(days=1)).strftime('%Y%m%d')
        with pytest.raises(Exception):  # 应该抛出ValidationError
            client.daily_limit_up_stats(future_date)
    
    @pytest.mark.asyncio
    async def test_force_refresh(self, client, sample_stock_data):
        """测试强制刷新功能"""
        with patch.object(client.data_manager, 'get_data', new_callable=AsyncMock) as mock_get_data:
            mock_get_data.side_effect = [
                sample_stock_data,
                pd.DataFrame([
                    {'ts_code': '000001.SZ', 'name': '平安银行'},
                    {'ts_code': '600000.SH', 'name': '浦发银行'},
                    {'ts_code': '688001.SH', 'name': '华兴源创'},
                    {'ts_code': '000002.SZ', 'name': 'ST万科'}
                ])
            ]
            
            # 第一次调用
            result1 = client.daily_limit_up_stats('20241015', save_to_db=True)
            
            # 重置mock以准备第二次调用
            mock_get_data.side_effect = [
                sample_stock_data,
                pd.DataFrame([
                    {'ts_code': '000001.SZ', 'name': '平安银行'},
                    {'ts_code': '600000.SH', 'name': '浦发银行'},
                    {'ts_code': '688001.SH', 'name': '华兴源创'},
                    {'ts_code': '000002.SZ', 'name': 'ST万科'}
                ])
            ]
            
            # 第二次调用（强制刷新）
            result2 = client.daily_limit_up_stats('20241015', force_refresh=True)
            
            # 验证强制刷新时确实重新获取了数据
            assert mock_get_data.call_count == 4  # 两次调用，每次2个get_data
    
    @pytest.mark.asyncio
    async def test_market_filter(self, client, sample_stock_data):
        """测试市场过滤功能"""
        with patch.object(client.data_manager, 'get_data', new_callable=AsyncMock) as mock_get_data:
            mock_get_data.side_effect = [
                sample_stock_data,
                pd.DataFrame([
                    {'ts_code': '000001.SZ', 'name': '平安银行'},
                    {'ts_code': '600000.SH', 'name': '浦发银行'},
                    {'ts_code': '688001.SH', 'name': '华兴源创'},
                    {'ts_code': '000002.SZ', 'name': 'ST万科'}
                ])
            ]
            
            # 只获取上海市场的涨停统计
            result = client.daily_limit_up_stats(
                '20241015',
                market_filter=['shanghai'],
                save_to_db=False
            )
            
            # 验证结果只包含上海市场的数据
            assert isinstance(result, dict)
            # 由于过滤了市场，总数可能会减少
            assert result['shanghai'] >= 0
            assert result['shenzhen'] == 0  # 应该被过滤掉
            assert result['star'] == 0      # 应该被过滤掉