"""
涨停统计验证工具测试

测试涨停统计功能的验证和回退机制
"""

import pytest
import pandas as pd
from datetime import datetime, timedelta
from unittest.mock import Mock, patch
import logging

import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from quickstock.utils.limit_up_validators import (
    TradeDateValidator,
    DataValidator,
    FallbackManager,
    ValidationUtils
)
from quickstock.core.errors import (
    InvalidTradeDateError,
    InsufficientDataError,
    StockClassificationError,
    LimitUpDetectionError
)


class TestTradeDateValidator:
    """Tests for TradeDateValidator date parsing and trade-date checks."""

    @pytest.mark.parametrize(
        "date_str, expected",
        [
            ("2024-01-15", datetime(2024, 1, 15)),
            ("20240115", datetime(2024, 1, 15)),
            ("2024/01/15", datetime(2024, 1, 15)),
            ("  2024-01-15  ", datetime(2024, 1, 15)),  # surrounding whitespace
        ],
    )
    def test_validate_date_format_success(self, date_str, expected):
        """Every supported input format parses to the same datetime."""
        assert TradeDateValidator.validate_date_format(date_str) == expected

    @pytest.mark.parametrize(
        "invalid_date",
        [
            "",
            None,
            "invalid-date",
            "2024-13-01",  # month out of range
            "2024-01-32",  # day out of range
            "24-01-15",    # two-digit year
            "2024/13/01",  # month out of range (slash format)
        ],
    )
    def test_validate_date_format_failure(self, invalid_date):
        """Malformed or out-of-range inputs raise InvalidTradeDateError."""
        with pytest.raises(InvalidTradeDateError):
            TradeDateValidator.validate_date_format(invalid_date)

    def test_validate_trade_date_success(self):
        """A recent past weekday validates and is returned unchanged."""
        # Start from yesterday and step back past Saturday/Sunday so the
        # chosen date is a weekday no matter which day the suite runs on
        # (it may still coincide with a holiday).
        candidate = datetime.now() - timedelta(days=1)
        while candidate.weekday() >= 5:  # 5 = Saturday, 6 = Sunday
            candidate -= timedelta(days=1)
        trade_date = candidate.strftime("%Y-%m-%d")

        assert TradeDateValidator.validate_trade_date(trade_date) == trade_date

    def test_validate_trade_date_future_date(self):
        """Dates in the future are rejected."""
        future_date = (datetime.now() + timedelta(days=30)).strftime("%Y-%m-%d")

        with pytest.raises(InvalidTradeDateError) as exc_info:
            TradeDateValidator.validate_trade_date(future_date)

        assert "不能查询未来日期" in str(exc_info.value)

    def test_validate_trade_date_too_old(self):
        """A date roughly 11 years in the past is rejected as too old."""
        old_date = (datetime.now() - timedelta(days=365 * 11)).strftime("%Y-%m-%d")

        with pytest.raises(InvalidTradeDateError) as exc_info:
            TradeDateValidator.validate_trade_date(old_date)

        assert "日期过于久远" in str(exc_info.value)

    @pytest.mark.parametrize(
        "date_str, expected",
        [
            ("2024-01-15", True),   # Monday, ordinary weekday
            ("2024-01-13", False),  # Saturday
            ("2024-01-01", False),  # New Year's Day
        ],
    )
    def test_is_likely_trading_day(self, date_str, expected):
        """Weekdays count as likely trading days; weekends and known holidays do not."""
        assert TradeDateValidator.is_likely_trading_day(date_str) == expected

    def test_weekend_warning(self, caplog):
        """Validating a weekend date logs a warning."""
        with caplog.at_level(logging.WARNING):
            # 2024-01-13 is a Saturday.
            TradeDateValidator.validate_trade_date("2024-01-13")

        assert "是周末" in caplog.text

    def test_holiday_warning(self, caplog):
        """Validating a known holiday logs a warning."""
        with caplog.at_level(logging.WARNING):
            TradeDateValidator.validate_trade_date("2024-01-01")

        assert "是已知节假日" in caplog.text


class TestDataValidator:
    """Tests for DataValidator's DataFrame-level and per-stock price checks."""

    def create_sample_stock_data(self, num_stocks: int = 100) -> pd.DataFrame:
        """Build a DataFrame of ``num_stocks`` rows with valid OHLC prices.

        Row ``i`` has code ``f'{i:06d}.SZ'`` and prices offset by ``i * 0.1``,
        chosen so that low < open < close < high in every row.
        """
        return pd.DataFrame([
            {
                'ts_code': f'{i:06d}.SZ',
                'open': 10.0 + i * 0.1,
                'close': 11.0 + i * 0.1,
                'high': 11.5 + i * 0.1,
                'low': 9.5 + i * 0.1,
                'volume': 1000000,
            }
            for i in range(num_stocks)
        ])

    def test_validate_stock_data_success(self):
        """A large, well-formed dataset passes validation."""
        stock_data = self.create_sample_stock_data(200)
        is_valid, message = DataValidator.validate_stock_data(stock_data, "2024-01-15")

        assert is_valid
        assert "验证通过" in message

    def test_validate_stock_data_empty(self):
        """An empty DataFrame fails validation."""
        empty_data = pd.DataFrame()
        is_valid, message = DataValidator.validate_stock_data(empty_data, "2024-01-15")

        assert not is_valid
        assert "没有股票数据" in message

    def test_validate_stock_data_insufficient_count(self):
        """Fewer rows than ``min_stocks`` fails, and the message reports both counts."""
        small_data = self.create_sample_stock_data(50)
        is_valid, message = DataValidator.validate_stock_data(small_data, "2024-01-15", min_stocks=100)

        assert not is_valid
        assert "数据量不足" in message
        assert "50" in message
        assert "100" in message

    def test_validate_stock_data_missing_columns(self):
        """Missing required price columns fail validation."""
        data = pd.DataFrame({
            'ts_code': ['000001.SZ'],
            'open': [10.0],
            # close, high and low are deliberately absent.
        })

        is_valid, message = DataValidator.validate_stock_data(data, "2024-01-15", min_stocks=1)

        assert not is_valid
        assert "缺少必要字段" in message

    def test_validate_stock_data_null_values(self):
        """Null prices fail validation."""
        data = self.create_sample_stock_data(100)
        data.loc[0, 'open'] = None
        data.loc[1, 'close'] = None

        is_valid, message = DataValidator.validate_stock_data(data, "2024-01-15")

        assert not is_valid
        assert "存在空值字段" in message

    def test_validate_stock_data_invalid_prices(self):
        """Negative or zero prices fail validation."""
        data = self.create_sample_stock_data(100)
        data.loc[0, 'open'] = -1.0  # negative price
        data.loc[1, 'close'] = 0.0  # zero price

        is_valid, message = DataValidator.validate_stock_data(data, "2024-01-15")

        assert not is_valid
        assert "无效价格数据" in message

    def test_validate_stock_data_price_logic_error(self):
        """A row whose high is below its low fails validation."""
        data = self.create_sample_stock_data(100)
        data.loc[0, 'high'] = 5.0  # high below low
        data.loc[0, 'low'] = 10.0

        is_valid, message = DataValidator.validate_stock_data(data, "2024-01-15")

        assert not is_valid
        assert "价格逻辑错误" in message

    def test_validate_price_data_success(self):
        """Consistent OHLC prices for a single stock validate."""
        price_data = {
            'open': 10.0,
            'close': 11.0,
            'high': 11.5,
            'low': 9.5,
        }

        assert DataValidator.validate_price_data("000001.SZ", price_data)

    def test_validate_price_data_missing_field(self):
        """A missing required field raises, and the message names the field."""
        price_data = {
            'open': 10.0,
            'close': 11.0,
            # 'high' is deliberately absent.
        }

        with pytest.raises(LimitUpDetectionError) as exc_info:
            DataValidator.validate_price_data("000001.SZ", price_data)

        assert "缺少必要的价格字段" in str(exc_info.value)
        assert "high" in str(exc_info.value)

    def test_validate_price_data_null_value(self):
        """A None price raises LimitUpDetectionError."""
        price_data = {
            'open': 10.0,
            'close': None,
            'high': 11.0,
        }

        with pytest.raises(LimitUpDetectionError) as exc_info:
            DataValidator.validate_price_data("000001.SZ", price_data)

        assert "价格字段 close 为空" in str(exc_info.value)

    def test_validate_price_data_invalid_value(self):
        """A negative price raises LimitUpDetectionError."""
        price_data = {
            'open': 10.0,
            'close': -1.0,  # negative price
            'high': 11.0,
        }

        with pytest.raises(LimitUpDetectionError) as exc_info:
            DataValidator.validate_price_data("000001.SZ", price_data)

        assert "价格字段 close 无效" in str(exc_info.value)

    def test_validate_price_data_logic_error(self):
        """A high price below open/close raises LimitUpDetectionError."""
        price_data = {
            'open': 10.0,
            'close': 11.0,
            'high': 9.0,  # below both open and close
        }

        with pytest.raises(LimitUpDetectionError) as exc_info:
            DataValidator.validate_price_data("000001.SZ", price_data)

        assert "最高价不能低于" in str(exc_info.value)


class TestFallbackManager:
    """Tests for FallbackManager's error-handling and degradation strategies."""

    @pytest.fixture
    def fallback_manager(self):
        """Provide a fresh FallbackManager for each test."""
        return FallbackManager()

    def test_handle_insufficient_data_success(self, fallback_manager):
        """Data at or above ``min_required`` is returned unchanged in size."""
        data = pd.DataFrame({'ts_code': [f'{i:06d}.SZ' for i in range(150)]})

        result = fallback_manager.handle_insufficient_data("2024-01-15", data, min_required=100)
        assert len(result) == 150

    def test_handle_insufficient_data_warning(self, fallback_manager, caplog):
        """Data somewhat below ``min_required`` (80 of 100) is kept, with a warning."""
        data = pd.DataFrame({'ts_code': [f'{i:06d}.SZ' for i in range(80)]})

        with caplog.at_level(logging.WARNING):
            result = fallback_manager.handle_insufficient_data("2024-01-15", data, min_required=100)

        assert len(result) == 80
        assert "数据量不足" in caplog.text

    def test_handle_insufficient_data_failure(self, fallback_manager):
        """Data far below ``min_required`` (30 of 100) raises InsufficientDataError."""
        data = pd.DataFrame({'ts_code': [f'{i:06d}.SZ' for i in range(30)]})

        with pytest.raises(InsufficientDataError):
            fallback_manager.handle_insufficient_data("2024-01-15", data, min_required=100)

    def test_handle_insufficient_data_empty(self, fallback_manager):
        """An empty DataFrame raises InsufficientDataError."""
        with pytest.raises(InsufficientDataError):
            fallback_manager.handle_insufficient_data("2024-01-15", pd.DataFrame(), min_required=100)

    def test_handle_classification_error(self, fallback_manager, caplog):
        """A classification failure falls back to a market and logs a warning."""
        error = StockClassificationError("000001.SZ", "未知格式")

        with caplog.at_level(logging.WARNING):
            result = fallback_manager.handle_classification_error("000001.SZ", error)

        assert result == "shenzhen"
        assert "分类失败" in caplog.text

    @pytest.mark.parametrize(
        "stock_code, expected",
        [
            ("600000.SH", "shanghai"),
            ("688001.SH", "star"),
            ("000001.SZ", "shenzhen"),
            ("430001.BJ", "beijing"),
            ("UNKNOWN.XX", "unknown"),
        ],
    )
    def test_handle_classification_error_various_codes(self, fallback_manager, stock_code, expected):
        """The fallback classification for each code maps to the expected market."""
        error = StockClassificationError(stock_code)
        assert fallback_manager.handle_classification_error(stock_code, error) == expected

    def test_handle_detection_error(self, fallback_manager, caplog):
        """A detection failure is treated as not limit-up and logs a warning."""
        error = LimitUpDetectionError("000001.SZ", "价格数据异常")

        with caplog.at_level(logging.WARNING):
            result = fallback_manager.handle_detection_error("000001.SZ", error)

        assert not result
        assert "涨停检测失败" in caplog.text

    def test_create_partial_results(self, fallback_manager, caplog):
        """Partial results keep successful data and attach a warnings summary."""
        successful_data = {
            'total': 100,
            'shanghai': 30,
            'shenzhen': 70,
        }
        failed_stocks = [f'{i:06d}.SZ' for i in range(20)]

        with caplog.at_level(logging.WARNING):
            result = fallback_manager.create_partial_results("2024-01-15", successful_data, failed_stocks)

        assert result['total'] == 100
        assert 'warnings' in result
        assert result['warnings']['partial_data']
        assert result['warnings']['failed_stocks_count'] == 20
        # Only the first 10 failed codes are listed in the summary.
        assert len(result['warnings']['failed_stocks']) == 10
        assert "股票处理失败" in caplog.text

    @pytest.mark.parametrize(
        "error_cls, error_args, expected_strategy",
        [
            (InvalidTradeDateError, ("invalid",), "reject"),
            (InsufficientDataError, ("2024-01-15",), "partial_results"),
            (StockClassificationError, ("UNKNOWN",), "default_classification"),
            (LimitUpDetectionError, ("000001.SZ",), "conservative_detection"),
            (Exception, ("unknown",), "fail_fast"),
        ],
    )
    def test_get_graceful_degradation_strategy(
        self, fallback_manager, error_cls, error_args, expected_strategy
    ):
        """Each error type maps to its degradation strategy, with a message."""
        # Build the exception inside the test so a constructor failure is
        # reported against this case rather than at collection time.
        error = error_cls(*error_args)

        strategy = fallback_manager.get_graceful_degradation_strategy(error)
        assert strategy['strategy'] == expected_strategy
        assert 'message' in strategy


class TestValidationUtils:
    """Tests for ValidationUtils sanitizers and consistency checks."""

    @staticmethod
    def _make_stats(**overrides):
        """Return a self-consistent stats dict, with ``overrides`` applied on top.

        Baseline: total 100 = non_st 90 + st 10 = 30 + 40 + 20 + 10 across markets.
        """
        stats = {
            'total': 100,
            'non_st': 90,
            'st': 10,
            'shanghai': 30,
            'shenzhen': 40,
            'star': 20,
            'beijing': 10,
        }
        stats.update(overrides)
        return stats

    @pytest.mark.parametrize(
        "input_code, expected",
        [
            ("  000001.SZ  ", "000001.SZ"),
            ("000001.sz", "000001.SZ"),
            ("000001@.SZ", "000001.SZ"),
            ("000-001.SZ", "000001.SZ"),
            ("", ""),
            (None, ""),
        ],
    )
    def test_sanitize_stock_code(self, input_code, expected):
        """Codes are trimmed, upper-cased and stripped of stray symbols."""
        assert ValidationUtils.sanitize_stock_code(input_code) == expected

    @pytest.mark.parametrize(
        "input_filter, expected",
        [
            (None, []),
            ([], []),
            (['shanghai', 'shenzhen'], ['shanghai', 'shenzhen']),
            (['SHANGHAI', 'Shenzhen'], ['shanghai', 'shenzhen']),
            (['shanghai', 'invalid', 'star'], ['shanghai', 'star']),
            (['unknown_market'], []),
        ],
    )
    def test_validate_market_filter(self, input_filter, expected):
        """Market names are lower-cased and unknown names are dropped."""
        assert ValidationUtils.validate_market_filter(input_filter) == expected

    def test_validate_market_filter_warning(self, caplog):
        """An unknown market name logs a warning."""
        with caplog.at_level(logging.WARNING):
            ValidationUtils.validate_market_filter(['invalid_market'])

        assert "无效的市场过滤器" in caplog.text

    def test_validate_statistics_consistency_success(self):
        """Internally consistent statistics pass."""
        assert ValidationUtils.validate_statistics_consistency(self._make_stats())

    def test_validate_statistics_consistency_missing_field(self, caplog):
        """Statistics missing required keys fail and log an error."""
        stats = {
            'total': 100,
            'non_st': 90,
            # All remaining keys are deliberately absent.
        }

        with caplog.at_level(logging.ERROR):
            result = ValidationUtils.validate_statistics_consistency(stats)

        assert not result
        assert "缺少字段" in caplog.text

    def test_validate_statistics_consistency_invalid_value(self, caplog):
        """A negative count fails and logs an error."""
        stats = self._make_stats(total=-1)

        with caplog.at_level(logging.ERROR):
            result = ValidationUtils.validate_statistics_consistency(stats)

        assert not result
        assert "无效" in caplog.text

    def test_validate_statistics_consistency_market_sum_error(self, caplog):
        """Per-market counts that do not sum to the total fail."""
        # 30 + 40 + 20 + 5 = 95, not 100.
        stats = self._make_stats(beijing=5)

        with caplog.at_level(logging.ERROR):
            result = ValidationUtils.validate_statistics_consistency(stats)

        assert not result
        assert "各市场总和" in caplog.text

    def test_validate_statistics_consistency_st_sum_error(self, caplog):
        """non_st + st not equal to the total fails."""
        # 95 + 10 = 105, not 100.
        stats = self._make_stats(non_st=95)

        with caplog.at_level(logging.ERROR):
            result = ValidationUtils.validate_statistics_consistency(stats)

        assert not result
        assert "非ST + ST" in caplog.text


if __name__ == "__main__":
    # Pass pytest's exit code to the shell so direct runs (e.g. in CI)
    # report failures instead of always exiting with status 0.
    sys.exit(pytest.main([__file__]))