"""
涨停统计系统性能测试

测试系统在大数据集和高并发场景下的性能表现
"""

import pytest
import pytest_asyncio
import asyncio
import tempfile
import os
import time
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any
from unittest.mock import Mock, AsyncMock
import pandas as pd

from quickstock.core.database import DatabaseManager
from quickstock.core.repository import LimitUpStatsRepository
from quickstock.services.limit_up_stats_service import LimitUpStatsService
from quickstock.models import LimitUpStats, LimitUpStatsRequest
from quickstock.utils.stock_classifier import StockCodeClassifier
from quickstock.utils.limit_up_detector import LimitUpDetector


class TestLimitUpStatsPerformance:
    """涨停统计系统性能测试类"""
    
    @pytest_asyncio.fixture
    async def temp_database(self):
        """Yield an initialized ``DatabaseManager`` backed by a throwaway file.

        A named temporary ``.db`` file is created and kept on disk, a
        ``DatabaseManager`` is built on its path and initialized, and the
        manager is handed to the test.  Afterwards the file and its ``-wal`` /
        ``-shm`` siblings are removed on a best-effort basis.
        """
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as handle:
            db_path = handle.name

        try:
            manager = DatabaseManager(db_path)
            await manager.initialize()
            yield manager
        finally:
            # Remove the database file and its "-wal"/"-shm" companions,
            # skipping paths that do not exist and ignoring removal errors.
            for suffix in ('', '-wal', '-shm'):
                leftover = f"{db_path}{suffix}"
                if not os.path.exists(leftover):
                    continue
                try:
                    os.unlink(leftover)
                except OSError:
                    pass
    
    def generate_large_stock_dataset(self, num_stocks: int, trade_date: str) -> pd.DataFrame:
        """生成大量股票数据"""
        stock_data = []
        
        for i in range(num_stocks):
            # 生成不同市场的股票代码
            if i % 4 == 0:  # 上海主板
                ts_code = f'60{i:04d}.SH'
                market = 'shanghai'
            elif i % 4 == 1:  # 深圳主板
                ts_code = f'00{i:04d}.SZ'
                market = 'shenzhen'
            elif i % 4 == 2:  # 科创板
                ts_code = f'688{i:03d}.SH'
                market = 'star'
            else:  # 北证
                ts_code = f'43{i:04d}.BJ'
                market = 'beijing'
            
            # 模拟价格数据
            base_price = 10.0 + (i % 100) * 0.1
            
            # 30%的股票涨停
            if i % 3 == 0:
                if market == 'star':
                    close_price = base_price * 1.20  # 科创板20%
                elif market == 'beijing':
                    close_price = base_price * 1.30  # 北证30%
                else:
                    close_price = base_price * 1.10  # 普通股票10%
                high_price = close_price
                pct_chg = (close_price - base_price) / base_price * 100
            else:
                # 非涨停股票
                close_price = base_price * (1 + (i % 10 - 5) * 0.01)
                high_price = close_price * 1.02
                pct_chg = (close_price - base_price) / base_price * 100
            
            stock_data.append({
                'ts_code': ts_code,
                'trade_date': trade_date,
                'open': base_price,
                'high': high_price,
                'low': base_price * 0.98,
                'close': close_price,
                'pre_close': base_price,
                'change': close_price - base_price,
                'pct_chg': pct_chg,
                'vol': (i % 1000 + 1) * 1000,
                'amount': close_price * (i % 1000 + 1) * 1000,
                'name': f'股票{i:04d}' if i % 20 != 0 else f'ST股票{i:04d}'
            })
        
        return pd.DataFrame(stock_data)
    
    @pytest.mark.asyncio
    async def test_large_dataset_processing_performance(self, temp_database):
        """Time one daily limit-up statistics run over 5000 synthetic stocks.

        ``get_data`` on the mocked data manager returns the generated quote
        table on its first call and the code/name table on its second.  The
        run must produce a ``LimitUpStats`` with a positive ``total`` and
        finish within 30 seconds (under 6 ms per stock on average).

        ``temp_database`` is still requested so the fixture's setup and
        cleanup run, but nothing here reads it: the service is built from the
        mocked data manager alone.
        """
        stock_count = 5000

        large_dataset = self.generate_large_stock_dataset(stock_count, '20241015')
        # Plain column selection; no need to rebuild the frame row by row.
        basic_data = large_dataset[['ts_code', 'name']].copy()

        # Mocked data manager: first call -> quotes, second call -> names.
        mock_data_manager = Mock()
        mock_data_manager.get_data = AsyncMock()
        mock_data_manager.get_data.side_effect = [large_dataset, basic_data]

        service = LimitUpStatsService(mock_data_manager)
        request = LimitUpStatsRequest(trade_date='20241015')

        # perf_counter() is monotonic, so the measurement cannot be skewed by
        # a system clock adjustment the way time.time() can.
        started = time.perf_counter()
        result = await service.get_daily_limit_up_stats(request)
        processing_time = time.perf_counter() - started

        # Sanity-check the result.
        assert isinstance(result, LimitUpStats)
        assert result.total > 0

        # Budget: 5000 stocks within 30 seconds.
        assert processing_time < 30.0, f"Processing took {processing_time:.2f}s, expected < 30s"

        # Same budget expressed per stock.
        avg_time_per_stock = processing_time / stock_count
        assert avg_time_per_stock < 0.006, f"Average time per stock: {avg_time_per_stock:.6f}s"

        print(f"Large dataset performance: {processing_time:.2f}s for 5000 stocks")
        print(f"Average time per stock: {avg_time_per_stock:.6f}s")
    
    @pytest.mark.asyncio
    async def test_database_batch_operations_performance(self, temp_database):
        """Time a 1000-record batch save, a full query and a range query.

        Budgets: batch save under 15 s, full query under 3 s, range query
        under 1 s.  The save must report 1000 successes and no failures, and
        the full query must return all 1000 records.
        """
        from datetime import date, timedelta

        repository = LimitUpStatsRepository(temp_database)

        # 1000 consecutive calendar days starting 2022-01-01.  The previous
        # f'20241{i:03d}' pattern produced strings such as '20241000' and
        # '20241999', which are not real dates.
        first_day = date(2022, 1, 1)
        trade_dates = [
            (first_day + timedelta(days=offset)).strftime('%Y%m%d')
            for offset in range(1000)
        ]

        batch_stats = []
        for i, trade_date in enumerate(trade_dates):
            batch_stats.append(LimitUpStats(
                trade_date=trade_date,
                total=100 + i % 100,
                non_st=90 + i % 100,
                shanghai=30 + i % 20,
                shenzhen=40 + i % 25,
                star=20 + i % 15,
                beijing=10 + i % 10,
                st=10,
                limit_up_stocks=[f'stock_{j}.SZ' for j in range(100 + i % 100)],
                market_breakdown={}
            ))

        # Batch save.  perf_counter() is the monotonic clock meant for
        # measuring short durations.
        started = time.perf_counter()
        result = await repository.batch_save_stats(batch_stats)
        save_time = time.perf_counter() - started

        assert result['success'] == 1000
        assert result['failed'] == 0

        # Budget: 1000 records saved within 15 seconds.
        assert save_time < 15.0, f"Batch save took {save_time:.2f}s, expected < 15s"

        # Full query.
        started = time.perf_counter()
        all_stats = await repository.query_limit_up_stats()
        query_time = time.perf_counter() - started

        assert len(all_stats) == 1000

        # Budget: 1000 records queried within 3 seconds.
        assert query_time < 3.0, f"Batch query took {query_time:.2f}s, expected < 3s"

        # Range query over the slice the old '20241100'..'20241200' bounds
        # selected: records 100 through 200.
        started = time.perf_counter()
        await repository.query_limit_up_stats(
            start_date=trade_dates[100],
            end_date=trade_dates[200]
        )
        range_query_time = time.perf_counter() - started

        # Budget: range query within 1 second.
        assert range_query_time < 1.0, f"Range query took {range_query_time:.2f}s, expected < 1s"

        print("Database performance:")
        print(f"  Batch save (1000 records): {save_time:.2f}s")
        print(f"  Batch query (1000 records): {query_time:.2f}s")
        print(f"  Range query: {range_query_time:.2f}s")
    
    @pytest.mark.asyncio
    async def test_concurrent_requests_performance(self, temp_database):
        """Time ten concurrent compute-and-save requests.

        Each request computes the statistics for one of the trade dates
        20241010..20241019 from the same 1000-stock dataset and saves the
        result.  All ten must complete within 60 seconds and leave ten
        distinct dates in the repository.
        """
        repository = LimitUpStatsRepository(temp_database)

        test_dataset = self.generate_large_stock_dataset(1000, '20241015')
        # Plain column selection; no need to rebuild the frame row by row.
        basic_data = test_dataset[['ts_code', 'name']].copy()

        # Mocked data manager: answer by the data_type of the first
        # positional argument -- quotes for 'stock_daily', names otherwise.
        mock_data_manager = Mock()
        mock_data_manager.get_data = AsyncMock()
        mock_data_manager.get_data.side_effect = lambda *args: (
            test_dataset if args[0].data_type == 'stock_daily'
            else basic_data
        )

        service = LimitUpStatsService(mock_data_manager)

        async def process_request(date_suffix):
            """Compute and persist the statistics for one trade date."""
            request = LimitUpStatsRequest(trade_date=f'2024101{date_suffix}')
            result = await service.get_daily_limit_up_stats(request)
            await repository.save_limit_up_stats(result)
            return result

        request_count = 10

        # perf_counter() is monotonic, so the measurement cannot be skewed by
        # a system clock adjustment the way time.time() can.
        started = time.perf_counter()
        results = await asyncio.gather(
            *(process_request(suffix) for suffix in range(request_count))
        )
        concurrent_time = time.perf_counter() - started

        assert len(results) == 10
        assert all(isinstance(result, LimitUpStats) for result in results)

        # Budget: ten concurrent requests within 60 seconds.
        assert concurrent_time < 60.0, f"Concurrent processing took {concurrent_time:.2f}s, expected < 60s"

        # Every request must have left its own date in the database.
        all_dates = await repository.list_available_dates()
        assert len(all_dates) == 10

        print(f"Concurrent performance: {concurrent_time:.2f}s for 10 concurrent requests")
        print(f"Average time per concurrent request: {concurrent_time/10:.2f}s")
    
    def test_stock_classifier_performance(self):
        """Time ``classify_market`` over 10000 generated stock codes.

        Budget: under 5 seconds in total (under 0.5 ms per code).  A code the
        classifier raises on is recorded as ``'unknown'``; the test fails if
        that happens for every code, because the timing would then measure
        nothing but exception handling.
        """
        classifier = StockCodeClassifier()

        # Generate codes for the four markets.  Duplicates are fine for a
        # throughput test, so the STAR suffix wraps at 1000 to keep the
        # numeric part at six digits (f'688{i:03d}' grew to seven digits once
        # i reached 1000).
        stock_codes = []
        for i in range(10000):
            if i % 4 == 0:
                stock_codes.append(f'60{i:04d}.SH')
            elif i % 4 == 1:
                stock_codes.append(f'00{i:04d}.SZ')
            elif i % 4 == 2:
                stock_codes.append(f'688{i % 1000:03d}.SH')
            else:
                stock_codes.append(f'43{i:04d}.BJ')

        # perf_counter() is monotonic, so the measurement cannot be skewed by
        # a system clock adjustment the way time.time() can.
        started = time.perf_counter()

        results = []
        failed = 0
        for code in stock_codes:
            try:
                results.append(classifier.classify_market(code))
            except Exception:
                # Best effort: keep going, but remember that it happened.
                failed += 1
                results.append('unknown')

        classification_time = time.perf_counter() - started

        assert len(results) == 10000
        assert failed < len(stock_codes), "classify_market raised for every generated code"

        # Budget: 10000 codes within 5 seconds.
        assert classification_time < 5.0, f"Classification took {classification_time:.2f}s, expected < 5s"

        # Same budget expressed per code.
        avg_time_per_code = classification_time / 10000
        assert avg_time_per_code < 0.0005, f"Average time per code: {avg_time_per_code:.6f}s"

        print("Stock classifier performance:")
        print(f"  Total time for 10000 codes: {classification_time:.2f}s")
        print(f"  Average time per code: {avg_time_per_code:.6f}s")
        print(f"  Codes the classifier raised on: {failed}")
    
    def test_limit_up_detector_performance(self):
        """Time ``is_limit_up`` over 10000 generated price records.

        Every third record closes 10% above its previous close with
        ``high == close``; the others move between -5% and +4%.  All records
        are checked with ``stock_type='normal'``.  Budget: under 3 seconds in
        total (under 0.3 ms per call), and at least one record must be
        reported as limit-up.
        """
        detector = LimitUpDetector()

        # Generate the price records.
        price_data = []
        for i in range(10000):
            base_price = 10.0 + i * 0.001
            if i % 3 == 0:  # limit-up record
                close_price = base_price * 1.10
                high_price = close_price
            else:  # ordinary record
                close_price = base_price * (1 + (i % 10 - 5) * 0.01)
                high_price = close_price * 1.02

            price_data.append({
                'open': base_price,
                'close': close_price,
                'high': high_price,
                'prev_close': base_price
            })

        # perf_counter() is monotonic, so the measurement cannot be skewed by
        # a system clock adjustment the way time.time() can.
        started = time.perf_counter()

        results = [
            detector.is_limit_up(
                open_price=record['open'],
                close_price=record['close'],
                high_price=record['high'],
                prev_close=record['prev_close'],
                stock_type='normal'
            )
            for record in price_data
        ]

        detection_time = time.perf_counter() - started

        assert len(results) == 10000
        limit_up_count = sum(results)
        assert limit_up_count > 0  # some records must be detected as limit-up

        # Budget: 10000 detections within 3 seconds.
        assert detection_time < 3.0, f"Detection took {detection_time:.2f}s, expected < 3s"

        # Same budget expressed per detection.
        avg_time_per_detection = detection_time / 10000
        assert avg_time_per_detection < 0.0003, f"Average time per detection: {avg_time_per_detection:.6f}s"

        print("Limit up detector performance:")
        print(f"  Total time for 10000 detections: {detection_time:.2f}s")
        print(f"  Average time per detection: {avg_time_per_detection:.6f}s")
        print(f"  Detected {limit_up_count} limit up stocks")
        print(f"  Detected {limit_up_count} limit up stocks")
    
    @pytest.mark.asyncio
    async def test_memory_usage_performance(self, temp_database):
        """Track process RSS while creating, saving and querying 1000 records.

        Each record carries 1000 limit-up codes plus a four-market breakdown.
        Allowed RSS growth: under 500 MB for creation, under 100 MB for the
        batch save and under 500 MB for the full query.  The test is skipped
        when ``psutil`` is not installed.
        """
        import gc
        from datetime import date, timedelta

        # Skip rather than error out when the optional dependency is missing.
        psutil = pytest.importorskip("psutil")

        repository = LimitUpStatsRepository(temp_database)
        process = psutil.Process(os.getpid())

        def rss_mb():
            """Return the resident set size of this process in MB."""
            return process.memory_info().rss / 1024 / 1024

        initial_memory = rss_mb()

        # 1000 consecutive calendar days starting 2022-01-01.  The previous
        # f'20241{i:03d}' pattern produced strings such as '20241999' that are
        # not real dates.
        first_day = date(2022, 1, 1)
        large_stats_list = []
        for i in range(1000):
            large_stats_list.append(LimitUpStats(
                trade_date=(first_day + timedelta(days=i)).strftime('%Y%m%d'),
                total=1000,  # large counts
                non_st=900,
                shanghai=300,
                shenzhen=400,
                star=200,
                beijing=100,
                st=100,
                limit_up_stocks=[f'stock_{j}.SZ' for j in range(1000)],  # large list
                market_breakdown={
                    'shanghai': [f'600{j:03d}.SH' for j in range(300)],
                    'shenzhen': [f'000{j:03d}.SZ' for j in range(400)],
                    'star': [f'688{j:03d}.SH' for j in range(200)],
                    'beijing': [f'430{j:03d}.BJ' for j in range(100)]
                }
            ))

        after_creation_memory = rss_mb()

        # Batch save.
        await repository.batch_save_stats(large_stats_list)
        after_save_memory = rss_mb()

        # Full query.
        all_stats = await repository.query_limit_up_stats()
        after_query_memory = rss_mb()

        # Drop the big objects and run a collection pass so the final reading
        # reflects what is still reachable instead of garbage awaiting GC.
        del large_stats_list
        del all_stats
        gc.collect()
        after_cleanup_memory = rss_mb()

        creation_memory_increase = after_creation_memory - initial_memory
        save_memory_increase = after_save_memory - after_creation_memory
        query_memory_increase = after_query_memory - after_save_memory

        # Growth per phase has to stay within the budgets.
        assert creation_memory_increase < 500, f"Data creation used {creation_memory_increase:.2f}MB"
        assert save_memory_increase < 100, f"Data save used {save_memory_increase:.2f}MB"
        assert query_memory_increase < 500, f"Data query used {query_memory_increase:.2f}MB"

        print("Memory usage performance:")
        print(f"  Initial memory: {initial_memory:.2f}MB")
        print(f"  After data creation: {after_creation_memory:.2f}MB (+{creation_memory_increase:.2f}MB)")
        print(f"  After data save: {after_save_memory:.2f}MB (+{save_memory_increase:.2f}MB)")
        print(f"  After data query: {after_query_memory:.2f}MB (+{query_memory_increase:.2f}MB)")
        print(f"  After cleanup: {after_cleanup_memory:.2f}MB")
    
    @pytest.mark.asyncio
    async def test_database_connection_pool_performance(self, temp_database):
        """Time 50 concurrent save / read-back / delete round trips.

        Each operation works on its own trade date and reports whether the
        read-back returned a record.  All 50 must succeed within 10 seconds.
        """
        from datetime import date, timedelta

        repository = LimitUpStatsRepository(temp_database)
        first_day = date(2024, 1, 1)

        async def db_operation(operation_id):
            """Save, fetch and delete one record; True if the fetch found it."""
            # One real calendar day per operation.  The previous
            # f'2024101{operation_id}' pattern produced nine-character
            # strings for ids 10 and above.
            trade_date = (first_day + timedelta(days=operation_id)).strftime('%Y%m%d')

            stats = LimitUpStats(
                trade_date=trade_date,
                total=100,
                non_st=90,
                shanghai=30,
                shenzhen=40,
                star=20,
                beijing=10,
                st=10,
                limit_up_stocks=[f'stock_{operation_id}.SZ'],
                market_breakdown={'shenzhen': [f'stock_{operation_id}.SZ']}
            )

            await repository.save_limit_up_stats(stats)
            retrieved = await repository.get_limit_up_stats(trade_date)
            await repository.delete_limit_up_stats(trade_date)

            return retrieved is not None

        # perf_counter() is monotonic, so the measurement cannot be skewed by
        # a system clock adjustment the way time.time() can.
        started = time.perf_counter()
        results = await asyncio.gather(*(db_operation(op) for op in range(50)))
        pool_time = time.perf_counter() - started

        assert all(results)

        # Budget: 50 concurrent operations within 10 seconds.
        assert pool_time < 10.0, f"Connection pool operations took {pool_time:.2f}s, expected < 10s"

        print("Database connection pool performance:")
        print(f"  50 concurrent operations: {pool_time:.2f}s")
        print(f"  Average time per operation: {pool_time/50:.2f}s")


class TestLimitUpStatsStressTest:
    """涨停统计系统压力测试"""
    
    @pytest_asyncio.fixture
    async def temp_database(self):
        """Yield an initialized ``DatabaseManager`` backed by a throwaway file.

        A named temporary ``.db`` file is created and kept on disk, a
        ``DatabaseManager`` is built on its path and initialized, and the
        manager is handed to the test.  Afterwards the file and its ``-wal`` /
        ``-shm`` siblings are removed on a best-effort basis.
        """
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as handle:
            db_path = handle.name

        try:
            manager = DatabaseManager(db_path)
            await manager.initialize()
            yield manager
        finally:
            # Remove the database file and its "-wal"/"-shm" companions,
            # skipping paths that do not exist and ignoring removal errors.
            for suffix in ('', '-wal', '-shm'):
                leftover = f"{db_path}{suffix}"
                if not os.path.exists(leftover):
                    continue
                try:
                    os.unlink(leftover)
                except OSError:
                    pass
    
    @pytest.mark.asyncio
    async def test_extreme_load_stress_test(self, temp_database):
        """Save and query 100 trading days of 5000-stock records.

        Records are saved in batches of ten, and every batch must report all
        of its records as successful.  Budgets: all saves under 120 seconds,
        the full query under 30 seconds and returning 100 records.
        """
        from datetime import date, timedelta

        repository = LimitUpStatsRepository(temp_database)

        # 100 consecutive calendar days starting 2024-01-01.  The previous
        # f'20241{i:03d}' pattern produced strings such as '20241000' that
        # are not real dates.
        first_day = date(2024, 1, 1)

        # Very large records (5000 stocks per day).
        extreme_stats_list = []
        for day in range(100):
            extreme_stats_list.append(LimitUpStats(
                trade_date=(first_day + timedelta(days=day)).strftime('%Y%m%d'),
                total=5000,  # whole-market stock count
                non_st=4500,
                shanghai=2000,
                shenzhen=1800,
                star=800,
                beijing=400,
                st=500,
                limit_up_stocks=[f'stock_{j}.SZ' for j in range(5000)],
                market_breakdown={
                    # Numeric formatting keeps these codes at six digits;
                    # f'600{j:03d}' / f'000{j:03d}' grew to seven digits once
                    # j reached 1000.
                    'shanghai': [f'{600000 + j:06d}.SH' for j in range(2000)],
                    'shenzhen': [f'{j:06d}.SZ' for j in range(1800)],
                    'star': [f'688{j:03d}.SH' for j in range(800)],
                    'beijing': [f'430{j:03d}.BJ' for j in range(400)]
                }
            ))

        # perf_counter() is monotonic, so the measurement cannot be skewed by
        # a system clock adjustment the way time.time() can.
        started = time.perf_counter()

        # Save in batches of ten to limit how much goes into a single call.
        batch_size = 10
        for offset in range(0, len(extreme_stats_list), batch_size):
            batch = extreme_stats_list[offset:offset + batch_size]
            result = await repository.batch_save_stats(batch)
            assert result['success'] == len(batch)

        extreme_save_time = time.perf_counter() - started

        # Budget: all saves within 120 seconds.
        assert extreme_save_time < 120.0, f"Extreme save took {extreme_save_time:.2f}s, expected < 120s"

        # Full query.
        started = time.perf_counter()
        all_stats = await repository.query_limit_up_stats()
        extreme_query_time = time.perf_counter() - started

        assert len(all_stats) == 100

        # Budget: full query within 30 seconds.
        assert extreme_query_time < 30.0, f"Extreme query took {extreme_query_time:.2f}s, expected < 30s"

        print("Extreme load stress test:")
        print(f"  Save 100 days × 5000 stocks: {extreme_save_time:.2f}s")
        print(f"  Query 100 records: {extreme_query_time:.2f}s")
    
    @pytest.mark.asyncio
    async def test_sustained_load_stress_test(self, temp_database):
        """Run 100 save-and-read-back operations in paced batches of ten.

        Stands in for an hour of one-request-per-second load, shortened to
        100 operations to keep the test quick.  Every operation must read its
        record back, and the average wall time per operation (pacing sleeps
        included) must stay under one second.
        """
        from datetime import date, timedelta

        repository = LimitUpStatsRepository(temp_database)

        sustained_operations = 100
        first_day = date(2024, 1, 1)

        async def sustained_operation(op_id):
            """Save one record and read it back; True if the read found it."""
            # One real calendar day per operation.  The previous
            # f'20241{op_id:03d}' pattern produced strings such as
            # '20241000' that are not real dates.
            trade_date = (first_day + timedelta(days=op_id)).strftime('%Y%m%d')

            stats = LimitUpStats(
                trade_date=trade_date,
                total=100 + op_id % 50,
                non_st=90 + op_id % 50,
                shanghai=30,
                shenzhen=40,
                star=20,
                beijing=10,
                st=10,
                limit_up_stocks=[f'stock_{j}.SZ' for j in range(100 + op_id % 50)],
                market_breakdown={}
            )

            await repository.save_limit_up_stats(stats)
            retrieved = await repository.get_limit_up_stats(trade_date)

            # Simulated processing delay.
            await asyncio.sleep(0.01)

            return retrieved is not None

        # perf_counter() is monotonic, so the measurement cannot be skewed by
        # a system clock adjustment the way time.time() can.
        started = time.perf_counter()

        # Issue the operations in batches to imitate a steady load.
        batch_size = 10
        all_results = []

        for batch_start in range(0, sustained_operations, batch_size):
            batch_end = min(batch_start + batch_size, sustained_operations)
            batch_results = await asyncio.gather(
                *(sustained_operation(op_id) for op_id in range(batch_start, batch_end))
            )
            all_results.extend(batch_results)

            # Short pause between batches.
            await asyncio.sleep(0.1)

        sustained_time = time.perf_counter() - started

        assert len(all_results) == sustained_operations
        assert all(all_results)

        # Budget: under one second per operation on average.
        avg_time_per_operation = sustained_time / sustained_operations
        assert avg_time_per_operation < 1.0, f"Average operation time: {avg_time_per_operation:.3f}s"

        print("Sustained load stress test:")
        print(f"  {sustained_operations} operations: {sustained_time:.2f}s")
        print(f"  Average time per operation: {avg_time_per_operation:.3f}s")
        print(f"  Operations per second: {sustained_operations/sustained_time:.2f}")


if __name__ == "__main__":
    # Hand pytest's return code to the shell so a failing run is not reported
    # as a success when this file is executed directly.
    raise SystemExit(pytest.main([__file__]))