"""
涨跌分布分析器集成测试

测试PriceDistributionAnalyzer的完整功能，包括与计算器、聚合器和缓存管理器的集成
"""

import asyncio
import logging
import os
import tempfile
import unittest
from unittest.mock import Mock, patch, AsyncMock
import pandas as pd
import numpy as np

from quickstock.utils.price_distribution_analyzer import (
    PriceDistributionAnalyzer,
    PriceDistributionAnalysisError,
    analyze_stock_distribution,
    calculate_market_distribution,
    perform_complete_analysis
)
from quickstock.models.price_distribution_models import (
    PriceDistributionRequest,
    PriceDistributionStats,
    DistributionRange
)
from quickstock.core.price_distribution_cache import PriceDistributionCacheManager


class TestPriceDistributionAnalyzerIntegration(unittest.TestCase):
    """Integration tests for PriceDistributionAnalyzer.

    Each test runs the analyzer together with a PriceDistributionCacheManager
    that points at a per-test temporary ``.db`` file.
    """

    def setUp(self):
        """Create the temp database, cache manager, analyzer and fixtures."""
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.DEBUG)

        # Only the path of the temporary database file is needed, so the
        # handle is closed immediately after creation.
        self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db')
        self.temp_db.close()
        # Register the removal right away: unittest does not call tearDown
        # when setUp raises, but already-registered cleanups still run, so the
        # file is not leaked if building the manager/analyzer below fails.
        self.addCleanup(self._remove_temp_db)

        # Cache manager pointing at the temporary database.
        self.cache_manager = PriceDistributionCacheManager(
            db_path=self.temp_db.name,
            memory_cache_size=100,
            logger=self.logger
        )

        # Analyzer under test, wired to the cache manager above.
        self.analyzer = PriceDistributionAnalyzer(
            cache_manager=self.cache_manager,
            logger=self.logger
        )

        # Shared fixtures. The classified data is derived from the stock
        # data, so the stock data must be created first.
        self.test_stock_data = self._create_test_stock_data()
        self.test_classified_data = self._create_test_classified_data()
        self.test_ranges = self._create_test_ranges()

    def _remove_temp_db(self):
        """Delete the temporary database file if it still exists."""
        if os.path.exists(self.temp_db.name):
            os.unlink(self.temp_db.name)

    def _create_test_stock_data(self) -> pd.DataFrame:
        """Return a deterministic 100-row frame of synthetic stock data.

        Columns: ``ts_code``, ``pct_chg``, ``name``, ``close``, ``volume``.
        The first five ``pct_chg`` values are overwritten with fixed values.
        """
        np.random.seed(42)  # fixed seed keeps the fixture reproducible

        stock_codes = [f"{i:06d}.SZ" for i in range(1, 101)]  # 100 stocks
        pct_changes = np.random.normal(0, 5, 100)  # normally distributed changes

        # Pin a few representative values.
        pct_changes[0] = 10.5   # large gain
        pct_changes[1] = -9.8   # large loss
        pct_changes[2] = 0.0    # flat
        pct_changes[3] = 4.5    # moderate gain
        pct_changes[4] = -3.2   # moderate loss

        return pd.DataFrame({
            'ts_code': stock_codes,
            'pct_chg': pct_changes,
            'name': [f'股票{i}' for i in range(1, 101)],
            'close': np.random.uniform(10, 100, 100),
            'volume': np.random.randint(1000, 100000, 100)
        })

    def _create_test_classified_data(self) -> dict:
        """Split ``self.test_stock_data`` into per-market frames.

        Returns a dict keyed by ``total``, ``shanghai``, ``shenzhen``,
        ``star``, ``beijing`` and ``non_st``. The four market frames are
        disjoint row slices (30/30/20/20 rows) whose ``ts_code`` column is
        rewritten with market-style codes.
        """
        total_data = self.test_stock_data.copy()

        # Shanghai market (codes starting with 60).
        shanghai_codes = [f"60{i:04d}.SH" for i in range(1, 31)]
        shanghai_data = total_data.iloc[:30].copy()
        shanghai_data['ts_code'] = shanghai_codes

        # Shenzhen market (codes starting with 00).
        shenzhen_codes = [f"00{i:04d}.SZ" for i in range(1, 31)]
        shenzhen_data = total_data.iloc[30:60].copy()
        shenzhen_data['ts_code'] = shenzhen_codes

        # STAR market (codes starting with 688).
        star_codes = [f"688{i:03d}.SH" for i in range(1, 21)]
        star_data = total_data.iloc[60:80].copy()
        star_data['ts_code'] = star_codes

        # Beijing exchange (codes starting with 8).
        beijing_codes = [f"8{i:05d}.BJ" for i in range(1, 21)]
        beijing_data = total_data.iloc[80:100].copy()
        beijing_data['ts_code'] = beijing_codes

        return {
            'total': total_data,
            'shanghai': shanghai_data,
            'shenzhen': shenzhen_data,
            'star': star_data,
            'beijing': beijing_data,
            'non_st': total_data  # fixture treats every stock as non-ST
        }

    def _create_test_ranges(self) -> list:
        """Return five positive and five negative DistributionRange fixtures."""
        return [
            DistributionRange("0-3%", 0.0, 3.0, True, "0-3%"),
            DistributionRange("3-5%", 3.0, 5.0, True, "3-5%"),
            DistributionRange("5-7%", 5.0, 7.0, True, "5-7%"),
            DistributionRange("7-10%", 7.0, 10.0, True, "7-10%"),
            DistributionRange(">=10%", 10.0, float('inf'), True, ">=10%"),
            DistributionRange("0到-3%", -3.0, 0.0, False, "0到-3%"),
            DistributionRange("-3到-5%", -5.0, -3.0, False, "-3到-5%"),
            DistributionRange("-5到-7%", -7.0, -5.0, False, "-5到-7%"),
            DistributionRange("-7到-10%", -10.0, -7.0, False, "-7到-10%"),
            DistributionRange("<=-10%", float('-inf'), -10.0, False, "<=-10%")
        ]

    def test_analyze_distribution_basic(self):
        """analyze_distribution returns a complete, consistent result."""
        async def run_test():
            result = await self.analyzer.analyze_distribution(
                self.test_stock_data,
                self.test_ranges
            )

            # Result structure.
            for key in (
                'total_stocks',
                'distribution_results',
                'range_summary',
                'processing_time',
                'data_quality_score',
            ):
                self.assertIn(key, result)

            # Total number of stocks.
            self.assertEqual(result['total_stocks'], 100)

            # Distribution results.
            distribution_results = result['distribution_results']
            self.assertIsInstance(distribution_results, dict)
            self.assertGreater(len(distribution_results), 0)

            # Every stock must be classified into exactly one range.
            total_classified = sum(
                res.stock_count for res in distribution_results.values()
            )
            self.assertEqual(total_classified, 100)

            # Processing time.
            self.assertGreater(result['processing_time'], 0)

            # Data quality score must lie in [0, 1].
            self.assertGreaterEqual(result['data_quality_score'], 0)
            self.assertLessEqual(result['data_quality_score'], 1)

        asyncio.run(run_test())

    def test_analyze_distribution_with_cache(self):
        """A repeated analysis with use_cache=True registers a cache hit."""
        async def run_test():
            # First call: expected to compute and populate the cache.
            result1 = await self.analyzer.analyze_distribution(
                self.test_stock_data,
                self.test_ranges,
                use_cache=True
            )

            # Second call: expected to be served from the cache.
            result2 = await self.analyzer.analyze_distribution(
                self.test_stock_data,
                self.test_ranges,
                use_cache=True
            )

            # Both calls must agree.
            self.assertEqual(result1['total_stocks'], result2['total_stocks'])

            # Cache statistics.
            stats = self.analyzer.get_performance_stats()
            self.assertGreater(stats['cache_hits'], 0)

        asyncio.run(run_test())

    def test_calculate_market_breakdown(self):
        """calculate_market_breakdown returns stats for every market."""
        async def run_test():
            result = await self.analyzer.calculate_market_breakdown(
                self.test_classified_data,
                self.test_ranges
            )

            # Result structure.
            self.assertIsInstance(result, dict)

            # All markets are present.
            expected_markets = {'total', 'shanghai', 'shenzhen', 'star', 'beijing', 'non_st'}
            self.assertEqual(set(result.keys()), expected_markets)

            # Per-market structure and stock counts.
            required_keys = (
                'market_name',
                'total_stocks',
                'positive_ranges',
                'negative_ranges',
                'positive_percentages',
                'negative_percentages',
                'stock_codes',
            )
            for market_name, market_stats in result.items():
                for key in required_keys:
                    self.assertIn(key, market_stats)

                expected_count = len(self.test_classified_data[market_name])
                self.assertEqual(market_stats['total_stocks'], expected_count)

        asyncio.run(run_test())

    def test_analyze_complete_distribution(self):
        """analyze_complete_distribution returns consistent stats."""
        async def run_test():
            # Build the request.
            request = PriceDistributionRequest(
                trade_date='20240115',
                include_st=True,
                market_filter=['shanghai', 'shenzhen', 'star', 'beijing']
            )

            result = await self.analyzer.analyze_complete_distribution(
                request,
                self.test_stock_data,
                self.test_classified_data
            )

            # Result type.
            self.assertIsInstance(result, PriceDistributionStats)

            # Basic attributes.
            self.assertEqual(result.trade_date, '20240115')
            self.assertEqual(result.total_stocks, 100)

            # Distribution data.
            self.assertIsInstance(result.positive_ranges, dict)
            self.assertIsInstance(result.negative_ranges, dict)
            self.assertIsInstance(result.positive_percentages, dict)
            self.assertIsInstance(result.negative_percentages, dict)

            # Market breakdown.
            self.assertIsInstance(result.market_breakdown, dict)
            self.assertGreater(len(result.market_breakdown), 0)

            # Processing time and quality score.
            self.assertGreater(result.processing_time, 0)
            self.assertGreaterEqual(result.data_quality_score, 0)
            self.assertLessEqual(result.data_quality_score, 1)

            # Positive and negative counts must add up to the total.
            positive_total = sum(result.positive_ranges.values())
            negative_total = sum(result.negative_ranges.values())
            self.assertEqual(positive_total + negative_total, result.total_stocks)

        asyncio.run(run_test())

    def test_error_handling_invalid_data(self):
        """Invalid inputs raise PriceDistributionAnalysisError."""
        async def run_test():
            # Empty frame.
            empty_data = pd.DataFrame()
            with self.assertRaises(PriceDistributionAnalysisError):
                await self.analyzer.analyze_distribution(empty_data)

            # Frame without the required pct_chg column.
            invalid_data = pd.DataFrame({
                'ts_code': ['000001.SZ'],
                'close': [10.0]  # pct_chg column is missing
            })
            with self.assertRaises(PriceDistributionAnalysisError):
                await self.analyzer.analyze_distribution(invalid_data)

            # Input that is not a DataFrame at all.
            with self.assertRaises(PriceDistributionAnalysisError):
                await self.analyzer.analyze_distribution("invalid_data")

        asyncio.run(run_test())

    def test_performance_stats_tracking(self):
        """Performance statistics count every uncached analysis."""
        async def run_test():
            # Start from clean statistics.
            self.analyzer.reset_performance_stats()

            # Run the analysis three times.
            for _ in range(3):
                await self.analyzer.analyze_distribution(
                    self.test_stock_data,
                    self.test_ranges,
                    use_cache=False  # disable caching so every call computes
                )

            # Inspect the statistics.
            stats = self.analyzer.get_performance_stats()
            self.assertEqual(stats['total_analyses'], 3)
            self.assertGreater(stats['total_processing_time'], 0)
            self.assertGreater(stats['average_processing_time'], 0)

        asyncio.run(run_test())

    def test_data_quality_score_calculation(self):
        """Data containing NaN values scores lower than clean data."""
        async def run_test():
            # High-quality data.
            good_data = self.test_stock_data.copy()
            result1 = await self.analyzer.analyze_distribution(good_data, self.test_ranges)

            # Low-quality data. ``.loc`` slices by label and includes the end
            # label, so rows 0..10 (11 rows) lose their pct_chg value.
            bad_data = self.test_stock_data.copy()
            bad_data.loc[:10, 'pct_chg'] = np.nan
            result2 = await self.analyzer.analyze_distribution(bad_data, self.test_ranges)

            # The clean data must score strictly higher.
            self.assertGreater(result1['data_quality_score'], result2['data_quality_score'])

        asyncio.run(run_test())

    def test_custom_ranges(self):
        """Caller-supplied ranges are the ones reported in the result."""
        async def run_test():
            # Custom ranges.
            custom_ranges = [
                DistributionRange("0-2%", 0.0, 2.0, True, "0-2%"),
                DistributionRange("2-5%", 2.0, 5.0, True, "2-5%"),
                DistributionRange(">=5%", 5.0, float('inf'), True, ">=5%"),
                DistributionRange("0到-2%", -2.0, 0.0, False, "0到-2%"),
                DistributionRange("<-2%", float('-inf'), -2.0, False, "<-2%")
            ]

            result = await self.analyzer.analyze_distribution(
                self.test_stock_data,
                custom_ranges
            )

            # The result keys must be exactly the custom range names.
            distribution_results = result['distribution_results']
            range_names = set(distribution_results.keys())
            expected_names = {"0-2%", "2-5%", ">=5%", "0到-2%", "<-2%"}
            self.assertEqual(range_names, expected_names)

        asyncio.run(run_test())

    def test_validation_result(self):
        """validate_analysis_result accepts a normal analysis result."""
        async def run_test():
            result = await self.analyzer.analyze_distribution(
                self.test_stock_data,
                self.test_ranges
            )

            # Validate the analysis result.
            validation = await self.analyzer.validate_analysis_result(result)

            # Structure of the validation report.
            for key in ('is_valid', 'errors', 'warnings', 'checks_performed'):
                self.assertIn(key, validation)

            # A normal result must be valid and error-free.
            self.assertTrue(validation['is_valid'])
            self.assertEqual(len(validation['errors']), 0)

        asyncio.run(run_test())

    def test_convenience_functions(self):
        """The module-level convenience functions return the expected types."""
        async def run_test():
            # analyze_stock_distribution
            result1 = await analyze_stock_distribution(
                self.test_stock_data,
                self.test_ranges,
                self.cache_manager,
                self.logger
            )
            self.assertIn('total_stocks', result1)

            # calculate_market_distribution
            result2 = await calculate_market_distribution(
                self.test_classified_data,
                self.test_ranges,
                self.cache_manager,
                self.logger
            )
            self.assertIsInstance(result2, dict)

            # perform_complete_analysis
            request = PriceDistributionRequest(
                trade_date='20240115',
                include_st=True
            )
            result3 = await perform_complete_analysis(
                request,
                self.test_stock_data,
                self.test_classified_data,
                self.cache_manager,
                self.logger
            )
            self.assertIsInstance(result3, PriceDistributionStats)

        asyncio.run(run_test())

    def test_cache_integration(self):
        """Values stored in the cache manager can be read back."""
        async def run_test():
            cache_key = "test_analysis_key"
            test_data = {"test": "data"}

            # Store a value.
            success = await self.cache_manager.set_distribution_stats(
                cache_key, test_data, market='test', include_st=True
            )
            self.assertTrue(success)

            # Read it back with the same key and qualifiers.
            cached_data = await self.cache_manager.get_distribution_stats(
                cache_key, market='test', include_st=True
            )
            self.assertEqual(cached_data, test_data)

            # Cache info report.
            cache_info = await self.cache_manager.get_cache_info()
            self.assertIn('memory_cache', cache_info)
            self.assertIn('global_stats', cache_info)

        asyncio.run(run_test())

    def test_large_dataset_performance(self):
        """A 1000-stock analysis finishes within the 10 second budget."""
        async def run_test():
            # Large dataset (1000 stocks).
            np.random.seed(42)
            large_stock_codes = [f"{i:06d}.SZ" for i in range(1, 1001)]
            large_pct_changes = np.random.normal(0, 3, 1000)

            large_data = pd.DataFrame({
                'ts_code': large_stock_codes,
                'pct_chg': large_pct_changes,
                'name': [f'股票{i}' for i in range(1, 1001)],
                'close': np.random.uniform(10, 100, 1000),
                'volume': np.random.randint(1000, 100000, 1000)
            })

            # perf_counter is monotonic, so the measured duration cannot be
            # distorted by system clock adjustments the way time.time() can.
            import time
            start_time = time.perf_counter()

            result = await self.analyzer.analyze_distribution(
                large_data,
                self.test_ranges
            )

            processing_time = time.perf_counter() - start_time

            # Result check.
            self.assertEqual(result['total_stocks'], 1000)

            # Performance check: must finish within 10 seconds.
            self.assertLess(processing_time, 10.0)

            self.logger.info(f"Large dataset analysis completed in {processing_time:.3f}s")

        asyncio.run(run_test())


class TestPriceDistributionAnalyzerEdgeCases(unittest.TestCase):
    """Edge-case tests for PriceDistributionAnalyzer (no cache manager)."""

    def setUp(self):
        """Create an analyzer that is given only a logger."""
        self.logger = logging.getLogger(__name__)
        self.analyzer = PriceDistributionAnalyzer(logger=self.logger)

    def test_extreme_values(self):
        """Extreme and infinite pct_chg values are analyzed without error."""
        async def run_test():
            # Data with extreme values, including positive infinity.
            extreme_data = pd.DataFrame({
                'ts_code': ['000001.SZ', '000002.SZ', '000003.SZ', '000004.SZ'],
                'pct_chg': [50.0, -30.0, 0.0, np.inf],
                'name': ['股票1', '股票2', '股票3', '股票4']
            })

            # The analyzer is expected to cope with the extreme values.
            result = await self.analyzer.analyze_distribution(extreme_data)

            # Result checks.
            self.assertIn('total_stocks', result)
            self.assertGreater(result['data_quality_score'], 0)

        asyncio.run(run_test())

    def test_all_same_values(self):
        """Identical pct_chg values all land in a single range."""
        async def run_test():
            # Ten stocks that all changed by 2.5%. Codes are zero-padded to
            # six digits; a fixed '00000' prefix would produce the malformed
            # seven-digit code '0000010.SZ' for the tenth stock.
            same_value_data = pd.DataFrame({
                'ts_code': [f'{i:06d}.SZ' for i in range(1, 11)],
                'pct_chg': [2.5] * 10,
                'name': [f'股票{i}' for i in range(1, 11)]
            })

            result = await self.analyzer.analyze_distribution(same_value_data)

            # Exactly one range may contain stocks.
            distribution_results = result['distribution_results']
            non_zero_ranges = [
                name for name, res in distribution_results.items()
                if res.stock_count > 0
            ]
            self.assertEqual(len(non_zero_ranges), 1)

        asyncio.run(run_test())

    def test_empty_market_data(self):
        """Markets with an empty frame are left out of the breakdown."""
        async def run_test():
            # Classified data in which one market has no rows.
            classified_data = {
                'shanghai': pd.DataFrame({
                    'ts_code': ['600001.SH', '600002.SH'],
                    'pct_chg': [1.5, -2.3],
                    'name': ['股票1', '股票2']
                }),
                'shenzhen': pd.DataFrame(),  # empty market
                'star': pd.DataFrame({
                    'ts_code': ['688001.SH'],
                    'pct_chg': [5.2],
                    'name': ['股票3']
                })
            }

            result = await self.analyzer.calculate_market_breakdown(classified_data)

            # Non-empty markets are present.
            self.assertIn('shanghai', result)
            self.assertIn('star', result)
            # The empty market must be skipped.
            self.assertNotIn('shenzhen', result)

        asyncio.run(run_test())


if __name__ == '__main__':
    # Emit INFO-level log records with timestamp, logger name and level
    # when the module is executed directly.
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)

    # Hand control to the unittest command-line runner.
    unittest.main()