"""
价格工具测试

测试PriceUtils类的各种价格计算和验证功能
"""

import pytest
import math
from decimal import Decimal
from unittest.mock import Mock, patch
from quickstock.utils.price_utils import (
    PriceUtils,
    PriceValidationResult,
    PriceComparisonResult,
    PriceValidationError,
    PriceComparisonError,
    PriceCalculationError,
    calculate_limit_up_price,
    calculate_limit_down_price,
    compare_prices,
    round_price,
    validate_ohlc_prices
)
from quickstock.models import LIMIT_UP_THRESHOLDS


class TestPriceUtils:
    """Tests for the price calculation and validation methods of PriceUtils."""

    def setup_method(self):
        """Create a fresh PriceUtils with default settings before each test."""
        self.price_utils = PriceUtils()

    def test_init_default_parameters(self):
        """Default construction uses the class-level default tolerance."""
        utils = PriceUtils()
        assert utils.price_tolerance == PriceUtils.DEFAULT_PRICE_TOLERANCE
        assert utils.logger is not None

    def test_init_custom_parameters(self):
        """A tolerance passed to the constructor is stored unchanged."""
        custom_tolerance = 0.01
        utils = PriceUtils(price_tolerance=custom_tolerance)
        assert utils.price_tolerance == custom_tolerance

    def test_calculate_limit_price_limit_up(self):
        """A +10% limit-up price from 10.00 is 11.00."""
        base_price = 10.00
        change_rate = 0.10

        result = self.price_utils.calculate_limit_price(base_price, change_rate, 'limit_up')
        expected = 11.00
        assert abs(result - expected) < 0.01

    def test_calculate_limit_price_limit_down(self):
        """A -10% limit-down price from 10.00 is 9.00."""
        base_price = 10.00
        change_rate = 0.10

        result = self.price_utils.calculate_limit_price(base_price, change_rate, 'limit_down')
        expected = 9.00
        assert abs(result - expected) < 0.01

    def test_calculate_limit_price_invalid_base_price(self):
        """Zero and negative base prices raise PriceCalculationError."""
        with pytest.raises(PriceCalculationError):
            self.price_utils.calculate_limit_price(0, 0.10, 'limit_up')

        with pytest.raises(PriceCalculationError):
            self.price_utils.calculate_limit_price(-5.0, 0.10, 'limit_up')

    def test_calculate_limit_price_invalid_change_rate(self):
        """Negative or non-numeric change rates raise PriceCalculationError."""
        with pytest.raises(PriceCalculationError):
            self.price_utils.calculate_limit_price(10.0, -0.10, 'limit_up')

        with pytest.raises(PriceCalculationError):
            self.price_utils.calculate_limit_price(10.0, 'invalid', 'limit_up')

    def test_calculate_limit_price_invalid_price_type(self):
        """An unknown price type raises PriceCalculationError."""
        with pytest.raises(PriceCalculationError):
            self.price_utils.calculate_limit_price(10.0, 0.10, 'invalid_type')

    def test_calculate_limit_up_price_normal_stock(self):
        """Normal stocks have a +10% limit-up."""
        base_price = 10.00
        result = self.price_utils.calculate_limit_up_price(base_price, 'normal')
        expected = 11.00
        assert abs(result - expected) < 0.01

    def test_calculate_limit_up_price_st_stock(self):
        """ST stocks have a +5% limit-up."""
        base_price = 10.00
        result = self.price_utils.calculate_limit_up_price(base_price, 'st')
        expected = 10.50
        assert abs(result - expected) < 0.01

    def test_calculate_limit_up_price_star_market(self):
        """STAR-market stocks have a +20% limit-up."""
        base_price = 20.00
        result = self.price_utils.calculate_limit_up_price(base_price, 'star')
        expected = 24.00
        assert abs(result - expected) < 0.01

    def test_calculate_limit_up_price_beijing_market(self):
        """Beijing Stock Exchange stocks have a +30% limit-up."""
        base_price = 15.00
        result = self.price_utils.calculate_limit_up_price(base_price, 'beijing')
        expected = 19.50
        assert abs(result - expected) < 0.01

    def test_calculate_limit_up_price_invalid_stock_type(self):
        """An unknown stock type raises PriceCalculationError."""
        with pytest.raises(PriceCalculationError):
            self.price_utils.calculate_limit_up_price(10.0, 'invalid_type')

    def test_calculate_limit_down_price_normal_stock(self):
        """Normal stocks have a -10% limit-down."""
        base_price = 10.00
        result = self.price_utils.calculate_limit_down_price(base_price, 'normal')
        expected = 9.00
        assert abs(result - expected) < 0.01

    def test_calculate_limit_down_price_st_stock(self):
        """ST stocks have a -5% limit-down."""
        base_price = 10.00
        result = self.price_utils.calculate_limit_down_price(base_price, 'st')
        expected = 9.50
        assert abs(result - expected) < 0.01

    def test_calculate_price_change(self):
        """A rise from 10.00 to 11.00 is +1.00 absolute and +10 percent relative."""
        current_price = 11.00
        previous_price = 10.00

        result = self.price_utils.calculate_price_change(current_price, previous_price)

        # Computed floats are compared with approx so the test does not depend on
        # the exact order of floating-point operations inside the implementation.
        assert result['absolute_change'] == pytest.approx(1.00)
        # relative_change is expressed in percent (10.0 means +10%).
        assert result['relative_change'] == pytest.approx(10.0)
        assert result['current_price'] == current_price
        assert result['previous_price'] == previous_price

    def test_calculate_price_change_negative(self):
        """A fall from 10.00 to 9.00 is -1.00 absolute and -10 percent relative."""
        current_price = 9.00
        previous_price = 10.00

        result = self.price_utils.calculate_price_change(current_price, previous_price)

        assert result['absolute_change'] == pytest.approx(-1.00)
        assert result['relative_change'] == pytest.approx(-10.0)

    def test_calculate_price_change_invalid_previous_price(self):
        """A zero or negative previous price raises PriceCalculationError."""
        with pytest.raises(PriceCalculationError):
            self.price_utils.calculate_price_change(11.0, 0)

        with pytest.raises(PriceCalculationError):
            self.price_utils.calculate_price_change(11.0, -5.0)

    def test_compare_prices_equal(self):
        """Prices within the default tolerance compare as equal."""
        price1 = 10.00
        price2 = 10.003  # within the default tolerance

        result = self.price_utils.compare_prices(price1, price2)

        assert isinstance(result, PriceComparisonResult)
        assert result.are_equal is True
        assert result.difference < self.price_utils.price_tolerance
        assert result.tolerance_used == self.price_utils.price_tolerance

    def test_compare_prices_not_equal(self):
        """Prices outside the default tolerance compare as not equal."""
        price1 = 10.00
        price2 = 10.10  # outside the default tolerance

        result = self.price_utils.compare_prices(price1, price2)

        assert result.are_equal is False
        assert result.difference > self.price_utils.price_tolerance

    def test_compare_prices_custom_tolerance(self):
        """An explicit tolerance overrides the instance default."""
        price1 = 10.00
        price2 = 10.02
        custom_tolerance = 0.03

        result = self.price_utils.compare_prices(price1, price2, custom_tolerance)

        assert result.are_equal is True
        assert result.tolerance_used == custom_tolerance

    def test_compare_prices_invalid_input(self):
        """Non-numeric or None inputs raise PriceComparisonError."""
        with pytest.raises(PriceComparisonError):
            self.price_utils.compare_prices("invalid", 10.0)

        with pytest.raises(PriceComparisonError):
            self.price_utils.compare_prices(10.0, None)

    def test_is_price_equal_true(self):
        """is_price_equal is True for prices within the default tolerance."""
        assert self.price_utils.is_price_equal(10.00, 10.003) is True

    def test_is_price_equal_false(self):
        """is_price_equal is False for prices outside the default tolerance."""
        assert self.price_utils.is_price_equal(10.00, 10.10) is False

    def test_round_price_default_precision(self):
        """round_price rounds to two decimals by default, half up."""
        assert self.price_utils.round_price(10.123) == 10.12
        assert self.price_utils.round_price(10.126) == 10.13
        assert self.price_utils.round_price(10.125) == 10.13  # round half up

    def test_round_price_custom_precision(self):
        """round_price honours an explicit number of decimal places."""
        assert self.price_utils.round_price(10.123, 1) == 10.1
        assert self.price_utils.round_price(10.123, 3) == 10.123

    def test_validate_price_data_valid(self):
        """A consistent OHLC record validates without errors."""
        price_data = {
            'open': 10.50,
            'high': 11.00,
            'low': 10.30,
            'close': 10.80
        }

        result = self.price_utils.validate_price_data(price_data)

        assert isinstance(result, PriceValidationResult)
        assert result.is_valid is True
        assert len(result.errors) == 0

    def test_validate_price_data_missing_fields(self):
        """Missing OHLC fields are reported as a 'Missing required fields' error."""
        price_data = {
            'open': 10.50,
            'high': 11.00
            # 'low' and 'close' are missing
        }

        result = self.price_utils.validate_price_data(price_data)

        assert result.is_valid is False
        assert len(result.errors) > 0
        assert any('Missing required fields' in error for error in result.errors)

    def test_validate_price_data_invalid_prices(self):
        """Negative, zero, non-numeric and NaN prices make the record invalid."""
        price_data = {
            'open': -10.50,  # negative price
            'high': 0,       # zero price
            'low': 'invalid', # not a number
            'close': float('nan')  # NaN
        }

        result = self.price_utils.validate_price_data(price_data)

        assert result.is_valid is False
        assert len(result.errors) > 0

    def test_validate_price_data_logic_errors(self):
        """OHLC values that violate high/low ordering make the record invalid."""
        price_data = {
            'open': 10.50,
            'high': 10.00,  # high is below open
            'low': 11.00,   # low is above open
            'close': 10.80
        }

        result = self.price_utils.validate_price_data(price_data)

        assert result.is_valid is False
        assert len(result.errors) > 0

    def test_validate_price_data_with_warnings(self):
        """A logically valid but unusual record is valid and carries warnings."""
        price_data = {
            'open': 10.00,
            'high': 10.00,  # high equals low
            'low': 10.00,
            'close': 10.00,
            'pre_close': 5.00  # 100% change versus the previous close
        }

        result = self.price_utils.validate_price_data(price_data)

        assert result.is_valid is True  # logically valid
        assert len(result.warnings) > 0  # but flagged with warnings

    def test_validate_ohlc_relationship_valid(self):
        """A consistent OHLC tuple satisfies every relationship check."""
        result = self.price_utils.validate_ohlc_relationship(10.50, 11.00, 10.30, 10.80)

        assert result['is_valid'] is True
        assert len(result['violations']) == 0
        assert all(result['relationships'].values())

    def test_validate_ohlc_relationship_invalid(self):
        """High below low produces at least one violation."""
        result = self.price_utils.validate_ohlc_relationship(10.50, 10.00, 11.00, 10.80)

        assert result['is_valid'] is False
        assert len(result['violations']) > 0

    def test_validate_ohlc_relationship_invalid_prices(self):
        """A negative price is reported as an 'invalid' violation."""
        result = self.price_utils.validate_ohlc_relationship(-10.50, 11.00, 10.30, 10.80)

        assert result['is_valid'] is False
        assert any('invalid' in violation for violation in result['violations'])

    def test_calculate_price_statistics_valid_data(self):
        """Summary statistics are computed for a list of valid prices."""
        prices = [10.0, 11.0, 9.0, 12.0, 8.0]

        result = self.price_utils.calculate_price_statistics(prices)

        assert result['count'] == 5
        assert result['min'] == 8.0
        assert result['max'] == 12.0
        assert result['mean'] == 10.0
        assert result['median'] == 10.0
        assert result['range'] == 4.0
        assert 'std_dev' in result
        assert 'coefficient_of_variation' in result

    def test_calculate_price_statistics_empty_data(self):
        """An empty price list yields an empty statistics dict."""
        result = self.price_utils.calculate_price_statistics([])
        assert result == {}

    def test_calculate_price_statistics_invalid_data(self):
        """A list with no valid prices yields an empty statistics dict."""
        prices = [-1.0, 0, 'invalid', None]
        result = self.price_utils.calculate_price_statistics(prices)
        assert result == {}

    def test_calculate_price_statistics_mixed_data(self):
        """Invalid entries are ignored and statistics use only the valid prices."""
        prices = [10.0, -1.0, 11.0, 0, 'invalid', 12.0]

        result = self.price_utils.calculate_price_statistics(prices)

        assert result['count'] == 3  # only three valid prices
        assert result['min'] == 10.0
        assert result['max'] == 12.0


class TestPriceValidationResult:
    """Tests for the PriceValidationResult container."""

    def test_to_dict(self):
        """to_dict exposes the validity flag, warnings, analysis and suggestions."""
        fields = {
            'is_valid': True,
            'errors': [],
            'warnings': ['test warning'],
            'price_analysis': {'test': 'data'},
            'suggestions': ['test suggestion'],
        }
        as_dict = PriceValidationResult(**fields).to_dict()

        # Identity check on the flag, value equality on everything else.
        assert as_dict['is_valid'] is True
        for key in ('warnings', 'price_analysis', 'suggestions'):
            assert as_dict[key] == fields[key]


class TestPriceComparisonResult:
    """Tests for the PriceComparisonResult container."""

    def test_to_dict(self):
        """to_dict reflects every value passed to the constructor."""
        fields = dict(
            are_equal=True,
            difference=0.001,
            relative_difference=0.01,
            tolerance_used=0.005,
            comparison_details={'test': 'data'},
        )
        as_dict = PriceComparisonResult(**fields).to_dict()

        # Identity check on the flag, value equality on everything else.
        assert as_dict['are_equal'] is True
        for key in ('difference', 'relative_difference', 'tolerance_used', 'comparison_details'):
            assert as_dict[key] == fields[key]


class TestConvenienceFunctions:
    """Tests for the module-level convenience wrappers around PriceUtils."""

    def test_calculate_limit_up_price_function(self):
        """calculate_limit_up_price gives +10% for a normal stock."""
        limit_up = calculate_limit_up_price(10.0, 'normal')
        assert abs(limit_up - 11.0) < 0.01

    def test_calculate_limit_down_price_function(self):
        """calculate_limit_down_price gives -10% for a normal stock."""
        limit_down = calculate_limit_down_price(10.0, 'normal')
        assert abs(limit_down - 9.0) < 0.01

    def test_compare_prices_function(self):
        """compare_prices returns a plain bool based on the default tolerance."""
        # (other price, expected result) checked in order against 10.0
        for other, expected in ((10.003, True), (10.1, False)):
            assert compare_prices(10.0, other) is expected

    def test_round_price_function(self):
        """round_price rounds to two decimals by default or to an explicit precision."""
        for args, expected in (((10.123,), 10.12), ((10.126, 1), 10.1)):
            assert round_price(*args) == expected

    def test_validate_ohlc_prices_function(self):
        """validate_ohlc_prices accepts a consistent bar and rejects high < low."""
        consistent_bar = (10.5, 11.0, 10.3, 10.8)
        inverted_bar = (10.5, 10.0, 11.0, 10.8)  # high below low
        assert validate_ohlc_prices(*consistent_bar) is True
        assert validate_ohlc_prices(*inverted_bar) is False


class TestExceptionHandling:
    """Tests for the price-utility exception hierarchy."""

    def test_price_validation_error(self):
        """PriceValidationError keeps its payload and serialises it via to_dict."""
        payload = {'test': 'data'}
        err = PriceValidationError("Test error", payload)

        assert err.price_data == payload

        serialised = err.to_dict()
        assert serialised['error_type'] == 'PriceValidationError'
        assert serialised['price_data'] == payload

    def test_price_comparison_error(self):
        """A PriceComparisonError is also a PriceValidationError."""
        assert isinstance(PriceComparisonError("Test comparison error"), PriceValidationError)

    def test_price_calculation_error(self):
        """A PriceCalculationError is also a PriceValidationError."""
        assert isinstance(PriceCalculationError("Test calculation error"), PriceValidationError)


class TestEdgeCases:
    """Boundary-condition tests for PriceUtils."""

    def setup_method(self):
        """Create a fresh PriceUtils with default settings before each test."""
        self.price_utils = PriceUtils()

    def test_very_small_prices(self):
        """Limit-up on a one-cent base matches the 2-decimal rounded +10% value."""
        base_price = 0.01
        result = self.price_utils.calculate_limit_up_price(base_price, 'normal')
        expected = round(base_price * 1.10, 2)
        assert abs(result - expected) < 0.001

    def test_very_large_prices(self):
        """Limit-up on a large base price is accurate to the cent."""
        base_price = 100000.0
        result = self.price_utils.calculate_limit_up_price(base_price, 'normal')
        # Prices are quoted to two decimals, so the result must match the rounded
        # +10% value to within one cent; a looser bound could hide rounding bugs.
        expected = round(base_price * 1.10, 2)
        assert abs(result - expected) < 0.01

    def test_floating_point_precision(self):
        """Limit-up on a base prone to binary float error is still exact to the cent."""
        base_price = 10.01
        result = self.price_utils.calculate_limit_up_price(base_price, 'normal')

        assert isinstance(result, float)
        assert result > base_price
        # 10.01 * 1.10 = 11.011, which is 11.01 at two decimals under round-half-up
        # (or any non-ceiling rounding mode).
        assert abs(result - 11.01) < 0.001

    def test_zero_tolerance_comparison(self):
        """With zero tolerance, any difference makes prices unequal."""
        result = self.price_utils.compare_prices(10.0, 10.001, tolerance=0.0)
        assert result.are_equal is False

    def test_very_large_tolerance_comparison(self):
        """With a very large tolerance, distant prices compare as equal."""
        result = self.price_utils.compare_prices(10.0, 15.0, tolerance=10.0)
        assert result.are_equal is True

    def test_identical_ohlc_prices(self):
        """An OHLC bar where all four prices are equal is valid."""
        price = 10.0
        result = self.price_utils.validate_ohlc_relationship(price, price, price, price)
        assert result['is_valid'] is True

    def test_extreme_price_changes(self):
        """A move from 1.0 to 100.0 is a +9900 percent change."""
        current_price = 100.0
        previous_price = 1.0  # +9900% change

        result = self.price_utils.calculate_price_change(current_price, previous_price)
        assert result['relative_change'] == pytest.approx(9900.0)

    def test_price_statistics_single_value(self):
        """A single price has zero range and zero standard deviation."""
        prices = [10.0]
        result = self.price_utils.calculate_price_statistics(prices)

        assert result['count'] == 1
        assert result['min'] == result['max'] == result['mean'] == result['median'] == 10.0
        assert result['range'] == 0.0
        assert result['std_dev'] == pytest.approx(0.0)

    def test_price_statistics_two_values(self):
        """Statistics for two prices use their midpoint as mean and median."""
        prices = [10.0, 12.0]
        result = self.price_utils.calculate_price_statistics(prices)

        assert result['count'] == 2
        assert result['min'] == 10.0
        assert result['max'] == 12.0
        assert result['mean'] == pytest.approx(11.0)
        assert result['median'] == pytest.approx(11.0)
        assert result['range'] == 2.0


class TestPrecisionHandling:
    """Tests for decimal-precision handling in rounding, comparison and limit prices."""

    def setup_method(self):
        """Create a fresh PriceUtils with default settings before each test."""
        self.price_utils = PriceUtils()

    def test_decimal_precision_rounding(self):
        """round_price rounds half up to two decimals, including binary-inexact inputs."""
        test_cases = [
            (10.125, 10.13),    # exact half: rounds up
            (10.124, 10.12),    # below half: rounds down
            (0.015, 0.02),      # 0.015 is not exactly representable in binary
            (999.999, 1000.00)  # carry into the integer part
        ]

        for input_price, expected in test_cases:
            result = self.price_utils.round_price(input_price)
            assert abs(result - expected) < 0.001, f"Failed for {input_price}: got {result}, expected {expected}"

    def test_price_comparison_with_rounding_errors(self):
        """A tiny float rounding error does not make prices unequal."""
        price1 = 10.0
        price2 = 10.0 + 1e-15  # negligible floating-point error

        result = self.price_utils.compare_prices(price1, price2)
        assert result.are_equal is True

    def test_limit_price_calculation_precision(self):
        """Limit-up prices are finite floats quoted to at most two decimals."""
        # Base prices whose +limit products are not exactly representable in binary.
        test_cases = [
            (10.01, 'normal'),
            (33.33, 'normal'),
            (7.77, 'st'),
            (123.45, 'star')
        ]

        for base_price, stock_type in test_cases:
            result = self.price_utils.calculate_limit_up_price(base_price, stock_type)
            label = f"{base_price} ({stock_type}) -> {result}"

            assert isinstance(result, float), label
            assert math.isfinite(result), label
            assert result > base_price, label
            # Inspect the float's shortest repr via Decimal. Unlike splitting str(result)
            # on '.', this is also correct for reprs in exponent form such as '1e+16'.
            exponent = Decimal(str(result)).as_tuple().exponent
            assert exponent >= -2, f"{label} has more than 2 decimal places"


if __name__ == '__main__':
    # Propagate pytest's exit status so running this file directly reports failures
    # to the shell instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__]))