"""
分布计算器集成测试

测试 DistributionCalculator 与真实数据的集成
"""

import time

import numpy as np
import pandas as pd
import pytest

from quickstock.models import DistributionRange
from quickstock.utils.distribution_calculator import DistributionCalculator


class TestDistributionCalculatorIntegration:
    """Integration tests for DistributionCalculator on realistic market data.

    Each test gets a fresh calculator and a reproducible synthetic snapshot of
    500 stocks built in ``setup_method``. Several tests print a readable
    summary of their results; run pytest with ``-s`` to see it.
    """

    def setup_method(self):
        """Create the calculator and a seeded synthetic market snapshot."""
        self.calculator = DistributionCalculator()

        # A dedicated seeded Generator keeps the data reproducible without
        # mutating NumPy's global RNG state (np.random.seed would leak into
        # any other test that draws from np.random).
        self.rng = np.random.default_rng(42)

        n_stocks = 500
        n_extreme = 50

        # Even-numbered codes get an SH suffix, odd-numbered ones SZ.
        stock_codes = [f'{i:06d}.{"SH" if i % 2 == 0 else "SZ"}' for i in range(1, n_stocks + 1)]

        # Most stocks move within a normal band (mean 0, sigma 3) ...
        base_changes = self.rng.normal(0, 3, n_stocks - n_extreme)
        # ... plus some extreme moves, shuffled in among them.
        extreme_changes = self.rng.choice([-15, -12, -8, 8, 12, 15], n_extreme)
        pct_changes = np.concatenate([base_changes, extreme_changes])
        self.rng.shuffle(pct_changes)

        self.market_data = pd.DataFrame({
            'ts_code': stock_codes,
            'pct_chg': pct_changes,
            'name': [f'股票{i:03d}' for i in range(1, n_stocks + 1)],
            'close': self.rng.uniform(10, 100, n_stocks),
            # integers() excludes the upper bound, matching np.random.randint.
            'volume': self.rng.integers(1000, 100000, n_stocks),
        })

    def test_full_market_analysis(self):
        """Classify the full market and check the statistics are complete and consistent."""
        classified = self.calculator.classify_by_ranges(self.market_data)
        statistics = self.calculator.calculate_statistics(classified)

        # Default configuration: 5 positive ranges + 5 negative ranges.
        assert len(statistics) == 10

        # Together the ranges must account for every stock.
        total_classified = sum(len(stocks) for stocks in classified.values())
        assert total_classified == len(self.market_data)

        # Range percentages must add up to 100%.
        total_percentage = sum(result.percentage for result in statistics.values())
        assert total_percentage == pytest.approx(100.0, abs=0.01)

        # Every per-range result must be internally consistent.
        for range_name, result in statistics.items():
            assert result.range_name == range_name
            assert result.stock_count >= 0
            assert len(result.stock_codes) == result.stock_count
            assert 0 <= result.percentage <= 100
            assert result.range_definition is not None

        # Debug output (visible with pytest -s).
        print("\n=== 市场涨跌分布分析结果 ===")
        for result in statistics.values():
            if result.stock_count > 0:
                print(f"{result.range_definition.display_name}: {result.stock_count}只 ({result.percentage:.1f}%)")

    def test_market_summary_analysis(self):
        """Check the aggregate summary produced by get_range_summary."""
        classified = self.calculator.classify_by_ranges(self.market_data)
        statistics = self.calculator.calculate_statistics(classified)
        summary = self.calculator.get_range_summary(statistics)

        assert summary['total_stocks'] == len(self.market_data)
        assert summary['total_ranges'] == 10
        assert summary['positive_ranges'] == 5
        assert summary['negative_ranges'] == 5

        # The largest range must hold at least as many stocks as the smallest.
        assert summary['largest_range'] is not None
        assert summary['smallest_range'] is not None
        assert summary['largest_range']['count'] >= summary['smallest_range']['count']

        print("\n=== 市场摘要 ===")
        print(f"总股票数: {summary['total_stocks']}")
        print(f"最大区间: {summary['largest_range']['name']} ({summary['largest_range']['count']}只)")
        print(f"最小区间: {summary['smallest_range']['name']} ({summary['smallest_range']['count']}只)")

    def test_extreme_market_conditions(self):
        """Split the market evenly between +10% and -10% moves and check the two ranges."""
        # 50 stocks at +10.0 and 50 stocks at -10.0.
        extreme_data = pd.DataFrame({
            'ts_code': [f'{i:06d}.SZ' for i in range(1, 101)],
            'pct_chg': [10.0] * 50 + [-10.0] * 50,
        })

        classified = self.calculator.classify_by_ranges(extreme_data)
        statistics = self.calculator.calculate_statistics(classified)

        # -10.0 is expected in the '-7--10%' range (see test_boundary_precision).
        assert statistics['>=10%'].stock_count == 50
        assert statistics['-7--10%'].stock_count == 50
        # Percentages are floats; compare with a tolerance, not exact equality.
        assert statistics['>=10%'].percentage == pytest.approx(50.0)
        assert statistics['-7--10%'].percentage == pytest.approx(50.0)

        # Every other range must be empty.
        occupied = {'>=10%', '-7--10%'}
        for range_name, result in statistics.items():
            if range_name not in occupied:
                assert result.stock_count == 0

    def test_custom_range_analysis(self):
        """Classify with four caller-supplied ranges (large/small rise, small/large fall)."""
        custom_ranges = [
            DistributionRange("大涨", 5.0, float('inf'), True, "大涨(>=5%)"),
            DistributionRange("小涨", 0.0, 5.0, True, "小涨(0%~5%)"),
            DistributionRange("小跌", -5.0, 0.0, False, "小跌(-5%~0%)"),
            DistributionRange("大跌", float('-inf'), -5.0, False, "大跌(<=-5%)"),
        ]

        classified = self.calculator.classify_by_ranges(self.market_data, custom_ranges)
        statistics = self.calculator.calculate_statistics(classified, custom_ranges)

        # Statistics must be keyed by exactly the custom range names.
        assert len(statistics) == 4
        assert '大涨' in statistics
        assert '小涨' in statistics
        assert '小跌' in statistics
        assert '大跌' in statistics

        # The custom ranges span (-inf, inf), so every stock must be counted.
        total_stocks = sum(result.stock_count for result in statistics.values())
        assert total_stocks == len(self.market_data)

        print("\n=== 自定义区间分析 ===")
        for result in statistics.values():
            print(f"{result.range_definition.display_name}: {result.stock_count}只 ({result.percentage:.1f}%)")

    def test_performance_with_large_dataset(self):
        """Classify 10,000 stocks and check the run finishes within a time budget."""
        n_stocks = 10000
        large_stock_codes = [f'{i:06d}.{"SH" if i % 3 == 0 else "SZ"}' for i in range(1, n_stocks + 1)]
        large_data = pd.DataFrame({
            'ts_code': large_stock_codes,
            'pct_chg': self.rng.normal(0, 4, n_stocks),
        })

        # perf_counter is monotonic and high-resolution. time.time() follows
        # the wall clock and can jump if the system time is adjusted.
        start_time = time.perf_counter()
        classified = self.calculator.classify_by_ranges(large_data)
        statistics = self.calculator.calculate_statistics(classified)
        processing_time = time.perf_counter() - start_time

        # Correctness: every stock is accounted for.
        total_stocks = sum(result.stock_count for result in statistics.values())
        assert total_stocks == n_stocks

        # Performance: classification plus statistics must take under 5 seconds.
        assert processing_time < 5.0

        print("\n=== 性能测试 ===")
        print(f"处理{n_stocks}只股票用时: {processing_time:.3f}秒")
        print(f"平均每只股票处理时间: {processing_time / n_stocks * 1000:.3f}毫秒")

    def test_data_quality_handling(self):
        """Rows with NaN, infinite or non-numeric changes must be left unclassified."""
        problematic_data = pd.DataFrame({
            'ts_code': ['good1', 'nan_pct', 'inf_pct', 'str_pct', 'good2'],
            'pct_chg': [2.5, np.nan, np.inf, 'invalid', -3.2],
        })

        # The calculator must not raise on bad values.
        classified = self.calculator.classify_by_ranges(problematic_data)
        statistics = self.calculator.calculate_statistics(classified)

        # Only good1 and good2 carry valid values.
        total_valid_stocks = sum(result.stock_count for result in statistics.values())
        assert total_valid_stocks == 2

        assert 'good1' in classified['0-3%']
        assert 'good2' in classified['-3--5%']  # -3.2 falls in the -3%..-5% range

    def test_boundary_precision(self):
        """Check which range each exact boundary value is assigned to.

        Positive boundaries belong to the range they start (3.0 -> '3-5%').
        Negative boundaries belong to the range they end (-3.0 -> '0--3%').
        """
        boundary_data = pd.DataFrame({
            'ts_code': [f'boundary_{i}' for i in range(10)],
            'pct_chg': [0.0, 3.0, 5.0, 7.0, 10.0, -3.0, -5.0, -7.0, -10.0, 2.999999],
        })

        classified = self.calculator.classify_by_ranges(boundary_data)

        assert 'boundary_0' in classified['0-3%']      # 0.0
        assert 'boundary_1' in classified['3-5%']      # 3.0
        assert 'boundary_2' in classified['5-7%']      # 5.0
        assert 'boundary_3' in classified['7-10%']     # 7.0
        assert 'boundary_4' in classified['>=10%']     # 10.0
        assert 'boundary_5' in classified['0--3%']     # -3.0
        assert 'boundary_6' in classified['-3--5%']    # -5.0
        assert 'boundary_7' in classified['-5--7%']    # -7.0
        assert 'boundary_8' in classified['-7--10%']   # -10.0
        assert 'boundary_9' in classified['0-3%']      # 2.999999, just under 3.0


if __name__ == "__main__":
    # pytest.main() returns the session exit code. Propagate it so that
    # running this file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))