"""
统计聚合器集成测试

测试StatisticsAggregator与DistributionCalculator的集成
"""

import pytest
import pandas as pd
import logging
from typing import Dict, List, Any

from quickstock.utils.statistics_aggregator import StatisticsAggregator, MarketStatistics
from quickstock.utils.distribution_calculator import DistributionCalculator
from quickstock.models import DistributionRange, PriceDistributionStats


class TestStatisticsAggregatorIntegration:
    """Integration tests for StatisticsAggregator used with DistributionCalculator.

    Every test starts from the 20-row stock fixture created in
    ``setup_method``. Steps shared by several tests live in the
    ``_``-prefixed helpers, whose names do not match pytest's default
    ``test*`` collection pattern.
    """

    # Allowed deviation (in percentage points) when the per-range percentages
    # are summed; absorbs rounding of the individual values.
    _PCT_SUM_TOLERANCE = 0.1

    # Tight tolerance for percentages that are expected to be exact; avoids
    # comparing floats with ``==``.
    _EXACT_TOLERANCE = 1e-9

    def setup_method(self) -> None:
        """Create the aggregator, the calculator and the fixture data."""
        self.logger = logging.getLogger(__name__)
        self.aggregator = StatisticsAggregator(self.logger)
        self.calculator = DistributionCalculator(self.logger)

        # Fixture stock data: 20 codes, five per board.
        self.stock_data = pd.DataFrame({
            'ts_code': [
                '000001.SZ', '000002.SZ', '000003.SZ', '000004.SZ', '000005.SZ',
                '600001.SH', '600002.SH', '600003.SH', '600004.SH', '600005.SH',
                '688001.SH', '688002.SH', '688003.SH', '688004.SH', '688005.SH',
                '300001.SZ', '300002.SZ', '300003.SZ', '300004.SZ', '300005.SZ'
            ],
            'pct_chg': [
                2.5, -1.8, 4.2, -3.5, 6.8,    # Shenzhen main board
                1.2, -2.1, 3.8, -4.2, 8.5,    # Shanghai main board
                15.2, -8.5, 12.3, -6.8, 18.9, # STAR Market
                0.8, -1.2, 2.8, -2.5, 4.5     # ChiNext
            ],
            'name': [f'股票{i:02d}' for i in range(1, 21)]
        })

        # Market classification mapping. The ChiNext (300xxx) codes are
        # grouped under 'shenzhen' together with the Shenzhen main board.
        self.market_mapping = {
            'shenzhen': ['000001.SZ', '000002.SZ', '000003.SZ', '000004.SZ', '000005.SZ',
                        '300001.SZ', '300002.SZ', '300003.SZ', '300004.SZ', '300005.SZ'],
            'shanghai': ['600001.SH', '600002.SH', '600003.SH', '600004.SH', '600005.SH'],
            'star': ['688001.SH', '688002.SH', '688003.SH', '688004.SH', '688005.SH']
        }

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _stocks_in_market(self, market_name: str) -> pd.DataFrame:
        """Return the fixture rows whose ts_code belongs to *market_name*."""
        codes = self.market_mapping[market_name]
        return self.stock_data[self.stock_data['ts_code'].isin(codes)]

    def _distribution_for(self, stocks: pd.DataFrame):
        """Classify *stocks* by range and return the calculator's statistics."""
        classified = self.calculator.classify_by_ranges(stocks)
        return self.calculator.calculate_statistics(classified)

    def _build_overall_stats(self) -> "PriceDistributionStats":
        """Build a PriceDistributionStats object covering the whole fixture.

        A range counts as positive when its result carries a range definition
        flagged ``is_positive``; every other range is filed under negative.
        """
        positive_ranges = {}
        negative_ranges = {}
        positive_percentages = {}
        negative_percentages = {}

        for range_name, result in self._distribution_for(self.stock_data).items():
            definition = result.range_definition
            if definition and definition.is_positive:
                counts, percentages = positive_ranges, positive_percentages
            else:
                counts, percentages = negative_ranges, negative_percentages
            counts[range_name] = result.stock_count
            percentages[range_name] = result.percentage

        return PriceDistributionStats(
            trade_date="20240101",
            total_stocks=len(self.stock_data),
            positive_ranges=positive_ranges,
            positive_percentages=positive_percentages,
            negative_ranges=negative_ranges,
            negative_percentages=negative_percentages,
            processing_time=1.5,
            data_quality_score=0.95
        )

    def _assert_percentages_sum_to_100(self, market_stats) -> None:
        """Assert positive plus negative percentages total roughly 100%."""
        positive_pct_sum = sum(market_stats.positive_percentages.values())
        negative_pct_sum = sum(market_stats.negative_percentages.values())
        total_pct = positive_pct_sum + negative_pct_sum
        assert abs(total_pct - 100.0) < self._PCT_SUM_TOLERANCE

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_end_to_end_market_analysis(self) -> None:
        """Run the full flow: per-market distribution, then aggregation."""
        # 1. Compute each market's distribution with DistributionCalculator.
        market_distribution_results = {
            market_name: self._distribution_for(self._stocks_in_market(market_name))
            for market_name in self.market_mapping
        }

        # 2. Aggregate the per-market results with StatisticsAggregator.
        aggregated_stats = self.aggregator.aggregate_market_stats(market_distribution_results)

        # 3. Check the aggregated result: exactly the three markets.
        assert len(aggregated_stats) == 3
        assert 'shenzhen' in aggregated_stats
        assert 'shanghai' in aggregated_stats
        assert 'star' in aggregated_stats

        # Shenzhen: 5 main-board + 5 ChiNext stocks.
        shenzhen_stats = aggregated_stats['shenzhen']
        assert isinstance(shenzhen_stats, MarketStatistics)
        assert shenzhen_stats.market_name == 'shenzhen'
        assert shenzhen_stats.total_stocks == 10

        # Shanghai and STAR Market: 5 stocks each.
        assert aggregated_stats['shanghai'].total_stocks == 5
        assert aggregated_stats['star'].total_stocks == 5

        # Per market, the range counts must add up to the total and the
        # percentages must add up to about 100%.
        for market_stats in aggregated_stats.values():
            total_positive = sum(market_stats.positive_ranges.values())
            total_negative = sum(market_stats.negative_ranges.values())
            assert total_positive + total_negative == market_stats.total_stocks
            self._assert_percentages_sum_to_100(market_stats)

    def test_statistics_summary_generation(self) -> None:
        """Check the summary generated from whole-market statistics."""
        stats = self._build_overall_stats()

        summary = self.aggregator.generate_summary(stats)

        assert summary['trade_date'] == "20240101"
        assert summary['total_stocks'] == 20
        assert 'positive_stocks' in summary
        assert 'negative_stocks' in summary
        assert 'largest_range' in summary
        assert 'smallest_range' in summary
        assert summary['processing_time'] == 1.5
        assert summary['data_quality_score'] == 0.95

        # Positive and negative stock counts together must equal the total.
        assert summary['positive_stocks'] + summary['negative_stocks'] == 20

    def test_data_consistency_validation(self) -> None:
        """Check that consistent statistics pass validation with all checks run."""
        stats = self._build_overall_stats()

        validation_result = self.aggregator.validate_data_consistency(stats)

        assert validation_result['is_valid'] is True
        assert len(validation_result['errors']) == 0
        assert len(validation_result['checks_performed']) > 0

        # Every required check must have been performed; collect all the
        # missing ones so a failure reports them together.
        expected_checks = [
            'total_stocks_consistency',
            'percentage_consistency',
            'market_breakdown_consistency',
            'data_quality_score',
            'processing_time'
        ]
        performed = validation_result['checks_performed']
        missing = [check for check in expected_checks if check not in performed]
        assert not missing, f"checks not performed: {missing}"

    def test_market_statistics_merging(self) -> None:
        """Merge Shenzhen and Shanghai statistics and check the combined result."""
        # 1. Compute the statistics of each market separately.
        shenzhen_stocks = self._stocks_in_market('shenzhen')
        shanghai_stocks = self._stocks_in_market('shanghai')

        shenzhen_results = self._distribution_for(shenzhen_stocks)
        shenzhen_market_stats = self.aggregator.aggregate_market_stats(
            {'shenzhen': shenzhen_results}
        )['shenzhen']

        shanghai_results = self._distribution_for(shanghai_stocks)
        shanghai_market_stats = self.aggregator.aggregate_market_stats(
            {'shanghai': shanghai_results}
        )['shanghai']

        # 2. Merge the two markets.
        merged_stats = self.aggregator.merge_market_statistics(
            [shenzhen_market_stats, shanghai_market_stats]
        )

        # 3. Check the merged result: 10 Shenzhen + 5 Shanghai stocks.
        assert merged_stats.market_name == "merged_shenzhen_shanghai"
        assert merged_stats.total_stocks == 15

        # Each stock is expected to be listed under exactly one range, so the
        # code lists across all ranges must add up to the total stock count.
        total_stock_codes = sum(len(codes) for codes in merged_stats.stock_codes.values())
        assert total_stock_codes == 15

        # Percentages must have been recomputed for the merged total.
        self._assert_percentages_sum_to_100(merged_stats)

    def test_range_statistics_calculation(self) -> None:
        """Check the per-range statistics derived from distribution results."""
        distribution_results = self._distribution_for(self.stock_data)

        range_stats = self.aggregator.calculate_range_statistics(distribution_results)

        assert range_stats['total_stocks'] == 20
        assert range_stats['total_ranges'] == len(distribution_results)
        assert range_stats['positive_ranges'] + range_stats['negative_ranges'] == range_stats['total_ranges']

        # Largest / smallest range information must be present.
        assert 'largest_range' in range_stats
        assert 'smallest_range' in range_stats
        assert 'name' in range_stats['largest_range']
        assert 'count' in range_stats['largest_range']
        assert 'percentage' in range_stats['largest_range']

        # Average number of stocks per range.
        expected_avg = 20 / len(distribution_results)
        assert abs(range_stats['average_stocks_per_range'] - expected_avg) < 0.01

    def test_percentage_calculation_accuracy(self) -> None:
        """Check percentage precision for even and uneven splits."""
        # Counts that divide the total into whole percentages.
        counts = {
            "range1": 33,
            "range2": 33,
            "range3": 34
        }
        total = 100

        percentages = self.aggregator.calculate_percentages(counts, total)

        # Compare with a tight tolerance rather than ``==`` so the test does
        # not depend on the exact float representation of the result.
        assert abs(percentages["range1"] - 33.0) < self._EXACT_TOLERANCE
        assert abs(percentages["range2"] - 33.0) < self._EXACT_TOLERANCE
        assert abs(percentages["range3"] - 34.0) < self._EXACT_TOLERANCE

        total_pct = sum(percentages.values())
        assert abs(total_pct - 100.0) < self._EXACT_TOLERANCE

        # Counts that do not divide evenly.
        counts_uneven = {
            "range1": 1,
            "range2": 1,
            "range3": 1
        }
        total_uneven = 3

        percentages_uneven = self.aggregator.calculate_percentages(counts_uneven, total_uneven)

        # Each range should come out at about 33.33%.
        for pct in percentages_uneven.values():
            assert abs(pct - 33.33) < 0.01

        # The sum should still be close to 100%.
        total_pct_uneven = sum(percentages_uneven.values())
        assert abs(total_pct_uneven - 100.0) < self._PCT_SUM_TOLERANCE


if __name__ == "__main__":
    # pytest.main() returns the run's exit code; propagate it so that running
    # this file directly signals test failures to the shell / CI instead of
    # always exiting with status 0.
    raise SystemExit(pytest.main([__file__]))