"""
统计聚合器单元测试

测试StatisticsAggregator类的各种功能
"""

import pytest
import logging
from unittest.mock import Mock, patch
from typing import Dict, List, Any

from quickstock.utils.statistics_aggregator import (
    StatisticsAggregator, 
    StatisticsAggregationError,
    MarketStatistics,
    aggregate_market_data,
    calculate_distribution_percentages,
    validate_statistics_consistency
)
from quickstock.utils.distribution_calculator import DistributionResult
from quickstock.models import DistributionRange, PriceDistributionStats


class TestStatisticsAggregator:
    """Unit tests for the methods of StatisticsAggregator."""

    def setup_method(self):
        """Create an aggregator wired to a mock logger plus shared sample data."""
        self.logger = Mock(spec=logging.Logger)
        self.aggregator = StatisticsAggregator(self.logger)

        # Range definitions shared by the tests; each tuple is passed
        # positionally to DistributionRange.
        gain_specs = (
            ("0-3%", 0.0, 3.0, True, "0-3%"),
            ("3-5%", 3.0, 5.0, True, "3-5%"),
            (">=5%", 5.0, float('inf'), True, ">=5%"),
        )
        loss_specs = (
            ("0--3%", -3.0, 0.0, False, "0到-3%"),
            ("-3--5%", -5.0, -3.0, False, "-3到-5%"),
            ("<=-5%", float('-inf'), -5.0, False, "<=-5%"),
        )
        self.positive_ranges = [DistributionRange(*spec) for spec in gain_specs]
        self.negative_ranges = [DistributionRange(*spec) for spec in loss_specs]

        # Sample distribution results keyed by their first field; the
        # stock counts (second field) add up to 400.
        gains, losses = self.positive_ranges, self.negative_ranges
        sample_rows = (
            ("0-3%", 100, ["000001.SZ", "000002.SZ"], 25.0, gains[0]),
            ("3-5%", 80, ["000003.SZ", "000004.SZ"], 20.0, gains[1]),
            (">=5%", 60, ["000005.SZ"], 15.0, gains[2]),
            ("0--3%", 90, ["000006.SZ"], 22.5, losses[0]),
            ("-3--5%", 50, ["000007.SZ"], 12.5, losses[1]),
            ("<=-5%", 20, ["000008.SZ"], 5.0, losses[2]),
        )
        self.sample_distribution_results = {
            row[0]: DistributionResult(*row) for row in sample_rows
        }

    @staticmethod
    def _balanced_stats(**extra_fields):
        """Build a PriceDistributionStats with a single 50-stock range per side.

        Any extra keyword arguments are forwarded to the constructor unchanged.
        """
        return PriceDistributionStats(
            trade_date="20240101",
            total_stocks=100,
            positive_ranges={"0-3%": 50},
            positive_percentages={"0-3%": 50.0},
            negative_ranges={"0--3%": 50},
            negative_percentages={"0--3%": 50.0},
            **extra_fields
        )

    def test_init(self):
        """The constructor works without a logger and keeps a supplied one."""
        # No logger supplied: some logger must still be attached.
        without_logger = StatisticsAggregator()
        assert without_logger.logger is not None

        # Logger supplied: it must be the one stored.
        injected_logger = Mock(spec=logging.Logger)
        with_logger = StatisticsAggregator(injected_logger)
        assert with_logger.logger == injected_logger

    def test_aggregate_market_stats_success(self):
        """Aggregating two populated markets yields one MarketStatistics each."""
        shenzhen_results = {
            "0-3%": DistributionResult("0-3%", 50, ["300001.SZ"], 50.0, self.positive_ranges[0]),
            "0--3%": DistributionResult("0--3%", 50, ["300002.SZ"], 50.0, self.negative_ranges[0]),
        }
        market_data = {
            "shanghai": self.sample_distribution_results,
            "shenzhen": shenzhen_results,
        }

        aggregated = self.aggregator.aggregate_market_stats(market_data)

        # Shape of the result.
        assert isinstance(aggregated, dict)
        for market in ("shanghai", "shenzhen"):
            assert market in aggregated

        # Shanghai market.
        shanghai = aggregated["shanghai"]
        assert isinstance(shanghai, MarketStatistics)
        assert shanghai.market_name == "shanghai"
        assert shanghai.total_stocks == 400  # 100+80+60+90+50+20
        assert len(shanghai.positive_ranges) == 3
        assert len(shanghai.negative_ranges) == 3

        # Shenzhen market.
        shenzhen = aggregated["shenzhen"]
        assert shenzhen.market_name == "shenzhen"
        assert shenzhen.total_stocks == 100
        assert len(shenzhen.positive_ranges) == 1
        assert len(shenzhen.negative_ranges) == 1

        # An info-level log entry is expected.
        self.logger.info.assert_called()

    def test_aggregate_market_stats_empty_data(self):
        """An empty input mapping returns an empty dict and logs a warning."""
        aggregated = self.aggregator.aggregate_market_stats({})

        assert aggregated == {}
        self.logger.warning.assert_called_with("No market data provided for aggregation")

    def test_aggregate_market_stats_empty_market(self):
        """A market with no distribution results is left out of the output."""
        aggregated = self.aggregator.aggregate_market_stats(
            {
                "shanghai": self.sample_distribution_results,
                "empty_market": {},
            }
        )

        assert "shanghai" in aggregated
        assert "empty_market" not in aggregated
        self.logger.warning.assert_called_with("No distribution results for market: empty_market")

    def test_aggregate_market_stats_error(self):
        """A failure during aggregation surfaces as StatisticsAggregationError."""
        # A stand-in result whose stock_count is not a number.
        broken_result = Mock(stock_count="invalid")
        market_data = {"invalid_market": {"range1": broken_result}}

        with pytest.raises(StatisticsAggregationError) as exc_info:
            self.aggregator.aggregate_market_stats(market_data)

        raised = exc_info.value
        assert "Error during market statistics aggregation" in str(raised)
        assert raised.aggregation_data['market_count'] == 1

    def test_calculate_percentages_normal(self):
        """Counts that sum to the total map straight to their percentages."""
        percentages = self.aggregator.calculate_percentages(
            {"range1": 25, "range2": 50, "range3": 25}, 100
        )

        assert percentages == {"range1": 25.0, "range2": 50.0, "range3": 25.0}

    def test_calculate_percentages_zero_total(self):
        """A zero total gives zero for every range and logs a warning."""
        percentages = self.aggregator.calculate_percentages({"range1": 0, "range2": 0}, 0)

        assert percentages == {"range1": 0.0, "range2": 0.0}
        self.logger.warning.assert_called_with("Total count is zero, returning zero percentages")

    def test_calculate_percentages_invalid_count(self):
        """Negative or non-numeric counts are reported as 0.0 with warnings."""
        raw_counts = {"range1": -5, "range2": "invalid", "range3": 50}

        percentages = self.aggregator.calculate_percentages(raw_counts, 100)

        # The two invalid entries collapse to zero; the valid one is kept.
        expected = {"range1": 0.0, "range2": 0.0, "range3": 50.0}
        for range_name, value in expected.items():
            assert percentages[range_name] == value

        # At least one warning per invalid entry.
        assert self.logger.warning.call_count >= 2

    def test_calculate_percentages_rounding(self):
        """Percentages come back rounded to two decimal places."""
        percentages = self.aggregator.calculate_percentages({"range1": 1, "range2": 2}, 3)

        # 1/3 -> 33.33 and 2/3 -> 66.67
        assert percentages["range1"] == 33.33
        assert percentages["range2"] == 66.67

    def test_generate_summary_success(self):
        """A fully populated stats object produces the expected summary."""
        breakdown = {
            "shanghai": {
                "total_stocks": 600,
                "positive_ranges": {"0-3%": 200, "3-5%": 100},
                "negative_ranges": {"0--3%": 250, "-3--5%": 50},
            },
            "shenzhen": {
                "total_stocks": 400,
                "positive_ranges": {"0-3%": 100, "3-5%": 100},
                "negative_ranges": {"0--3%": 150, "-3--5%": 50},
            },
        }
        stats = PriceDistributionStats(
            trade_date="20240101",
            total_stocks=1000,
            positive_ranges={"0-3%": 300, "3-5%": 200},
            positive_percentages={"0-3%": 30.0, "3-5%": 20.0},
            negative_ranges={"0--3%": 400, "-3--5%": 100},
            negative_percentages={"0--3%": 40.0, "-3--5%": 10.0},
            market_breakdown=breakdown,
            processing_time=2.5,
            data_quality_score=0.95,
        )

        summary = self.aggregator.generate_summary(stats)

        # Headline figures.
        headline = {
            "trade_date": "20240101",
            "total_stocks": 1000,
            "positive_stocks": 500,
            "negative_stocks": 500,
            "positive_percentage": 50.0,
            "negative_percentage": 50.0,
        }
        for field, value in headline.items():
            assert summary[field] == value

        # Largest and smallest ranges.
        assert summary["largest_range"]["name"] == "0--3%"
        assert summary["largest_range"]["count"] == 400
        assert summary["smallest_range"]["name"] == "-3--5%"
        assert summary["smallest_range"]["count"] == 100

        # Per-market section.
        for market in ("shanghai", "shenzhen"):
            assert market in summary["market_breakdown"]
        assert summary["market_breakdown"]["shanghai"]["total_stocks"] == 600

        # Metadata is carried through.
        assert summary["processing_time"] == 2.5
        assert summary["data_quality_score"] == 0.95

        self.logger.info.assert_called()

    def test_generate_summary_empty_stats(self):
        """Passing None as the stats yields an empty summary."""
        summary = self.aggregator.generate_summary(None)

        assert summary == {}

    def test_generate_summary_error(self):
        """A failure during summary generation raises StatisticsAggregationError."""
        # A stand-in stats object whose total_stocks is not a number.
        broken_stats = Mock(total_stocks="invalid")

        with pytest.raises(StatisticsAggregationError) as exc_info:
            self.aggregator.generate_summary(broken_stats)

        assert "Error generating statistics summary" in str(exc_info.value)

    def test_validate_data_consistency_valid(self):
        """Internally consistent stats validate with no errors."""
        stats = PriceDistributionStats(
            trade_date="20240101",
            total_stocks=100,
            positive_ranges={"0-3%": 30, "3-5%": 20},
            positive_percentages={"0-3%": 30.0, "3-5%": 20.0},
            negative_ranges={"0--3%": 40, "-3--5%": 10},
            negative_percentages={"0--3%": 40.0, "-3--5%": 10.0},
            market_breakdown={
                "shanghai": {
                    "total_stocks": 60,
                    "positive_ranges": {"0-3%": 20, "3-5%": 10},
                    "negative_ranges": {"0--3%": 25, "-3--5%": 5},
                }
            },
            processing_time=1.5,
            data_quality_score=0.9,
        )

        report = self.aggregator.validate_data_consistency(stats)

        assert report["is_valid"] is True
        assert len(report["errors"]) == 0
        assert len(report["checks_performed"]) > 0
        for check_name in ("total_stocks_consistency", "percentage_consistency"):
            assert check_name in report["checks_performed"]

    def test_validate_data_consistency_total_mismatch(self):
        """A total that disagrees with the range counts is reported as an error."""
        # Construct a consistent object first, then break it afterwards.
        stats = PriceDistributionStats(
            trade_date="20240101",
            total_stocks=90,
            positive_ranges={"0-3%": 30, "3-5%": 20},
            positive_percentages={"0-3%": 33.33, "3-5%": 22.22},
            negative_ranges={"0--3%": 30, "-3--5%": 10},
            negative_percentages={"0--3%": 33.33, "-3--5%": 11.11},
        )

        # The range counts still add up to 90.
        stats.total_stocks = 100

        report = self.aggregator.validate_data_consistency(stats)

        assert report["is_valid"] is False
        assert len(report["errors"]) > 0
        assert any("Total stocks mismatch" in message for message in report["errors"])

    def test_validate_data_consistency_percentage_mismatch(self):
        """A percentage that disagrees with its count is reported as an error."""
        # Construct a consistent object first, then break it afterwards.
        stats = PriceDistributionStats(
            trade_date="20240101",
            total_stocks=100,
            positive_ranges={"0-3%": 30, "3-5%": 20},
            positive_percentages={"0-3%": 30.0, "3-5%": 20.0},
            negative_ranges={"0--3%": 40, "-3--5%": 10},
            negative_percentages={"0--3%": 40.0, "-3--5%": 10.0},
        )

        # 30 of 100 stocks is 30.0%, so 25.0 is inconsistent.
        stats.positive_percentages["0-3%"] = 25.0

        report = self.aggregator.validate_data_consistency(stats)

        assert report["is_valid"] is False
        assert len(report["errors"]) > 0
        assert any("Percentage mismatch" in message for message in report["errors"])

    def test_validate_data_consistency_invalid_quality_score(self):
        """An out-of-range data quality score is reported as an error."""
        # Construct with an acceptable score, then overwrite it.
        stats = self._balanced_stats(data_quality_score=0.9)
        stats.data_quality_score = 1.5

        report = self.aggregator.validate_data_consistency(stats)

        assert report["is_valid"] is False
        assert any("Data quality score out of range" in message for message in report["errors"])

    def test_validate_data_consistency_negative_processing_time(self):
        """A negative processing time is reported as an error."""
        # Construct with an acceptable time, then overwrite it.
        stats = self._balanced_stats(processing_time=1.0)
        stats.processing_time = -1.0

        report = self.aggregator.validate_data_consistency(stats)

        assert report["is_valid"] is False
        assert any("Processing time cannot be negative" in message for message in report["errors"])

    def test_validate_data_consistency_long_processing_time(self):
        """A very long processing time produces a warning, not an error."""
        # 400 seconds, i.e. more than five minutes.
        stats = self._balanced_stats(processing_time=400.0)

        report = self.aggregator.validate_data_consistency(stats)

        assert report["is_valid"] is True  # warning only, so still valid
        assert len(report["warnings"]) > 0
        assert any(
            "Processing time seems unusually long" in message
            for message in report["warnings"]
        )

    def test_merge_market_statistics_success(self):
        """Merging two markets sums their counts and combines stock codes."""
        shanghai = MarketStatistics(
            market_name="shanghai",
            total_stocks=100,
            positive_ranges={"0-3%": 30, "3-5%": 20},
            positive_percentages={"0-3%": 30.0, "3-5%": 20.0},
            negative_ranges={"0--3%": 40, "-3--5%": 10},
            negative_percentages={"0--3%": 40.0, "-3--5%": 10.0},
            stock_codes={"0-3%": ["000001.SZ"], "0--3%": ["000002.SZ"]},
        )
        shenzhen = MarketStatistics(
            market_name="shenzhen",
            total_stocks=50,
            positive_ranges={"0-3%": 20, "3-5%": 10},
            positive_percentages={"0-3%": 40.0, "3-5%": 20.0},
            negative_ranges={"0--3%": 15, "-3--5%": 5},
            negative_percentages={"0--3%": 30.0, "-3--5%": 10.0},
            stock_codes={"0-3%": ["300001.SZ"], "0--3%": ["300002.SZ"]},
        )

        merged = self.aggregator.merge_market_statistics([shanghai, shenzhen])

        assert merged.market_name == "merged_shanghai_shenzhen"
        assert merged.total_stocks == 150
        assert merged.positive_ranges["0-3%"] == 50  # 30 + 20
        assert merged.positive_ranges["3-5%"] == 30  # 20 + 10
        assert merged.negative_ranges["0--3%"] == 55  # 40 + 15

        # Codes from both inputs end up under the shared range key.
        merged_codes = merged.stock_codes["0-3%"]
        assert len(merged_codes) == 2
        for code in ("000001.SZ", "300001.SZ"):
            assert code in merged_codes

        self.logger.info.assert_called()

    def test_merge_market_statistics_single(self):
        """Merging a one-element list gives back an equal statistics object."""
        only_market = MarketStatistics(
            market_name="shanghai",
            total_stocks=100,
            positive_ranges={"0-3%": 50},
            positive_percentages={"0-3%": 50.0},
            negative_ranges={"0--3%": 50},
            negative_percentages={"0--3%": 50.0},
            stock_codes={"0-3%": ["000001.SZ"]},
        )

        merged = self.aggregator.merge_market_statistics([only_market])

        assert merged == only_market

    def test_merge_market_statistics_empty(self):
        """Merging an empty list raises StatisticsAggregationError."""
        with pytest.raises(StatisticsAggregationError) as exc_info:
            self.aggregator.merge_market_statistics([])

        assert "No market statistics provided for merging" in str(exc_info.value)

    def test_calculate_range_statistics(self):
        """Range statistics over the sample data match hand-computed values."""
        summary = self.aggregator.calculate_range_statistics(self.sample_distribution_results)

        assert summary["total_ranges"] == 6
        assert summary["total_stocks"] == 400  # 100+80+60+90+50+20
        assert summary["positive_ranges"] == 3
        assert summary["negative_ranges"] == 3

        largest = summary["largest_range"]
        assert largest["name"] == "0-3%"
        assert largest["count"] == 100

        smallest = summary["smallest_range"]
        assert smallest["name"] == "<=-5%"
        assert smallest["count"] == 20

        assert summary["average_stocks_per_range"] == 66.67  # 400/6

    def test_calculate_range_statistics_empty(self):
        """An empty distribution mapping yields an empty statistics dict."""
        summary = self.aggregator.calculate_range_statistics({})

        assert summary == {}


class TestMarketStatistics:
    """Unit tests for the MarketStatistics data container."""

    def test_to_dict(self):
        """to_dict() exposes the fields the object was constructed with."""
        market = MarketStatistics(
            market_name="shanghai",
            total_stocks=100,
            positive_ranges={"0-3%": 50},
            positive_percentages={"0-3%": 50.0},
            negative_ranges={"0--3%": 50},
            negative_percentages={"0--3%": 50.0},
            stock_codes={"0-3%": ["000001.SZ"]},
        )

        as_dict = market.to_dict()

        expected_entries = {
            "market_name": "shanghai",
            "total_stocks": 100,
            "positive_ranges": {"0-3%": 50},
            "stock_codes": {"0-3%": ["000001.SZ"]},
        }
        for key, value in expected_entries.items():
            assert as_dict[key] == value


class TestStatisticsAggregationError:
    """Unit tests for the StatisticsAggregationError exception type."""

    def test_init_with_data(self):
        """The message and the supplied aggregation data are both kept."""
        payload = {"key": "value"}

        err = StatisticsAggregationError("Test error", payload)

        assert str(err) == "Test error"
        assert err.aggregation_data == payload

    def test_init_without_data(self):
        """Without a data argument, aggregation_data is an empty dict."""
        err = StatisticsAggregationError("Test error")

        assert str(err) == "Test error"
        assert err.aggregation_data == {}

    def test_to_dict(self):
        """to_dict() reports the error type, the message and the data."""
        payload = {"key": "value"}

        as_dict = StatisticsAggregationError("Test error", payload).to_dict()

        assert as_dict["error_type"] == "StatisticsAggregationError"
        assert as_dict["message"] == "Test error"
        assert as_dict["aggregation_data"] == payload


class TestConvenienceFunctions:
    """Unit tests for the module-level convenience wrappers."""

    def setup_method(self):
        """Prepare a two-entry distribution result mapping for the tests."""
        up_range = DistributionRange("0-3%", 0.0, 3.0, True, "0-3%")
        down_range = DistributionRange("0--3%", -3.0, 0.0, False, "0到-3%")
        self.sample_distribution_results = {
            "0-3%": DistributionResult("0-3%", 100, ["000001.SZ"], 50.0, up_range),
            "0--3%": DistributionResult("0--3%", 100, ["000002.SZ"], 50.0, down_range),
        }

    @patch('quickstock.utils.statistics_aggregator.StatisticsAggregator')
    def test_aggregate_market_data(self, mock_aggregator_class):
        """aggregate_market_data builds an aggregator with the logger and delegates."""
        # The patched class hands back this stand-in instance.
        fake_instance = Mock()
        fake_instance.aggregate_market_stats.return_value = {"result": "data"}
        mock_aggregator_class.return_value = fake_instance

        logger = Mock()
        market_data = {"shanghai": self.sample_distribution_results}

        outcome = aggregate_market_data(market_data, logger)

        mock_aggregator_class.assert_called_once_with(logger)
        fake_instance.aggregate_market_stats.assert_called_once_with(market_data)
        assert outcome == {"result": "data"}

    @patch('quickstock.utils.statistics_aggregator.StatisticsAggregator')
    def test_calculate_distribution_percentages(self, mock_aggregator_class):
        """calculate_distribution_percentages delegates to calculate_percentages."""
        # The patched class hands back this stand-in instance.
        fake_instance = Mock()
        fake_instance.calculate_percentages.return_value = {"range1": 50.0}
        mock_aggregator_class.return_value = fake_instance

        range_counts = {"range1": 50}
        grand_total = 100

        outcome = calculate_distribution_percentages(range_counts, grand_total)

        mock_aggregator_class.assert_called_once_with()
        fake_instance.calculate_percentages.assert_called_once_with(range_counts, grand_total)
        assert outcome == {"range1": 50.0}

    @patch('quickstock.utils.statistics_aggregator.StatisticsAggregator')
    def test_validate_statistics_consistency(self, mock_aggregator_class):
        """validate_statistics_consistency delegates to validate_data_consistency."""
        # The patched class hands back this stand-in instance.
        fake_instance = Mock()
        fake_instance.validate_data_consistency.return_value = {"is_valid": True}
        mock_aggregator_class.return_value = fake_instance

        placeholder_stats = Mock()

        outcome = validate_statistics_consistency(placeholder_stats)

        mock_aggregator_class.assert_called_once_with()
        fake_instance.validate_data_consistency.assert_called_once_with(placeholder_stats)
        assert outcome == {"is_valid": True}


if __name__ == "__main__":
    # pytest.main() returns the session's exit code instead of exiting, so
    # propagate it; otherwise running this file directly would exit with
    # status 0 even when tests fail.
    raise SystemExit(pytest.main([__file__]))