"""
数据库管理器全面测试

为DatabaseManager提供全面的单元测试覆盖，包括连接管理、事务处理、错误恢复等
"""

import pytest
import pytest_asyncio
import asyncio
import tempfile
import os
import sqlite3
from pathlib import Path
from datetime import datetime
from unittest.mock import Mock, patch, MagicMock

from quickstock.core.database import DatabaseManager, get_database_manager, initialize_database
from quickstock.core.errors import DatabaseError


class TestDatabaseManagerComprehensive:
    """Comprehensive tests for DatabaseManager.

    Covers schema initialization, connection handling, transactions, CHECK and
    foreign-key constraints, triggers, backup/restore, statistics and error
    handling.
    """

    # Shared INSERT for limit_up_stats. Values are, in order:
    # (trade_date, total, non_st, shanghai, shenzhen, star, beijing, st).
    # Rows must satisfy the CHECK constraints exercised in
    # test_constraint_validation_comprehensive:
    #   total == shanghai + shenzhen + star + beijing
    #   total == st + non_st
    _STATS_INSERT_SQL = (
        "INSERT INTO limit_up_stats (trade_date, total, non_st, shanghai, shenzhen, star, beijing, st) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
    )

    # Shared INSERT for limit_up_stocks (child table of limit_up_stats).
    _STOCK_INSERT_SQL = (
        "INSERT INTO limit_up_stocks "
        "(trade_date, ts_code, stock_name, market, is_st, open_price, close_price, high_price, pct_change) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"
    )

    @pytest_asyncio.fixture
    async def temp_db_manager(self):
        """Yield an initialized DatabaseManager backed by a throwaway file.

        The database file and its ``-wal``/``-shm`` sidecars are removed after
        the test, whether it passed or failed.
        """
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
            db_path = tmp_file.name

        try:
            db_manager = DatabaseManager(db_path)
            await db_manager.initialize()
            yield db_manager
        finally:
            await self._cleanup_db_files(db_path)

    async def _cleanup_db_files(self, db_path):
        """Best-effort removal of a database file and its -wal/-shm sidecars.

        Args:
            db_path: Path of the main database file.
        """
        files_to_clean = [db_path, f"{db_path}-wal", f"{db_path}-shm"]
        for file_path in files_to_clean:
            if os.path.exists(file_path):
                try:
                    os.unlink(file_path)
                except OSError:
                    pass  # Cleanup is best-effort; never mask the test outcome.

    @pytest.mark.asyncio
    async def test_database_initialization_comprehensive(self, temp_db_manager):
        """Initialization creates the file, version row, tables, indexes and triggers."""
        db_manager = temp_db_manager

        # The database file exists on disk.
        assert os.path.exists(db_manager.db_path)

        # The stored schema version matches the manager's current version.
        async with db_manager.get_connection() as db:
            cursor = await db.execute("SELECT value FROM database_info WHERE key = 'version'")
            version = await cursor.fetchone()
            assert version is not None
            assert int(version[0]) == db_manager.CURRENT_VERSION

        # Schema validation reports no problems.
        validation_result = await db_manager.validate_schema()
        assert validation_result['valid'] is True
        assert len(validation_result['errors']) == 0

        # Tables.
        expected_tables = ['limit_up_stats', 'limit_up_stocks', 'database_info']
        for table in expected_tables:
            assert validation_result['tables'][table] is True

        # Indexes.
        expected_indexes = [
            'idx_limit_up_stats_date',
            'idx_limit_up_stocks_date',
            'idx_limit_up_stocks_code',
            'idx_limit_up_stocks_market',
            'idx_limit_up_stocks_is_st',
            'idx_limit_up_stocks_unique'
        ]
        for index in expected_indexes:
            assert validation_result['indexes'][index] is True

        # Triggers.
        expected_triggers = [
            'update_limit_up_stats_timestamp',
            'update_database_info_timestamp'
        ]
        for trigger in expected_triggers:
            assert validation_result['triggers'][trigger] is True

    @pytest.mark.asyncio
    async def test_connection_pool_management(self, temp_db_manager):
        """Many concurrent get_connection() users can each run a query."""
        db_manager = temp_db_manager

        connections = []

        async def get_connection_worker():
            async with db_manager.get_connection() as db:
                connections.append(db)
                cursor = await db.execute("SELECT 1")
                result = await cursor.fetchone()
                assert result[0] == 1
                await asyncio.sleep(0.1)  # Hold the connection to force overlap.

        # Run 10 workers concurrently.
        tasks = [get_connection_worker() for _ in range(10)]
        await asyncio.gather(*tasks)

        # Every worker acquired a connection and completed its query.
        assert len(connections) == 10

    @pytest.mark.asyncio
    async def test_transaction_handling_comprehensive(self, temp_db_manager):
        """Committed transactions persist; rolled-back transactions leave no trace."""
        db_manager = temp_db_manager

        # Successful transaction.
        async with db_manager.get_connection() as db:
            await db.execute("BEGIN TRANSACTION")
            try:
                await db.execute(self._STATS_INSERT_SQL, ('20241015', 100, 90, 30, 40, 20, 10, 10))
                await db.execute("COMMIT")
            except Exception:
                await db.execute("ROLLBACK")
                raise

        # The committed row is visible from a fresh connection.
        async with db_manager.get_connection() as db:
            cursor = await db.execute("SELECT COUNT(*) FROM limit_up_stats WHERE trade_date = '20241015'")
            count = await cursor.fetchone()
            assert count[0] == 1

        # Transaction that fails part-way and is rolled back.
        async with db_manager.get_connection() as db:
            await db.execute("BEGIN TRANSACTION")
            try:
                await db.execute(self._STATS_INSERT_SQL, ('20241016', 50, 45, 15, 20, 10, 5, 5))
                # Deliberately fail: the table does not exist. Pin that it
                # really raises instead of silently swallowing any error.
                with pytest.raises(sqlite3.OperationalError):
                    await db.execute("INSERT INTO invalid_table VALUES (1)")
            finally:
                await db.execute("ROLLBACK")

        # The rolled-back row is not visible.
        async with db_manager.get_connection() as db:
            cursor = await db.execute("SELECT COUNT(*) FROM limit_up_stats WHERE trade_date = '20241016'")
            count = await cursor.fetchone()
            assert count[0] == 0

    @pytest.mark.asyncio
    async def test_constraint_validation_comprehensive(self, temp_db_manager):
        """Each CHECK constraint on limit_up_stats rejects a row violating only it."""
        db_manager = temp_db_manager

        async with db_manager.get_connection() as db:
            # Wrong date format (sums are consistent).
            with pytest.raises(sqlite3.IntegrityError):
                await db.execute(self._STATS_INSERT_SQL, ('2024-10-15', 100, 90, 30, 40, 20, 10, 10))
                await db.commit()

            # total != shanghai + shenzhen + star + beijing (95 != 100).
            with pytest.raises(sqlite3.IntegrityError):
                await db.execute(self._STATS_INSERT_SQL, ('20241015', 100, 90, 30, 40, 20, 5, 10))
                await db.commit()

            # total != st + non_st (95 != 100).
            with pytest.raises(sqlite3.IntegrityError):
                await db.execute(self._STATS_INSERT_SQL, ('20241015', 100, 80, 30, 40, 20, 10, 15))
                await db.commit()

            # Negative total. Both sum rules still hold (-1 == -1 + 0 + 0 + 0
            # and -1 == 0 + -1), so only a non-negativity CHECK can reject it.
            with pytest.raises(sqlite3.IntegrityError):
                await db.execute(self._STATS_INSERT_SQL, ('20241015', -1, -1, -1, 0, 0, 0, 0))
                await db.commit()

    @pytest.mark.asyncio
    async def test_foreign_key_constraints_comprehensive(self, temp_db_manager):
        """Child rows need a parent row and are cascade-deleted with it."""
        db_manager = temp_db_manager

        async with db_manager.get_connection() as db:
            # Parent row.
            await db.execute(self._STATS_INSERT_SQL, ('20241015', 1, 1, 0, 1, 0, 0, 0))
            await db.commit()

            # Valid child row.
            await db.execute(
                self._STOCK_INSERT_SQL,
                ('20241015', '000001.SZ', '平安银行', 'shenzhen', False, 10.0, 11.0, 11.0, 10.0)
            )
            await db.commit()

            # Child row whose trade_date has no parent. The commit stays inside
            # the raises block in case the FK is declared DEFERRABLE.
            with pytest.raises(sqlite3.IntegrityError):
                await db.execute(
                    self._STOCK_INSERT_SQL,
                    ('20241016', '000002.SZ', 'ST万科', 'shenzhen', True, 6.0, 6.3, 6.3, 5.0)
                )
                await db.commit()

            # Deleting the parent cascades to its children.
            await db.execute("DELETE FROM limit_up_stats WHERE trade_date = '20241015'")
            await db.commit()

            cursor = await db.execute("SELECT COUNT(*) FROM limit_up_stocks WHERE trade_date = '20241015'")
            count = await cursor.fetchone()
            assert count[0] == 0

    @pytest.mark.asyncio
    async def test_trigger_functionality_comprehensive(self, temp_db_manager):
        """UPDATE triggers refresh updated_at and leave created_at untouched."""
        db_manager = temp_db_manager

        async with db_manager.get_connection() as db:
            # database_info timestamp before any update in this test.
            cursor = await db.execute(
                "SELECT updated_at FROM database_info WHERE key = 'version'"
            )
            initial_info = await cursor.fetchone()
            assert initial_info is not None

            await db.execute(self._STATS_INSERT_SQL, ('20241015', 100, 90, 30, 40, 20, 10, 10))
            await db.commit()

            cursor = await db.execute(
                "SELECT created_at, updated_at FROM limit_up_stats WHERE trade_date = '20241015'"
            )
            initial_created, initial_updated = await cursor.fetchone()

            # SQLite CURRENT_TIMESTAMP has one-second resolution; sleep past a
            # full second so a refreshed timestamp is guaranteed to differ.
            await asyncio.sleep(1.1)

            # Keep the row consistent with the CHECK constraints:
            # 40 + 40 + 20 + 10 == 110 and 10 + 100 == 110.
            await db.execute(
                "UPDATE limit_up_stats SET total = 110, non_st = 100, shanghai = 40 WHERE trade_date = '20241015'"
            )
            await db.commit()

            cursor = await db.execute(
                "SELECT created_at, updated_at FROM limit_up_stats WHERE trade_date = '20241015'"
            )
            updated_created, updated_updated = await cursor.fetchone()

            # created_at is unchanged; updated_at moved.
            assert updated_created == initial_created
            assert updated_updated != initial_updated

            # database_info trigger.
            await db.execute(
                "UPDATE database_info SET value = 'test_value' WHERE key = 'version'"
            )
            await db.commit()

            cursor = await db.execute(
                "SELECT updated_at FROM database_info WHERE key = 'version'"
            )
            db_info_updated = await cursor.fetchone()
            assert db_info_updated is not None
            assert db_info_updated[0] != initial_info[0]

    @pytest.mark.asyncio
    async def test_backup_restore_comprehensive(self, temp_db_manager):
        """A backup taken before modifications restores the original rows."""
        db_manager = temp_db_manager

        async with db_manager.get_connection() as db:
            await db.execute(self._STATS_INSERT_SQL, ('20241015', 100, 90, 30, 40, 20, 10, 10))
            await db.execute(self._STATS_INSERT_SQL, ('20241016', 80, 75, 25, 30, 15, 10, 5))
            await db.commit()

        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as backup_file:
            backup_path = backup_file.name

        try:
            backup_success = await db_manager.backup_database(backup_path)
            assert backup_success is True
            assert os.path.exists(backup_path)
            assert os.path.getsize(backup_path) > 0

            # Modify the live database. The UPDATE keeps the CHECK constraints
            # satisfied: 145 + 30 + 15 + 10 == 200 and 5 + 195 == 200.
            async with db_manager.get_connection() as db:
                await db.execute("DELETE FROM limit_up_stats WHERE trade_date = '20241015'")
                await db.execute(
                    "UPDATE limit_up_stats SET total = 200, non_st = 195, shanghai = 145 WHERE trade_date = '20241016'"
                )
                await db.commit()

            # The modification took effect.
            async with db_manager.get_connection() as db:
                cursor = await db.execute("SELECT COUNT(*) FROM limit_up_stats")
                count = await cursor.fetchone()
                assert count[0] == 1

                cursor = await db.execute("SELECT total FROM limit_up_stats WHERE trade_date = '20241016'")
                total = await cursor.fetchone()
                assert total[0] == 200

            restore_success = await db_manager.restore_database(backup_path)
            assert restore_success is True

            # The original rows are back.
            async with db_manager.get_connection() as db:
                cursor = await db.execute("SELECT COUNT(*) FROM limit_up_stats")
                count = await cursor.fetchone()
                assert count[0] == 2

                cursor = await db.execute("SELECT total FROM limit_up_stats WHERE trade_date = '20241016'")
                total = await cursor.fetchone()
                assert total[0] == 80

        finally:
            await self._cleanup_db_files(backup_path)

    @pytest.mark.asyncio
    async def test_database_statistics_comprehensive(self, temp_db_manager):
        """get_database_stats reports path, size, row counts, version and created_at."""
        db_manager = temp_db_manager

        async with db_manager.get_connection() as db:
            for i in range(10):
                # Grow shanghai with total so both CHECK sum rules hold:
                # (30 + 10i) + 40 + 20 + 10 == 100 + 10i == 10 + (90 + 10i).
                await db.execute(
                    self._STATS_INSERT_SQL,
                    (f'2024101{i}', 100 + i * 10, 90 + i * 10, 30 + i * 10, 40, 20, 10, 10)
                )
            await db.commit()

        stats = await db_manager.get_database_stats()

        for key in ('db_path', 'db_size', 'tables', 'version', 'created_at'):
            assert key in stats

        assert stats['tables']['limit_up_stats'] == 10
        assert stats['tables']['limit_up_stocks'] == 0
        assert stats['version'] == str(db_manager.CURRENT_VERSION)
        assert stats['db_size'] > 0

        # created_at must be ISO-8601 parseable (raises ValueError otherwise).
        datetime.fromisoformat(stats['created_at'])

    @pytest.mark.asyncio
    async def test_schema_validation_edge_cases(self):
        """validate_schema flags an empty database and one with missing tables."""
        # Empty database.
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
            empty_db_path = tmp_file.name

        try:
            empty_db_manager = DatabaseManager(empty_db_path)
            validation_result = await empty_db_manager.validate_schema()
            assert validation_result['valid'] is False
            assert len(validation_result['errors']) > 0
        finally:
            await self._cleanup_db_files(empty_db_path)

        # Only some of the tables exist.
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
            partial_db_path = tmp_file.name

        try:
            partial_db_manager = DatabaseManager(partial_db_path)

            async with partial_db_manager.get_connection() as db:
                await db.execute(partial_db_manager.SCHEMA_SQL['limit_up_stats'])
                # Deliberately skip the other tables.
                await db.commit()

            validation_result = await partial_db_manager.validate_schema()
            assert validation_result['valid'] is False
            assert any('limit_up_stocks' in error for error in validation_result['errors'])
        finally:
            await self._cleanup_db_files(partial_db_path)

    @pytest.mark.asyncio
    async def test_error_handling_comprehensive(self):
        """initialize() raises DatabaseError when the database cannot be created."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            # A path whose parent component is a regular file can never be
            # created, not even by root. A hard-coded "/invalid/..." path could
            # be created by root, so this fails reliably on every platform.
            blocker_file = os.path.join(tmp_dir, 'not_a_directory')
            with open(blocker_file, 'w', encoding='utf-8'):
                pass
            invalid_path = os.path.join(blocker_file, 'nested', 'database.db')
            db_manager = DatabaseManager(invalid_path)

            with pytest.raises(DatabaseError):
                await db_manager.initialize()

            # Permission check: only meaningful on POSIX for a non-root user,
            # because root bypasses directory permission bits.
            if os.name != 'nt' and os.geteuid() != 0:
                readonly_dir = os.path.join(tmp_dir, 'readonly_dir')
                os.mkdir(readonly_dir)
                # Set perms explicitly (mkdir's mode is masked by umask):
                # r-x lets the directory be traversed, but nothing can be created in it.
                os.chmod(readonly_dir, 0o555)
                try:
                    readonly_manager = DatabaseManager(os.path.join(readonly_dir, 'test.db'))
                    with pytest.raises(DatabaseError):
                        await readonly_manager.initialize()
                finally:
                    # Restore write permission so TemporaryDirectory can clean up.
                    os.chmod(readonly_dir, 0o755)

    @pytest.mark.asyncio
    async def test_concurrent_database_operations(self, temp_db_manager):
        """Concurrent inserts all land; queries return valid counts."""
        db_manager = temp_db_manager

        async def insert_worker(worker_id):
            async with db_manager.get_connection() as db:
                await db.execute(
                    self._STATS_INSERT_SQL,
                    (f'2024101{worker_id}', 100, 90, 30, 40, 20, 10, 10)
                )
                await db.commit()

        async def query_worker():
            async with db_manager.get_connection() as db:
                cursor = await db.execute("SELECT COUNT(*) FROM limit_up_stats")
                count = await cursor.fetchone()
                return count[0]

        insert_tasks = [insert_worker(i) for i in range(10)]
        query_tasks = [query_worker() for _ in range(5)]

        await asyncio.gather(*insert_tasks)
        query_results = await asyncio.gather(*query_tasks)

        final_count = await query_worker()
        assert final_count == 10

        # Every query returned a sane row count.
        assert all(isinstance(result, int) and result >= 0 for result in query_results)

    @pytest.mark.asyncio
    async def test_database_migration_simulation(self, temp_db_manager):
        """validate_schema reports version info after the stored version is bumped."""
        db_manager = temp_db_manager

        # Simulate a newer on-disk schema version.
        async with db_manager.get_connection() as db:
            await db.execute(
                "UPDATE database_info SET value = ? WHERE key = 'version'",
                (str(db_manager.CURRENT_VERSION + 1),)
            )
            await db.commit()

        validation_result = await db_manager.validate_schema()
        # NOTE(review): only checks the key exists; tighten once the expected
        # handling of a newer on-disk version is specified.
        assert 'version' in validation_result

    @pytest.mark.asyncio
    async def test_performance_with_large_dataset(self, temp_db_manager):
        """1000 inserts in one transaction and a COUNT(*) finish within loose bounds."""
        db_manager = temp_db_manager
        import time

        # perf_counter is monotonic and high-resolution, suited to interval timing.
        start_time = time.perf_counter()

        async with db_manager.get_connection() as db:
            await db.execute("BEGIN TRANSACTION")
            try:
                for i in range(1000):
                    await db.execute(
                        self._STATS_INSERT_SQL,
                        (f'20241{i:03d}', 100, 90, 30, 40, 20, 10, 10)
                    )
                await db.execute("COMMIT")
            except Exception:
                await db.execute("ROLLBACK")
                raise

        insert_time = time.perf_counter() - start_time
        assert insert_time < 10.0

        start_time = time.perf_counter()

        async with db_manager.get_connection() as db:
            cursor = await db.execute("SELECT COUNT(*) FROM limit_up_stats")
            count = await cursor.fetchone()
            assert count[0] == 1000

        query_time = time.perf_counter() - start_time
        assert query_time < 1.0

    def test_singleton_pattern(self):
        """get_database_manager() always returns the same instance."""
        manager1 = get_database_manager()
        manager2 = get_database_manager()
        manager3 = get_database_manager()

        assert manager1 is manager2
        assert manager2 is manager3
        assert manager1 is manager3

    @pytest.mark.asyncio
    async def test_initialize_database_function(self):
        """initialize_database() creates a database that passes schema validation."""
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
            db_path = tmp_file.name

        try:
            success = await initialize_database(db_path)
            assert success is True
            assert os.path.exists(db_path)

            db_manager = DatabaseManager(db_path)
            validation_result = await db_manager.validate_schema()
            assert validation_result['valid'] is True
        finally:
            await self._cleanup_db_files(db_path)


class TestDatabaseManagerErrorRecovery:
    """Error-recovery tests for DatabaseManager (corrupted files, interrupted transactions)."""

    @staticmethod
    def _cleanup_db_files(db_path):
        """Best-effort removal of a database file and its -wal/-shm sidecars.

        Args:
            db_path: Path of the main database file.
        """
        for file_path in (db_path, f"{db_path}-wal", f"{db_path}-shm"):
            if os.path.exists(file_path):
                try:
                    os.unlink(file_path)
                except OSError:
                    pass  # Cleanup is best-effort; never mask the test outcome.

    @pytest.mark.asyncio
    async def test_connection_recovery_after_corruption(self):
        """Opening a database whose file was overwritten with junk raises DatabaseError."""
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
            db_path = tmp_file.name

        try:
            # Create a healthy database containing one row.
            db_manager = DatabaseManager(db_path)
            await db_manager.initialize()

            async with db_manager.get_connection() as db:
                await db.execute(
                    "INSERT INTO limit_up_stats (trade_date, total, non_st, shanghai, shenzhen, star, beijing, st) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
                    ('20241015', 100, 90, 30, 40, 20, 10, 10)
                )
                await db.commit()

            # Corrupt the file by replacing its contents with non-SQLite bytes.
            with open(db_path, 'wb') as f:
                f.write(b'corrupted data')

            # Connecting to the corrupted file must fail.
            corrupted_manager = DatabaseManager(db_path)
            with pytest.raises(DatabaseError):
                async with corrupted_manager.get_connection() as db:
                    await db.execute("SELECT 1")

        finally:
            self._cleanup_db_files(db_path)

    @pytest.mark.asyncio
    async def test_transaction_recovery_after_interruption(self):
        """A transaction abandoned without COMMIT is invisible to later connections."""
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
            db_path = tmp_file.name

        try:
            db_manager = DatabaseManager(db_path)
            await db_manager.initialize()

            # Start a transaction and leave the connection without committing.
            async with db_manager.get_connection() as db:
                await db.execute("BEGIN TRANSACTION")
                await db.execute(
                    "INSERT INTO limit_up_stats (trade_date, total, non_st, shanghai, shenzhen, star, beijing, st) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
                    ('20241015', 100, 90, 30, 40, 20, 10, 10)
                )
                # Intentionally no COMMIT: simulates an interrupted transaction.

            # A fresh connection must not see the uncommitted row.
            async with db_manager.get_connection() as db:
                cursor = await db.execute("SELECT COUNT(*) FROM limit_up_stats WHERE trade_date = '20241015'")
                count = await cursor.fetchone()
                assert count[0] == 0

        finally:
            self._cleanup_db_files(db_path)


if __name__ == "__main__":
    # Propagate pytest's exit status so shell/CI callers see test failures.
    raise SystemExit(pytest.main([__file__]))