"""
涨跌分布统计异常处理单元测试

测试涨跌分布统计功能的异常类和错误处理机制
"""

import asyncio
import copy
import time
from typing import Dict, Any, List
from unittest.mock import Mock, AsyncMock, patch

import pytest

from quickstock.core.price_distribution_errors import (
    PriceDistributionError,
    InvalidDistributionRangeError,
    InsufficientPriceDataError,
    DistributionCalculationError,
    MarketClassificationError,
    PriceDistributionCacheError,
    InvalidTradeDateError,
    StatisticsAggregationError,
    PriceDistributionValidationError,
    PriceDistributionServiceError
)
from quickstock.core.price_distribution_fault_tolerance import FaultTolerantPriceDistributionService
from quickstock.models.price_distribution_models import PriceDistributionStats, PriceDistributionRequest


class TestPriceDistributionExceptions:
    """Unit tests for the price-distribution exception hierarchy.

    Every test constructs one exception, then checks the rendered message
    (``str(err)``), its ``error_code``, the attributes it exposes and the
    suggestion text it carries.
    """

    def test_price_distribution_error_base(self):
        """The base error keeps message, code, details and suggestions as given."""
        err = PriceDistributionError(
            message="测试错误",
            error_code="TEST_ERROR",
            details={"key": "value"},
            suggestions="测试建议",
        )

        assert str(err) == "测试错误"
        assert err.error_code == "TEST_ERROR"
        assert err.details == {"key": "value"}
        assert err.suggestions == "测试建议"

    def test_invalid_distribution_range_error(self):
        """A range error names the range, states the reason and keeps the bounds."""
        err = InvalidDistributionRangeError(
            range_name="0-3%",
            min_value=0.0,
            max_value=3.0,
            reason="最小值大于最大值",
        )

        rendered = str(err)
        for fragment in ("无效的分布区间 '0-3%'", "最小值大于最大值"):
            assert fragment in rendered
        assert err.error_code == "INVALID_DISTRIBUTION_RANGE"
        assert err.range_name == "0-3%"
        assert err.min_value == 0.0
        assert err.max_value == 3.0
        assert "区间定义是否正确" in err.suggestions

    def test_insufficient_price_data_error(self):
        """An insufficient-data error reports counts and the missing fields."""
        missing = ["pct_chg", "close"]
        err = InsufficientPriceDataError(
            trade_date="20240115",
            available_count=50,
            required_count=100,
            missing_fields=missing,
            reason="数据源暂时不可用",
        )

        rendered = str(err)
        for fragment in ("20240115", "可用: 50, 需要: 100", "pct_chg, close"):
            assert fragment in rendered
        assert err.error_code == "INSUFFICIENT_PRICE_DATA"
        assert err.available_count == 50
        assert err.required_count == 100
        assert err.missing_fields == ["pct_chg", "close"]
        assert "确认该日期是否为交易日" in err.suggestions

    def test_distribution_calculation_error(self):
        """A calculation error lists the affected stocks and the invalid row count."""
        stocks = ["000001.SZ", "000002.SZ", "600000.SH"]
        err = DistributionCalculationError(
            calculation_type="区间分类",
            reason="涨跌幅数据包含无效值",
            affected_stocks=stocks,
            invalid_data_count=3,
        )

        rendered = str(err)
        expected_fragments = (
            "分布计算失败 (区间分类)",
            "涨跌幅数据包含无效值",
            "无效数据: 3 条",
            "000001.SZ, 000002.SZ, 600000.SH",
        )
        for fragment in expected_fragments:
            assert fragment in rendered
        assert err.error_code == "DISTRIBUTION_CALCULATION_ERROR"
        assert err.affected_stocks == stocks
        assert "股票价格数据的完整性" in err.suggestions

    def test_market_classification_error(self):
        """A classification error lists the unclassified codes and their count."""
        codes = ["000001.SZ", "000002.SZ"]
        err = MarketClassificationError(
            stock_codes=codes,
            classification_type="市场板块",
            reason="股票代码格式不正确",
            unclassified_count=2,
        )

        rendered = str(err)
        expected_fragments = (
            "市场分类失败 (市场板块)",
            "股票代码格式不正确",
            "未分类股票: 2 只",
            "000001.SZ, 000002.SZ",
        )
        for fragment in expected_fragments:
            assert fragment in rendered
        assert err.error_code == "MARKET_CLASSIFICATION_ERROR"
        assert err.stock_codes == codes
        assert "股票代码格式是否正确" in err.suggestions

    def test_price_distribution_cache_error(self):
        """A cache error names the operation, key, reason and cache layer."""
        err = PriceDistributionCacheError(
            operation="get",
            cache_key="price_dist:20240115",
            reason="Redis连接超时",
            cache_layer="L2",
        )

        rendered = str(err)
        expected_fragments = (
            "缓存操作 'get' 失败",
            "price_dist:20240115",
            "Redis连接超时",
            "缓存层: L2",
        )
        for fragment in expected_fragments:
            assert fragment in rendered
        assert err.error_code == "PRICE_DISTRIBUTION_CACHE_ERROR"
        assert err.operation == "get"
        assert err.cache_key == "price_dist:20240115"
        assert "缓存服务是否正常运行" in err.suggestions

    def test_invalid_trade_date_error(self):
        """An invalid-date error echoes the bad date and a valid format example."""
        err = InvalidTradeDateError(
            date="2024-13-01",
            reason="月份超出有效范围",
            valid_format="2024-01-15",
        )

        rendered = str(err)
        expected_fragments = (
            "无效的交易日期 '2024-13-01'",
            "月份超出有效范围",
            "有效格式示例: 2024-01-15",
        )
        for fragment in expected_fragments:
            assert fragment in rendered
        assert err.error_code == "INVALID_TRADE_DATE"
        assert err.invalid_date == "2024-13-01"
        assert "YYYY-MM-DD 或 YYYYMMDD" in err.suggestions

    def test_statistics_aggregation_error(self):
        """An aggregation error lists segments and counts the inconsistent items."""
        segments = ["shanghai", "shenzhen"]
        inconsistency = {"total_mismatch": True, "percentage_error": True}

        err = StatisticsAggregationError(
            aggregation_type="市场板块聚合",
            reason="数据不一致",
            market_segments=segments,
            data_inconsistency=inconsistency,
        )

        rendered = str(err)
        expected_fragments = (
            "统计聚合失败 (市场板块聚合)",
            "数据不一致",
            "shanghai, shenzhen",
            "数据不一致项: 2 个",
        )
        for fragment in expected_fragments:
            assert fragment in rendered
        assert err.error_code == "STATISTICS_AGGREGATION_ERROR"
        assert err.market_segments == segments
        assert "输入数据的完整性和一致性" in err.suggestions

    def test_price_distribution_validation_error(self):
        """A validation error names the field, its value and the expected format."""
        err = PriceDistributionValidationError(
            validation_type="参数",
            field_name="trade_date",
            field_value="invalid_date",
            reason="日期格式不正确",
            expected_format="YYYYMMDD",
        )

        rendered = str(err)
        expected_fragments = (
            "参数验证失败，字段 'trade_date'",
            "日期格式不正确",
            "当前值: invalid_date",
            "期望格式: YYYYMMDD",
        )
        for fragment in expected_fragments:
            assert fragment in rendered
        assert err.error_code == "PRICE_DISTRIBUTION_VALIDATION_ERROR"
        assert err.field_name == "trade_date"
        assert "输入参数的格式和有效性" in err.suggestions

    def test_price_distribution_service_error(self):
        """A service error renders each service-state entry and the recovery hint."""
        state = {"cache_status": "disconnected", "data_source": "unavailable"}

        err = PriceDistributionServiceError(
            service_operation="get_price_distribution_stats",
            reason="多个依赖服务不可用",
            service_state=state,
            recovery_suggestion="请检查网络连接和服务状态",
        )

        rendered = str(err)
        expected_fragments = (
            "涨跌分布统计服务操作失败 (get_price_distribution_stats)",
            "多个依赖服务不可用",
            "cache_status: disconnected",
            "data_source: unavailable",
        )
        for fragment in expected_fragments:
            assert fragment in rendered
        assert err.error_code == "PRICE_DISTRIBUTION_SERVICE_ERROR"
        assert err.service_operation == "get_price_distribution_stats"
        assert "请检查网络连接和服务状态" in err.suggestions


class TestFaultTolerantPriceDistributionService:
    """Tests for the fault-tolerant price-distribution service.

    The wrapped base service is a ``Mock`` whose async entry points are
    ``AsyncMock`` instances. Each test scripts success, failure or recovery
    paths through ``return_value`` / ``side_effect``.
    """

    @pytest.fixture
    def mock_base_service(self):
        """Create a mocked base service.

        The cache lookup defaults to a miss (``None``). An unconfigured
        ``AsyncMock`` would instead resolve to an auto-generated mock object,
        i.e. a truthy "cache hit", which could silently satisfy recovery
        paths. Tests that exercise the cache fallback override
        ``return_value`` explicitly.
        """
        service = Mock()
        service.get_price_distribution_stats = AsyncMock()
        service.get_stats_with_partial_data = AsyncMock()
        service.get_stats_with_simple_calculation = AsyncMock()
        service.get_stats_excluding_stocks = AsyncMock()
        service.get_stats_with_default_classification = AsyncMock()
        service.get_stats_without_cache = AsyncMock()
        service.get_stats_with_simple_aggregation = AsyncMock()
        service.get_stats_from_backup_source = AsyncMock()
        service.cache_manager = Mock()
        service.cache_manager.get_distribution_stats = AsyncMock(return_value=None)
        return service

    @pytest.fixture
    def fault_tolerant_service(self, mock_base_service):
        """Create the fault-tolerant service wrapping the mocked base service."""
        return FaultTolerantPriceDistributionService(mock_base_service)

    @pytest.fixture
    def sample_request(self):
        """Create a sample request for trade date 20240115."""
        return PriceDistributionRequest(
            trade_date="20240115",
            include_st=True,
            market_filter=None,
            distribution_ranges=None,
            force_refresh=False,
            save_to_db=True
        )

    @pytest.fixture
    def sample_stats(self):
        """Create a sample stats result with 1000 stocks and quality score 0.9."""
        return PriceDistributionStats(
            trade_date="20240115",
            total_stocks=1000,
            positive_ranges={"0-3%": 300, "3-5%": 150},
            positive_percentages={"0-3%": 30.0, "3-5%": 15.0},
            negative_ranges={"0到-3%": 280, "-3到-5%": 120},
            negative_percentages={"0到-3%": 28.0, "-3到-5%": 12.0},
            market_breakdown={},
            data_quality_score=0.9,
            processing_time=1.0
        )

    @pytest.mark.asyncio
    async def test_successful_request(self, fault_tolerant_service, mock_base_service,
                                    sample_request, sample_stats):
        """A successful base call is returned unchanged and counted as a success."""
        mock_base_service.get_price_distribution_stats.return_value = sample_stats

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        assert result == sample_stats
        assert fault_tolerant_service.error_stats['successful_requests'] == 1
        assert fault_tolerant_service.error_stats['failed_requests'] == 0
        mock_base_service.get_price_distribution_stats.assert_called_once_with(sample_request)

    @pytest.mark.asyncio
    async def test_retry_mechanism(self, fault_tolerant_service, mock_base_service,
                                 sample_request, sample_stats):
        """Two failures followed by a success yields the success after 3 calls."""
        # The first two calls fail; the third succeeds.
        mock_base_service.get_price_distribution_stats.side_effect = [
            InsufficientPriceDataError("20240115", 50, 100),
            InsufficientPriceDataError("20240115", 50, 100),
            sample_stats
        ]

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        assert result == sample_stats
        assert mock_base_service.get_price_distribution_stats.call_count == 3

    @pytest.mark.asyncio
    async def test_insufficient_data_recovery(self, fault_tolerant_service, mock_base_service,
                                            sample_request, sample_stats):
        """InsufficientPriceDataError is recovered via get_stats_with_partial_data."""
        # The primary call always fails.
        mock_base_service.get_price_distribution_stats.side_effect = InsufficientPriceDataError(
            "20240115", 50, 100
        )

        # Partial-data recovery succeeds with a lower-quality copy of the sample.
        partial_stats = copy.copy(sample_stats)
        partial_stats.data_quality_score = 0.7
        mock_base_service.get_stats_with_partial_data.return_value = partial_stats

        # Make the cache miss explicit so the cached fallback cannot be used.
        mock_base_service.cache_manager.get_distribution_stats.return_value = None

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        assert result == partial_stats
        assert fault_tolerant_service.error_stats['recovered_requests'] == 1
        mock_base_service.get_stats_with_partial_data.assert_called_once()

    @pytest.mark.asyncio
    async def test_calculation_error_recovery(self, fault_tolerant_service, mock_base_service,
                                            sample_request, sample_stats):
        """DistributionCalculationError is recovered via get_stats_with_simple_calculation."""
        # The primary call always fails.
        mock_base_service.get_price_distribution_stats.side_effect = DistributionCalculationError(
            "区间分类", "数据包含无效值", ["000001.SZ"]
        )

        # Simplified-calculation recovery succeeds.
        simplified_stats = sample_stats
        simplified_stats.data_quality_score = 0.8
        mock_base_service.get_stats_with_simple_calculation.return_value = simplified_stats

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        assert result == simplified_stats
        assert fault_tolerant_service.error_stats['recovered_requests'] == 1
        mock_base_service.get_stats_with_simple_calculation.assert_called_once()

    @pytest.mark.asyncio
    async def test_classification_error_recovery(self, fault_tolerant_service, mock_base_service,
                                               sample_request, sample_stats):
        """MarketClassificationError is recovered via get_stats_with_default_classification."""
        # The primary call always fails.
        mock_base_service.get_price_distribution_stats.side_effect = MarketClassificationError(
            ["000001.SZ"], "市场板块", "代码格式错误"
        )

        # Default-classification recovery succeeds.
        default_stats = sample_stats
        default_stats.data_quality_score = 0.8
        mock_base_service.get_stats_with_default_classification.return_value = default_stats

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        assert result == default_stats
        assert fault_tolerant_service.error_stats['recovered_requests'] == 1
        mock_base_service.get_stats_with_default_classification.assert_called_once()

    @pytest.mark.asyncio
    async def test_cache_error_recovery(self, fault_tolerant_service, mock_base_service,
                                      sample_request, sample_stats):
        """PriceDistributionCacheError is recovered via get_stats_without_cache."""
        # The primary call always fails.
        mock_base_service.get_price_distribution_stats.side_effect = PriceDistributionCacheError(
            "get", "price_dist:20240115", "Redis连接超时"
        )

        # Bypassing the cache succeeds.
        mock_base_service.get_stats_without_cache.return_value = sample_stats

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        assert result == sample_stats
        assert fault_tolerant_service.error_stats['recovered_requests'] == 1
        mock_base_service.get_stats_without_cache.assert_called_once()

    @pytest.mark.asyncio
    async def test_fallback_strategy(self, fault_tolerant_service, mock_base_service,
                                   sample_request, sample_stats):
        """A generic failure falls back to the cached stats and records the fallback."""
        # The primary call fails with a generic exception.
        mock_base_service.get_price_distribution_stats.side_effect = Exception("服务不可用")

        # Opt in to a cache hit (the fixture defaults to a miss).
        cached_stats = sample_stats
        cached_stats.data_quality_score = 0.6
        mock_base_service.cache_manager.get_distribution_stats.return_value = cached_stats

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        assert result == cached_stats
        assert fault_tolerant_service.error_stats['fallback_used'] == 1

    @pytest.mark.asyncio
    async def test_complete_failure(self, fault_tolerant_service, mock_base_service, sample_request):
        """When every strategy fails, a PriceDistributionServiceError is raised."""
        # Every call fails with the same exception.
        persistent_error = Exception("服务不可用")
        mock_base_service.get_price_distribution_stats.side_effect = persistent_error

        # Every recovery method fails too.
        mock_base_service.get_stats_with_partial_data.side_effect = persistent_error
        mock_base_service.get_stats_with_simple_calculation.side_effect = persistent_error
        mock_base_service.get_stats_excluding_stocks.side_effect = persistent_error
        mock_base_service.get_stats_with_default_classification.side_effect = persistent_error
        mock_base_service.get_stats_without_cache.side_effect = persistent_error
        mock_base_service.get_stats_with_simple_aggregation.side_effect = persistent_error
        mock_base_service.get_stats_from_backup_source.side_effect = persistent_error

        # The cache also misses.
        mock_base_service.cache_manager.get_distribution_stats.return_value = None

        with pytest.raises(PriceDistributionServiceError) as exc_info:
            await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
                sample_request
            )

        assert "所有恢复策略都失败" in str(exc_info.value)
        assert fault_tolerant_service.error_stats['failed_requests'] == 1

    @pytest.mark.asyncio
    async def test_poor_quality_recovery(self, fault_tolerant_service, mock_base_service,
                                       sample_request, sample_stats):
        """A low-quality base result must not come back with a worse score."""
        poor_stats = sample_stats
        poor_stats.data_quality_score = 0.5  # assumed below the service's quality threshold
        mock_base_service.get_price_distribution_stats.return_value = poor_stats

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        # NOTE(review): the result may be ``poor_stats`` itself (score 0.5), so this
        # only proves the score did not degrade. Consider asserting the specific
        # improvement path once the service's behavior here is pinned down.
        assert result.data_quality_score >= 0.5

    def test_should_retry_error(self, fault_tolerant_service):
        """Transient errors are retried; input/validation errors are not."""
        # Errors that should be retried.
        assert fault_tolerant_service._should_retry_error(ConnectionError("网络错误"))
        assert fault_tolerant_service._should_retry_error(PriceDistributionCacheError("get", "key"))
        assert fault_tolerant_service._should_retry_error(DistributionCalculationError("计算", "临时错误"))

        # Errors that should not be retried.
        assert not fault_tolerant_service._should_retry_error(InvalidTradeDateError("invalid"))
        assert not fault_tolerant_service._should_retry_error(PriceDistributionValidationError("参数", "field"))

    def test_validate_result_quality(self, fault_tolerant_service, sample_stats):
        """Each quality criterion rejects a result on its own.

        Each negative case mutates its own shallow copy of the fixture, so only
        one criterion is violated at a time. Aliasing the fixture would leave the
        low quality score in place and mask whether the stock-count check works.
        """
        # High-quality result.
        assert fault_tolerant_service._validate_result_quality(sample_stats)

        # Low quality score only.
        poor_stats = copy.copy(sample_stats)
        poor_stats.data_quality_score = 0.5  # assumed below the threshold
        assert not fault_tolerant_service._validate_result_quality(poor_stats)

        # Too few stocks only (quality score stays at the fixture's 0.9).
        insufficient_stats = copy.copy(sample_stats)
        insufficient_stats.total_stocks = 50  # assumed below the threshold
        assert not fault_tolerant_service._validate_result_quality(insufficient_stats)

        # Missing result.
        assert not fault_tolerant_service._validate_result_quality(None)

    def test_get_adjacent_trade_dates(self, fault_tolerant_service):
        """Adjacent trade dates skip the weekend around Monday 2024-01-15."""
        dates = fault_tolerant_service._get_adjacent_trade_dates("20240115", 2)

        assert isinstance(dates, list)
        assert len(dates) > 0
        # 2024-01-15 is a Monday, so the previous weekday is Friday 2024-01-12
        # and the next one is Tuesday 2024-01-16.
        date_strings = [str(d) for d in dates]
        assert any("20240112" in date_str for date_str in date_strings)  # Friday
        assert any("20240116" in date_str for date_str in date_strings)  # Tuesday

    def test_error_stats_tracking(self, fault_tolerant_service):
        """Derived rates are the counters divided by the total request count."""
        # Initial state.
        stats = fault_tolerant_service.get_fault_tolerance_stats()
        assert stats['total_requests'] == 0
        assert stats['success_rate'] == 0.0

        # Populate the counters directly.
        fault_tolerant_service.error_stats['total_requests'] = 10
        fault_tolerant_service.error_stats['successful_requests'] = 8
        fault_tolerant_service.error_stats['recovered_requests'] = 1
        fault_tolerant_service.error_stats['failed_requests'] = 1

        stats = fault_tolerant_service.get_fault_tolerance_stats()
        # approx: the rates are floats and need not be bit-exact decimal values.
        assert stats['success_rate'] == pytest.approx(0.8)
        assert stats['recovery_rate'] == pytest.approx(0.1)
        assert stats['failure_rate'] == pytest.approx(0.1)

    def test_reset_stats(self, fault_tolerant_service):
        """Resetting clears counters, error-type tallies and the fallback cache."""
        # Seed some statistics.
        fault_tolerant_service.error_stats['total_requests'] = 10
        fault_tolerant_service.error_stats['error_types']['TestError'] = 5

        fault_tolerant_service.reset_fault_tolerance_stats()

        assert fault_tolerant_service.error_stats['total_requests'] == 0
        assert fault_tolerant_service.error_stats['error_types'] == {}
        assert len(fault_tolerant_service._fallback_cache) == 0


class TestErrorRecoveryStrategies:
    """Tests for the individual recovery helpers of the fault-tolerant service."""

    @pytest.fixture
    def fault_tolerant_service(self):
        """Create a fault-tolerant service wrapping a bare ``Mock`` base service."""
        mock_service = Mock()
        return FaultTolerantPriceDistributionService(mock_service)

    @pytest.fixture
    def sample_request(self):
        """Create a minimal request for trade date 20240115."""
        return PriceDistributionRequest(
            trade_date="20240115",
            include_st=True
        )

    def test_create_empty_distribution_stats(self, fault_tolerant_service, sample_request):
        """The empty stats object carries the request date and zeroed contents.

        The helper is called without ``await``, so this is a plain synchronous
        test rather than an asyncio one.
        """
        empty_stats = fault_tolerant_service._create_empty_distribution_stats(sample_request)

        assert empty_stats.trade_date == "20240115"
        assert empty_stats.total_stocks == 0
        assert empty_stats.data_quality_score == 0.0
        assert empty_stats.positive_ranges == {}
        assert empty_stats.negative_ranges == {}

    @pytest.mark.asyncio
    async def test_clean_and_correct_result(self, fault_tolerant_service):
        """Cleaning raises the quality score and caps positive percentages at 100%."""
        # Build an inconsistent result without running the model's validation.
        problematic_stats = PriceDistributionStats.__new__(PriceDistributionStats)
        problematic_stats.trade_date = "20240115"
        problematic_stats.total_stocks = 100
        problematic_stats.positive_ranges = {"0-3%": 30, "3-5%": 20}
        problematic_stats.positive_percentages = {"0-3%": 150.0, "3-5%": 80.0}  # sums to more than 100%
        problematic_stats.negative_ranges = {"0到-3%": 25, "-3到-5%": 15}
        problematic_stats.negative_percentages = {"0到-3%": 25.0, "-3到-5%": 15.0}
        problematic_stats.market_breakdown = {}
        problematic_stats.data_quality_score = 0.6
        problematic_stats.processing_time = 1.0
        problematic_stats.created_at = None
        problematic_stats.updated_at = None

        # Capture the score before the call. If the helper corrects the object in
        # place, ``cleaned_stats`` is ``problematic_stats`` and comparing the two
        # afterwards would degenerate to ``x > x``.
        original_score = problematic_stats.data_quality_score

        cleaned_stats = await fault_tolerant_service._clean_and_correct_result(problematic_stats)

        assert cleaned_stats is not None
        assert cleaned_stats.data_quality_score > original_score

        # Positive percentages must be brought back within 100%. The small
        # tolerance absorbs float rounding from renormalization.
        total_positive = sum(cleaned_stats.positive_percentages.values())
        assert total_positive <= 100.0 + 1e-6

    def test_update_error_stats(self, fault_tolerant_service):
        """Each recorded error increments the tally for its class name."""
        error = InvalidTradeDateError("invalid_date")

        fault_tolerant_service._update_error_stats(error)

        assert fault_tolerant_service.error_stats['error_types']['InvalidTradeDateError'] == 1

        # Recording the same error type again increments its count.
        fault_tolerant_service._update_error_stats(error)
        assert fault_tolerant_service.error_stats['error_types']['InvalidTradeDateError'] == 2


if __name__ == "__main__":
    # Propagate pytest's exit code so a direct run exits non-zero on failure.
    # Discarding the return value of pytest.main would always exit 0.
    raise SystemExit(pytest.main([__file__, "-v"]))