"""
性能优化功能测试

测试向量化操作、并行处理和内存优化功能
"""

import pytest
import pytest_asyncio
import asyncio
import time
import pandas as pd
import numpy as np
from unittest.mock import Mock, patch
import tempfile
import os

from quickstock.utils.performance_optimizations import (
    VectorizedOperations, ParallelProcessor, MemoryOptimizer,
    PerformanceOptimizer, PerformanceConfig
)
from quickstock.utils.performance_benchmarks import (
    PerformanceBenchmark, BenchmarkConfig, ContinuousPerformanceMonitor
)


class TestVectorizedOperations:
    """Tests for VectorizedOperations (limit-up detection, classification, aggregation)."""

    @pytest.fixture
    def vectorized_ops(self):
        """Create a VectorizedOperations instance."""
        return VectorizedOperations()

    @pytest.fixture
    def sample_stock_data(self):
        """Build 1000 rows of synthetic daily bars dated 20241015.

        Codes rotate by ``i % 4`` through 60xxxx.SH, 00xxxx.SZ, 688xxx.SH and
        43xxxx.BJ. Rows with ``i % 3 == 0`` (about one third) close 10% above
        ``pre_close`` with ``high == close``; rows with ``i % 20 == 0`` get an
        ``ST`` name prefix.
        """
        data = []
        for i in range(1000):
            # Rotate through codes from the four markets
            if i % 4 == 0:
                ts_code = f'60{i:04d}.SH'
            elif i % 4 == 1:
                ts_code = f'00{i:04d}.SZ'
            elif i % 4 == 2:
                ts_code = f'688{i:03d}.SH'
            else:
                ts_code = f'43{i:04d}.BJ'

            base_price = 10.0 + i * 0.01

            # About one third of rows (i % 3 == 0) close at +10% with high == close
            if i % 3 == 0:
                close_price = base_price * 1.10
                high_price = close_price
            else:
                close_price = base_price * (1 + (i % 10 - 5) * 0.01)
                high_price = close_price * 1.02

            data.append({
                'ts_code': ts_code,
                'trade_date': '20241015',
                'open': base_price,
                'high': high_price,
                'low': base_price * 0.98,
                'close': close_price,
                'pre_close': base_price,
                'change': close_price - base_price,
                'pct_chg': (close_price - base_price) / base_price * 100,
                'vol': (i % 1000 + 1) * 1000,
                'amount': close_price * (i % 1000 + 1) * 1000,
                'name': f'股票{i:04d}' if i % 20 != 0 else f'ST股票{i:04d}'
            })

        return pd.DataFrame(data)

    def test_vectorized_limit_up_detection(self, vectorized_ops, sample_stock_data):
        """Vectorized detection returns one annotated row per input row."""
        result = vectorized_ops.vectorized_limit_up_detection(sample_stock_data)

        # Result shape and added columns
        assert isinstance(result, pd.DataFrame)
        assert len(result) == len(sample_stock_data)
        assert 'is_limit_up' in result.columns
        assert 'market' in result.columns
        assert 'is_st' in result.columns
        assert 'confidence' in result.columns

        # Some, but not more than all, rows are flagged as limit-up
        limit_up_count = result['is_limit_up'].sum()
        assert limit_up_count > 0
        assert limit_up_count <= len(result)

        # All four markets present in the fixture are recognised
        market_counts = result['market'].value_counts()
        assert 'shanghai' in market_counts.index
        assert 'shenzhen' in market_counts.index
        assert 'star' in market_counts.index
        assert 'beijing' in market_counts.index

        # The fixture contains ST-named rows
        st_count = result['is_st'].sum()
        assert st_count > 0

    def test_vectorized_market_classification(self, vectorized_ops):
        """Code prefixes/suffixes map to the expected market labels."""
        ts_codes = pd.Series([
            '600000.SH',  # Shanghai main board
            '000001.SZ',  # Shenzhen main board
            '688001.SH',  # STAR market
            '430001.BJ',  # Beijing exchange
            '688999.SH',  # STAR market
            '300001.SZ'   # Shenzhen ChiNext
        ])

        markets = vectorized_ops._vectorized_market_classification(ts_codes)

        # Whole-list comparison gives a readable diff on failure
        assert markets.tolist() == [
            'shanghai', 'shenzhen', 'star', 'beijing', 'star', 'shenzhen'
        ]

    def test_vectorized_st_detection(self, vectorized_ops):
        """ST / *ST / delisting / suspension names are flagged; normal names are not."""
        stock_names = pd.Series([
            '平安银行',
            'ST万科',
            '*ST海马',
            '退市大控',
            '暂停交易',
            '正常股票'
        ])

        is_st = vectorized_ops._vectorized_st_detection(stock_names)

        assert is_st.tolist() == [False, True, True, True, True, False]

    def test_vectorized_statistics_aggregation(self, vectorized_ops, sample_stock_data):
        """Aggregation over detected rows returns a consistent stats dict."""
        limit_up_data = vectorized_ops.vectorized_limit_up_detection(sample_stock_data)

        stats = vectorized_ops.vectorized_statistics_aggregation(limit_up_data)

        # Structure
        assert isinstance(stats, dict)
        required_keys = ['total', 'non_st', 'shanghai', 'shenzhen', 'star', 'beijing', 'st',
                        'limit_up_stocks', 'market_breakdown']
        for key in required_keys:
            assert key in stats

        # Internal consistency
        assert stats['total'] >= 0
        assert stats['non_st'] + stats['st'] <= stats['total']  # some rows may be unclassified
        assert len(stats['limit_up_stocks']) == stats['total']

        # Per-market breakdown never exceeds the total
        assert isinstance(stats['market_breakdown'], dict)
        market_total = sum(len(stocks) for stocks in stats['market_breakdown'].values())
        assert market_total <= stats['total']

    def test_empty_data_handling(self, vectorized_ops):
        """An empty DataFrame yields an empty result and zeroed stats."""
        empty_df = pd.DataFrame()

        result = vectorized_ops.vectorized_limit_up_detection(empty_df)
        assert result.empty

        stats = vectorized_ops.vectorized_statistics_aggregation(empty_df)
        assert stats['total'] == 0
        assert len(stats['limit_up_stocks']) == 0

    def test_performance_comparison(self, vectorized_ops, sample_stock_data):
        """Vectorized detection over 1000 rows finishes within one second."""
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # (NTP adjustments) and is coarse on some platforms.
        start_time = time.perf_counter()
        vectorized_result = vectorized_ops.vectorized_limit_up_detection(sample_stock_data)
        vectorized_time = time.perf_counter() - start_time

        assert isinstance(vectorized_result, pd.DataFrame)
        assert len(vectorized_result) == len(sample_stock_data)

        # Requirement: 1000 stocks processed in under 1 second
        assert vectorized_time < 1.0, f"Vectorized operation took {vectorized_time:.3f}s"

        print(f"向量化操作性能: {vectorized_time:.3f}s for {len(sample_stock_data)} stocks")


class TestParallelProcessor:
    """Tests for ParallelProcessor (chunked parallel detection and batch operations)."""

    @pytest.fixture
    def parallel_processor(self):
        """Create a ParallelProcessor with 4 workers, batch 500, chunk 100."""
        config = PerformanceConfig(
            enable_parallel_processing=True,
            max_workers=4,
            batch_size=500,
            chunk_size=100
        )
        return ParallelProcessor(config)

    @pytest.fixture
    def large_stock_data(self):
        """Build 2000 rows of synthetic daily bars alternating SH / SZ codes.

        Rows with ``i % 3 == 0`` close 10% above ``pre_close`` with
        ``high == close``.
        """
        data = []
        for i in range(2000):  # larger dataset than the vectorized tests
            ts_code = f'60{i:04d}.SH' if i % 2 == 0 else f'00{i:04d}.SZ'
            base_price = 10.0 + i * 0.01

            if i % 3 == 0:
                close_price = base_price * 1.10
                high_price = close_price
            else:
                close_price = base_price * (1 + (i % 10 - 5) * 0.01)
                high_price = close_price * 1.02

            data.append({
                'ts_code': ts_code,
                'trade_date': '20241015',
                'open': base_price,
                'high': high_price,
                'low': base_price * 0.98,
                'close': close_price,
                'pre_close': base_price,
                'change': close_price - base_price,
                'pct_chg': (close_price - base_price) / base_price * 100,
                'vol': (i % 1000 + 1) * 1000,
                'amount': close_price * (i % 1000 + 1) * 1000,
                'name': f'股票{i:04d}'
            })

        return pd.DataFrame(data)

    @pytest.mark.asyncio
    async def test_parallel_stock_processing(self, parallel_processor, large_stock_data):
        """Parallel processing returns one annotated row per input row."""
        result = await parallel_processor.parallel_stock_processing(large_stock_data)

        assert isinstance(result, pd.DataFrame)
        assert len(result) == len(large_stock_data)
        assert 'is_limit_up' in result.columns
        assert 'market' in result.columns

        limit_up_count = result['is_limit_up'].sum()
        assert limit_up_count > 0

    @pytest.mark.asyncio
    async def test_parallel_vs_sequential_performance(self, parallel_processor, large_stock_data):
        """Parallel processing is not drastically slower than the vectorized path."""
        # Parallel run
        start_time = time.perf_counter()
        parallel_result = await parallel_processor.parallel_stock_processing(large_stock_data)
        parallel_time = time.perf_counter() - start_time

        # Sequential (vectorized) run
        start_time = time.perf_counter()
        sequential_result = parallel_processor.vectorized_ops.vectorized_limit_up_detection(large_stock_data)
        sequential_time = time.perf_counter() - start_time

        # Both paths produce the same number of rows
        assert len(parallel_result) == len(sequential_result)

        print(f"并行处理时间: {parallel_time:.3f}s")
        print(f"串行处理时间: {sequential_time:.3f}s")

        # The vectorized run over 2000 rows takes only milliseconds, so a pure
        # ratio check is dominated by executor start-up cost and is flaky.
        # Allow a fixed overhead budget in addition to the 2x ratio.
        overhead_allowance_s = 0.5
        assert parallel_time < max(sequential_time * 2, overhead_allowance_s), "并行处理性能不应该比串行慢太多"

    @pytest.mark.asyncio
    async def test_parallel_batch_operations(self, parallel_processor):
        """Batched zero-argument operations run concurrently and return their values."""
        def test_operation(x):
            time.sleep(0.01)  # simulate a slow operation
            return x * 2

        # Bind ``i`` as a default argument: a bare ``lambda: test_operation(i)``
        # captures the loop variable late, so every operation would compute
        # test_operation(19).
        operations = [lambda i=i: test_operation(i) for i in range(20)]

        start_time = time.perf_counter()
        results = await parallel_processor.parallel_batch_operations(operations, batch_size=5)
        parallel_time = time.perf_counter() - start_time

        assert len(results) == 20
        assert all(isinstance(r, int) for r in results if r is not None)

        # Each successful result is a distinct expected value (one per operation)
        successful = [r for r in results if r is not None]
        expected_values = {i * 2 for i in range(20)}
        assert set(successful) <= expected_values
        assert len(set(successful)) == len(successful)

        # Running in parallel should beat the serial lower bound
        expected_sequential_time = 20 * 0.01  # 20 operations * 0.01s
        assert parallel_time < expected_sequential_time, f"并行时间 {parallel_time:.3f}s 应该小于串行时间 {expected_sequential_time:.3f}s"

    def test_split_dataframe(self, parallel_processor, large_stock_data):
        """Splitting preserves all rows and yields full-size chunks except the last."""
        chunk_size = 100
        chunks = parallel_processor._split_dataframe(large_stock_data, chunk_size)

        assert len(chunks) > 1
        assert sum(len(chunk) for chunk in chunks) == len(large_stock_data)

        # Every chunk but the last is exactly chunk_size rows
        for chunk in chunks[:-1]:
            assert len(chunk) == chunk_size

        # The last chunk may be shorter
        assert len(chunks[-1]) <= chunk_size


class TestMemoryOptimizer:
    """Tests for MemoryOptimizer (memory readings, DataFrame downcasting, batching)."""

    @pytest.fixture
    def memory_optimizer(self):
        """Create a MemoryOptimizer with a 512 MB limit."""
        config = PerformanceConfig(
            enable_memory_optimization=True,
            memory_limit_mb=512
        )
        return MemoryOptimizer(config)

    @pytest.fixture
    def memory_intensive_data(self):
        """Build a 5000-row DataFrame with repeated strings, large numbers and categories."""
        n_rows = 5000
        categories = ['类别A', '类别B', '类别C']
        data = {
            'ts_code': [f'60{i:04d}.SH' for i in range(1000)] * 5,  # repeated codes
            'name': ['股票名称很长很长很长' + str(i) for i in range(1000)] * 5,
            'large_int': [i * 1000000 for i in range(n_rows)],  # large integers
            'large_float': [i * 1.123456789 for i in range(n_rows)],  # high-precision floats
            # Cycle the categories for exactly n_rows values. ``categories * 1667``
            # yields 5001 items and makes pd.DataFrame raise a length mismatch.
            'category_data': [categories[i % len(categories)] for i in range(n_rows)]
        }
        return pd.DataFrame(data)

    def test_get_memory_usage(self, memory_optimizer):
        """get_memory_usage returns positive rss/vms and a percent in [0, 100]."""
        memory_usage = memory_optimizer.get_memory_usage()

        assert isinstance(memory_usage, dict)
        assert 'rss_mb' in memory_usage
        assert 'vms_mb' in memory_usage
        assert 'percent' in memory_usage

        assert memory_usage['rss_mb'] > 0
        assert memory_usage['vms_mb'] > 0
        assert 0 <= memory_usage['percent'] <= 100

    def test_optimize_dataframe_memory(self, memory_optimizer, memory_intensive_data):
        """Optimization never increases deep memory usage and keeps shape/columns."""
        original_memory = memory_intensive_data.memory_usage(deep=True).sum() / 1024 / 1024

        optimized_data = memory_optimizer.optimize_dataframe_memory(memory_intensive_data)

        optimized_memory = optimized_data.memory_usage(deep=True).sum() / 1024 / 1024

        assert optimized_memory <= original_memory
        memory_saved = original_memory - optimized_memory

        print(f"内存优化效果: 原始 {original_memory:.2f}MB, 优化后 {optimized_memory:.2f}MB, 节省 {memory_saved:.2f}MB")

        # Data integrity: same rows and same column order
        assert len(optimized_data) == len(memory_intensive_data)
        assert list(optimized_data.columns) == list(memory_intensive_data.columns)

    def test_batch_process_with_memory_control(self, memory_optimizer, memory_intensive_data):
        """Batched processing covers every row exactly once."""
        def simple_process(batch_data):
            # Report the batch size so the sum can be checked against the input
            return len(batch_data)

        results = memory_optimizer.batch_process_with_memory_control(
            memory_intensive_data,
            simple_process,
            batch_size=200
        )

        assert len(results) > 0
        assert sum(results) == len(memory_intensive_data)

    def test_memory_limit_check(self, memory_optimizer):
        """check_memory_limit returns a bool."""
        is_over_limit = memory_optimizer.check_memory_limit()

        # Only the return type is checked; whether the limit is exceeded
        # depends on the environment the tests run in.
        assert isinstance(is_over_limit, bool)

    def test_cleanup_memory(self, memory_optimizer):
        """cleanup_memory runs and memory can still be read afterwards."""
        before_cleanup = memory_optimizer.get_memory_usage()

        memory_optimizer.cleanup_memory()

        after_cleanup = memory_optimizer.get_memory_usage()

        # RSS is not guaranteed to drop after cleanup; this only checks that a
        # valid reading is still returned.
        assert after_cleanup['rss_mb'] >= 0

        print(f"内存清理: 清理前 {before_cleanup['rss_mb']:.2f}MB, 清理后 {after_cleanup['rss_mb']:.2f}MB")


class TestPerformanceOptimizer:
    """Integration tests for the PerformanceOptimizer facade."""

    @pytest.fixture
    def performance_optimizer(self):
        """Build an optimizer with vectorization, parallelism and memory optimization on."""
        return PerformanceOptimizer(
            PerformanceConfig(
                enable_vectorization=True,
                enable_parallel_processing=True,
                enable_memory_optimization=True,
                max_workers=4,
                batch_size=500,
            )
        )

    @pytest.fixture
    def test_stock_data(self):
        """Build 1500 synthetic daily bars alternating SH / SZ codes.

        Rows whose index is a multiple of 3 close 10% above ``pre_close`` with
        ``high == close``; the rest close within a small band with a 2% wick.
        """
        def make_row(idx):
            pre_close = 10.0 + idx * 0.01
            hits_limit = idx % 3 == 0
            if hits_limit:
                close = pre_close * 1.10
                high = close
            else:
                close = pre_close * (1 + (idx % 10 - 5) * 0.01)
                high = close * 1.02
            return {
                'ts_code': f'60{idx:04d}.SH' if idx % 2 == 0 else f'00{idx:04d}.SZ',
                'trade_date': '20241015',
                'open': pre_close,
                'high': high,
                'low': pre_close * 0.98,
                'close': close,
                'pre_close': pre_close,
                'change': close - pre_close,
                'pct_chg': (close - pre_close) / pre_close * 100,
                'vol': (idx % 1000 + 1) * 1000,
                'amount': close * (idx % 1000 + 1) * 1000,
                'name': f'股票{idx:04d}',
            }

        return pd.DataFrame([make_row(idx) for idx in range(1500)])

    @pytest.mark.asyncio
    async def test_optimize_limit_up_detection(self, performance_optimizer, test_stock_data):
        """Optimized detection annotates every row and finds some limit-ups."""
        detected = await performance_optimizer.optimize_limit_up_detection(test_stock_data)

        assert isinstance(detected, pd.DataFrame)
        assert len(detected) == len(test_stock_data)
        for column in ('is_limit_up', 'market', 'confidence'):
            assert column in detected.columns

        assert detected['is_limit_up'].sum() > 0

    def test_optimize_statistics_aggregation(self, performance_optimizer, test_stock_data):
        """Aggregation over pre-annotated rows returns all per-market keys."""
        # Stand-in detection output: every 3rd row limit-up, all Shanghai, no ST
        annotated = test_stock_data.assign(
            is_limit_up=test_stock_data.index % 3 == 0,
            market='shanghai',
            is_st=False,
        )

        stats = performance_optimizer.optimize_statistics_aggregation(annotated)

        assert isinstance(stats, dict)
        expected = ['total', 'non_st', 'shanghai', 'shenzhen', 'star', 'beijing', 'st']
        missing = [name for name in expected if name not in stats]
        assert not missing

    def test_get_performance_stats(self, performance_optimizer):
        """get_performance_stats exposes config, memory usage and CPU count."""
        snapshot = performance_optimizer.get_performance_stats()

        assert isinstance(snapshot, dict)
        for section in ('config', 'memory_usage', 'cpu_count'):
            assert section in snapshot

        cfg = snapshot['config']
        for field_name in ('vectorization_enabled', 'parallel_processing_enabled', 'max_workers'):
            assert field_name in cfg

        mem = snapshot['memory_usage']
        for field_name in ('rss_mb', 'percent'):
            assert field_name in mem


class TestPerformanceBenchmark:
    """Tests for PerformanceBenchmark (data generation, benchmarks, reporting)."""

    @pytest.fixture
    def benchmark_config(self):
        """Create a small benchmark config to keep test time low."""
        return BenchmarkConfig(
            iterations=3,  # fewer iterations to reduce test time
            warmup_iterations=1,
            data_sizes=[100, 500],  # small data sizes to reduce test time
            concurrent_levels=[1, 3],
            detailed_logging=False
        )

    @pytest.fixture
    def performance_benchmark(self, benchmark_config):
        """Create a PerformanceBenchmark from the small config."""
        return PerformanceBenchmark(benchmark_config)

    def test_generate_test_data(self, performance_benchmark):
        """Generated data has the requested size and the core price columns."""
        data_size = 100
        test_data = performance_benchmark.generate_test_data(data_size)

        assert isinstance(test_data, pd.DataFrame)
        assert len(test_data) == data_size

        required_columns = ['ts_code', 'open', 'close', 'high', 'pre_close', 'name']
        for col in required_columns:
            assert col in test_data.columns

    @pytest.mark.asyncio
    async def test_benchmark_limit_up_detection(self, performance_benchmark):
        """Limit-up benchmark returns results with positive timing and throughput."""
        results = await performance_benchmark.benchmark_limit_up_detection()

        assert isinstance(results, list)
        assert len(results) > 0

        for result in results:
            assert hasattr(result, 'test_name')
            assert hasattr(result, 'avg_time')
            assert hasattr(result, 'throughput')
            assert result.avg_time > 0
            assert result.throughput > 0

    @pytest.mark.asyncio
    async def test_benchmark_concurrent_processing(self, performance_benchmark):
        """Concurrent benchmark returns named results with non-negative success rate."""
        results = await performance_benchmark.benchmark_concurrent_processing()

        assert isinstance(results, list)
        assert len(results) > 0

        for result in results:
            assert 'concurrent_processing' in result.test_name
            assert result.avg_time > 0
            assert result.success_rate >= 0

    def test_benchmark_memory_optimization(self, performance_benchmark):
        """Memory-optimization benchmark returns named results with positive timing."""
        results = performance_benchmark.benchmark_memory_optimization()

        assert isinstance(results, list)
        assert len(results) > 0

        for result in results:
            assert 'memory_optimization' in result.test_name
            assert result.avg_time > 0

    @pytest.mark.asyncio
    async def test_run_full_benchmark_suite(self, performance_benchmark):
        """The full suite returns a non-empty result list for each benchmark type."""
        suite_results = await performance_benchmark.run_full_benchmark_suite()

        assert isinstance(suite_results, dict)
        assert 'limit_up_detection' in suite_results
        assert 'concurrent_processing' in suite_results
        assert 'memory_optimization' in suite_results

        # Only the per-type result lists are inspected, so iterate the values
        for results in suite_results.values():
            assert isinstance(results, list)
            assert len(results) > 0

    def test_generate_performance_report(self, performance_benchmark):
        """The report is a non-empty string mentioning the title and a recorded result."""
        from quickstock.utils.performance_benchmarks import BenchmarkResult

        # Seed one result so the report has content to render
        test_result = BenchmarkResult(
            test_name='test_limit_up_detection_100',
            total_time=1.0,
            avg_time=0.1,
            min_time=0.08,
            max_time=0.12,
            std_time=0.02,
            throughput=1000.0,
            memory_usage_mb=50.0,
            cpu_usage_percent=25.0,
            success_rate=1.0,
            iterations=10,
            data_size=100
        )

        performance_benchmark.results.append(test_result)

        report = performance_benchmark.generate_performance_report()

        assert isinstance(report, str)
        assert len(report) > 0
        assert '性能基准测试报告' in report
        assert 'test_limit_up_detection' in report


class TestContinuousPerformanceMonitor:
    """Tests for ContinuousPerformanceMonitor (sampling loop and summaries)."""

    @pytest.fixture
    def performance_monitor(self):
        """Create a monitor that samples every second."""
        return ContinuousPerformanceMonitor(monitor_interval=1)  # 1s interval for tests

    @pytest.mark.asyncio
    async def test_performance_monitoring(self, performance_monitor):
        """Running the monitor briefly collects snapshots with the expected keys."""
        monitor_task = asyncio.create_task(performance_monitor.start_monitoring())
        try:
            # Let the monitor collect a few samples
            await asyncio.sleep(3)
        finally:
            performance_monitor.stop_monitoring()
            # Wait for the monitoring coroutine to finish. Otherwise the task is
            # left pending when the event loop closes and any exception it raised
            # is silently lost. wait_for cancels the task if it overruns.
            try:
                await asyncio.wait_for(monitor_task, timeout=5)
            except asyncio.TimeoutError:
                pass

        assert len(performance_monitor.performance_data) > 0

        sample_data = performance_monitor.performance_data[0]
        required_keys = ['timestamp', 'memory_rss_mb', 'cpu_percent', 'num_threads']
        for key in required_keys:
            assert key in sample_data

    def test_get_performance_summary(self, performance_monitor):
        """The summary averages memory and CPU across the collected snapshots."""
        # Two synthetic snapshots with known averages (105 MB RSS, 12.5% CPU)
        performance_monitor.performance_data = [
            {
                'timestamp': '2024-10-15T10:00:00',
                'memory_rss_mb': 100.0,
                'memory_vms_mb': 200.0,
                'memory_percent': 5.0,
                'cpu_percent': 10.0,
                'num_threads': 5,
                'open_files': 10,
                'connections': 2
            },
            {
                'timestamp': '2024-10-15T10:01:00',
                'memory_rss_mb': 110.0,
                'memory_vms_mb': 210.0,
                'memory_percent': 5.5,
                'cpu_percent': 15.0,
                'num_threads': 6,
                'open_files': 12,
                'connections': 3
            }
        ]

        summary = performance_monitor.get_performance_summary()

        assert isinstance(summary, dict)
        assert 'data_points' in summary
        assert 'memory_stats' in summary
        assert 'cpu_stats' in summary
        assert 'latest_snapshot' in summary

        assert summary['data_points'] == 2
        assert summary['memory_stats']['avg_mb'] == 105.0
        assert summary['cpu_stats']['avg_percent'] == 12.5


if __name__ == "__main__":
    # Propagate pytest's exit status so scripted runs fail when tests fail
    raise SystemExit(pytest.main([__file__]))