"""
数据库管理模块

提供SQLite数据库连接管理、表结构创建和迁移功能
"""

import sqlite3
import asyncio
import aiosqlite
from typing import Optional, Dict, Any, List
from pathlib import Path
import logging
from contextlib import asynccontextmanager
from datetime import datetime
import threading
from concurrent.futures import ThreadPoolExecutor

from .errors import DatabaseError, ValidationError

logger = logging.getLogger(__name__)


class DatabaseManager:
    """
    Database manager: connection handling, schema creation and migration.

    Creates and versions the SQLite schema (tables, indexes, triggers),
    hands out aiosqlite connections with foreign keys enabled, and provides
    schema validation, statistics, backup and restore helpers.
    """
    
    # Schema version recorded in database_info['version']; databases with a
    # lower version are migrated by _migrate_database during initialize().
    CURRENT_VERSION = 1
    
    # Table definitions. Every statement uses IF NOT EXISTS, so initialize()
    # can run repeatedly against an existing database.
    SCHEMA_SQL = {
        'limit_up_stats': """
            CREATE TABLE IF NOT EXISTS limit_up_stats (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                trade_date TEXT NOT NULL UNIQUE,
                total INTEGER NOT NULL DEFAULT 0,
                non_st INTEGER NOT NULL DEFAULT 0,
                shanghai INTEGER NOT NULL DEFAULT 0,
                shenzhen INTEGER NOT NULL DEFAULT 0,
                star INTEGER NOT NULL DEFAULT 0,
                beijing INTEGER NOT NULL DEFAULT 0,
                st INTEGER NOT NULL DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                CONSTRAINT chk_trade_date CHECK (trade_date GLOB '[0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9]'),
                CONSTRAINT chk_total_consistency CHECK (total = shanghai + shenzhen + star + beijing),
                CONSTRAINT chk_st_consistency CHECK (total = st + non_st),
                CONSTRAINT chk_non_negative CHECK (
                    total >= 0 AND non_st >= 0 AND shanghai >= 0 AND 
                    shenzhen >= 0 AND star >= 0 AND beijing >= 0 AND st >= 0
                )
            )
        """,
        
        'limit_up_stocks': """
            CREATE TABLE IF NOT EXISTS limit_up_stocks (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                trade_date TEXT NOT NULL,
                ts_code TEXT NOT NULL,
                stock_name TEXT,
                market TEXT NOT NULL,
                is_st BOOLEAN NOT NULL DEFAULT 0,
                open_price REAL,
                close_price REAL,
                high_price REAL,
                pct_change REAL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (trade_date) REFERENCES limit_up_stats(trade_date) ON DELETE CASCADE,
                CONSTRAINT chk_market CHECK (market IN ('shanghai', 'shenzhen', 'star', 'beijing')),
                CONSTRAINT chk_prices CHECK (
                    open_price > 0 AND close_price > 0 AND high_price > 0 AND
                    close_price <= high_price AND open_price <= high_price
                ),
                CONSTRAINT chk_pct_change CHECK (pct_change >= 0)
            )
        """,
        
        'database_info': """
            CREATE TABLE IF NOT EXISTS database_info (
                key TEXT PRIMARY KEY,
                value TEXT NOT NULL,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """,
        
        'price_distribution_stats': """
            CREATE TABLE IF NOT EXISTS price_distribution_stats (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                trade_date TEXT NOT NULL,
                market_type TEXT NOT NULL,
                range_name TEXT NOT NULL,
                stock_count INTEGER NOT NULL DEFAULT 0,
                percentage REAL NOT NULL DEFAULT 0.0,
                stock_codes TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                CONSTRAINT chk_trade_date_format CHECK (trade_date GLOB '[0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9]'),
                CONSTRAINT chk_market_type CHECK (market_type IN ('total', 'non_st', 'shanghai', 'shenzhen', 'star', 'beijing', 'st')),
                CONSTRAINT chk_stock_count CHECK (stock_count >= 0),
                CONSTRAINT chk_percentage CHECK (percentage >= 0.0 AND percentage <= 100.0),
                UNIQUE (trade_date, market_type, range_name)
            )
        """,
        
        'price_distribution_metadata': """
            CREATE TABLE IF NOT EXISTS price_distribution_metadata (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                trade_date TEXT NOT NULL UNIQUE,
                total_stocks INTEGER NOT NULL DEFAULT 0,
                processing_time REAL DEFAULT 0.0,
                data_quality_score REAL DEFAULT 1.0,
                data_source TEXT DEFAULT 'unknown',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                CONSTRAINT chk_trade_date_format CHECK (trade_date GLOB '[0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9]'),
                CONSTRAINT chk_total_stocks CHECK (total_stocks >= 0),
                CONSTRAINT chk_processing_time CHECK (processing_time >= 0.0),
                CONSTRAINT chk_data_quality_score CHECK (data_quality_score >= 0.0 AND data_quality_score <= 1.0)
            )
        """
    }
    
    # Index definitions
    INDEXES_SQL = [
        "CREATE INDEX IF NOT EXISTS idx_limit_up_stats_date ON limit_up_stats(trade_date)",
        "CREATE INDEX IF NOT EXISTS idx_limit_up_stocks_date ON limit_up_stocks(trade_date)",
        "CREATE INDEX IF NOT EXISTS idx_limit_up_stocks_code ON limit_up_stocks(ts_code)",
        "CREATE INDEX IF NOT EXISTS idx_limit_up_stocks_market ON limit_up_stocks(market)",
        "CREATE INDEX IF NOT EXISTS idx_limit_up_stocks_is_st ON limit_up_stocks(is_st)",
        "CREATE UNIQUE INDEX IF NOT EXISTS idx_limit_up_stocks_unique ON limit_up_stocks(trade_date, ts_code)",
        "CREATE INDEX IF NOT EXISTS idx_price_distribution_stats_date ON price_distribution_stats(trade_date)",
        "CREATE INDEX IF NOT EXISTS idx_price_distribution_stats_market ON price_distribution_stats(market_type)",
        "CREATE INDEX IF NOT EXISTS idx_price_distribution_stats_range ON price_distribution_stats(range_name)",
        "CREATE INDEX IF NOT EXISTS idx_price_distribution_metadata_date ON price_distribution_metadata(trade_date)"
    ]
    
    # Triggers that refresh updated_at after an UPDATE. The nested UPDATE does
    # not re-fire the same trigger as long as SQLite's recursive_triggers
    # pragma stays at its default (OFF).
    TRIGGERS_SQL = [
        """
        CREATE TRIGGER IF NOT EXISTS update_limit_up_stats_timestamp 
        AFTER UPDATE ON limit_up_stats
        BEGIN
            UPDATE limit_up_stats SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
        END
        """,
        """
        CREATE TRIGGER IF NOT EXISTS update_database_info_timestamp 
        AFTER UPDATE ON database_info
        BEGIN
            UPDATE database_info SET updated_at = CURRENT_TIMESTAMP WHERE key = NEW.key;
        END
        """,
        """
        CREATE TRIGGER IF NOT EXISTS update_price_distribution_stats_timestamp 
        AFTER UPDATE ON price_distribution_stats
        BEGIN
            UPDATE price_distribution_stats SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
        END
        """,
        """
        CREATE TRIGGER IF NOT EXISTS update_price_distribution_metadata_timestamp 
        AFTER UPDATE ON price_distribution_metadata
        BEGIN
            UPDATE price_distribution_metadata SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
        END
        """
    ]
    
    def __init__(self, db_path: Optional[str] = None):
        """
        Create a manager for the given database file.

        No connection is opened here; call initialize() to create the schema.

        Args:
            db_path: Path of the SQLite database file. If None, defaults to
                ``~/.quickstock/quickstock.db`` (the folder is created).
        """
        if db_path is None:
            # Default location: .quickstock folder under the user's home directory
            home_dir = Path.home()
            db_dir = home_dir / '.quickstock'
            db_dir.mkdir(exist_ok=True)
            db_path = str(db_dir / 'quickstock.db')
        
        self.db_path = db_path
        # Not assigned anywhere in this class; close() shuts it down if set.
        self._connection_pool: Optional[ThreadPoolExecutor] = None
        self._lock = threading.Lock()
        
        logger.info(f"Database manager initialized with path: {self.db_path}")
    
    async def initialize(self) -> bool:
        """
        Initialize the database: create tables, indexes and triggers, and
        record (or migrate) the schema version.

        Returns:
            True on success; failures raise instead of returning False.

        Raises:
            DatabaseError: If any initialization step fails.
        """
        try:
            logger.info("Initializing database...")
            
            # Make sure the directory holding the database file exists
            db_dir = Path(self.db_path).parent
            db_dir.mkdir(parents=True, exist_ok=True)
            
            async with aiosqlite.connect(self.db_path) as db:
                # Enforce foreign key constraints (SQLite leaves them off by default)
                await db.execute("PRAGMA foreign_keys = ON")
                
                # WAL mode lets readers run concurrently with a writer
                await db.execute("PRAGMA journal_mode = WAL")
                
                # Create tables
                for table_name, schema_sql in self.SCHEMA_SQL.items():
                    logger.debug(f"Creating table: {table_name}")
                    await db.execute(schema_sql)
                
                # Create indexes
                for index_sql in self.INDEXES_SQL:
                    logger.debug(f"Creating index: {index_sql[:50]}...")
                    await db.execute(index_sql)
                
                # Create triggers
                for trigger_sql in self.TRIGGERS_SQL:
                    logger.debug(f"Creating trigger: {trigger_sql[:50]}...")
                    await db.execute(trigger_sql)
                
                # Record the schema version, or migrate an older one
                await self._initialize_database_info(db)
                
                await db.commit()
                
            logger.info("Database initialization completed successfully")
            return True
            
        except Exception as e:
            error_msg = f"Database initialization failed: {str(e)}"
            logger.error(error_msg)
            raise DatabaseError(error_msg) from e
    
    async def _initialize_database_info(self, db: aiosqlite.Connection):
        """
        Seed or check the database_info table.

        On a fresh database, inserts the ``version`` and ``created_at`` rows.
        Otherwise compares the stored version with CURRENT_VERSION and
        migrates older databases.

        Raises:
            DatabaseError: If the stored version is newer than this code supports.
        """
        # Has the database been initialized before?
        cursor = await db.execute(
            "SELECT value FROM database_info WHERE key = 'version'"
        )
        result = await cursor.fetchone()
        
        if result is None:
            # First initialization
            await db.execute(
                "INSERT INTO database_info (key, value) VALUES (?, ?)",
                ('version', str(self.CURRENT_VERSION))
            )
            await db.execute(
                "INSERT INTO database_info (key, value) VALUES (?, ?)",
                ('created_at', datetime.now().isoformat())
            )
            logger.info(f"Database initialized with version {self.CURRENT_VERSION}")
        else:
            # Version compatibility check
            current_version = int(result[0])
            if current_version > self.CURRENT_VERSION:
                raise DatabaseError(
                    f"Database version {current_version} is newer than supported version {self.CURRENT_VERSION}"
                )
            elif current_version < self.CURRENT_VERSION:
                logger.info(f"Database migration needed: {current_version} -> {self.CURRENT_VERSION}")
                await self._migrate_database(db, current_version)
    
    async def _migrate_database(self, db: aiosqlite.Connection, from_version: int):
        """
        Migrate the schema from ``from_version`` to CURRENT_VERSION.

        Args:
            db: Open database connection (the caller commits).
            from_version: Version currently stored in database_info.
        """
        logger.info(f"Migrating database from version {from_version} to {self.CURRENT_VERSION}")
        
        # Only version 1 exists so far; add per-version migration steps here
        if from_version < 1:
            # Migration to version 1
            pass
        
        # Record the new version
        await db.execute(
            "UPDATE database_info SET value = ? WHERE key = 'version'",
            (str(self.CURRENT_VERSION),)
        )
        
        logger.info(f"Database migration completed to version {self.CURRENT_VERSION}")
    
    @asynccontextmanager
    async def get_connection(self):
        """
        Async context manager yielding a connection with foreign keys enabled.

        The connection is closed when the ``async with`` block exits.

        Yields:
            aiosqlite.Connection: Open database connection.

        Raises:
            DatabaseError: If opening or configuring the connection fails, or
                if the ``async with`` body raises. A DatabaseError raised by
                the body propagates unchanged; other exceptions are wrapped.
        """
        connected = False
        try:
            async with aiosqlite.connect(self.db_path) as db:
                # Foreign key enforcement is a per-connection setting in SQLite
                await db.execute("PRAGMA foreign_keys = ON")
                connected = True
                yield db
        except DatabaseError:
            # Already the documented error type; re-wrapping would only add a
            # misleading "Failed to get database connection" prefix.
            raise
        except Exception as e:
            if connected:
                # Raised by the caller's body (or while closing), not by connecting
                error_msg = f"Database operation failed: {str(e)}"
            else:
                error_msg = f"Failed to get database connection: {str(e)}"
            logger.error(error_msg)
            raise DatabaseError(error_msg) from e
    
    @staticmethod
    async def _object_exists(db: aiosqlite.Connection, object_type: str, name: str) -> bool:
        """Return True if sqlite_master has an object of ``object_type`` named ``name``."""
        cursor = await db.execute(
            "SELECT name FROM sqlite_master WHERE type=? AND name=?",
            (object_type, name)
        )
        return await cursor.fetchone() is not None
    
    async def validate_schema(self) -> Dict[str, Any]:
        """
        Check that every expected table, index and trigger exists.

        Returns:
            Dict with ``valid`` (bool), per-object presence maps under
            ``tables``/``indexes``/``triggers``, and ``errors`` (messages).
            Problems, including connection failures, are reported in the
            dict rather than raised.
        """
        validation_result = {
            'valid': True,
            'tables': {},
            'indexes': {},
            'triggers': {},
            'errors': []
        }
        
        expected_indexes = [
            'idx_limit_up_stats_date',
            'idx_limit_up_stocks_date',
            'idx_limit_up_stocks_code',
            'idx_limit_up_stocks_market',
            'idx_limit_up_stocks_is_st',
            'idx_limit_up_stocks_unique',
            'idx_price_distribution_stats_date',
            'idx_price_distribution_stats_market',
            'idx_price_distribution_stats_range',
            'idx_price_distribution_metadata_date'
        ]
        expected_triggers = [
            'update_limit_up_stats_timestamp',
            'update_database_info_timestamp',
            'update_price_distribution_stats_timestamp',
            'update_price_distribution_metadata_timestamp'
        ]
        
        # (sqlite_master type, result bucket, label for error messages, expected names)
        checks = [
            ('table', 'tables', 'Table', list(self.SCHEMA_SQL.keys())),
            ('index', 'indexes', 'Index', expected_indexes),
            ('trigger', 'triggers', 'Trigger', expected_triggers),
        ]
        
        try:
            async with self.get_connection() as db:
                for object_type, bucket, label, names in checks:
                    for name in names:
                        exists = await self._object_exists(db, object_type, name)
                        validation_result[bucket][name] = exists
                        if not exists:
                            validation_result['valid'] = False
                            validation_result['errors'].append(f"{label} {name} not found")
                
        except Exception as e:
            validation_result['valid'] = False
            validation_result['errors'].append(f"Schema validation error: {str(e)}")
            logger.error(f"Schema validation failed: {str(e)}")
        
        return validation_result
    
    async def get_database_stats(self) -> Dict[str, Any]:
        """
        Collect database statistics.

        Returns:
            Dict with ``db_path``, ``db_size`` (size in bytes of the main
            database file only), ``version`` and ``created_at`` (from
            database_info, or None), and ``tables`` mapping each data table
            to its row count. On failure an ``error`` key holds the message;
            nothing is raised.
        """
        stats = {
            'db_path': self.db_path,
            'db_size': 0,
            'tables': {},
            'version': None,
            'created_at': None
        }
        
        try:
            # Size of the database file
            db_file = Path(self.db_path)
            if db_file.exists():
                stats['db_size'] = db_file.stat().st_size
            
            async with self.get_connection() as db:
                # Version / creation metadata
                cursor = await db.execute(
                    "SELECT key, value FROM database_info WHERE key IN ('version', 'created_at')"
                )
                for key, value in await cursor.fetchall():
                    stats[key] = value
                
                # Row counts for every data table (database_info holds metadata only)
                data_tables = [name for name in self.SCHEMA_SQL if name != 'database_info']
                for table_name in data_tables:
                    # Table names come from SCHEMA_SQL, never from callers,
                    # so interpolating them into the SQL text is safe
                    cursor = await db.execute(f"SELECT COUNT(*) FROM {table_name}")
                    count = await cursor.fetchone()
                    stats['tables'][table_name] = count[0] if count else 0
                
        except Exception as e:
            logger.error(f"Failed to get database stats: {str(e)}")
            stats['error'] = str(e)
        
        return stats
    
    @staticmethod
    def _sqlite_backup(source_path: str, target_path: str) -> None:
        """
        Copy ``source_path`` into ``target_path`` via SQLite's online backup API.

        Blocking; run it in an executor from async code. Unlike a raw file
        copy, this reads committed data still held in the source's ``-wal``
        file and produces a single self-contained, consistent target file.
        ``sqlite3.connect`` creates missing files, so callers must check
        that ``source_path`` exists first.
        """
        source = sqlite3.connect(source_path)
        try:
            target = sqlite3.connect(target_path)
            try:
                source.backup(target)
            finally:
                target.close()
        finally:
            source.close()
    
    @staticmethod
    def _copy_database_files(source_path: str, target_path: str) -> None:
        """
        File-level copy of a database together with its ``-wal`` sidecar.

        In WAL mode committed pages may live only in ``<db>-wal``, and SQLite
        replays whatever ``-wal`` file sits next to a database when it opens
        it. The target's old ``-wal``/``-shm`` are therefore deleted first
        (a stale WAL paired with a different main file corrupts it). Then the
        main file and, if present, the source's ``-wal`` are copied. The
        ``-shm`` index is not copied; SQLite rebuilds it. Blocking.
        """
        import shutil
        
        for suffix in ('-wal', '-shm'):
            stale = Path(target_path + suffix)
            if stale.exists():
                stale.unlink()
        shutil.copy2(source_path, target_path)
        source_wal = Path(source_path + '-wal')
        if source_wal.exists():
            shutil.copy2(str(source_wal), target_path + '-wal')
    
    async def backup_database(self, backup_path: str) -> bool:
        """
        Back up the database to ``backup_path``.

        Uses SQLite's online backup API in a worker thread instead of copying
        the file. initialize() puts the database in WAL mode, so a plain copy
        of the main file can miss committed data or capture a half-written
        state.

        Args:
            backup_path: Destination file; parent directories are created.
                An existing SQLite database at this path is overwritten.

        Returns:
            True on success, False on any failure (the error is logged).
        """
        try:
            if not Path(self.db_path).exists():
                raise FileNotFoundError(f"Database file not found: {self.db_path}")
            if Path(backup_path).resolve() == Path(self.db_path).resolve():
                raise ValueError(f"Backup path is the database itself: {backup_path}")
            
            # Make sure the backup directory exists
            Path(backup_path).parent.mkdir(parents=True, exist_ok=True)
            
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._sqlite_backup, self.db_path, backup_path)
            
            logger.info(f"Database backed up to: {backup_path}")
            return True
            
        except Exception as e:
            logger.error(f"Database backup failed: {str(e)}")
            return False
    
    async def restore_database(self, backup_path: str) -> bool:
        """
        Restore the database from ``backup_path``.

        The current database (if present) is first saved as
        ``<db_path>.backup.<YYYYmmdd_HHMMSS>`` together with its ``-wal``.
        The backup is then copied over the live file, and the live file's
        stale ``-wal``/``-shm`` are removed so SQLite cannot replay them onto
        the restored data. If the restored schema fails validation, the
        saved copy is put back.

        Args:
            backup_path: Backup file to restore from.

        Returns:
            True on success, False on any failure (the error is logged).
        """
        try:
            if not Path(backup_path).exists():
                raise FileNotFoundError(f"Backup file not found: {backup_path}")
            if Path(backup_path).resolve() == Path(self.db_path).resolve():
                raise ValueError(f"Backup path is the database itself: {backup_path}")
            
            loop = asyncio.get_running_loop()
            
            # Keep a copy of the current database (if any) so a bad restore
            # can be undone. A file-level copy is used on purpose: the database
            # being replaced may be corrupt, which would make the SQLite
            # backup API fail.
            current_backup: Optional[str] = None
            if Path(self.db_path).exists():
                current_backup = f"{self.db_path}.backup.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                await loop.run_in_executor(
                    None, self._copy_database_files, self.db_path, current_backup
                )
                logger.info(f"Current database backed up to: {current_backup}")
            
            # Restore the database (also drops the live file's stale -wal/-shm)
            await loop.run_in_executor(
                None, self._copy_database_files, backup_path, self.db_path
            )
            
            # Validate the restored database; roll back to the saved copy on failure
            validation_result = await self.validate_schema()
            if not validation_result['valid']:
                if current_backup is not None:
                    await loop.run_in_executor(
                        None, self._copy_database_files, current_backup, self.db_path
                    )
                    logger.warning(f"Restore rolled back to previous database from: {current_backup}")
                raise DatabaseError(f"Restored database validation failed: {validation_result['errors']}")
            
            logger.info(f"Database restored from: {backup_path}")
            return True
            
        except Exception as e:
            logger.error(f"Database restore failed: {str(e)}")
            return False
    
    def close(self):
        """Release manager resources (shuts down the executor if one was set)."""
        if self._connection_pool:
            self._connection_pool.shutdown(wait=True)
            self._connection_pool = None
        logger.info("Database manager closed")


# Process-wide singleton state used by get_database_manager(); the lock
# serializes lazy creation so concurrent first calls build only one manager.
_db_manager_lock = threading.Lock()
_db_manager: Optional[DatabaseManager] = None


def get_database_manager(db_path: Optional[str] = None) -> DatabaseManager:
    """
    Return the process-wide DatabaseManager, creating it on first use.

    Creation is guarded by a lock, so concurrent first calls share one
    instance.

    Args:
        db_path: Database file path, used only when the singleton is first
            created (None selects DatabaseManager's default path). A different
            path passed later is ignored and a warning is logged.

    Returns:
        The shared DatabaseManager instance.
    """
    global _db_manager
    
    with _db_manager_lock:
        if _db_manager is None:
            _db_manager = DatabaseManager(db_path)
        elif (
            db_path is not None
            and Path(db_path).resolve() != Path(_db_manager.db_path).resolve()
        ):
            # The singleton is already bound to another file; make the ignored
            # argument visible instead of silently using the old path.
            logger.warning(
                "Database manager already initialized with path %s; "
                "ignoring requested path %s",
                _db_manager.db_path,
                db_path,
            )
        return _db_manager


async def initialize_database(db_path: Optional[str] = None) -> bool:
    """
    Convenience helper: initialize the shared database manager's schema.

    Args:
        db_path: Database file path, honored only if the shared manager has
            not been created yet (see get_database_manager).

    Returns:
        The result of DatabaseManager.initialize(): True on success (it
        raises DatabaseError on failure).
    """
    manager = get_database_manager(db_path)
    initialized = await manager.initialize()
    return initialized