"""
数据库管理器测试

测试数据库初始化、连接管理、模式验证等功能
"""

import pytest
import pytest_asyncio
import asyncio
import tempfile
import os
from pathlib import Path
from datetime import datetime

from quickstock.core.database import DatabaseManager, get_database_manager, initialize_database
from quickstock.core.errors import DatabaseError


class TestDatabaseManager:
    """Tests for DatabaseManager.

    Covers schema initialization and validation, connection settings,
    database_info bookkeeping, statistics, backup/restore, CHECK/foreign-key
    constraints, timestamp triggers, and the module-level helper functions.
    """

    # Shared INSERT statement for limit_up_stats rows used by several tests.
    _INSERT_STATS_SQL = (
        "INSERT INTO limit_up_stats (trade_date, total, non_st, shanghai, shenzhen, star, beijing, st) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
    )

    # A row that satisfies both consistency rules exercised in
    # test_database_constraints:
    #   total == shanghai + shenzhen + star + beijing  (30 + 40 + 15 + 15 = 100)
    #   total == st + non_st                           (10 + 90 = 100)
    _VALID_STATS_ROW = ('20241015', 100, 90, 30, 40, 15, 15, 10)

    @staticmethod
    def _make_temp_db_path():
        """Create an empty temporary ``.db`` file and return its path.

        The file is not deleted automatically; pair with ``_remove_db_files``.
        """
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
            return tmp_file.name

    @staticmethod
    def _remove_db_files(db_path):
        """Delete a database file and its SQLite ``-wal``/``-shm`` side files, if present."""
        for path in (db_path, db_path + '-wal', db_path + '-shm'):
            if os.path.exists(path):
                os.unlink(path)

    @pytest_asyncio.fixture
    async def temp_db_manager(self):
        """Yield an initialized DatabaseManager backed by a throwaway file.

        The database file and its WAL/SHM files are removed afterwards, even
        if initialization raises.
        """
        db_path = self._make_temp_db_path()
        try:
            db_manager = DatabaseManager(db_path)
            await db_manager.initialize()
            yield db_manager
        finally:
            self._remove_db_files(db_path)

    @pytest.mark.asyncio
    async def test_database_initialization(self, temp_db_manager):
        """initialize() creates the file and every expected table, index and trigger."""
        db_manager = temp_db_manager

        # The database file must exist on disk.
        assert os.path.exists(db_manager.db_path)

        # The schema must validate cleanly.
        validation_result = await db_manager.validate_schema()
        assert validation_result['valid'] is True
        assert len(validation_result['errors']) == 0

        # Every table must be present.
        expected_tables = ['limit_up_stats', 'limit_up_stocks', 'database_info']
        for table in expected_tables:
            assert validation_result['tables'][table] is True

        # Every index must be present.
        expected_indexes = [
            'idx_limit_up_stats_date',
            'idx_limit_up_stocks_date',
            'idx_limit_up_stocks_code',
            'idx_limit_up_stocks_market',
            'idx_limit_up_stocks_is_st',
            'idx_limit_up_stocks_unique',
        ]
        for index in expected_indexes:
            assert validation_result['indexes'][index] is True

        # Every trigger must be present.
        expected_triggers = [
            'update_limit_up_stats_timestamp',
            'update_database_info_timestamp',
        ]
        for trigger in expected_triggers:
            assert validation_result['triggers'][trigger] is True

    @pytest.mark.asyncio
    async def test_database_connection(self, temp_db_manager):
        """get_connection() yields a working connection with foreign keys enabled."""
        db_manager = temp_db_manager

        async with db_manager.get_connection() as db:
            # A trivial query must round-trip.
            cursor = await db.execute("SELECT 1")
            result = await cursor.fetchone()
            assert result[0] == 1

            # Foreign-key enforcement must be switched on for this connection.
            cursor = await db.execute("PRAGMA foreign_keys")
            result = await cursor.fetchone()
            assert result[0] == 1

    @pytest.mark.asyncio
    async def test_database_info_initialization(self, temp_db_manager):
        """initialize() records the schema version and an ISO-format creation time."""
        db_manager = temp_db_manager

        async with db_manager.get_connection() as db:
            # The stored version must match the manager's CURRENT_VERSION.
            cursor = await db.execute(
                "SELECT value FROM database_info WHERE key = 'version'"
            )
            result = await cursor.fetchone()
            assert result is not None
            assert int(result[0]) == db_manager.CURRENT_VERSION

            # The creation time must exist and parse as ISO 8601.
            cursor = await db.execute(
                "SELECT value FROM database_info WHERE key = 'created_at'"
            )
            result = await cursor.fetchone()
            assert result is not None
            datetime.fromisoformat(result[0])

    @pytest.mark.asyncio
    async def test_database_stats(self, temp_db_manager):
        """get_database_stats() reports path, size, per-table counts and version."""
        db_manager = temp_db_manager

        stats = await db_manager.get_database_stats()

        # Top-level keys.
        assert 'db_path' in stats
        assert 'db_size' in stats
        assert 'tables' in stats
        assert 'version' in stats
        assert 'created_at' in stats

        # Per-table row counts; a fresh database is empty.
        assert 'limit_up_stats' in stats['tables']
        assert 'limit_up_stocks' in stats['tables']
        assert stats['tables']['limit_up_stats'] == 0
        assert stats['tables']['limit_up_stocks'] == 0

        # The version is reported as a string.
        assert stats['version'] == str(db_manager.CURRENT_VERSION)

    @pytest.mark.asyncio
    async def test_database_backup_restore(self, temp_db_manager):
        """restore_database() rolls data back to the state captured by backup_database()."""
        db_manager = temp_db_manager

        # Seed one internally consistent row (total = 100).
        async with db_manager.get_connection() as db:
            await db.execute(self._INSERT_STATS_SQL, self._VALID_STATS_ROW)
            await db.commit()

        backup_path = self._make_temp_db_path()
        try:
            backup_success = await db_manager.backup_database(backup_path)
            assert backup_success is True
            assert os.path.exists(backup_path)

            # Modify the live database. The row must stay consistent so the
            # CHECK constraints accept it: 130+40+15+15 == 10+190 == 200.
            async with db_manager.get_connection() as db:
                await db.execute(
                    "UPDATE limit_up_stats SET total = 200, non_st = 190, shanghai = 130 "
                    "WHERE trade_date = '20241015'"
                )
                await db.commit()

            # Confirm the modification landed.
            async with db_manager.get_connection() as db:
                cursor = await db.execute(
                    "SELECT total FROM limit_up_stats WHERE trade_date = '20241015'"
                )
                result = await cursor.fetchone()
                assert result[0] == 200

            restore_success = await db_manager.restore_database(backup_path)
            assert restore_success is True

            # After restoring, the original value must be back.
            async with db_manager.get_connection() as db:
                cursor = await db.execute(
                    "SELECT total FROM limit_up_stats WHERE trade_date = '20241015'"
                )
                result = await cursor.fetchone()
                assert result[0] == 100

        finally:
            self._remove_db_files(backup_path)

    @pytest.mark.asyncio
    async def test_schema_validation_with_missing_table(self):
        """validate_schema() reports a missing table as invalid and names it in errors."""
        db_path = self._make_temp_db_path()
        try:
            db_manager = DatabaseManager(db_path)

            # Create only part of the schema; limit_up_stocks is deliberately omitted.
            async with db_manager.get_connection() as db:
                await db.execute(db_manager.SCHEMA_SQL['limit_up_stats'])
                await db.commit()

            validation_result = await db_manager.validate_schema()

            assert validation_result['valid'] is False
            assert any('limit_up_stocks' in error for error in validation_result['errors'])

        finally:
            self._remove_db_files(db_path)

    @pytest.mark.asyncio
    async def test_database_constraints(self, temp_db_manager):
        """Each CHECK constraint on limit_up_stats rejects a row that breaks only that rule.

        Every negative row violates exactly one constraint. A failure here
        therefore points at that constraint instead of being masked by another.
        """
        db_manager = temp_db_manager

        async with db_manager.get_connection() as db:
            # Date-format CHECK: all counts are consistent; only trade_date is bad.
            with pytest.raises(Exception):
                await db.execute(
                    self._INSERT_STATS_SQL,
                    ('invalid-date', 100, 90, 30, 40, 15, 15, 10),
                )
                await db.commit()

            # Market-sum CHECK: st + non_st == total (10 + 90), but
            # shanghai + shenzhen + star + beijing = 30+40+15+5 = 90 != 100.
            with pytest.raises(Exception):
                await db.execute(
                    self._INSERT_STATS_SQL,
                    ('20241015', 100, 90, 30, 40, 15, 5, 10),
                )
                await db.commit()

            # ST-sum CHECK: markets sum to 100, but st + non_st = 10 + 80 = 90 != 100.
            with pytest.raises(Exception):
                await db.execute(
                    self._INSERT_STATS_SQL,
                    ('20241015', 100, 80, 30, 40, 15, 15, 10),
                )
                await db.commit()

    @pytest.mark.asyncio
    async def test_database_triggers(self, temp_db_manager):
        """The update trigger refreshes updated_at and leaves created_at untouched."""
        db_manager = temp_db_manager

        async with db_manager.get_connection() as db:
            await db.execute(self._INSERT_STATS_SQL, self._VALID_STATS_ROW)
            await db.commit()

            cursor = await db.execute(
                "SELECT created_at, updated_at FROM limit_up_stats WHERE trade_date = '20241015'"
            )
            initial_timestamps = await cursor.fetchone()

            # SQLite's CURRENT_TIMESTAMP has one-second resolution, so wait
            # more than a second to guarantee a different updated_at.
            # NOTE(review): assumes the schema/trigger uses CURRENT_TIMESTAMP;
            # the longer wait is harmless if it uses a finer clock.
            await asyncio.sleep(1.1)

            # Keep the row consistent for the CHECK constraints:
            # 40+40+15+15 == 10+100 == 110.
            await db.execute(
                "UPDATE limit_up_stats SET total = 110, non_st = 100, shanghai = 40 "
                "WHERE trade_date = '20241015'"
            )
            await db.commit()

            cursor = await db.execute(
                "SELECT created_at, updated_at FROM limit_up_stats WHERE trade_date = '20241015'"
            )
            updated_timestamps = await cursor.fetchone()

            assert updated_timestamps[0] == initial_timestamps[0]  # created_at unchanged
            assert updated_timestamps[1] != initial_timestamps[1]  # updated_at refreshed

    def test_singleton_database_manager(self):
        """get_database_manager() returns the same instance on every call."""
        manager1 = get_database_manager()
        manager2 = get_database_manager()

        assert manager1 is manager2

    @pytest.mark.asyncio
    async def test_initialize_database_function(self):
        """initialize_database() builds a schema that a fresh manager validates."""
        db_path = self._make_temp_db_path()
        try:
            success = await initialize_database(db_path)
            assert success is True

            assert os.path.exists(db_path)

            # A separate manager on the same file must see a valid schema.
            db_manager = DatabaseManager(db_path)
            validation_result = await db_manager.validate_schema()
            assert validation_result['valid'] is True

        finally:
            self._remove_db_files(db_path)

    @pytest.mark.asyncio
    async def test_database_error_handling(self):
        """initialize() wraps failures to open the database in DatabaseError."""
        # NOTE(review): relies on /invalid/path being impossible to create.
        # This may not hold when tests run as root; confirm for the CI setup.
        invalid_path = "/invalid/path/database.db"
        db_manager = DatabaseManager(invalid_path)

        with pytest.raises(DatabaseError):
            await db_manager.initialize()

    @pytest.mark.asyncio
    async def test_foreign_key_constraints(self, temp_db_manager):
        """limit_up_stocks rows require a matching limit_up_stats row for trade_date."""
        db_manager = temp_db_manager

        async with db_manager.get_connection() as db:
            # Without a parent stats row the insert must violate the foreign key.
            with pytest.raises(Exception):
                await db.execute(
                    """
                    INSERT INTO limit_up_stocks 
                    (trade_date, ts_code, stock_name, market, is_st, open_price, close_price, high_price, pct_change)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    ('20241015', '000001.SZ', '平安银行', 'shenzhen', False, 10.0, 11.0, 11.0, 10.0)
                )
                await db.commit()

            # Insert the parent stats row (1 stock, Shenzhen, non-ST; all sums consistent).
            await db.execute(
                self._INSERT_STATS_SQL,
                ('20241015', 1, 1, 0, 1, 0, 0, 0),
            )
            await db.commit()

            # The child row is now accepted.
            await db.execute(
                """
                INSERT INTO limit_up_stocks 
                (trade_date, ts_code, stock_name, market, is_st, open_price, close_price, high_price, pct_change)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                ('20241015', '000001.SZ', '平安银行', 'shenzhen', False, 10.0, 11.0, 11.0, 10.0)
            )
            await db.commit()

            cursor = await db.execute(
                "SELECT COUNT(*) FROM limit_up_stocks WHERE trade_date = '20241015'"
            )
            result = await cursor.fetchone()
            assert result[0] == 1


if __name__ == '__main__':
    # Run this module's tests directly. pytest.main returns an exit code;
    # raising SystemExit with it lets shells/CI see failures (otherwise the
    # process would always exit 0).
    raise SystemExit(pytest.main([__file__]))