"""
涨跌分布统计容错机制集成测试

测试涨跌分布统计功能的容错处理和错误恢复的集成场景
"""

import asyncio
import copy
import time
from typing import Dict, Any, List
from unittest.mock import Mock, AsyncMock, patch, MagicMock

import pandas as pd
import pytest

from quickstock.core.price_distribution_errors import (
    PriceDistributionError,
    InsufficientPriceDataError,
    DistributionCalculationError,
    MarketClassificationError,
    PriceDistributionCacheError,
    PriceDistributionServiceError
)
from quickstock.core.price_distribution_fault_tolerance import FaultTolerantPriceDistributionService
from quickstock.models.price_distribution_models import PriceDistributionStats, PriceDistributionRequest
from quickstock.core.errors import DataSourceError, NetworkError, CacheError


class TestFaultToleranceIntegration:
    """Integration tests for FaultTolerantPriceDistributionService.

    The base service and its collaborators are replaced by mocks. Each test
    scripts a success/failure sequence on those mocks, then checks the result
    returned by the fault-tolerant wrapper and the counters exposed by
    ``get_fault_tolerance_stats()``.
    """

    @pytest.fixture
    def mock_data_manager(self):
        """Mock data manager exposing an async ``get_stock_daily_data``."""
        manager = Mock()
        manager.get_stock_daily_data = AsyncMock()
        return manager

    @pytest.fixture
    def mock_stock_classifier(self):
        """Mock stock classifier exposing a synchronous ``classify_stocks``."""
        classifier = Mock()
        classifier.classify_stocks = Mock()
        return classifier

    @pytest.fixture
    def mock_cache_manager(self):
        """Mock cache manager with async get/set/delete for distribution stats."""
        cache = Mock()
        cache.get_distribution_stats = AsyncMock()
        cache.set_distribution_stats = AsyncMock()
        cache.delete_distribution_stats = AsyncMock()
        return cache

    @pytest.fixture
    def mock_base_service(self, mock_data_manager, mock_stock_classifier, mock_cache_manager):
        """Mock base service: the primary entry point plus every recovery-strategy method.

        Every service method is an ``AsyncMock`` so each test can configure
        ``return_value`` / ``side_effect`` for its scenario.
        """
        service = Mock()
        service.data_manager = mock_data_manager
        service.stock_classifier = mock_stock_classifier
        service.cache_manager = mock_cache_manager

        # Primary entry point.
        service.get_price_distribution_stats = AsyncMock()

        # Recovery-strategy methods.
        service.get_stats_with_partial_data = AsyncMock()
        service.get_stats_with_simple_calculation = AsyncMock()
        service.get_stats_excluding_stocks = AsyncMock()
        service.get_stats_with_default_classification = AsyncMock()
        service.get_stats_without_cache = AsyncMock()
        service.get_stats_with_simple_aggregation = AsyncMock()
        service.get_stats_from_backup_source = AsyncMock()
        service.recalculate_partial_stats = AsyncMock()

        return service

    @pytest.fixture
    def fault_tolerant_service(self, mock_base_service):
        """Wrap the mock base service in the fault-tolerant service under test."""
        service = FaultTolerantPriceDistributionService(mock_base_service)
        # Shorten retry delays so the retry paths finish quickly in tests.
        service.fault_tolerance_config['retry_delay'] = 0.01
        service.fault_tolerance_config['max_delay'] = 0.1
        return service

    @pytest.fixture
    def sample_request(self):
        """A baseline request for trade date 20240115."""
        return PriceDistributionRequest(
            trade_date="20240115",
            include_st=True,
            market_filter=None,
            distribution_ranges=None,
            force_refresh=False,
            save_to_db=True
        )

    @pytest.fixture
    def sample_stock_data(self):
        """Four-stock daily-quote DataFrame for 20240115."""
        return pd.DataFrame({
            'ts_code': ['000001.SZ', '000002.SZ', '600000.SH', '600001.SH'],
            'trade_date': ['20240115'] * 4,
            'open': [10.0, 20.0, 30.0, 40.0],
            'close': [10.5, 19.5, 31.0, 39.0],
            'pct_chg': [5.0, -2.5, 3.33, -2.5]
        })

    @pytest.fixture
    def sample_stats(self):
        """A high-quality (0.95) statistics result used as the baseline response.

        NOTE(review): the range counts add up to 620 positive + 530 negative =
        1150, which exceeds ``total_stocks=1000``. The percentages likewise sum
        to 115%. The tests below only inspect ``data_quality_score``,
        ``total_stocks`` and ``market_breakdown`` keys, so this inconsistency
        does not affect them.
        """
        return PriceDistributionStats(
            trade_date="20240115",
            total_stocks=1000,
            positive_ranges={
                "0-3%": 300,
                "3-5%": 150,
                "5-7%": 100,
                "7-10%": 50,
                ">=10%": 20
            },
            positive_percentages={
                "0-3%": 30.0,
                "3-5%": 15.0,
                "5-7%": 10.0,
                "7-10%": 5.0,
                ">=10%": 2.0
            },
            negative_ranges={
                "0到-3%": 280,
                "-3到-5%": 120,
                "-5到-7%": 80,
                "-7到-10%": 40,
                "<=-10%": 10
            },
            negative_percentages={
                "0到-3%": 28.0,
                "-3到-5%": 12.0,
                "-5到-7%": 8.0,
                "-7到-10%": 4.0,
                "<=-10%": 1.0
            },
            market_breakdown={
                "total": {"positive": 620, "negative": 530},
                "shanghai": {"positive": 310, "negative": 265},
                "shenzhen": {"positive": 310, "negative": 265}
            },
            data_quality_score=0.95,
            processing_time=1.5
        )

    @pytest.mark.asyncio
    async def test_end_to_end_success_scenario(self, fault_tolerant_service, mock_base_service,
                                             sample_request, sample_stats):
        """Healthy base service: the result passes through and one success is recorded."""
        mock_base_service.get_price_distribution_stats.return_value = sample_stats

        # perf_counter is monotonic, so wall-clock adjustments cannot skew the measurement.
        start_time = time.perf_counter()
        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )
        execution_time = time.perf_counter() - start_time

        # Result.
        assert result == sample_stats
        assert result.data_quality_score >= 0.8

        # Statistics.
        stats = fault_tolerant_service.get_fault_tolerance_stats()
        assert stats['total_requests'] == 1
        assert stats['successful_requests'] == 1
        assert stats['failed_requests'] == 0
        assert stats['success_rate'] == 1.0

        # The base service is called exactly once with the original request.
        mock_base_service.get_price_distribution_stats.assert_called_once_with(sample_request)

        # A pass-through success should be fast.
        assert execution_time < 5.0

    @pytest.mark.asyncio
    async def test_network_error_recovery_chain(self, fault_tolerant_service, mock_base_service,
                                              sample_request, sample_stats):
        """Two network errors followed by success: the wrapper retries until it succeeds."""
        # Sequence: network error -> network error -> success.
        mock_base_service.get_price_distribution_stats.side_effect = [
            NetworkError("连接超时", "NETWORK_TIMEOUT"),
            NetworkError("DNS解析失败", "DNS_ERROR"),
            sample_stats
        ]

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        # Eventually succeeds.
        assert result == sample_stats

        # Two failed attempts plus the successful one.
        assert mock_base_service.get_price_distribution_stats.call_count == 3

        # Both network errors are counted.
        stats = fault_tolerant_service.get_fault_tolerance_stats()
        assert stats['successful_requests'] == 1
        assert stats['error_types']['NetworkError'] == 2

    @pytest.mark.asyncio
    async def test_data_insufficient_recovery_cascade(self, fault_tolerant_service, mock_base_service,
                                                    sample_request, sample_stats):
        """Insufficient data: partial-data recovery fails, the cached-stats fallback is used."""
        # Primary call fails with insufficient data.
        insufficient_error = InsufficientPriceDataError(
            trade_date="20240115",
            available_count=50,
            required_count=100,
            missing_fields=["pct_chg"]
        )
        mock_base_service.get_price_distribution_stats.side_effect = insufficient_error

        # First-level recovery: partial-data processing also fails.
        mock_base_service.get_stats_with_partial_data.side_effect = Exception("部分数据处理失败")

        # Second-level recovery: cached stats are available.
        # Copy so the fixture object is not mutated in place.
        cached_stats = copy.deepcopy(sample_stats)
        cached_stats.data_quality_score = 0.6
        mock_base_service.cache_manager.get_distribution_stats.return_value = cached_stats

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        # The fallback result is returned unchanged.
        assert result == cached_stats
        assert result.data_quality_score == 0.6

        # Statistics.
        stats = fault_tolerant_service.get_fault_tolerance_stats()
        assert stats['fallback_used'] == 1
        assert stats['error_types']['InsufficientPriceDataError'] == 1

    @pytest.mark.asyncio
    async def test_calculation_error_multi_recovery(self, fault_tolerant_service, mock_base_service,
                                                  sample_request, sample_stats):
        """Calculation error: simple calculation fails, excluding the bad stocks succeeds."""
        # Primary call fails with a calculation error naming two affected stocks.
        calc_error = DistributionCalculationError(
            calculation_type="区间分类",
            reason="涨跌幅数据包含NaN值",
            affected_stocks=["000001.SZ", "000002.SZ"],
            invalid_data_count=2
        )
        mock_base_service.get_price_distribution_stats.side_effect = calc_error

        # First-level recovery: simplified calculation fails.
        mock_base_service.get_stats_with_simple_calculation.side_effect = Exception("简化计算也失败")

        # Second-level recovery: excluding the problem stocks succeeds.
        recovered_stats = copy.deepcopy(sample_stats)
        recovered_stats.data_quality_score = 0.85
        recovered_stats.total_stocks = 998  # the two affected stocks are excluded
        mock_base_service.get_stats_excluding_stocks.return_value = recovered_stats

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        # Recovery succeeded.
        assert result == recovered_stats
        assert result.total_stocks == 998

        # The affected stocks from the error are passed to the exclusion strategy.
        mock_base_service.get_stats_excluding_stocks.assert_called_once_with(
            sample_request, exclude_stocks=["000001.SZ", "000002.SZ"]
        )

        # Statistics.
        stats = fault_tolerant_service.get_fault_tolerance_stats()
        assert stats['recovered_requests'] == 1

    @pytest.mark.asyncio
    async def test_cache_error_bypass_strategy(self, fault_tolerant_service, mock_base_service,
                                             sample_request, sample_stats):
        """Cache error: the request is recovered by bypassing the cache."""
        cache_error = PriceDistributionCacheError(
            operation="get",
            cache_key="price_dist:20240115",
            reason="Redis连接断开",
            cache_layer="L2"
        )
        mock_base_service.get_price_distribution_stats.side_effect = cache_error

        # Bypassing the cache succeeds.
        mock_base_service.get_stats_without_cache.return_value = sample_stats

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        assert result == sample_stats

        # The no-cache strategy was used.
        mock_base_service.get_stats_without_cache.assert_called_once()

        # Statistics.
        stats = fault_tolerant_service.get_fault_tolerance_stats()
        assert stats['recovered_requests'] == 1
        assert stats['error_types']['PriceDistributionCacheError'] == 1

    @pytest.mark.asyncio
    async def test_classification_error_fallback_chain(self, fault_tolerant_service, mock_base_service,
                                                     sample_request, sample_stats):
        """Classification error: default classification fails, a simplified retry succeeds."""
        classification_error = MarketClassificationError(
            stock_codes=["000001.SZ", "000002.SZ"],
            classification_type="市场板块",
            reason="新股票代码无法识别",
            unclassified_count=2
        )

        # First-level recovery: default classification fails.
        mock_base_service.get_stats_with_default_classification.side_effect = Exception("默认分类失败")

        # Second-level recovery: simplified classification succeeds.
        simplified_stats = copy.deepcopy(sample_stats)
        simplified_stats.data_quality_score = 0.7
        # Simplified market breakdown: only "total" and "non_st".
        simplified_stats.market_breakdown = {
            "total": {"positive": 620, "negative": 530},
            "non_st": {"positive": 580, "negative": 490}
        }

        # The first primary call fails; the next primary call (the simplified request) succeeds.
        # NOTE(review): the contents of the simplified request are not inspected here.
        mock_base_service.get_price_distribution_stats.side_effect = [
            classification_error,
            simplified_stats
        ]

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        # The simplified result is returned.
        assert result == simplified_stats
        assert result.data_quality_score == 0.7
        assert "shanghai" not in result.market_breakdown
        assert "shenzhen" not in result.market_breakdown

    @pytest.mark.asyncio
    async def test_complete_failure_scenario(self, fault_tolerant_service, mock_base_service,
                                           sample_request):
        """Every strategy fails: a PriceDistributionServiceError carrying the root cause is raised."""
        persistent_error = DataSourceError("数据源完全不可用", "DATA_SOURCE_DOWN")

        # Primary call fails.
        mock_base_service.get_price_distribution_stats.side_effect = persistent_error

        # Every recovery strategy fails too.
        mock_base_service.get_stats_with_partial_data.side_effect = persistent_error
        mock_base_service.get_stats_with_simple_calculation.side_effect = persistent_error
        mock_base_service.get_stats_excluding_stocks.side_effect = persistent_error
        mock_base_service.get_stats_with_default_classification.side_effect = persistent_error
        mock_base_service.get_stats_without_cache.side_effect = persistent_error
        mock_base_service.get_stats_with_simple_aggregation.side_effect = persistent_error
        mock_base_service.get_stats_from_backup_source.side_effect = persistent_error

        # No cached fallback is available.
        mock_base_service.cache_manager.get_distribution_stats.return_value = None

        with pytest.raises(PriceDistributionServiceError) as exc_info:
            await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
                sample_request
            )

        # The final error names the failure and preserves the root cause.
        assert "所有恢复策略都失败" in str(exc_info.value)
        assert exc_info.value.service_operation == "get_price_distribution_stats"
        assert "数据源完全不可用" in exc_info.value.reason

        # Statistics.
        stats = fault_tolerant_service.get_fault_tolerance_stats()
        assert stats['failed_requests'] == 1
        assert stats['error_types']['DataSourceError'] > 0

    @pytest.mark.asyncio
    async def test_quality_improvement_workflow(self, fault_tolerant_service, mock_base_service,
                                              sample_request, sample_stats):
        """A low-quality primary result is expected to be improved via recalculation."""
        # The primary call returns a result below both thresholds. Separate copies
        # are required: aliasing sample_stats would let the "improved" values
        # below overwrite the "poor" ones, so the primary result would already
        # pass both thresholds and the improvement path would never run.
        poor_quality_stats = copy.deepcopy(sample_stats)
        poor_quality_stats.data_quality_score = 0.5  # below the quality threshold (0.8)
        poor_quality_stats.total_stocks = 50  # below the minimum stock count (100)

        mock_base_service.get_price_distribution_stats.return_value = poor_quality_stats

        # Recalculation yields a result that meets both thresholds.
        improved_stats = copy.deepcopy(sample_stats)
        improved_stats.data_quality_score = 0.85
        improved_stats.total_stocks = 950
        mock_base_service.recalculate_partial_stats.return_value = improved_stats

        result = await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
            sample_request
        )

        # The returned result meets both thresholds.
        assert result.data_quality_score >= 0.8
        assert result.total_stocks >= 100

    @pytest.mark.asyncio
    async def test_concurrent_requests_handling(self, fault_tolerant_service, mock_base_service,
                                              sample_stats):
        """Five concurrent requests each get the response for their own trade date."""
        requests = [
            PriceDistributionRequest(trade_date=f"2024011{i}", include_st=True)
            for i in range(5, 10)  # 20240115 to 20240119
        ]

        # Build a distinct response per trade date.
        responses_by_date = {}
        for i, req in enumerate(requests):
            responses_by_date[req.trade_date] = PriceDistributionStats(
                trade_date=req.trade_date,
                total_stocks=1000 + i * 10,
                positive_ranges={"0-3%": 300 + i * 5},
                positive_percentages={"0-3%": 30.0 + i * 0.5},
                negative_ranges={"0到-3%": 280 + i * 5},
                negative_percentages={"0到-3%": 28.0 + i * 0.5},
                market_breakdown={},
                data_quality_score=0.9,
                processing_time=0.5
            )

        # Answer by trade date rather than by call order. A positional
        # side_effect list would bind each response to whichever task reached
        # the mock first, not to the request that was sent.
        async def respond_for_date(request, *args, **kwargs):
            return responses_by_date[request.trade_date]

        mock_base_service.get_price_distribution_stats.side_effect = respond_for_date

        results = await asyncio.gather(*(
            fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(req)
            for req in requests
        ))

        # Every request succeeds with its own data (gather preserves input order).
        assert len(results) == 5
        for i, result in enumerate(results):
            assert result.trade_date == requests[i].trade_date
            assert result.total_stocks == 1000 + i * 10

        # Statistics.
        stats = fault_tolerant_service.get_fault_tolerance_stats()
        assert stats['total_requests'] == 5
        assert stats['successful_requests'] == 5
        assert stats['success_rate'] == 1.0

    @pytest.mark.asyncio
    async def test_performance_under_stress(self, fault_tolerant_service, mock_base_service,
                                          sample_request, sample_stats):
        """Ten concurrent slow (100 ms) requests should finish well under serial time."""
        async def slow_response(*args, **kwargs):
            await asyncio.sleep(0.1)  # 100 ms simulated latency
            return sample_stats

        mock_base_service.get_price_distribution_stats.side_effect = slow_response

        start_time = time.perf_counter()

        # Valid YYYYMMDD dates 20240110..20240119. The previous f"20240{i+10:02d}"
        # produced 7-character strings such as "2024010".
        tasks = [
            fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
                PriceDistributionRequest(
                    trade_date=f"202401{i + 10:02d}",
                    include_st=True
                )
            )
            for i in range(10)
        ]

        results = await asyncio.gather(*tasks)

        total_time = time.perf_counter() - start_time

        # Concurrent execution must be much faster than the ~1 s serial time.
        assert total_time < 2.0
        assert len(results) == 10

        # Every request succeeded.
        stats = fault_tolerant_service.get_fault_tolerance_stats()
        assert stats['successful_requests'] == 10

    def test_configuration_customization(self, mock_base_service):
        """Updates to ``fault_tolerance_config`` on an instance are reflected when read back."""
        service = FaultTolerantPriceDistributionService(mock_base_service)

        service.fault_tolerance_config.update({
            'max_retries': 5,
            'retry_delay': 2.0,
            'data_quality_threshold': 0.9,
            'min_stock_count_threshold': 200,
            'enable_fallback': False
        })

        assert service.fault_tolerance_config['max_retries'] == 5
        assert service.fault_tolerance_config['retry_delay'] == 2.0
        assert service.fault_tolerance_config['data_quality_threshold'] == 0.9
        assert service.fault_tolerance_config['min_stock_count_threshold'] == 200
        assert service.fault_tolerance_config['enable_fallback'] is False

    @pytest.mark.asyncio
    async def test_error_context_preservation(self, fault_tolerant_service, mock_base_service,
                                            sample_request):
        """The final service error keeps the original reason and execution context."""
        detailed_error = InsufficientPriceDataError(
            trade_date="20240115",
            available_count=30,
            required_count=100,
            missing_fields=["pct_chg", "volume"],
            reason="数据源维护中"
        )

        # The primary call, partial-data recovery and cached fallback all fail.
        mock_base_service.get_price_distribution_stats.side_effect = detailed_error
        mock_base_service.get_stats_with_partial_data.side_effect = Exception("恢复失败")
        mock_base_service.cache_manager.get_distribution_stats.return_value = None

        with pytest.raises(PriceDistributionServiceError) as exc_info:
            await fault_tolerant_service.get_price_distribution_stats_with_fault_tolerance(
                sample_request
            )

        # The original reason and the service state survive into the final error.
        final_error = exc_info.value
        assert "数据源维护中" in final_error.reason
        assert final_error.service_operation == "get_price_distribution_stats"
        assert "execution_time" in final_error.service_state
        assert "retry_attempts" in final_error.service_state


if __name__ == "__main__":
    # Run this module's tests directly. pytest.main returns pytest's exit code;
    # propagating it makes a failing run exit non-zero instead of always 0.
    raise SystemExit(pytest.main([__file__, "-v"]))