"""
涨停检测器测试

测试LimitUpDetector类的各种涨停检测场景
"""

import pytest
import math
from unittest.mock import Mock, patch
from quickstock.utils.limit_up_detector import (
    LimitUpDetector,
    LimitUpDetectionResult,
    LimitUpDetectionError,
    InsufficientPriceDataError,
    InvalidPriceDataError,
    detect_limit_up,
    calculate_limit_up_price,
    get_limit_up_threshold
)
from quickstock.models import StockDailyData, LIMIT_UP_THRESHOLDS


class TestLimitUpDetector:
    """Unit tests for LimitUpDetector.

    Covers threshold lookup, limit-up price calculation, single-stock
    detection, price validation, batch detection and the private helpers.
    """

    def setup_method(self):
        """Create a fresh detector with default settings before every test."""
        self.detector = LimitUpDetector()

    @staticmethod
    def _make_daily_data(**overrides):
        """Build a StockDailyData record, defaulting to a limit-up day.

        The defaults describe 000001.SZ closing at 11.00 (= 10.00 * 1.10)
        with close == high. Any field can be replaced via keyword overrides.
        """
        fields = dict(
            ts_code="000001.SZ",
            trade_date="20241015",
            open=10.50,
            high=11.00,
            low=10.30,
            close=11.00,
            pre_close=10.00,
            change=1.00,
            pct_chg=10.0,
            vol=1000000,
            amount=10500000.0,
            name="平安银行",
        )
        fields.update(overrides)
        return StockDailyData(**fields)

    def test_init_default_parameters(self):
        """Default construction uses the class-level PRICE_TOLERANCE and sets a logger."""
        detector = LimitUpDetector()
        assert detector.price_tolerance == LimitUpDetector.PRICE_TOLERANCE
        assert detector.logger is not None

    def test_init_custom_parameters(self):
        """A custom price_tolerance is stored on the instance."""
        custom_tolerance = 0.01
        detector = LimitUpDetector(price_tolerance=custom_tolerance)
        assert detector.price_tolerance == custom_tolerance

    def test_get_limit_up_threshold_normal_stock(self):
        """Normal stocks use a 10% threshold."""
        threshold = self.detector.get_limit_up_threshold('normal')
        assert threshold == LIMIT_UP_THRESHOLDS['normal']
        assert threshold == 0.10

    def test_get_limit_up_threshold_st_stock(self):
        """ST stocks use a 5% threshold."""
        threshold = self.detector.get_limit_up_threshold('st')
        assert threshold == LIMIT_UP_THRESHOLDS['st']
        assert threshold == 0.05

    def test_get_limit_up_threshold_star_market(self):
        """STAR Market stocks use a 20% threshold."""
        threshold = self.detector.get_limit_up_threshold('star')
        assert threshold == LIMIT_UP_THRESHOLDS['star']
        assert threshold == 0.20

    def test_get_limit_up_threshold_beijing_market(self):
        """Beijing Stock Exchange stocks use a 30% threshold."""
        threshold = self.detector.get_limit_up_threshold('beijing')
        assert threshold == LIMIT_UP_THRESHOLDS['beijing']
        assert threshold == 0.30

    def test_get_limit_up_threshold_new_stock(self):
        """Newly listed stocks use a 44% threshold."""
        threshold = self.detector.get_limit_up_threshold('new_stock')
        assert threshold == LIMIT_UP_THRESHOLDS['new_stock']
        assert threshold == 0.44

    def test_get_limit_up_threshold_invalid_type(self):
        """An unknown stock type raises ValueError."""
        with pytest.raises(ValueError, match="Unknown stock type"):
            self.detector.get_limit_up_threshold('invalid_type')

    def test_calculate_limit_up_price_normal_stock(self):
        """Normal stock limit-up price is prev_close * 1.10."""
        prev_close = 10.00
        limit_up_price = self.detector.calculate_limit_up_price(prev_close, 'normal')
        assert limit_up_price == pytest.approx(10.00 * 1.10, abs=0.01)

    def test_calculate_limit_up_price_st_stock(self):
        """ST stock limit-up price is prev_close * 1.05."""
        prev_close = 5.00
        limit_up_price = self.detector.calculate_limit_up_price(prev_close, 'st')
        assert limit_up_price == pytest.approx(5.00 * 1.05, abs=0.01)

    def test_calculate_limit_up_price_star_market(self):
        """STAR Market limit-up price is prev_close * 1.20."""
        prev_close = 20.00
        limit_up_price = self.detector.calculate_limit_up_price(prev_close, 'star')
        assert limit_up_price == pytest.approx(20.00 * 1.20, abs=0.01)

    def test_calculate_limit_up_price_beijing_market(self):
        """Beijing Stock Exchange limit-up price is prev_close * 1.30."""
        prev_close = 15.00
        limit_up_price = self.detector.calculate_limit_up_price(prev_close, 'beijing')
        assert limit_up_price == pytest.approx(15.00 * 1.30, abs=0.01)

    def test_calculate_limit_up_price_invalid_prev_close(self):
        """Zero or negative previous close prices are rejected."""
        with pytest.raises(ValueError, match="Previous close price must be positive"):
            self.detector.calculate_limit_up_price(0, 'normal')

        with pytest.raises(ValueError, match="Previous close price must be positive"):
            self.detector.calculate_limit_up_price(-5.0, 'normal')

    def test_is_limit_up_perfect_match(self):
        """Close equal to the limit-up price and to the high is a limit-up."""
        prev_close = 10.00
        limit_up_price = 11.00

        is_limit_up = self.detector.is_limit_up(
            open_price=10.50,
            close_price=limit_up_price,
            high_price=limit_up_price,
            prev_close=prev_close,
            stock_type='normal'
        )
        assert is_limit_up is True

    def test_is_limit_up_within_tolerance(self):
        """A close slightly below the limit-up price, inside the tolerance, still counts."""
        prev_close = 10.00
        limit_up_price = 11.00

        # NOTE(review): assumes the default PRICE_TOLERANCE is at least 0.003.
        close_price = limit_up_price - 0.003
        is_limit_up = self.detector.is_limit_up(
            open_price=10.50,
            close_price=close_price,
            high_price=close_price,
            prev_close=prev_close,
            stock_type='normal'
        )
        assert is_limit_up is True

    def test_is_limit_up_outside_tolerance(self):
        """A close further below the limit-up price than the tolerance is not a limit-up."""
        prev_close = 10.00
        limit_up_price = 11.00

        # NOTE(review): assumes the default PRICE_TOLERANCE is below 0.01.
        close_price = limit_up_price - 0.01
        is_limit_up = self.detector.is_limit_up(
            open_price=10.50,
            close_price=close_price,
            high_price=close_price,
            prev_close=prev_close,
            stock_type='normal'
        )
        assert is_limit_up is False

    def test_is_limit_up_close_not_equal_high(self):
        """Closing at the limit-up price is not enough when the intraday high was higher."""
        prev_close = 10.00
        limit_up_price = 11.00

        is_limit_up = self.detector.is_limit_up(
            open_price=10.50,
            close_price=limit_up_price,
            high_price=limit_up_price + 0.10,  # high above the close
            prev_close=prev_close,
            stock_type='normal'
        )
        assert is_limit_up is False

    def test_is_limit_up_st_stock(self):
        """ST stocks hit limit-up at +5%."""
        prev_close = 10.00
        limit_up_price = 10.50

        is_limit_up = self.detector.is_limit_up(
            open_price=10.20,
            close_price=limit_up_price,
            high_price=limit_up_price,
            prev_close=prev_close,
            stock_type='st'
        )
        assert is_limit_up is True

    def test_is_limit_up_star_market(self):
        """STAR Market stocks hit limit-up at +20%."""
        prev_close = 20.00
        limit_up_price = 24.00

        is_limit_up = self.detector.is_limit_up(
            open_price=21.00,
            close_price=limit_up_price,
            high_price=limit_up_price,
            prev_close=prev_close,
            stock_type='star'
        )
        assert is_limit_up is True

    def test_is_limit_up_beijing_market(self):
        """Beijing Stock Exchange stocks hit limit-up at +30%."""
        prev_close = 15.00
        limit_up_price = 19.50

        is_limit_up = self.detector.is_limit_up(
            open_price=16.00,
            close_price=limit_up_price,
            high_price=limit_up_price,
            prev_close=prev_close,
            stock_type='beijing'
        )
        assert is_limit_up is True

    def test_is_limit_up_without_prev_close(self):
        """Without a previous close the open price is used as the base price."""
        open_price = 10.00
        limit_up_price = 11.00

        is_limit_up = self.detector.is_limit_up(
            open_price=open_price,
            close_price=limit_up_price,
            high_price=limit_up_price,
            prev_close=None,
            stock_type='normal'
        )
        assert is_limit_up is True

    def test_detect_limit_up_with_details(self):
        """Detailed detection returns a fully populated LimitUpDetectionResult."""
        prev_close = 10.00
        open_price = 10.50
        close_price = 11.00
        high_price = 11.00
        ts_code = "000001.SZ"

        result = self.detector.detect_limit_up_with_details(
            open_price=open_price,
            close_price=close_price,
            high_price=high_price,
            prev_close=prev_close,
            stock_type='normal',
            ts_code=ts_code
        )

        assert isinstance(result, LimitUpDetectionResult)
        assert result.ts_code == ts_code
        assert result.is_limit_up is True
        assert result.confidence > 0.5
        # The limit-up price is computed (10.00 * 1.10), so compare with a
        # tolerance rather than exact float equality.
        assert result.limit_up_price == pytest.approx(11.00)
        assert result.actual_close_price == close_price
        assert result.threshold_used == 0.10
        assert result.stock_type == 'normal'
        assert 'base_price' in result.detection_details
        assert 'pct_change' in result.detection_details

    def test_validate_price_data_valid(self):
        """Consistent positive prices pass validation without raising."""
        self.detector._validate_price_data(10.0, 11.0, 11.5, 9.5, "000001.SZ")

    def test_validate_price_data_negative_price(self):
        """A negative price is rejected."""
        with pytest.raises(InvalidPriceDataError):
            self.detector._validate_price_data(-1.0, 11.0, 11.5, 9.5, "000001.SZ")

    def test_validate_price_data_zero_price(self):
        """A zero price is rejected."""
        with pytest.raises(InvalidPriceDataError):
            self.detector._validate_price_data(0.0, 11.0, 11.5, 9.5, "000001.SZ")

    def test_validate_price_data_close_higher_than_high(self):
        """A close above the high is rejected."""
        with pytest.raises(InvalidPriceDataError):
            self.detector._validate_price_data(10.0, 12.0, 11.0, 9.5, "000001.SZ")

    def test_validate_price_data_open_higher_than_high(self):
        """An open above the high is rejected."""
        with pytest.raises(InvalidPriceDataError):
            self.detector._validate_price_data(12.0, 11.0, 11.0, 9.5, "000001.SZ")

    def test_batch_detect_limit_up(self):
        """Batch detection returns one result per input, in order, with correct outcomes."""
        stock_data_list = [
            self._make_daily_data(),
            self._make_daily_data(
                ts_code="600000.SH",
                open=8.50,
                high=9.00,
                low=8.40,
                close=8.80,
                pre_close=8.00,
                change=0.80,
                pct_chg=10.0,
                vol=800000,
                amount=7000000.0,
                name="浦发银行",
            ),
        ]

        results = self.detector.batch_detect_limit_up(stock_data_list)

        assert len(results) == 2
        assert all(isinstance(result, LimitUpDetectionResult) for result in results)
        assert results[0].ts_code == "000001.SZ"
        assert results[1].ts_code == "600000.SH"
        # 000001.SZ closes at its limit-up price (11.00) and at the high.
        assert results[0].is_limit_up is True
        # 600000.SH closes at 8.80 but its high (9.00) is above the close, so
        # close != high and it is not a limit-up close.
        assert results[1].is_limit_up is False

    def test_batch_detect_limit_up_with_error(self):
        """An exception for one stock becomes a zero-confidence, non-limit-up result."""
        stock_data_list = [self._make_daily_data()]

        # Force the per-stock detection to fail.
        with patch.object(self.detector, 'detect_limit_up_with_details', side_effect=Exception("Test error")):
            results = self.detector.batch_detect_limit_up(stock_data_list)

            assert len(results) == 1
            assert results[0].is_limit_up is False
            assert results[0].confidence == 0.0
            assert 'error' in results[0].detection_details

    def test_validate_price_data_completeness_complete(self):
        """A fully populated record is complete, with no missing or invalid fields."""
        stock_data = self._make_daily_data()

        result = self.detector.validate_price_data_completeness(stock_data)

        assert result['is_complete'] is True
        assert len(result['missing_fields']) == 0
        assert len(result['invalid_fields']) == 0
        assert result['ts_code'] == "000001.SZ"

    def test_validate_price_data_completeness_with_warnings(self):
        """An implausible record is complete but produces warnings."""
        stock_data = self._make_daily_data(
            open=10.00,
            high=10.00,
            low=10.00,        # high == low: suspicious
            close=10.00,
            pre_close=5.00,   # +100% move: abnormal
            change=5.00,
            pct_chg=100.0,
            vol=1000000,
            amount=10000000.0,
            name="测试股票",
        )

        result = self.detector.validate_price_data_completeness(stock_data)

        assert result['is_complete'] is True
        assert len(result['warnings']) > 0
        assert any('Unusual price change' in warning for warning in result['warnings'])
        assert any('High equals low' in warning for warning in result['warnings'])

    def test_infer_stock_type_from_code_normal(self):
        """An ordinary Shenzhen main-board code maps to 'normal'."""
        stock_type = self.detector._infer_stock_type_from_code("000001.SZ", "平安银行")
        assert stock_type == 'normal'

    def test_infer_stock_type_from_code_st(self):
        """An ST-prefixed name maps to 'st'."""
        stock_type = self.detector._infer_stock_type_from_code("000002.SZ", "ST万科")
        assert stock_type == 'st'

    def test_infer_stock_type_from_code_star(self):
        """A 688xxx.SH code maps to 'star'."""
        stock_type = self.detector._infer_stock_type_from_code("688001.SH", "华兴源创")
        assert stock_type == 'star'

    def test_is_st_stock_by_name_st(self):
        """ST, *ST and delisting-related names are recognised as ST."""
        assert self.detector._is_st_stock_by_name("ST万科") is True
        assert self.detector._is_st_stock_by_name("*ST海润") is True
        assert self.detector._is_st_stock_by_name("退市大控") is True
        assert self.detector._is_st_stock_by_name("暂停上市") is True

    def test_is_st_stock_by_name_normal(self):
        """Ordinary, empty and missing names are not ST."""
        assert self.detector._is_st_stock_by_name("平安银行") is False
        assert self.detector._is_st_stock_by_name("招商银行") is False
        assert self.detector._is_st_stock_by_name("") is False
        assert self.detector._is_st_stock_by_name(None) is False

    def test_round_price(self):
        """Prices are rounded to two decimal places."""
        assert self.detector._round_price(10.123) == 10.12
        assert self.detector._round_price(10.126) == 10.13
        assert self.detector._round_price(1000.999) == 1001.00

    def test_calculate_detection_confidence_high(self):
        """An exact limit-up close at the high yields high confidence."""
        confidence = self.detector._calculate_detection_confidence(
            close_price=11.00,
            limit_up_price=11.00,
            high_price=11.00,
            price_diff=0.0,
            is_within_tolerance=True,
            close_equals_high=True
        )
        assert confidence >= 0.9

    def test_calculate_detection_confidence_low(self):
        """Outside tolerance with close != high yields zero confidence."""
        confidence = self.detector._calculate_detection_confidence(
            close_price=10.50,
            limit_up_price=11.00,
            high_price=11.00,
            price_diff=0.50,
            is_within_tolerance=False,
            close_equals_high=False
        )
        assert confidence == 0.0


class TestConvenienceFunctions:
    """Tests for the module-level convenience wrappers of LimitUpDetector."""

    def test_detect_limit_up_function(self):
        """A normal stock closing at +10% with close == high is a limit-up."""
        call_kwargs = dict(
            open_price=10.50,
            close_price=11.00,
            high_price=11.00,
            prev_close=10.00,
            stock_type='normal',
        )
        assert detect_limit_up(**call_kwargs) is True

    def test_calculate_limit_up_price_function(self):
        """A 10.00 normal stock has a limit-up price of about 11.00."""
        assert abs(calculate_limit_up_price(10.00, 'normal') - 11.00) < 0.01

    def test_get_limit_up_threshold_function(self):
        """The threshold for normal stocks is 10%."""
        assert get_limit_up_threshold('normal') == 0.10


class TestLimitUpDetectionResult:
    """Tests for the LimitUpDetectionResult value object."""

    def test_to_dict(self):
        """to_dict() exposes the constructor fields under their own names."""
        as_dict = LimitUpDetectionResult(
            ts_code="000001.SZ",
            is_limit_up=True,
            confidence=0.95,
            limit_up_price=11.00,
            actual_close_price=11.00,
            price_difference=0.0,
            threshold_used=0.10,
            stock_type='normal',
            detection_details={'test': 'data'},
        ).to_dict()

        assert as_dict['ts_code'] == "000001.SZ"
        assert as_dict['is_limit_up'] is True
        assert as_dict['confidence'] == 0.95
        assert as_dict['detection_details'] == {'test': 'data'}


class TestExceptionHandling:
    """Tests for the limit-up detection exception types."""

    def test_limit_up_detection_error(self):
        """The base error keeps its stock code and details and serialises via to_dict()."""
        code = "000001.SZ"
        err = LimitUpDetectionError("Test error", code, {'test': 'data'})

        assert err.stock_code == code
        assert err.detection_details == {'test': 'data'}

        serialized = err.to_dict()
        assert serialized['error_type'] == 'LimitUpDetectionError'
        assert serialized['stock_code'] == code

    def test_insufficient_price_data_error(self):
        """Missing fields are recorded together with the required-field list."""
        absent = ['open', 'close']
        err = InsufficientPriceDataError("000001.SZ", absent)

        assert err.stock_code == "000001.SZ"
        assert err.detection_details['missing_fields'] == absent
        assert 'required_fields' in err.detection_details

    def test_invalid_price_data_error(self):
        """Validation messages are stored under 'validation_errors'."""
        problems = ['Price must be positive', 'Invalid range']
        err = InvalidPriceDataError("000001.SZ", problems)

        assert err.stock_code == "000001.SZ"
        assert err.detection_details['validation_errors'] == problems


class TestEdgeCases:
    """Boundary-condition tests: extreme prices and floating-point precision."""

    def setup_method(self):
        """Give every test a fresh default detector."""
        self.detector = LimitUpDetector()

    def _closes_at_high(self, close_price, prev_close):
        """Run is_limit_up for a normal stock that opened at 10.50 and closed at its high."""
        return self.detector.is_limit_up(
            open_price=10.50,
            close_price=close_price,
            high_price=close_price,
            prev_close=prev_close,
            stock_type='normal'
        )

    def test_very_small_price(self):
        """A tiny previous close still produces a sensible limit-up price."""
        base = 0.01
        # Compare against the value rounded to 2 decimals to account for
        # the detector's price-precision handling.
        target = round(base * 1.10, 2)
        actual = self.detector.calculate_limit_up_price(base, 'normal')
        assert abs(actual - target) < 0.01

    def test_very_large_price(self):
        """A very large previous close produces prev_close * 1.10."""
        base = 10000.00
        target = base * 1.10
        actual = self.detector.calculate_limit_up_price(base, 'normal')
        assert abs(actual - target) < 0.01

    def test_price_precision_edge_case(self):
        """A close near the tolerance boundary yields a boolean verdict."""
        # NOTE(review): this only checks the return type; the expected verdict
        # depends on PRICE_TOLERANCE and the rounding used by the detector.
        verdict = self._closes_at_high(11.006, 10.005)
        assert isinstance(verdict, bool)

    def test_floating_point_precision(self):
        """A close computed as 10.00 * 1.10 is still treated as a limit-up."""
        # The product may not be exactly 11.00, so detection must tolerate
        # floating-point representation error.
        assert self._closes_at_high(10.00 * 1.10, 10.00) is True


if __name__ == '__main__':
    # pytest.main() returns the exit code instead of exiting. Propagate it so
    # that running this file directly exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__]))