"""
Unit tests for limit-up statistics data models.
"""
import pytest
from datetime import datetime, timedelta
from quickstock.models import (
    StockDailyData, 
    LimitUpStats, 
    LimitUpStatsRequest,
    MARKET_CLASSIFICATION_RULES,
    ST_PATTERNS,
    LIMIT_UP_THRESHOLDS
)


class TestStockDailyData:
    """Tests for the StockDailyData data model."""

    @staticmethod
    def _valid_kwargs(**overrides):
        """Return constructor kwargs for a valid daily bar, with *overrides* applied.

        The baseline values satisfy low <= open/close <= high and
        change == close - pre_close, so each negative test only overrides
        the single field it is exercising.
        """
        kwargs = {
            'ts_code': "000001.SZ",
            'trade_date': "20241015",
            'open': 10.50,
            'high': 11.55,
            'low': 10.30,
            'close': 11.00,
            'pre_close': 10.00,
            'change': 1.00,
            'pct_chg': 10.0,
            'vol': 1000000,
            'amount': 11000000.0,
            'name': "平安银行",
        }
        kwargs.update(overrides)
        return kwargs

    def test_valid_stock_data_creation(self):
        """A fully valid record is constructed and exposes its fields."""
        data = StockDailyData(**self._valid_kwargs())

        assert data.ts_code == "000001.SZ"
        assert data.trade_date == "20241015"
        assert data.close == 11.00
        assert data.name == "平安银行"

    def test_invalid_stock_code_format(self):
        """A malformed ts_code is rejected."""
        with pytest.raises(ValueError, match="Invalid stock code format"):
            StockDailyData(**self._valid_kwargs(ts_code="invalid_code", name="测试股票"))

    def test_invalid_trade_date_format(self):
        """A trade date not in YYYYMMDD form is rejected."""
        with pytest.raises(ValueError, match="Invalid trade date format"):
            # Dashed format instead of YYYYMMDD.
            StockDailyData(**self._valid_kwargs(trade_date="2024-10-15"))

    def test_invalid_price_values(self):
        """A negative price is rejected."""
        # NOTE(review): open=-1.0 also breaks low <= open; this test relies on
        # per-field price validation running before the relationship check.
        with pytest.raises(ValueError, match="Invalid open"):
            StockDailyData(**self._valid_kwargs(open=-1.0))

    def test_price_relationship_validation(self):
        """Prices violating low <= open <= high are rejected."""
        with pytest.raises(ValueError, match="Price relationship validation failed"):
            # Open price above the day's high.
            StockDailyData(**self._valid_kwargs(open=12.00))

    @pytest.mark.parametrize(
        "overrides",
        [{'vol': -1000}, {'amount': -1.0}],
        ids=["negative_vol", "negative_amount"],
    )
    def test_negative_volume_amount(self, overrides):
        """Negative volume or negative turnover amount is rejected."""
        with pytest.raises(ValueError, match="Volume and amount must be non-negative"):
            StockDailyData(**self._valid_kwargs(**overrides))

    def test_empty_stock_name(self):
        """An empty stock name is rejected."""
        with pytest.raises(ValueError, match="Stock name is required"):
            StockDailyData(**self._valid_kwargs(name=""))

    def test_to_dict_conversion(self):
        """to_dict returns a plain dict with one entry per model field."""
        data = StockDailyData(**self._valid_kwargs())

        result = data.to_dict()
        assert isinstance(result, dict)
        assert result['ts_code'] == "000001.SZ"
        assert result['close'] == 11.00
        # One key per constructor field.
        assert len(result) == 12

    def test_from_dict_creation(self):
        """from_dict builds an instance from a field dictionary."""
        data = StockDailyData.from_dict(self._valid_kwargs())
        assert data.ts_code == "000001.SZ"
        assert data.close == 11.00


class TestLimitUpStats:
    """Tests for the LimitUpStats data model."""

    @staticmethod
    def _five_stock_kwargs(**overrides):
        """Build kwargs for a consistent five-stock snapshot, then apply *overrides*.

        Baseline counts: total 5 = ST 1 + non-ST 4 = shanghai 1 + shenzhen 2
        + star 1 + beijing 1, with five entries in limit_up_stocks.
        """
        base = dict(
            trade_date="20241015",
            total=5,
            non_st=4,
            shanghai=1,
            shenzhen=2,
            star=1,
            beijing=1,
            st=1,
            limit_up_stocks=["000001.SZ", "600000.SH", "688001.SH", "430001.BJ", "000002.SZ"],
        )
        return {**base, **overrides}

    @staticmethod
    def _single_stock_kwargs():
        """Build kwargs for a snapshot with a single non-ST Shanghai limit-up stock."""
        return dict(
            trade_date="20241015",
            total=1,
            non_st=1,
            shanghai=1,
            shenzhen=0,
            star=0,
            beijing=0,
            st=0,
            limit_up_stocks=["600000.SH"],
        )

    def test_valid_limit_up_stats_creation(self):
        """A consistent snapshot (with market breakdown) is accepted."""
        breakdown = {
            "shanghai": ["600000.SH"],
            "shenzhen": ["000001.SZ", "000002.SZ"],
            "star": ["688001.SH"],
            "beijing": ["430001.BJ"]
        }
        stats = LimitUpStats(**self._five_stock_kwargs(market_breakdown=breakdown))

        assert (stats.total, stats.non_st, stats.st) == (5, 4, 1)
        assert len(stats.limit_up_stocks) == 5

    def test_invalid_trade_date_format(self):
        """A trade date not in YYYYMMDD form is rejected."""
        bad_kwargs = self._five_stock_kwargs(trade_date="2024-10-15")
        with pytest.raises(ValueError, match="Invalid trade date format"):
            LimitUpStats(**bad_kwargs)

    def test_negative_count_validation(self):
        """A negative total count is rejected."""
        bad_kwargs = self._five_stock_kwargs(total=-1, limit_up_stocks=[])
        with pytest.raises(ValueError, match="Invalid total"):
            LimitUpStats(**bad_kwargs)

    def test_total_market_sum_consistency(self):
        """Total must equal the sum of the per-market counts."""
        # Markets now sum to 4 while total stays 5.
        bad_kwargs = self._five_stock_kwargs(shenzhen=1)
        with pytest.raises(ValueError, match="Total count .* doesn't match market sum"):
            LimitUpStats(**bad_kwargs)

    def test_total_st_nonst_consistency(self):
        """Total must equal ST + non-ST."""
        # ST + non-ST now sums to 4 while total stays 5.
        bad_kwargs = self._five_stock_kwargs(non_st=3)
        with pytest.raises(ValueError, match="Total count .* doesn't match ST \\+ non-ST"):
            LimitUpStats(**bad_kwargs)

    def test_limit_up_stocks_count_consistency(self):
        """The length of limit_up_stocks must equal total."""
        # Only two stocks listed while total stays 5.
        bad_kwargs = self._five_stock_kwargs(limit_up_stocks=["000001.SZ", "600000.SH"])
        with pytest.raises(ValueError, match="Limit up stocks count .* doesn't match total"):
            LimitUpStats(**bad_kwargs)

    def test_auto_timestamp_creation(self):
        """created_at / updated_at are populated and initially equal."""
        stats = LimitUpStats(**self._single_stock_kwargs())

        assert stats.created_at is not None
        assert stats.updated_at is not None
        assert stats.created_at == stats.updated_at

    def test_to_dict_conversion(self):
        """to_dict returns a dict including the timestamp fields."""
        result = LimitUpStats(**self._single_stock_kwargs()).to_dict()

        assert isinstance(result, dict)
        assert result['trade_date'] == "20241015"
        assert result['total'] == 1
        assert all(key in result for key in ('created_at', 'updated_at'))

    def test_get_summary(self):
        """get_summary returns exactly the seven count fields."""
        kwargs = self._five_stock_kwargs()
        summary = LimitUpStats(**kwargs).get_summary()

        summary_keys = ('total', 'non_st', 'shanghai', 'shenzhen', 'star', 'beijing', 'st')
        assert summary == {key: kwargs[key] for key in summary_keys}


class TestLimitUpStatsRequest:
    """Tests for the LimitUpStatsRequest data model."""

    def test_valid_request_creation_yyyymmdd(self):
        """A YYYYMMDD date is accepted and defaults are applied."""
        request = LimitUpStatsRequest(trade_date="20241015")

        assert request.trade_date == "20241015"
        assert request.include_st is True
        assert request.market_filter is None
        assert request.force_refresh is False
        assert request.save_to_db is True
        assert request.timeout == 30
        assert request.data_type == "limit_up_stats"

    def test_valid_request_creation_yyyy_mm_dd(self):
        """A YYYY-MM-DD date is accepted and normalized."""
        request = LimitUpStatsRequest(trade_date="2024-10-15")

        # Should be normalized to YYYYMMDD.
        assert request.trade_date == "20241015"

    def test_invalid_date_format(self):
        """A slash-separated date is rejected."""
        with pytest.raises(ValueError, match="Invalid trade date format"):
            LimitUpStatsRequest(trade_date="2024/10/15")

    def test_invalid_date_value(self):
        """A well-formed but impossible date is rejected."""
        with pytest.raises(ValueError, match="Invalid date"):
            LimitUpStatsRequest(trade_date="20241301")  # month 13

    def test_future_date_validation(self):
        """A trade date after today is rejected."""
        future_date = (datetime.now() + timedelta(days=1)).strftime("%Y%m%d")
        with pytest.raises(ValueError, match="Trade date cannot be in the future"):
            LimitUpStatsRequest(trade_date=future_date)

    def test_empty_trade_date(self):
        """An empty trade date is rejected."""
        with pytest.raises(ValueError, match="Trade date is required"):
            LimitUpStatsRequest(trade_date="")

    def test_valid_market_filter(self):
        """A list of known markets is accepted."""
        request = LimitUpStatsRequest(
            trade_date="20241015",
            market_filter=["shanghai", "shenzhen"]
        )

        assert set(request.market_filter) == {"shanghai", "shenzhen"}
        assert len(request.market_filter) == 2

    def test_invalid_market_filter(self):
        """An unknown market name is rejected."""
        with pytest.raises(ValueError, match="Invalid market filters"):
            LimitUpStatsRequest(
                trade_date="20241015",
                market_filter=["invalid_market"]
            )

    def test_market_filter_normalization(self):
        """market_filter is normalized to a de-duplicated list."""
        # A single string becomes a one-element list.
        request = LimitUpStatsRequest(
            trade_date="20241015",
            market_filter="shanghai"
        )
        assert request.market_filter == ["shanghai"]

        # Duplicates are removed.
        request = LimitUpStatsRequest(
            trade_date="20241015",
            market_filter=["shanghai", "shanghai", "shenzhen"]
        )
        assert len(request.market_filter) == 2
        assert "shanghai" in request.market_filter
        assert "shenzhen" in request.market_filter

    @pytest.mark.parametrize("timeout", [0, -1])
    def test_invalid_timeout(self, timeout):
        """Zero and negative timeouts are rejected (a timeout must be positive)."""
        with pytest.raises(ValueError, match="Timeout must be a positive integer"):
            LimitUpStatsRequest(
                trade_date="20241015",
                timeout=timeout
            )

    def test_get_formatted_date(self):
        """get_formatted_date renders the date as YYYY-MM-DD."""
        request = LimitUpStatsRequest(trade_date="20241015")
        assert request.get_formatted_date() == "2024-10-15"

    def test_is_market_included(self):
        """is_market_included honours the optional market filter."""
        # Without a filter, every market is included.
        request = LimitUpStatsRequest(trade_date="20241015")
        assert request.is_market_included("shanghai") is True
        assert request.is_market_included("shenzhen") is True

        # With a filter, only the listed markets are included.
        request = LimitUpStatsRequest(
            trade_date="20241015",
            market_filter=["shanghai"]
        )
        assert request.is_market_included("shanghai") is True
        assert request.is_market_included("shenzhen") is False

    def test_to_dict_conversion(self):
        """to_dict round-trips every configured field."""
        request = LimitUpStatsRequest(
            trade_date="20241015",
            include_st=False,
            market_filter=["shanghai"],
            force_refresh=True,
            save_to_db=False,
            timeout=60
        )

        result = request.to_dict()
        assert result['trade_date'] == "20241015"
        assert result['include_st'] is False
        assert result['market_filter'] == ["shanghai"]
        assert result['force_refresh'] is True
        assert result['save_to_db'] is False
        assert result['timeout'] == 60
        assert result['data_type'] == "limit_up_stats"

    def test_from_dict_creation(self):
        """from_dict restores every field present in the input dictionary."""
        data_dict = {
            'trade_date': "20241015",
            'include_st': False,
            'market_filter': ["shanghai"],
            'force_refresh': True,
            'save_to_db': False,
            'timeout': 60
        }

        request = LimitUpStatsRequest.from_dict(data_dict)
        assert request.trade_date == "20241015"
        assert request.include_st is False
        assert request.market_filter == ["shanghai"]
        # These keys are supplied too; verify they are not silently dropped.
        assert request.force_refresh is True
        assert request.save_to_db is False
        assert request.timeout == 60


class TestConstants:
    """Tests for the module-level constant definitions."""

    def test_market_classification_rules(self):
        """MARKET_CLASSIFICATION_RULES covers every market with well-formed entries."""
        assert 'shanghai' in MARKET_CLASSIFICATION_RULES
        assert 'shenzhen' in MARKET_CLASSIFICATION_RULES
        assert 'star' in MARKET_CLASSIFICATION_RULES
        assert 'beijing' in MARKET_CLASSIFICATION_RULES

        # Every market entry carries a pattern list and a text description.
        for rules in MARKET_CLASSIFICATION_RULES.values():
            assert 'patterns' in rules
            assert 'description' in rules
            assert isinstance(rules['patterns'], list)
            assert isinstance(rules['description'], str)

    def test_st_patterns(self):
        """ST_PATTERNS is a non-empty list including the *ST and ST patterns."""
        assert isinstance(ST_PATTERNS, list)
        assert len(ST_PATTERNS) > 0
        assert r'\*ST' in ST_PATTERNS
        assert r'ST' in ST_PATTERNS

    def test_limit_up_thresholds(self):
        """Every limit-up threshold exists and is a fraction in (0, 1]."""
        for key in ('normal', 'st', 'star', 'beijing'):
            assert key in LIMIT_UP_THRESHOLDS
            # Thresholds are fractional price moves, so each must lie in (0, 1].
            assert 0 < LIMIT_UP_THRESHOLDS[key] <= 1

        # ST stocks have a narrower daily limit than regular stocks.
        assert LIMIT_UP_THRESHOLDS['st'] < LIMIT_UP_THRESHOLDS['normal']


if __name__ == "__main__":
    # pytest.main returns an exit code; propagate it so running this file
    # directly reports failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))