"""
分布计算器单元测试

测试 DistributionCalculator 类的核心算法功能
"""

import pytest
import pandas as pd
import numpy as np
from quickstock.utils.distribution_calculator import (
    DistributionCalculator,
    DistributionResult,
    DistributionCalculationError,
    create_default_ranges,
    classify_stocks_by_change,
    calculate_distribution_stats
)
from quickstock.models import DistributionRange
from quickstock.core.errors import ValidationError


class TestDistributionCalculator:
    """Unit tests for the core algorithms of ``DistributionCalculator``."""

    def setup_method(self):
        """Create a fresh calculator and a shared 15-row fixture before each test."""
        self.calculator = DistributionCalculator()

        # Fixture: five gains, five losses and five exact range boundaries.
        self.test_data = pd.DataFrame({
            'ts_code': [
                '000001.SZ', '000002.SZ', '000003.SZ', '000004.SZ', '000005.SZ',
                '600001.SH', '600002.SH', '600003.SH', '600004.SH', '600005.SH',
                '688001.SH', '688002.SH', '300001.SZ', '300002.SZ', '300003.SZ'
            ],
            'pct_chg': [
                2.5, 4.2, 6.8, 9.1, 12.3,    # gains: 0-3%, 3-5%, 5-7%, 7-10%, >=10%
                -1.8, -3.5, -6.2, -8.7, -11.5,  # losses: 0--3%, -3--5%, -5--7%, -7--10%, <=-10%
                0.0, 3.0, 5.0, 7.0, 10.0     # boundary values
            ]
        })

    def test_init(self):
        """A new calculator exposes the ten default ranges."""
        calculator = DistributionCalculator()
        assert calculator is not None
        assert calculator.default_ranges is not None
        assert len(calculator.default_ranges) == 10  # 5 positive + 5 negative ranges

    def test_create_default_ranges(self):
        """The default ranges split into five positive and five negative, with the expected names."""
        ranges = self.calculator._create_default_ranges()

        assert len(ranges) == 10

        # Positive ranges
        positive_ranges = [r for r in ranges if r.is_positive]
        assert len(positive_ranges) == 5

        # Negative ranges
        negative_ranges = [r for r in ranges if not r.is_positive]
        assert len(negative_ranges) == 5

        # Range names
        positive_names = {r.name for r in positive_ranges}
        expected_positive = {"0-3%", "3-5%", "5-7%", "7-10%", ">=10%"}
        assert positive_names == expected_positive

        negative_names = {r.name for r in negative_ranges}
        expected_negative = {"0--3%", "-3--5%", "-5--7%", "-7--10%", "<=-10%"}
        assert negative_names == expected_negative

    def test_validate_ranges_valid(self):
        """The default ranges pass validation."""
        ranges = self.calculator.default_ranges
        assert self.calculator.validate_ranges(ranges) is True

    def test_validate_ranges_empty(self):
        """An empty range list is rejected."""
        with pytest.raises(DistributionCalculationError, match="Ranges list cannot be empty"):
            self.calculator.validate_ranges([])

    def test_validate_ranges_not_list(self):
        """A non-list argument is rejected."""
        with pytest.raises(DistributionCalculationError, match="Ranges must be a list"):
            self.calculator.validate_ranges("not a list")

    def test_validate_ranges_invalid_type(self):
        """List items that are not DistributionRange instances are rejected."""
        invalid_ranges = [{"name": "test", "min": 0, "max": 1}]
        with pytest.raises(DistributionCalculationError, match="not a DistributionRange instance"):
            self.calculator.validate_ranges(invalid_ranges)

    def test_validate_ranges_duplicate_names(self):
        """Two ranges sharing a name are rejected."""
        range1 = DistributionRange("test", 0.0, 1.0, True, "Test 1")
        range2 = DistributionRange("test", 1.0, 2.0, True, "Test 2")

        with pytest.raises(DistributionCalculationError, match="Duplicate range name"):
            self.calculator.validate_ranges([range1, range2])

    def test_validate_ranges_invalid_values(self):
        """A non-numeric bound fails when the range is constructed."""
        # ValidationError is raised while constructing the DistributionRange,
        # not by validate_ranges, so no calculator call is needed here.
        with pytest.raises(ValidationError, match="Min and max values must be numeric"):
            DistributionRange("test", "invalid", 1.0, True, "Test")

    def test_classify_by_ranges_basic(self):
        """Each non-boundary fixture stock lands in the range matching its change."""
        classified = self.calculator.classify_by_ranges(self.test_data)

        # Every default range must be present as a key.
        expected_ranges = {r.name for r in self.calculator.default_ranges}
        assert set(classified.keys()) == expected_ranges

        # Assert membership of the specific stock instead of a non-empty
        # bucket: the fixture's boundary rows also populate the positive
        # ranges, so a bare count check could hide a misclassification.
        expected_membership = {
            "0-3%": '000001.SZ',      # 2.5%
            "3-5%": '000002.SZ',      # 4.2%
            "5-7%": '000003.SZ',      # 6.8%
            "7-10%": '000004.SZ',     # 9.1%
            ">=10%": '000005.SZ',     # 12.3%
            "0--3%": '600001.SH',     # -1.8%
            "-3--5%": '600002.SH',    # -3.5%
            "-5--7%": '600003.SH',    # -6.2%
            "-7--10%": '600004.SH',   # -8.7%
            "<=-10%": '600005.SH',    # -11.5%
        }
        for range_name, ts_code in expected_membership.items():
            assert ts_code in classified[range_name], (
                f"{ts_code} should be classified into {range_name}"
            )

    def test_classify_by_ranges_boundary_values(self):
        """Positive boundary values fall into the range they open."""
        boundary_data = pd.DataFrame({
            'ts_code': ['test1', 'test2', 'test3', 'test4', 'test5'],
            'pct_chg': [0.0, 3.0, 5.0, 7.0, 10.0]
        })

        classified = self.calculator.classify_by_ranges(boundary_data)

        # 0.0% belongs to 0-3%
        assert 'test1' in classified["0-3%"]

        # 3.0% belongs to 3-5% (left-closed, right-open)
        assert 'test2' in classified["3-5%"]

        # 5.0% belongs to 5-7%
        assert 'test3' in classified["5-7%"]

        # 7.0% belongs to 7-10%
        assert 'test4' in classified["7-10%"]

        # 10.0% belongs to >=10%
        assert 'test5' in classified[">=10%"]

    def test_classify_by_ranges_negative_boundary_values(self):
        """Negative boundary values fall into the range whose lower bound they equal."""
        boundary_data = pd.DataFrame({
            'ts_code': ['test1', 'test2', 'test3', 'test4', 'test5'],
            'pct_chg': [-0.1, -3.0, -5.0, -7.0, -10.0]
        })

        classified = self.calculator.classify_by_ranges(boundary_data)

        # Ranges are left-closed, right-open: [min_value, max_value)
        # -0.1% belongs to 0--3% ([-3.0, 0.0))
        assert 'test1' in classified["0--3%"]

        # -3.0% belongs to 0--3% ([-3.0, 0.0)) - left bound is inclusive
        assert 'test2' in classified["0--3%"]

        # -5.0% belongs to -3--5% ([-5.0, -3.0)) - left bound is inclusive
        assert 'test3' in classified["-3--5%"]

        # -7.0% belongs to -5--7% ([-7.0, -5.0)) - left bound is inclusive
        assert 'test4' in classified["-5--7%"]

        # -10.0% belongs to -7--10% ([-10.0, -7.0)) - left bound is inclusive
        assert 'test5' in classified["-7--10%"]

    def test_classify_by_ranges_invalid_data(self):
        """Empty, incomplete and non-DataFrame inputs each raise a specific error."""
        # Empty DataFrame
        empty_data = pd.DataFrame()
        with pytest.raises(DistributionCalculationError, match="Stock data cannot be empty"):
            self.calculator.classify_by_ranges(empty_data)

        # Required column missing
        invalid_data = pd.DataFrame({'ts_code': ['test1']})
        with pytest.raises(DistributionCalculationError, match="Missing required columns"):
            self.calculator.classify_by_ranges(invalid_data)

        # Not a DataFrame at all
        with pytest.raises(DistributionCalculationError, match="Stock data must be a pandas DataFrame"):
            self.calculator.classify_by_ranges("not a dataframe")

    def test_classify_by_ranges_with_nan_values(self):
        """Rows whose change is NaN are left out of every range."""
        data_with_nan = pd.DataFrame({
            'ts_code': ['test1', 'test2', 'test3'],
            'pct_chg': [2.5, np.nan, 4.2]
        })

        classified = self.calculator.classify_by_ranges(data_with_nan)

        # The valid rows are still classified normally.
        assert 'test1' in classified["0-3%"]
        assert 'test3' in classified["3-5%"]
        # test2 (NaN) must not show up in any range.
        all_stocks = []
        for stocks in classified.values():
            all_stocks.extend(stocks)
        assert 'test2' not in all_stocks

    def test_calculate_statistics_basic(self):
        """Statistics give one well-formed result per range and sum to about 100%."""
        classified = self.calculator.classify_by_ranges(self.test_data)
        statistics = self.calculator.calculate_statistics(classified)

        # Shape of the result
        assert isinstance(statistics, dict)
        assert len(statistics) == len(self.calculator.default_ranges)

        # Each individual result
        for range_name, result in statistics.items():
            assert isinstance(result, DistributionResult)
            assert result.range_name == range_name
            assert result.stock_count >= 0
            assert 0 <= result.percentage <= 100
            assert isinstance(result.stock_codes, list)
            assert result.range_definition is not None

        # Percentages should add up to roughly 100%
        total_percentage = sum(result.percentage for result in statistics.values())
        assert abs(total_percentage - 100.0) < 0.01  # tolerate floating-point error

    def test_calculate_statistics_empty_data(self):
        """When every range is empty the statistics are an empty dict."""
        empty_classified = {r.name: [] for r in self.calculator.default_ranges}
        statistics = self.calculator.calculate_statistics(empty_classified)

        assert statistics == {}

    def test_calculate_percentages(self):
        """Counts convert to percentages of the total; a zero total gives all zeros."""
        counts = {"range1": 10, "range2": 20, "range3": 70}
        total = 100

        percentages = self.calculator.calculate_percentages(counts, total)

        assert percentages["range1"] == 10.0
        assert percentages["range2"] == 20.0
        assert percentages["range3"] == 70.0

        # A total of zero must not divide by zero.
        percentages_zero = self.calculator.calculate_percentages(counts, 0)
        assert all(p == 0.0 for p in percentages_zero.values())

    def test_get_range_summary(self):
        """The summary has the expected keys and totals consistent with the fixture."""
        classified = self.calculator.classify_by_ranges(self.test_data)
        statistics = self.calculator.calculate_statistics(classified)
        summary = self.calculator.get_range_summary(statistics)

        assert isinstance(summary, dict)
        assert "total_ranges" in summary
        assert "total_stocks" in summary
        assert "positive_ranges" in summary
        assert "negative_ranges" in summary
        assert "largest_range" in summary
        assert "smallest_range" in summary

        assert summary["total_ranges"] == len(statistics)
        assert summary["positive_ranges"] + summary["negative_ranges"] == summary["total_ranges"]
        assert summary["total_stocks"] == len(self.test_data)

    def test_get_range_summary_empty(self):
        """Summarising empty statistics gives zero counts and no extreme ranges."""
        summary = self.calculator.get_range_summary({})

        assert summary["total_ranges"] == 0
        assert summary["total_stocks"] == 0
        assert summary["positive_ranges"] == 0
        assert summary["negative_ranges"] == 0
        assert summary["largest_range"] is None
        assert summary["smallest_range"] is None

    def test_custom_ranges(self):
        """Caller-supplied ranges, including unbounded ones, are used for classification."""
        custom_ranges = [
            DistributionRange("small_up", 0.0, 2.0, True, "小涨"),
            DistributionRange("big_up", 2.0, float('inf'), True, "大涨"),
            DistributionRange("small_down", -2.0, 0.0, False, "小跌"),
            DistributionRange("big_down", float('-inf'), -2.0, False, "大跌")
        ]

        test_data = pd.DataFrame({
            'ts_code': ['test1', 'test2', 'test3', 'test4'],
            'pct_chg': [1.0, 5.0, -1.0, -5.0]
        })

        classified = self.calculator.classify_by_ranges(test_data, custom_ranges)
        statistics = self.calculator.calculate_statistics(classified, custom_ranges)

        assert len(statistics) == 4
        assert statistics["small_up"].stock_count == 1
        assert statistics["big_up"].stock_count == 1
        assert statistics["small_down"].stock_count == 1
        assert statistics["big_down"].stock_count == 1

    def test_extreme_values(self):
        """Very large and very small changes land in the outermost and innermost ranges."""
        extreme_data = pd.DataFrame({
            'ts_code': ['test1', 'test2', 'test3', 'test4'],
            'pct_chg': [999.99, -999.99, 0.001, -0.001]
        })

        classified = self.calculator.classify_by_ranges(extreme_data)

        # Huge gain -> >=10%
        assert 'test1' in classified[">=10%"]

        # Huge loss -> <=-10%
        assert 'test2' in classified["<=-10%"]

        # Tiny gain -> 0-3%
        assert 'test3' in classified["0-3%"]

        # Tiny loss -> 0--3%
        assert 'test4' in classified["0--3%"]


class TestConvenienceFunctions:
    """Tests for the module-level convenience wrappers."""

    def test_create_default_ranges(self):
        """create_default_ranges() returns ten DistributionRange objects."""
        default_ranges = create_default_ranges()

        assert len(default_ranges) == 10
        for entry in default_ranges:
            assert isinstance(entry, DistributionRange)

    def test_classify_stocks_by_change(self):
        """classify_stocks_by_change() buckets one gain and one loss correctly."""
        rows = [('test1', 2.5), ('test2', -3.5)]
        frame = pd.DataFrame(rows, columns=['ts_code', 'pct_chg'])

        buckets = classify_stocks_by_change(frame)

        assert isinstance(buckets, dict)
        assert 'test1' in buckets["0-3%"]
        assert 'test2' in buckets["-3--5%"]

    def test_calculate_distribution_stats(self):
        """calculate_distribution_stats() returns DistributionResult values summing to about 100%."""
        rows = [('test1', 2.5), ('test2', 4.2), ('test3', -3.5)]
        frame = pd.DataFrame(rows, columns=['ts_code', 'pct_chg'])

        stats = calculate_distribution_stats(frame)

        assert isinstance(stats, dict)
        for outcome in stats.values():
            assert isinstance(outcome, DistributionResult)

        # Looser tolerance than elsewhere in this file to allow for a larger
        # floating-point error in the summed percentages.
        percentage_sum = sum(outcome.percentage for outcome in stats.values())
        deviation = abs(percentage_sum - 100.0)
        assert deviation < 0.02


class TestEdgeCases:
    """Edge-case tests for ``DistributionCalculator``."""

    def setup_method(self):
        """Create a fresh calculator before each test."""
        self.calculator = DistributionCalculator()

    def test_single_stock(self):
        """A single stock fills exactly one range at 100%."""
        single_stock = pd.DataFrame({
            'ts_code': ['test1'],
            'pct_chg': [2.5]
        })

        classified = self.calculator.classify_by_ranges(single_stock)
        statistics = self.calculator.calculate_statistics(classified)

        # Only one range holds a stock, so its share must be 100%.
        non_empty_ranges = [r for r in statistics.values() if r.stock_count > 0]
        assert len(non_empty_ranges) == 1
        assert non_empty_ranges[0].percentage == 100.0

    def test_all_same_change(self):
        """Stocks with identical changes all land in the same range."""
        same_change_data = pd.DataFrame({
            'ts_code': ['test1', 'test2', 'test3'],
            'pct_chg': [2.5, 2.5, 2.5]
        })

        classified = self.calculator.classify_by_ranges(same_change_data)
        statistics = self.calculator.calculate_statistics(classified)

        # All three stocks share one range.
        non_empty_ranges = [r for r in statistics.values() if r.stock_count > 0]
        assert len(non_empty_ranges) == 1
        assert non_empty_ranges[0].stock_count == 3
        assert non_empty_ranges[0].percentage == 100.0

    def test_duplicate_stock_codes(self):
        """Duplicate ts_code rows are classified without raising."""
        duplicate_data = pd.DataFrame({
            'ts_code': ['test1', 'test1', 'test2'],
            'pct_chg': [2.5, 4.2, -3.5]
        })

        # Expected to be processed normally.
        # NOTE(review): the calculator presumably emits a warning here;
        # this test does not verify that.
        classified = self.calculator.classify_by_ranges(duplicate_data)

        # Either duplicate row may be the one kept, so accept both ranges.
        assert 'test1' in classified["0-3%"] or 'test1' in classified["3-5%"]
        assert 'test2' in classified["-3--5%"]

    def test_very_large_dataset(self):
        """A 1000-stock dataset is fully classified and sums to about 100%."""
        # Seed the generator so a failure can be reproduced with the exact
        # same sample instead of depending on global random state.
        rng = np.random.default_rng(20240101)
        large_data = pd.DataFrame({
            'ts_code': [f'test{i:04d}' for i in range(1000)],
            'pct_chg': rng.normal(0, 5, 1000)  # normal distribution, mean 0, std 5
        })

        classified = self.calculator.classify_by_ranges(large_data)
        statistics = self.calculator.calculate_statistics(classified)

        # Every stock must be counted exactly once.
        total_stocks = sum(result.stock_count for result in statistics.values())
        assert total_stocks == 1000

        # Percentages should add up to roughly 100%.
        total_percentage = sum(result.percentage for result in statistics.values())
        assert abs(total_percentage - 100.0) < 0.01


if __name__ == "__main__":
    # pytest.main returns the run's exit code; propagate it so that running
    # this file as a script exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__]))