"""
数据库操作层测试

测试LimitUpStatsRepository的CRUD操作、批量操作和事务支持
"""

import pytest
import pytest_asyncio
import asyncio
import tempfile
import os
from datetime import datetime
from typing import List

from quickstock.core.repository import LimitUpStatsRepository
from quickstock.core.database import DatabaseManager
from quickstock.models import LimitUpStats, StockDailyData
from quickstock.core.errors import DatabaseError, ValidationError


class TestLimitUpStatsRepository:
    """Tests for LimitUpStatsRepository.

    Covers CRUD operations, batch save/delete, transaction rollback, model
    validation and concurrent saves, each against a fresh temporary database.
    """

    @staticmethod
    def _make_stats(trade_date: str, **overrides) -> LimitUpStats:
        """Build a ``LimitUpStats`` for ``trade_date`` from consistent defaults.

        The defaults are internally consistent: the four market counts
        (30 + 40 + 20 + 10) add up to ``total`` (100), and ``non_st + st``
        (90 + 10) also equals ``total``. Keyword ``overrides`` replace
        individual fields; callers that change ``total`` must keep the other
        counts consistent themselves.
        """
        fields = {
            'trade_date': trade_date,
            'total': 100,
            'non_st': 90,
            'shanghai': 30,
            'shenzhen': 40,
            'star': 20,
            'beijing': 10,
            'st': 10,
            # Fresh containers on every call, so records never share them.
            'limit_up_stocks': [],
            'market_breakdown': {},
        }
        fields.update(overrides)
        return LimitUpStats(**fields)

    @pytest_asyncio.fixture
    async def temp_repository(self):
        """Yield a repository backed by a fresh temporary database file.

        The database file and any ``-wal`` / ``-shm`` side files are removed
        afterwards, even if initialization fails.
        """
        # delete=False: only reserve a unique path; the file must outlive the
        # ``with`` so DatabaseManager can open it.
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
            db_path = tmp_file.name

        try:
            db_manager = DatabaseManager(db_path)
            await db_manager.initialize()
            yield LimitUpStatsRepository(db_manager)
        finally:
            # NOTE(review): db_manager is not closed before its files are
            # removed. If it keeps a connection open, unlinking can fail on
            # Windows. Confirm whether DatabaseManager exposes a close API and
            # call it here.
            for suffix in ('', '-wal', '-shm'):
                path = db_path + suffix
                if os.path.exists(path):
                    os.unlink(path)

    @pytest.fixture
    def sample_limit_up_stats(self):
        """Return consistent stats for 20241015 with one stock in each of three markets."""
        return self._make_stats(
            '20241015',
            limit_up_stocks=['000001.SZ', '600000.SH', '688001.SH'],
            market_breakdown={
                'shenzhen': ['000001.SZ'],
                'shanghai': ['600000.SH'],
                'star': ['688001.SH'],
            },
        )

    @pytest.fixture
    def sample_stock_details(self):
        """Return daily bars for the three stocks listed in ``sample_limit_up_stats``."""
        return [
            StockDailyData(
                ts_code='000001.SZ',
                trade_date='20241015',
                open=10.0,
                high=11.0,
                low=9.8,
                close=11.0,
                pre_close=10.0,
                change=1.0,
                pct_chg=10.0,
                vol=1000000,
                amount=10500000.0,
                name='平安银行'
            ),
            StockDailyData(
                ts_code='600000.SH',
                trade_date='20241015',
                open=8.0,
                high=8.8,
                low=7.9,
                close=8.8,
                pre_close=8.0,
                change=0.8,
                pct_chg=10.0,
                vol=800000,
                amount=6800000.0,
                name='浦发银行'
            ),
            StockDailyData(
                ts_code='688001.SH',
                trade_date='20241015',
                open=20.0,
                high=24.0,
                low=19.5,
                close=24.0,
                pre_close=20.0,
                change=4.0,
                pct_chg=20.0,
                vol=500000,
                amount=11000000.0,
                name='华兴源创'
            )
        ]

    @pytest.mark.asyncio
    async def test_save_and_get_limit_up_stats(self, temp_repository, sample_limit_up_stats):
        """A saved record reads back with identical counts and with timestamps set."""
        repository = temp_repository
        stats = sample_limit_up_stats

        assert await repository.save_limit_up_stats(stats) is True

        retrieved_stats = await repository.get_limit_up_stats(stats.trade_date)
        assert retrieved_stats is not None

        # Every scalar field must round-trip unchanged; the field name is the
        # assertion message so a failure says which one differed.
        for field in ('trade_date', 'total', 'non_st', 'shanghai',
                      'shenzhen', 'star', 'beijing', 'st'):
            assert getattr(retrieved_stats, field) == getattr(stats, field), field

        # Timestamps are expected to be filled in on save.
        assert retrieved_stats.created_at is not None
        assert retrieved_stats.updated_at is not None

    @pytest.mark.asyncio
    async def test_save_with_stock_details(self, temp_repository, sample_limit_up_stats, sample_stock_details):
        """Saving with per-stock details keeps the stock list and the market breakdown."""
        repository = temp_repository
        stats = sample_limit_up_stats
        stock_details = sample_stock_details

        assert await repository.save_limit_up_stats(stats, stock_details) is True

        retrieved_stats = await repository.get_limit_up_stats(stats.trade_date)
        assert retrieved_stats is not None

        # The limit-up list has one entry per detailed stock.
        assert len(retrieved_stats.limit_up_stocks) == len(stock_details)
        for stock in stock_details:
            assert stock.ts_code in retrieved_stats.limit_up_stocks

        # Each stock is filed under its own market.
        expected_markets = {
            'shenzhen': '000001.SZ',
            'shanghai': '600000.SH',
            'star': '688001.SH',
        }
        for market, ts_code in expected_markets.items():
            assert market in retrieved_stats.market_breakdown
            assert ts_code in retrieved_stats.market_breakdown[market]

    @pytest.mark.asyncio
    async def test_update_existing_stats(self, temp_repository, sample_limit_up_stats):
        """Saving again for an existing trade date overwrites the stored counts."""
        repository = temp_repository
        stats = sample_limit_up_stats

        # Initial insert.
        await repository.save_limit_up_stats(stats)

        # Same date, new but still consistent counts: 35 + 45 + 25 + 15 == 120.
        updated_stats = self._make_stats(
            stats.trade_date,
            total=120,
            non_st=110,
            shanghai=35,
            shenzhen=45,
            star=25,
            beijing=15,
            limit_up_stocks=['000001.SZ', '600000.SH', '688001.SH', '000002.SZ'],
            market_breakdown={
                'shenzhen': ['000001.SZ', '000002.SZ'],
                'shanghai': ['600000.SH'],
                'star': ['688001.SH'],
            },
        )

        assert await repository.save_limit_up_stats(updated_stats) is True

        retrieved_stats = await repository.get_limit_up_stats(stats.trade_date)
        assert retrieved_stats is not None
        assert retrieved_stats.total == 120
        assert retrieved_stats.shanghai == 35
        assert retrieved_stats.shenzhen == 45
        assert retrieved_stats.star == 25
        assert retrieved_stats.beijing == 15

    @pytest.mark.asyncio
    async def test_query_limit_up_stats(self, temp_repository):
        """Query supports ordering, an inclusive date range, limit and offset."""
        repository = temp_repository

        # Counts grow with the index but stay consistent: the market counts
        # sum to 100 + 10 * i, which matches total.
        test_dates = ['20241010', '20241011', '20241012', '20241015']
        for i, date in enumerate(test_dates):
            stats = self._make_stats(
                date,
                total=100 + i * 10,
                non_st=90 + i * 10,
                shanghai=30 + i * 2,
                shenzhen=40 + i * 3,
                star=20 + i * 2,
                beijing=10 + i * 3,
                limit_up_stocks=[f'00000{i}.SZ'],
                market_breakdown={'shenzhen': [f'00000{i}.SZ']},
            )
            await repository.save_limit_up_stats(stats)

        # No filters: every record, newest date first.
        all_stats = await repository.query_limit_up_stats()
        assert len(all_stats) == 4
        dates = [s.trade_date for s in all_stats]
        assert dates == sorted(dates, reverse=True)

        # Both range endpoints are included.
        range_stats = await repository.query_limit_up_stats(
            start_date='20241011',
            end_date='20241012',
        )
        assert len(range_stats) == 2
        assert all(s.trade_date in ['20241011', '20241012'] for s in range_stats)

        # limit takes the first N of the same newest-first ordering...
        limited_stats = await repository.query_limit_up_stats(limit=2)
        assert [s.trade_date for s in limited_stats] == dates[:2]

        # ...and offset skips that many entries of it.
        offset_stats = await repository.query_limit_up_stats(limit=2, offset=1)
        assert [s.trade_date for s in offset_stats] == dates[1:3]

    @pytest.mark.asyncio
    async def test_delete_limit_up_stats(self, temp_repository, sample_limit_up_stats, sample_stock_details):
        """Deleting removes the record; deleting a missing date returns False."""
        repository = temp_repository
        stats = sample_limit_up_stats

        # Store the record together with its stock details.
        await repository.save_limit_up_stats(stats, sample_stock_details)
        assert await repository.get_limit_up_stats(stats.trade_date) is not None

        assert await repository.delete_limit_up_stats(stats.trade_date) is True
        assert await repository.get_limit_up_stats(stats.trade_date) is None

        # A date that was never stored reports failure instead of raising.
        assert await repository.delete_limit_up_stats('20241001') is False

    @pytest.mark.asyncio
    async def test_list_available_dates(self, temp_repository):
        """Stored dates are listed newest first and can be filtered by range."""
        repository = temp_repository

        test_dates = ['20241010', '20241015', '20241020']
        for date in test_dates:
            await repository.save_limit_up_stats(self._make_stats(date))

        # Unfiltered: every stored date, newest first.
        all_dates = await repository.list_available_dates()
        assert len(all_dates) == 3
        assert set(all_dates) == set(test_dates)
        assert all_dates == sorted(all_dates, reverse=True)

        # A range that brackets only the middle date.
        filtered_dates = await repository.list_available_dates(
            start_date='20241012',
            end_date='20241018',
        )
        assert filtered_dates == ['20241015']

    @pytest.mark.asyncio
    async def test_get_database_stats(self, temp_repository, sample_limit_up_stats):
        """Database statistics reflect the row count, date range and averages."""
        repository = temp_repository

        # An empty database reports zero rows.
        initial_stats = await repository.get_database_stats()
        assert 'tables' in initial_stats
        assert initial_stats['tables']['limit_up_stats'] == 0

        await repository.save_limit_up_stats(sample_limit_up_stats)

        updated_stats = await repository.get_database_stats()
        assert updated_stats['tables']['limit_up_stats'] == 1

        for key in ('date_range', 'averages', 'last_update'):
            assert key in updated_stats

        # With a single row, earliest == latest and the average equals its value.
        assert updated_stats['date_range']['earliest'] == sample_limit_up_stats.trade_date
        assert updated_stats['date_range']['latest'] == sample_limit_up_stats.trade_date
        assert updated_stats['averages']['total'] == sample_limit_up_stats.total

    @pytest.mark.asyncio
    async def test_batch_save_stats(self, temp_repository):
        """A batch of consistent records is saved completely."""
        repository = temp_repository

        # The market counts grow with total, so every record satisfies the sum
        # constraint that test_batch_save_with_errors expects to be enforced.
        # With fixed market counts, records 1..4 would violate it.
        stats_list = [
            self._make_stats(
                f'2024101{i}',
                total=100 + i * 10,
                non_st=90 + i * 10,
                shanghai=30 + i * 2,
                shenzhen=40 + i * 3,
                star=20 + i * 2,
                beijing=10 + i * 3,
            )
            for i in range(5)
        ]

        result = await repository.batch_save_stats(stats_list)

        assert result['total'] == 5
        assert result['success'] == 5
        assert result['failed'] == 0
        assert len(result['errors']) == 0

        all_dates = await repository.list_available_dates()
        assert len(all_dates) == 5

    @pytest.mark.asyncio
    async def test_batch_save_with_errors(self, temp_repository):
        """An invalid record in a batch is reported without blocking valid ones."""
        repository = temp_repository

        valid_stats = self._make_stats('20241010')

        # Build a consistent record, then break it after construction so any
        # sum check done by the model at construction time is bypassed. The
        # markets now sum to 95 instead of 100.
        invalid_stats = self._make_stats('20241011')
        invalid_stats.beijing = 5

        result = await repository.batch_save_stats([valid_stats, invalid_stats])

        assert result['total'] == 2
        assert result['success'] == 1  # only the consistent record is stored
        assert result['failed'] == 1   # the inconsistent one is rejected
        assert len(result['errors']) == 1

    @pytest.mark.asyncio
    async def test_batch_delete_stats(self, temp_repository):
        """Batch delete removes existing dates and counts missing ones as failures."""
        repository = temp_repository

        for date in ['20241010', '20241011', '20241012']:
            await repository.save_limit_up_stats(self._make_stats(date))

        # '20241099' was never stored.
        delete_dates = ['20241010', '20241012', '20241099']
        result = await repository.batch_delete_stats(delete_dates)

        assert result['total'] == 3
        assert result['success'] == 2  # the two stored dates are deleted
        assert result['failed'] == 1   # the missing date is a failure

        remaining_dates = await repository.list_available_dates()
        assert remaining_dates == ['20241011']

    @pytest.mark.asyncio
    async def test_transaction_rollback(self, temp_repository):
        """A save that violates the sum constraint raises and leaves nothing behind."""
        repository = temp_repository

        # Construct a consistent record, then break it after construction so
        # the repository, not the model, has to reject it (95 != 100).
        invalid_stats = self._make_stats('20241015')
        invalid_stats.beijing = 5

        with pytest.raises(DatabaseError):
            await repository.save_limit_up_stats(invalid_stats)

        # The failed transaction must not leave a partial row.
        assert await repository.get_limit_up_stats('20241015') is None

    @pytest.mark.asyncio
    async def test_data_validation(self, temp_repository):
        """LimitUpStats rejects a malformed trade date and negative counts at construction."""
        # Not a valid trade date string.
        with pytest.raises(ValueError):
            self._make_stats('invalid-date')

        # Negative total.
        with pytest.raises(ValueError):
            self._make_stats('20241015', total=-100)

    @pytest.mark.asyncio
    async def test_concurrent_operations(self, temp_repository):
        """Concurrent saves for distinct dates all succeed and are all stored."""
        repository = temp_repository

        async def save_stats(date_suffix: int):
            return await repository.save_limit_up_stats(
                self._make_stats(f'2024101{date_suffix}')
            )

        # Start five saves for different dates at the same time.
        results = await asyncio.gather(*(save_stats(i) for i in range(5)))
        assert all(results)

        all_dates = await repository.list_available_dates()
        assert len(all_dates) == 5


if __name__ == '__main__':
    # Propagate pytest's exit code so that running this file directly fails
    # (non-zero status) when tests fail, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))