"""
涨停检测器全面测试

为LimitUpDetector提供全面的单元测试覆盖，包括各种价格场景、边界情况和性能测试
"""

import pytest
import math
from unittest.mock import Mock, patch, MagicMock
from quickstock.utils.limit_up_detector import (
    LimitUpDetector,
    LimitUpDetectionResult,
    LimitUpDetectionError,
    InsufficientPriceDataError,
    InvalidPriceDataError,
    detect_limit_up,
    calculate_limit_up_price,
    get_limit_up_threshold
)
from quickstock.models import StockDailyData, LIMIT_UP_THRESHOLDS


class TestLimitUpDetectorComprehensive:
    """涨停检测器全面测试类"""
    
    def setup_method(self):
        """测试前置设置"""
        self.detector = LimitUpDetector()
    
    def test_init_with_all_parameters(self):
        """测试所有参数的初始化"""
        logger = Mock()
        detector = LimitUpDetector(
            price_tolerance=0.01,
            logger=logger
        )
        assert detector.price_tolerance == 0.01
        assert detector.logger is logger
    
    def test_all_stock_types_thresholds(self):
        """测试所有股票类型的涨停阈值"""
        test_cases = [
            ('normal', 0.10),
            ('st', 0.05),
            ('star', 0.20),
            ('beijing', 0.30),
            ('new_stock', 0.44),
            ('kcb_new', 0.44),  # 科创板新股
            ('cyb_new', 0.20),  # 创业板新股
        ]
        
        for stock_type, expected_threshold in test_cases:
            if stock_type in LIMIT_UP_THRESHOLDS:
                threshold = self.detector.get_limit_up_threshold(stock_type)
                assert threshold == expected_threshold, f"Failed for {stock_type}"
    
    def test_limit_up_price_calculation_precision(self):
        """测试涨停价格计算精度"""
        test_cases = [
            # (前收盘价, 股票类型, 期望涨停价)
            (10.00, 'normal', 11.00),
            (10.01, 'normal', 11.01),
            (9.99, 'normal', 10.99),
            (5.00, 'st', 5.25),
            (5.01, 'st', 5.26),
            (20.00, 'star', 24.00),
            (15.00, 'beijing', 19.50),
            (10.00, 'new_stock', 14.40),
        ]
        
        for prev_close, stock_type, expected in test_cases:
            result = self.detector.calculate_limit_up_price(prev_close, stock_type)
            assert abs(result - expected) < 0.01, f"Failed for {prev_close}, {stock_type}: got {result}, expected {expected}"
    
    def test_limit_up_detection_various_scenarios(self):
        """测试各种涨停检测场景"""
        # 完美涨停
        assert self.detector.is_limit_up(10.0, 11.0, 11.0, 10.0, 'normal') is True
        
        # 在容差范围内的涨停
        assert self.detector.is_limit_up(10.0, 10.995, 10.995, 10.0, 'normal') is True
        
        # 超出容差范围
        assert self.detector.is_limit_up(10.0, 10.98, 10.98, 10.0, 'normal') is False
        
        # 收盘价等于涨停价但不等于最高价
        assert self.detector.is_limit_up(10.0, 11.0, 11.1, 10.0, 'normal') is False
        
        # ST股票涨停
        assert self.detector.is_limit_up(10.0, 10.5, 10.5, 10.0, 'st') is True
        
        # 科创板涨停
        assert self.detector.is_limit_up(20.0, 24.0, 24.0, 20.0, 'star') is True
        
        # 北证涨停
        assert self.detector.is_limit_up(15.0, 19.5, 19.5, 15.0, 'beijing') is True
    
    def test_price_data_validation_comprehensive(self):
        """测试价格数据验证的全面场景"""
        # 有效数据
        self.detector._validate_price_data(10.0, 11.0, 11.5, 9.5, "000001.SZ")
        
        # 无效数据测试
        invalid_cases = [
            (-1.0, 11.0, 11.5, 9.5),  # 负开盘价
            (10.0, -1.0, 11.5, 9.5),  # 负收盘价
            (10.0, 11.0, -1.0, 9.5),  # 负最高价
            (10.0, 11.0, 11.5, -1.0), # 负最低价
            (0.0, 11.0, 11.5, 9.5),   # 零开盘价
            (10.0, 12.0, 11.0, 9.5),  # 收盘价高于最高价
            (12.0, 11.0, 11.0, 9.5),  # 开盘价高于最高价
            (10.0, 11.0, 11.0, 12.0), # 最低价高于最高价
        ]
        
        for open_p, close_p, high_p, low_p in invalid_cases:
            with pytest.raises(InvalidPriceDataError):
                self.detector._validate_price_data(open_p, close_p, high_p, low_p, "000001.SZ")
    
    def test_stock_type_inference_comprehensive(self):
        """测试股票类型推断的全面场景"""
        test_cases = [
            # (股票代码, 股票名称, 期望类型)
            ('000001.SZ', '平安银行', 'normal'),
            ('000002.SZ', 'ST万科', 'st'),
            ('000003.SZ', '*ST海润', 'st'),
            ('688001.SH', '华兴源创', 'star'),
            ('688002.SH', 'ST科创', 'st'),  # 科创板ST股票
            ('430001.BJ', '北交所股票', 'beijing'),
            ('600000.SH', '浦发银行', 'normal'),
            ('300001.SZ', '特锐德', 'normal'),  # 创业板
            ('200001.SZ', 'B股股票', 'normal'),  # B股
        ]
        
        for code, name, expected_type in test_cases:
            inferred_type = self.detector._infer_stock_type_from_code(code, name)
            assert inferred_type == expected_type, f"Failed for {code}, {name}: got {inferred_type}, expected {expected_type}"
    
    def test_batch_detection_comprehensive(self):
        """测试批量检测的全面场景"""
        # 创建多样化的测试数据
        stock_data_list = [
            # 普通涨停股票
            StockDailyData(
                ts_code="000001.SZ", trade_date="20241015",
                open=10.0, high=11.0, low=10.0, close=11.0,
                pre_close=10.0, change=1.0, pct_chg=10.0,
                vol=1000000, amount=10500000.0, name="平安银行"
            ),
            # ST涨停股票
            StockDailyData(
                ts_code="000002.SZ", trade_date="20241015",
                open=10.0, high=10.5, low=10.0, close=10.5,
                pre_close=10.0, change=0.5, pct_chg=5.0,
                vol=800000, amount=8200000.0, name="ST万科"
            ),
            # 科创板涨停股票
            StockDailyData(
                ts_code="688001.SH", trade_date="20241015",
                open=20.0, high=24.0, low=20.0, close=24.0,
                pre_close=20.0, change=4.0, pct_chg=20.0,
                vol=500000, amount=11000000.0, name="华兴源创"
            ),
            # 非涨停股票
            StockDailyData(
                ts_code="600000.SH", trade_date="20241015",
                open=8.0, high=8.5, low=8.0, close=8.3,
                pre_close=8.0, change=0.3, pct_chg=3.75,
                vol=2000000, amount=16600000.0, name="浦发银行"
            ),
        ]
        
        results = self.detector.batch_detect_limit_up(stock_data_list)
        
        # 验证结果
        assert len(results) == 4
        assert all(isinstance(result, LimitUpDetectionResult) for result in results)
        
        # 验证涨停检测结果
        limit_up_results = [r for r in results if r.is_limit_up]
        assert len(limit_up_results) == 3  # 前三只应该是涨停
        
        # 验证股票代码
        limit_up_codes = [r.ts_code for r in limit_up_results]
        assert "000001.SZ" in limit_up_codes
        assert "000002.SZ" in limit_up_codes
        assert "688001.SH" in limit_up_codes
        assert "600000.SH" not in limit_up_codes
    
    def test_detection_confidence_calculation(self):
        """测试检测置信度计算"""
        # 高置信度：完美涨停
        confidence = self.detector._calculate_detection_confidence(
            close_price=11.0,
            limit_up_price=11.0,
            high_price=11.0,
            price_diff=0.0,
            is_within_tolerance=True,
            close_equals_high=True
        )
        assert confidence >= 0.95
        
        # 中等置信度：在容差范围内
        confidence = self.detector._calculate_detection_confidence(
            close_price=10.995,
            limit_up_price=11.0,
            high_price=10.995,
            price_diff=0.005,
            is_within_tolerance=True,
            close_equals_high=True
        )
        assert 0.8 <= confidence < 0.95
        
        # 低置信度：不符合涨停条件
        confidence = self.detector._calculate_detection_confidence(
            close_price=10.5,
            limit_up_price=11.0,
            high_price=11.0,
            price_diff=0.5,
            is_within_tolerance=False,
            close_equals_high=False
        )
        assert confidence == 0.0
    
    def test_price_precision_handling(self):
        """测试价格精度处理"""
        # 测试价格四舍五入
        assert self.detector._round_price(10.123) == 10.12
        assert self.detector._round_price(10.126) == 10.13
        assert self.detector._round_price(10.125) == 10.13  # 银行家舍入
        
        # 测试浮点精度问题
        prev_close = 10.00
        calculated_limit = prev_close * 1.10  # 可能不完全等于11.00
        
        is_limit_up = self.detector.is_limit_up(
            open_price=10.5,
            close_price=calculated_limit,
            high_price=calculated_limit,
            prev_close=prev_close,
            stock_type='normal'
        )
        assert is_limit_up is True
    
    def test_edge_cases_comprehensive(self):
        """测试全面的边界情况"""
        # 极小价格
        assert self.detector.is_limit_up(0.01, 0.011, 0.011, 0.01, 'normal') is True
        
        # 极大价格
        assert self.detector.is_limit_up(1000.0, 1100.0, 1100.0, 1000.0, 'normal') is True
        
        # 价格精度边界
        prev_close = 10.005
        limit_price = round(prev_close * 1.10, 2)
        assert self.detector.is_limit_up(10.0, limit_price, limit_price, prev_close, 'normal') is True
        
        # 容差边界测试
        tolerance = self.detector.price_tolerance
        prev_close = 10.0
        limit_price = 11.0
        
        # 刚好在容差内
        close_price = limit_price - tolerance + 0.0001
        assert self.detector.is_limit_up(10.0, close_price, close_price, prev_close, 'normal') is True
        
        # 刚好超出容差
        close_price = limit_price - tolerance - 0.0001
        assert self.detector.is_limit_up(10.0, close_price, close_price, prev_close, 'normal') is False
    
    def test_data_completeness_validation(self):
        """测试数据完整性验证"""
        # 完整数据
        complete_stock = StockDailyData(
            ts_code="000001.SZ", trade_date="20241015",
            open=10.0, high=11.0, low=9.8, close=11.0,
            pre_close=10.0, change=1.0, pct_chg=10.0,
            vol=1000000, amount=10500000.0, name="平安银行"
        )
        
        result = self.detector.validate_price_data_completeness(complete_stock)
        assert result['is_complete'] is True
        assert len(result['missing_fields']) == 0
        
        # 缺失字段数据
        incomplete_stock = StockDailyData(
            ts_code="000001.SZ", trade_date="20241015",
            open=None, high=11.0, low=9.8, close=11.0,
            pre_close=10.0, change=1.0, pct_chg=10.0,
            vol=1000000, amount=10500000.0, name="平安银行"
        )
        
        result = self.detector.validate_price_data_completeness(incomplete_stock)
        assert result['is_complete'] is False
        assert 'open' in result['missing_fields']
    
    def test_performance_optimization(self):
        """测试性能优化"""
        import time
        
        # 创建大量测试数据
        stock_data_list = []
        for i in range(1000):
            stock_data_list.append(StockDailyData(
                ts_code=f"{i:06d}.SZ", trade_date="20241015",
                open=10.0, high=11.0, low=10.0, close=11.0,
                pre_close=10.0, change=1.0, pct_chg=10.0,
                vol=1000000, amount=10500000.0, name=f"股票{i}"
            ))
        
        # 测试批量检测性能
        start_time = time.time()
        results = self.detector.batch_detect_limit_up(stock_data_list)
        end_time = time.time()
        
        # 验证结果
        assert len(results) == 1000
        
        # 性能要求：1000只股票的检测应该在合理时间内完成
        processing_time = end_time - start_time
        assert processing_time < 5.0  # 5秒内完成
        
        # 平均每只股票的处理时间
        avg_time_per_stock = processing_time / 1000
        assert avg_time_per_stock < 0.005  # 每只股票5毫秒内
    
    def test_error_recovery_mechanisms(self):
        """测试错误恢复机制"""
        # 创建包含错误数据的股票列表
        stock_data_list = [
            # 正常数据
            StockDailyData(
                ts_code="000001.SZ", trade_date="20241015",
                open=10.0, high=11.0, low=10.0, close=11.0,
                pre_close=10.0, change=1.0, pct_chg=10.0,
                vol=1000000, amount=10500000.0, name="平安银行"
            ),
            # 异常数据（负价格）
            StockDailyData(
                ts_code="000002.SZ", trade_date="20241015",
                open=-10.0, high=11.0, low=10.0, close=11.0,
                pre_close=10.0, change=1.0, pct_chg=10.0,
                vol=1000000, amount=10500000.0, name="异常股票"
            ),
        ]
        
        # 批量检测应该能处理错误并继续
        results = self.detector.batch_detect_limit_up(stock_data_list)
        
        # 验证结果
        assert len(results) == 2
        
        # 正常股票应该有正确结果
        normal_result = next(r for r in results if r.ts_code == "000001.SZ")
        assert normal_result.is_limit_up is True
        
        # 异常股票应该有错误标记
        error_result = next(r for r in results if r.ts_code == "000002.SZ")
        assert error_result.is_limit_up is False
        assert error_result.confidence == 0.0
        assert 'error' in error_result.detection_details
    
    def test_caching_mechanism(self):
        """测试缓存机制（跳过，因为当前实现不支持缓存）"""
        # 当前实现不支持缓存，跳过此测试
        pytest.skip("Cache functionality not supported in current implementation")
    
    def test_logging_integration(self):
        """测试日志集成"""
        mock_logger = Mock()
        detector = LimitUpDetector(logger=mock_logger)
        
        # 执行检测操作
        detector.is_limit_up(10.0, 11.0, 11.0, 10.0, 'normal')
        
        # 验证日志调用
        assert mock_logger.debug.called or mock_logger.info.called
    
    def test_thread_safety(self):
        """测试线程安全性"""
        import threading
        
        results = []
        errors = []
        
        def detection_worker():
            try:
                for i in range(100):
                    result = self.detector.is_limit_up(
                        open_price=10.0 + i * 0.01,
                        close_price=11.0 + i * 0.01,
                        high_price=11.0 + i * 0.01,
                        prev_close=10.0 + i * 0.01,
                        stock_type='normal'
                    )
                    results.append(result)
            except Exception as e:
                errors.append(e)
        
        # 创建多个线程
        threads = []
        for _ in range(10):
            thread = threading.Thread(target=detection_worker)
            threads.append(thread)
            thread.start()
        
        # 等待所有线程完成
        for thread in threads:
            thread.join()
        
        # 验证结果
        assert len(errors) == 0, f"Thread safety errors: {errors}"
        assert len(results) == 1000


class TestLimitUpDetectionResultComprehensive:
    """涨停检测结果全面测试"""
    
    def test_result_serialization(self):
        """测试结果序列化"""
        result = LimitUpDetectionResult(
            ts_code="000001.SZ",
            is_limit_up=True,
            confidence=0.95,
            limit_up_price=11.0,
            actual_close_price=11.0,
            price_difference=0.0,
            threshold_used=0.10,
            stock_type='normal',
            detection_details={'test': 'data'}
        )
        
        # 测试转换为字典
        result_dict = result.to_dict()
        assert isinstance(result_dict, dict)
        assert result_dict['ts_code'] == "000001.SZ"
        assert result_dict['is_limit_up'] is True
        assert result_dict['confidence'] == 0.95
        
        # 测试从字典创建
        new_result = LimitUpDetectionResult.from_dict(result_dict)
        assert new_result.ts_code == result.ts_code
        assert new_result.is_limit_up == result.is_limit_up
        assert new_result.confidence == result.confidence


class TestExceptionHandlingComprehensive:
    """异常处理全面测试"""
    
    def test_limit_up_detection_error_comprehensive(self):
        """测试涨停检测异常的全面场景"""
        error = LimitUpDetectionError(
            "检测失败", 
            "000001.SZ", 
            {'reason': '价格数据异常', 'details': 'test'}
        )
        
        assert error.stock_code == "000001.SZ"
        assert error.detection_details['reason'] == '价格数据异常'
        
        # 测试转换为字典
        error_dict = error.to_dict()
        assert error_dict['error_type'] == 'LimitUpDetectionError'
        assert error_dict['stock_code'] == "000001.SZ"
    
    def test_insufficient_price_data_error_comprehensive(self):
        """测试价格数据不足异常的全面场景"""
        missing_fields = ['open', 'close', 'high']
        error = InsufficientPriceDataError("000001.SZ", missing_fields)
        
        assert error.stock_code == "000001.SZ"
        assert error.detection_details['missing_fields'] == missing_fields
        assert 'required_fields' in error.detection_details
        assert 'suggestions' in error.detection_details
    
    def test_invalid_price_data_error_comprehensive(self):
        """测试无效价格数据异常的全面场景"""
        validation_errors = ['开盘价不能为负', '收盘价不能高于最高价']
        error = InvalidPriceDataError("000001.SZ", validation_errors)
        
        assert error.stock_code == "000001.SZ"
        assert error.detection_details['validation_errors'] == validation_errors
        assert 'data_quality_issues' in error.detection_details


if __name__ == "__main__":
    pytest.main([__file__])