"""
涨跌分布统计数据模型单元测试

测试 PriceDistributionRequest、PriceDistributionStats 和 DistributionRange 数据模型
"""

import pytest
from datetime import datetime, timedelta
from quickstock.models import (
    DistributionRange, 
    PriceDistributionStats, 
    PriceDistributionRequest
)
from quickstock.core.errors import ValidationError


class TestDistributionRange:
    """Tests for the DistributionRange data model."""

    @staticmethod
    def _make_range(**overrides):
        """Build a DistributionRange starting from the "0-3%" positive baseline.

        Any keyword in ``overrides`` replaces the corresponding baseline field,
        so each test only spells out the fields it actually cares about.
        """
        params = dict(
            name="0-3%",
            min_value=0.0,
            max_value=3.0,
            is_positive=True,
            display_name="0-3%",
        )
        params.update(overrides)
        return DistributionRange(**params)

    def test_valid_positive_range(self):
        """A well-formed positive range keeps every field it was given."""
        positive = self._make_range()

        assert positive.name == "0-3%"
        assert positive.min_value == 0.0
        assert positive.max_value == 3.0
        assert positive.is_positive is True
        assert positive.display_name == "0-3%"

    def test_valid_negative_range(self):
        """A well-formed negative range keeps every field it was given."""
        negative = self._make_range(
            name="0--3%",
            min_value=-3.0,
            max_value=0.0,
            is_positive=False,
            display_name="0到-3%",
        )

        assert negative.name == "0--3%"
        assert negative.min_value == -3.0
        assert negative.max_value == 0.0
        assert negative.is_positive is False
        assert negative.display_name == "0到-3%"

    def test_infinite_range(self):
        """An open-ended range (max = +inf) contains arbitrarily large values."""
        unbounded = self._make_range(
            name=">=10%",
            min_value=10.0,
            max_value=float('inf'),
            display_name=">=10%",
        )

        assert unbounded.max_value == float('inf')
        for sample in (15.0, 100.0):
            assert unbounded.contains(sample) is True

    def test_contains_method(self):
        """contains() is inclusive at the lower bound and exclusive at the upper."""
        window = self._make_range(
            name="3-5%",
            min_value=3.0,
            max_value=5.0,
            display_name="3-5%",
        )

        boundary_cases = [
            (3.0, True),    # lower bound is included
            (4.99, True),   # strictly inside the range
            (5.0, False),   # upper bound is excluded
            (2.99, False),  # just below the range
            (5.01, False),  # just above the range
        ]
        for sample, expected in boundary_cases:
            assert window.contains(sample) is expected, sample

    def test_invalid_range_name(self):
        """An empty name is rejected."""
        with pytest.raises(ValidationError, match="Range name is required"):
            self._make_range(name="")

    def test_invalid_display_name(self):
        """An empty display name is rejected."""
        with pytest.raises(ValidationError, match="Display name is required"):
            self._make_range(display_name="")

    def test_invalid_min_max_values(self):
        """Non-numeric bounds are rejected."""
        with pytest.raises(ValidationError, match="Min and max values must be numeric"):
            # a string instead of a number
            self._make_range(min_value="0.0")

    def test_invalid_range_logic(self):
        """A range whose min exceeds its max is rejected."""
        with pytest.raises(ValidationError, match="Min value .* must be less than max value"):
            self._make_range(
                name="invalid",
                min_value=5.0,
                max_value=3.0,  # max is below min
                display_name="invalid",
            )

    def test_positive_range_with_negative_values(self):
        """A positive range may not extend below zero."""
        with pytest.raises(ValidationError, match="Positive range cannot have negative values"):
            self._make_range(
                name="invalid",
                min_value=-1.0,  # negative lower bound
                max_value=3.0,
                display_name="invalid",
            )

    def test_negative_range_with_positive_values(self):
        """A negative range may not extend above zero."""
        with pytest.raises(ValidationError, match="Negative range cannot have positive values"):
            self._make_range(
                name="invalid",
                min_value=-3.0,
                max_value=1.0,  # positive upper bound
                is_positive=False,
                display_name="invalid",
            )

    def test_to_dict(self):
        """to_dict() exposes all five fields under their attribute names."""
        serialized = self._make_range().to_dict()

        assert serialized == {
            'name': "0-3%",
            'min_value': 0.0,
            'max_value': 3.0,
            'is_positive': True,
            'display_name': "0-3%",
        }

    def test_from_dict(self):
        """from_dict() rebuilds an equivalent instance from a plain mapping."""
        payload = {
            'name': "0-3%",
            'min_value': 0.0,
            'max_value': 3.0,
            'is_positive': True,
            'display_name': "0-3%",
        }

        rebuilt = DistributionRange.from_dict(payload)

        assert rebuilt.name == "0-3%"
        assert rebuilt.min_value == 0.0
        assert rebuilt.max_value == 3.0
        assert rebuilt.is_positive is True
        assert rebuilt.display_name == "0-3%"

    def test_create_default_ranges(self):
        """The default set is five positive plus five negative ranges."""
        ranges = DistributionRange.create_default_ranges()

        assert len(ranges) == 10  # 5 positive + 5 negative

        positives = [r for r in ranges if r.is_positive]
        negatives = [r for r in ranges if not r.is_positive]
        assert len(positives) == 5
        assert len(negatives) == 5

        # Spot-check the innermost and outermost buckets on each side.
        names = {r.name for r in ranges}
        for expected_name in ("0-3%", ">=10%", "0--3%", "<=-10%"):
            assert expected_name in names


class TestPriceDistributionStats:
    """Tests for the PriceDistributionStats data model."""

    @staticmethod
    def _stats_kwargs(**overrides):
        """Return constructor kwargs for a consistent 100-stock, 50 up / 50 down split.

        A brand-new dict (with brand-new nested dicts) is built on every call so
        that no test can leak mutations into another; ``overrides`` replaces
        individual fields.
        """
        kwargs = {
            'trade_date': "20241015",
            'total_stocks': 100,
            'positive_ranges': {"0-3%": 50},
            'positive_percentages': {"0-3%": 50.0},
            'negative_ranges': {"0--3%": 50},
            'negative_percentages': {"0--3%": 50.0},
        }
        kwargs.update(overrides)
        return kwargs

    @classmethod
    def _two_bucket_kwargs(cls, **overrides):
        """Kwargs for 100 stocks split 30/20 across two up buckets and 25/25 across two down buckets."""
        kwargs = cls._stats_kwargs(
            positive_ranges={"0-3%": 30, "3-5%": 20},
            positive_percentages={"0-3%": 30.0, "3-5%": 20.0},
            negative_ranges={"0--3%": 25, "-3--5%": 25},
            negative_percentages={"0--3%": 25.0, "-3--5%": 25.0},
        )
        kwargs.update(overrides)
        return kwargs

    def test_valid_stats(self):
        """Valid statistics keep their inputs and fill in default metadata."""
        stats = PriceDistributionStats(**self._two_bucket_kwargs())

        assert stats.trade_date == "20241015"
        assert stats.total_stocks == 100
        assert stats.positive_ranges == {"0-3%": 30, "3-5%": 20}
        assert stats.negative_ranges == {"0--3%": 25, "-3--5%": 25}
        assert stats.created_at is not None
        assert stats.processing_time == 0.0
        assert stats.data_quality_score == 1.0

    def test_invalid_trade_date(self):
        """A trade date not in YYYYMMDD form is rejected."""
        with pytest.raises(ValidationError, match="Invalid trade date format"):
            PriceDistributionStats(**self._stats_kwargs(
                trade_date="2024-10-15",  # wrong format
            ))

    def test_invalid_total_stocks(self):
        """A negative total stock count is rejected."""
        with pytest.raises(ValidationError, match="Invalid total stocks count"):
            PriceDistributionStats(**self._stats_kwargs(
                total_stocks=-1,  # negative
            ))

    def test_inconsistent_totals(self):
        """Bucket counts that do not add up to total_stocks are rejected."""
        with pytest.raises(ValidationError, match="Range totals .* don't match total stocks"):
            PriceDistributionStats(**self._stats_kwargs(
                positive_ranges={"0-3%": 30},  # 30 up + 50 down = 80 != 100
                positive_percentages={"0-3%": 30.0},
            ))

    def test_missing_percentage(self):
        """Every positive bucket must have a matching percentage entry."""
        with pytest.raises(ValidationError, match="Missing percentage for positive range"):
            PriceDistributionStats(**self._stats_kwargs(
                positive_ranges={"0-3%": 50, "3-5%": 25},  # two buckets...
                positive_percentages={"0-3%": 50.0},       # ...but only one percentage
                negative_ranges={"0--3%": 25},
                negative_percentages={"0--3%": 25.0},
            ))

    def test_percentage_mismatch(self):
        """A percentage that disagrees with count / total is rejected."""
        with pytest.raises(ValidationError, match="Percentage mismatch"):
            PriceDistributionStats(**self._stats_kwargs(
                positive_percentages={"0-3%": 60.0},  # should be 50%
            ))

    def test_invalid_processing_time(self):
        """A negative processing time is rejected."""
        with pytest.raises(ValidationError, match="Invalid processing time"):
            PriceDistributionStats(**self._stats_kwargs(
                processing_time=-1.0,  # negative
            ))

    def test_invalid_data_quality_score(self):
        """A data-quality score outside [0, 1] is rejected."""
        with pytest.raises(ValidationError, match="Data quality score must be between 0 and 1"):
            PriceDistributionStats(**self._stats_kwargs(
                data_quality_score=1.5,  # out of range
            ))

    def test_get_summary(self):
        """get_summary() aggregates bucket counts into up/down totals and shares."""
        stats = PriceDistributionStats(**self._two_bucket_kwargs(
            processing_time=1.5,
            data_quality_score=0.95,
        ))

        summary = stats.get_summary()

        assert summary['trade_date'] == "20241015"
        assert summary['total_stocks'] == 100
        assert summary['positive_stocks'] == 50  # 30 + 20
        assert summary['negative_stocks'] == 50  # 25 + 25
        # The shares are derived by floating-point arithmetic inside the model,
        # so compare with a tolerance rather than exact equality.
        assert summary['positive_percentage'] == pytest.approx(50.0)
        assert summary['negative_percentage'] == pytest.approx(50.0)
        assert summary['processing_time'] == 1.5
        assert summary['data_quality_score'] == 0.95

    def test_get_market_summary(self):
        """get_market_summary() returns a known market's breakdown, None for unknown ones."""
        market_breakdown = {
            'shanghai': {
                'total_stocks': 50,
                'positive_ranges': {'0-3%': 20},
                'negative_ranges': {'0--3%': 30},
                'positive_percentages': {'0-3%': 40.0},
                'negative_percentages': {'0--3%': 60.0},
            }
        }

        stats = PriceDistributionStats(**self._stats_kwargs(
            market_breakdown=market_breakdown,
        ))

        shanghai_summary = stats.get_market_summary('shanghai')

        assert shanghai_summary is not None
        assert shanghai_summary['market'] == 'shanghai'
        assert shanghai_summary['total_stocks'] == 50
        assert shanghai_summary['positive_ranges'] == {'0-3%': 20}
        assert shanghai_summary['negative_ranges'] == {'0--3%': 30}

        # A market that is not in the breakdown yields None.
        assert stats.get_market_summary('nonexistent') is None

    def test_to_dict_and_from_dict(self):
        """to_dict() followed by from_dict() preserves the core fields."""
        stats = PriceDistributionStats(**self._stats_kwargs())

        new_stats = PriceDistributionStats.from_dict(stats.to_dict())

        assert new_stats.trade_date == stats.trade_date
        assert new_stats.total_stocks == stats.total_stocks
        assert new_stats.positive_ranges == stats.positive_ranges
        assert new_stats.negative_ranges == stats.negative_ranges


class TestPriceDistributionRequest:
    """Tests for the PriceDistributionRequest data model."""

    def test_valid_request(self):
        """A fully specified valid request keeps its fields and gets default ranges."""
        request = PriceDistributionRequest(
            trade_date="20241015",
            include_st=True,
            market_filter=['shanghai', 'shenzhen'],
            force_refresh=False,
            save_to_db=True,
            timeout=30
        )

        assert request.trade_date == "20241015"
        assert request.include_st is True
        assert request.market_filter == ['shanghai', 'shenzhen']
        assert request.force_refresh is False
        assert request.save_to_db is True
        assert request.timeout == 30
        assert request.data_type == "price_distribution_stats"
        assert request.distribution_ranges is not None  # default ranges are filled in

    def test_date_format_conversion(self):
        """YYYY-MM-DD input is normalized to YYYYMMDD; YYYYMMDD is kept as-is."""
        request = PriceDistributionRequest(trade_date="2024-10-15")
        assert request.trade_date == "20241015"

        request = PriceDistributionRequest(trade_date="20241015")
        assert request.trade_date == "20241015"

    def test_invalid_date_format(self):
        """A slash-separated date is rejected."""
        with pytest.raises(ValidationError, match="Invalid trade date format"):
            PriceDistributionRequest(trade_date="2024/10/15")  # wrong format

    def test_invalid_date_value(self):
        """A well-formed but impossible date is rejected."""
        with pytest.raises(ValidationError, match="Invalid date"):
            PriceDistributionRequest(trade_date="20241301")  # there is no month 13

    def test_future_date(self):
        """A trade date in the future is rejected."""
        # Use a margin of several days rather than one: with ``+1 day`` the
        # test is flaky around midnight, because a date computed just before
        # 00:00 is already "today" by the time the model validates it.
        future_date = (datetime.now() + timedelta(days=7)).strftime('%Y%m%d')
        with pytest.raises(ValidationError, match="Trade date cannot be in the future"):
            PriceDistributionRequest(trade_date=future_date)

    def test_empty_trade_date(self):
        """An empty trade date is rejected."""
        with pytest.raises(ValidationError, match="Trade date is required"):
            PriceDistributionRequest(trade_date="")

    def test_invalid_market_filter(self):
        """An unknown market name in the filter is rejected."""
        with pytest.raises(ValidationError, match="Invalid market filters"):
            PriceDistributionRequest(
                trade_date="20241015",
                market_filter=['invalid_market']
            )

    def test_market_filter_normalization(self):
        """A single string becomes a list, and duplicates are removed in order."""
        request = PriceDistributionRequest(
            trade_date="20241015",
            market_filter='shanghai'
        )
        assert request.market_filter == ['shanghai']

        request = PriceDistributionRequest(
            trade_date="20241015",
            market_filter=['shanghai', 'shenzhen', 'shanghai']
        )
        assert request.market_filter == ['shanghai', 'shenzhen']

    def test_invalid_distribution_ranges(self):
        """distribution_ranges must be a list."""
        with pytest.raises(ValidationError, match="Distribution ranges must be a list"):
            PriceDistributionRequest(
                trade_date="20241015",
                distribution_ranges="invalid"
            )

    def test_invalid_timeout(self):
        """A negative timeout is rejected."""
        with pytest.raises(ValidationError, match="Timeout must be a positive integer"):
            PriceDistributionRequest(
                trade_date="20241015",
                timeout=-1
            )

    def test_get_formatted_date(self):
        """get_formatted_date() renders the stored date as YYYY-MM-DD."""
        request = PriceDistributionRequest(trade_date="20241015")
        assert request.get_formatted_date() == "2024-10-15"

    def test_is_market_included(self):
        """Without a filter every market is included; with one, only listed markets."""
        request = PriceDistributionRequest(
            trade_date="20241015",
            market_filter=None
        )
        assert request.is_market_included('shanghai') is True
        assert request.is_market_included('shenzhen') is True

        request = PriceDistributionRequest(
            trade_date="20241015",
            market_filter=['shanghai']
        )
        assert request.is_market_included('shanghai') is True
        assert request.is_market_included('shenzhen') is False

    def test_get_positive_and_negative_ranges(self):
        """The default configuration splits into five positive and five negative ranges."""
        request = PriceDistributionRequest(trade_date="20241015")

        positive_ranges = request.get_positive_ranges()
        negative_ranges = request.get_negative_ranges()

        assert len(positive_ranges) == 5  # 5 default positive ranges
        assert len(negative_ranges) == 5  # 5 default negative ranges

        for range_obj in positive_ranges:
            assert range_obj.is_positive is True

        for range_obj in negative_ranges:
            assert range_obj.is_positive is False

    def test_find_range_for_value(self):
        """find_range_for_value() picks the bucket containing the value."""
        request = PriceDistributionRequest(trade_date="20241015")

        # positive value
        range_obj = request.find_range_for_value(2.5)
        assert range_obj is not None
        assert range_obj.name == "0-3%"

        # negative value
        range_obj = request.find_range_for_value(-2.5)
        assert range_obj is not None
        assert range_obj.name == "0--3%"

        # boundary: 3.0 belongs to the bucket it opens, not the one it closes
        range_obj = request.find_range_for_value(3.0)
        assert range_obj is not None
        assert range_obj.name == "3-5%"

    def test_to_dict_and_from_dict(self):
        """to_dict() followed by from_dict() preserves all request options."""
        request = PriceDistributionRequest(
            trade_date="20241015",
            include_st=False,
            market_filter=['shanghai'],
            force_refresh=True,
            save_to_db=False,
            timeout=60
        )

        new_request = PriceDistributionRequest.from_dict(request.to_dict())

        assert new_request.trade_date == request.trade_date
        assert new_request.include_st == request.include_st
        assert new_request.market_filter == request.market_filter
        assert new_request.force_refresh == request.force_refresh
        assert new_request.save_to_db == request.save_to_db
        assert new_request.timeout == request.timeout


if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # (e.g. from a shell script or CI step) fails when any test fails;
    # discarding the return value would always exit with status 0.
    raise SystemExit(pytest.main([__file__]))