"""
数据库操作层 - Repository模式实现

提供涨停统计数据的CRUD操作，包括批量操作和事务支持
"""

import asyncio
import aiosqlite
from typing import Optional, List, Dict, Any, Tuple
from datetime import datetime
import logging
from contextlib import asynccontextmanager

from ..models import LimitUpStats, StockDailyData
from .database import DatabaseManager, get_database_manager
from .errors import DatabaseError, ValidationError

# Module-level logger named after this module (``__name__``).
logger = logging.getLogger(__name__)


class LimitUpStatsRepository:
    """Data-access layer for limit-up statistics.

    Reads and writes two tables through connections obtained from
    ``DatabaseManager.get_connection()``:

    * ``limit_up_stats`` -- one aggregate row per trade date.
    * ``limit_up_stocks`` -- optional per-stock detail rows for a trade date.

    Apart from ``ValidationError`` raised by ``save_limit_up_stats``, failures
    in public methods are logged and re-raised as ``DatabaseError``.
    """

    # Column order of every ``SELECT ... FROM limit_up_stats`` in this class;
    # used to map a result row to a dict and onto LimitUpStats attributes.
    _STATS_COLUMNS = (
        'trade_date', 'total', 'non_st', 'shanghai', 'shenzhen',
        'star', 'beijing', 'st', 'created_at', 'updated_at',
    )

    def __init__(self, db_manager: Optional[DatabaseManager] = None):
        """Initialise the repository.

        Args:
            db_manager: Database manager to use. When None (or otherwise falsy)
                the global instance from ``get_database_manager()`` is used.
        """
        self.db_manager = db_manager or get_database_manager()

    async def save_limit_up_stats(self, stats: LimitUpStats,
                                  stock_details: Optional[List[StockDailyData]] = None) -> bool:
        """Insert or update the statistics for ``stats.trade_date``.

        The aggregate row and (if given) the per-stock detail rows are written
        in one transaction; any failure rolls the whole save back.

        Args:
            stats: Statistics to persist; ``stats.validate()`` is called first.
            stock_details: Optional limit-up stock rows. When non-empty they
                replace all existing ``limit_up_stocks`` rows for that date.

        Returns:
            True on success.

        Raises:
            ValidationError: ``stats.validate()`` rejected the data.
            DatabaseError: Any other failure, including database errors.
        """
        try:
            stats.validate()

            async with self.db_manager.get_connection() as db:
                async with self._transaction(db):
                    existing_stats = await self._get_stats_by_date(db, stats.trade_date)

                    if existing_stats:
                        await self._update_stats(db, stats)
                        logger.info(f"Updated limit up stats for {stats.trade_date}")
                    else:
                        await self._insert_stats(db, stats)
                        logger.info(f"Inserted new limit up stats for {stats.trade_date}")

                    if stock_details:
                        await self._save_stock_details(db, stats.trade_date, stock_details)

            return True

        except ValidationError:
            # Documented contract: validation failures surface unwrapped.
            raise
        except Exception as e:
            error_msg = f"Failed to save limit up stats for {stats.trade_date}: {str(e)}"
            logger.error(error_msg)
            raise DatabaseError(error_msg) from e

    async def get_limit_up_stats(self, trade_date: str) -> Optional[LimitUpStats]:
        """Load the statistics stored for one trade date.

        Args:
            trade_date: Trade date (YYYYMMDD).

        Returns:
            The rebuilt LimitUpStats (with ``limit_up_stocks`` and
            ``market_breakdown`` read from ``limit_up_stocks``), or None if no
            row exists for that date.

        Raises:
            DatabaseError: If the lookup fails.
        """
        try:
            async with self.db_manager.get_connection() as db:
                stats_data = await self._get_stats_by_date(db, trade_date)
                if not stats_data:
                    return None
                return await self._hydrate_stats(db, stats_data)

        except Exception as e:
            error_msg = f"Failed to get limit up stats for {trade_date}: {str(e)}"
            logger.error(error_msg)
            raise DatabaseError(error_msg) from e

    async def query_limit_up_stats(self, start_date: Optional[str] = None,
                                   end_date: Optional[str] = None,
                                   limit: Optional[int] = None,
                                   offset: int = 0) -> List[LimitUpStats]:
        """Query statistics within an optional date range, newest first.

        Args:
            start_date: Inclusive lower bound (YYYYMMDD); ignored if falsy.
            end_date: Inclusive upper bound (YYYYMMDD); ignored if falsy.
            limit: Maximum number of rows; a falsy value (None or 0) means
                no limit.
            offset: Number of rows to skip; only applied when > 0.

        Returns:
            List of rebuilt LimitUpStats objects ordered by trade_date DESC.

        Raises:
            DatabaseError: If the query fails.
        """
        try:
            async with self.db_manager.get_connection() as db:
                where_clause, params = self._build_date_filter(start_date, end_date)

                sql = f"""
                    SELECT trade_date, total, non_st, shanghai, shenzhen, star, beijing, st,
                           created_at, updated_at
                    FROM limit_up_stats
                    {where_clause}
                    ORDER BY trade_date DESC
                """

                # SQLite only accepts OFFSET as part of a LIMIT clause, so an
                # offset without a limit needs "LIMIT -1" (negative = unbounded).
                # Both values are bound as parameters rather than interpolated.
                if limit:
                    sql += " LIMIT ?"
                    params.append(int(limit))
                elif offset > 0:
                    sql += " LIMIT -1"
                if offset > 0:
                    sql += " OFFSET ?"
                    params.append(int(offset))

                cursor = await db.execute(sql, params)
                rows = await cursor.fetchall()

                results = []
                for row in rows:
                    results.append(await self._hydrate_stats(db, self._row_to_dict(row)))
                return results

        except Exception as e:
            error_msg = f"Failed to query limit up stats: {str(e)}"
            logger.error(error_msg)
            raise DatabaseError(error_msg) from e

    async def delete_limit_up_stats(self, trade_date: str) -> bool:
        """Delete the statistics and detail rows for one trade date.

        Args:
            trade_date: Trade date (YYYYMMDD).

        Returns:
            True if a ``limit_up_stats`` row was deleted, False if none existed.

        Raises:
            DatabaseError: If the deletion fails (the transaction is rolled back).
        """
        try:
            async with self.db_manager.get_connection() as db:
                async with self._transaction(db):
                    # Detail rows are deleted explicitly, so this does not depend
                    # on an ON DELETE CASCADE foreign key being configured.
                    stocks_deleted, stats_deleted = await self._delete_rows_for_date(db, trade_date)

            if stats_deleted > 0:
                logger.info(f"Deleted limit up stats for {trade_date} (stats: {stats_deleted}, stocks: {stocks_deleted})")
                return True

            logger.warning(f"No limit up stats found for {trade_date}")
            return False

        except Exception as e:
            error_msg = f"Failed to delete limit up stats for {trade_date}: {str(e)}"
            logger.error(error_msg)
            raise DatabaseError(error_msg) from e

    async def list_available_dates(self, start_date: Optional[str] = None,
                                   end_date: Optional[str] = None) -> List[str]:
        """List trade dates that have statistics, newest first.

        Args:
            start_date: Inclusive lower bound; ignored if falsy.
            end_date: Inclusive upper bound; ignored if falsy.

        Returns:
            Trade dates ordered descending.

        Raises:
            DatabaseError: If the query fails.
        """
        try:
            async with self.db_manager.get_connection() as db:
                where_clause, params = self._build_date_filter(start_date, end_date)

                sql = f"""
                    SELECT trade_date 
                    FROM limit_up_stats 
                    {where_clause}
                    ORDER BY trade_date DESC
                """

                cursor = await db.execute(sql, params)
                rows = await cursor.fetchall()

                return [row[0] for row in rows]

        except Exception as e:
            error_msg = f"Failed to list available dates: {str(e)}"
            logger.error(error_msg)
            raise DatabaseError(error_msg) from e

    async def get_database_stats(self) -> Dict[str, Any]:
        """Return database statistics enriched with limit_up_stats summaries.

        Starts from ``db_manager.get_database_stats()`` and adds, when data is
        present, ``date_range`` (earliest/latest trade_date), ``averages``
        (AVG of total/non_st/st rounded to 2 places) and ``last_update``
        (MAX(updated_at)).

        Raises:
            DatabaseError: If any query fails.
        """
        try:
            stats = await self.db_manager.get_database_stats()

            async with self.db_manager.get_connection() as db:
                cursor = await db.execute(
                    "SELECT MIN(trade_date), MAX(trade_date) FROM limit_up_stats"
                )
                date_range = await cursor.fetchone()

                if date_range and date_range[0]:
                    stats['date_range'] = {
                        'earliest': date_range[0],
                        'latest': date_range[1]
                    }

                cursor = await db.execute(
                    "SELECT AVG(total), AVG(non_st), AVG(st) FROM limit_up_stats"
                )
                averages = await cursor.fetchone()

                if averages and averages[0] is not None:
                    # AVG ignores NULLs per column, so a single column can still
                    # be NULL; keep it as None instead of crashing in round().
                    stats['averages'] = {
                        key: (round(value, 2) if value is not None else None)
                        for key, value in zip(('total', 'non_st', 'st'), averages)
                    }

                cursor = await db.execute(
                    "SELECT MAX(updated_at) FROM limit_up_stats"
                )
                last_update = await cursor.fetchone()

                if last_update and last_update[0]:
                    stats['last_update'] = last_update[0]

            return stats

        except Exception as e:
            error_msg = f"Failed to get database stats: {str(e)}"
            logger.error(error_msg)
            raise DatabaseError(error_msg) from e

    async def batch_save_stats(self, stats_list: List[LimitUpStats]) -> Dict[str, Any]:
        """Insert or update many statistics rows in one transaction.

        Per-item failures (validation or SQL) are recorded in the result and
        do not stop the batch; successful items are committed together.

        Args:
            stats_list: Statistics to persist.

        Returns:
            Dict with ``total``, ``success``, ``failed`` counts and an
            ``errors`` list of ``{'trade_date', 'error'}`` entries.

        Raises:
            DatabaseError: If the batch as a whole fails (rolled back).
        """
        result = {
            'total': len(stats_list),
            'success': 0,
            'failed': 0,
            'errors': []
        }

        try:
            async with self.db_manager.get_connection() as db:
                async with self._transaction(db):
                    for stats in stats_list:
                        try:
                            stats.validate()

                            if await self._get_stats_by_date(db, stats.trade_date):
                                await self._update_stats(db, stats)
                            else:
                                await self._insert_stats(db, stats)

                            result['success'] += 1

                        except Exception as e:
                            # getattr: a malformed item must not crash the handler
                            # and abort (roll back) the whole batch.
                            trade_date = getattr(stats, 'trade_date', None)
                            result['failed'] += 1
                            result['errors'].append({
                                'trade_date': trade_date,
                                'error': str(e)
                            })
                            logger.warning(f"Failed to save stats for {trade_date}: {str(e)}")

            logger.info(f"Batch save completed: {result['success']} success, {result['failed']} failed")

        except Exception as e:
            error_msg = f"Batch save operation failed: {str(e)}"
            logger.error(error_msg)
            raise DatabaseError(error_msg) from e

        return result

    async def batch_delete_stats(self, trade_dates: List[str]) -> Dict[str, Any]:
        """Delete statistics and detail rows for many dates in one transaction.

        Args:
            trade_dates: Trade dates to delete.

        Returns:
            Dict with ``total``, ``success``, ``failed`` counts and an
            ``errors`` list; a date with no stats row counts as failed with
            ``'No data found'``.

        Raises:
            DatabaseError: If the batch as a whole fails (rolled back).
        """
        result = {
            'total': len(trade_dates),
            'success': 0,
            'failed': 0,
            'errors': []
        }

        try:
            async with self.db_manager.get_connection() as db:
                async with self._transaction(db):
                    for trade_date in trade_dates:
                        try:
                            _, stats_deleted = await self._delete_rows_for_date(db, trade_date)

                            if stats_deleted > 0:
                                result['success'] += 1
                            else:
                                result['failed'] += 1
                                result['errors'].append({
                                    'trade_date': trade_date,
                                    'error': 'No data found'
                                })

                        except Exception as e:
                            result['failed'] += 1
                            result['errors'].append({
                                'trade_date': trade_date,
                                'error': str(e)
                            })

            logger.info(f"Batch delete completed: {result['success']} success, {result['failed']} failed")

        except Exception as e:
            error_msg = f"Batch delete operation failed: {str(e)}"
            logger.error(error_msg)
            raise DatabaseError(error_msg) from e

        return result

    # Private helpers

    @asynccontextmanager
    async def _transaction(self, db: aiosqlite.Connection):
        """Run the enclosed statements inside one explicit transaction.

        Issues ``BEGIN TRANSACTION`` on entry, commits on normal exit, and
        rolls back then re-raises if the body raises. BaseException is caught
        so a cancelled task does not leave the connection mid-transaction.
        """
        await db.execute("BEGIN TRANSACTION")
        try:
            yield db
        except BaseException:
            await db.rollback()
            raise
        else:
            await db.commit()

    @staticmethod
    def _build_date_filter(start_date: Optional[str],
                           end_date: Optional[str]) -> Tuple[str, List[Any]]:
        """Build a ``WHERE`` clause (or "") and its parameters for a date range.

        Falsy bounds are omitted; both bounds are inclusive.
        """
        conditions = []
        params: List[Any] = []

        if start_date:
            conditions.append("trade_date >= ?")
            params.append(start_date)

        if end_date:
            conditions.append("trade_date <= ?")
            params.append(end_date)

        where_clause = ("WHERE " + " AND ".join(conditions)) if conditions else ""
        return where_clause, params

    @classmethod
    def _row_to_dict(cls, row) -> Dict[str, Any]:
        """Map a limit_up_stats row selected in ``_STATS_COLUMNS`` order to a dict."""
        return dict(zip(cls._STATS_COLUMNS, row))

    async def _hydrate_stats(self, db: aiosqlite.Connection,
                             stats_data: Dict[str, Any]) -> LimitUpStats:
        """Rebuild a LimitUpStats from a stats row plus its detail rows.

        Uses ``object.__new__`` so ``__init__`` is bypassed: data read back
        from the database is not re-validated.
        """
        trade_date = stats_data['trade_date']
        limit_up_stocks = await self._get_limit_up_stocks_by_date(db, trade_date)
        market_breakdown = await self._get_market_breakdown(db, trade_date)

        stats_obj = object.__new__(LimitUpStats)
        for column in self._STATS_COLUMNS:
            setattr(stats_obj, column, stats_data[column])
        stats_obj.limit_up_stocks = limit_up_stocks
        stats_obj.market_breakdown = market_breakdown
        return stats_obj

    async def _delete_rows_for_date(self, db: aiosqlite.Connection,
                                    trade_date: str) -> Tuple[int, int]:
        """Delete detail rows then the stats row for ``trade_date``.

        Returns:
            ``(stocks_deleted, stats_deleted)`` row counts.
        """
        cursor = await db.execute(
            "DELETE FROM limit_up_stocks WHERE trade_date = ?",
            (trade_date,)
        )
        stocks_deleted = cursor.rowcount

        cursor = await db.execute(
            "DELETE FROM limit_up_stats WHERE trade_date = ?",
            (trade_date,)
        )
        return stocks_deleted, cursor.rowcount

    async def _get_stats_by_date(self, db: aiosqlite.Connection, trade_date: str) -> Optional[Dict[str, Any]]:
        """Return the limit_up_stats row for ``trade_date`` as a dict, or None."""
        cursor = await db.execute(
            """
            SELECT trade_date, total, non_st, shanghai, shenzhen, star, beijing, st,
                   created_at, updated_at
            FROM limit_up_stats 
            WHERE trade_date = ?
            """,
            (trade_date,)
        )
        row = await cursor.fetchone()
        return self._row_to_dict(row) if row else None

    async def _insert_stats(self, db: aiosqlite.Connection, stats: LimitUpStats):
        """Insert a new limit_up_stats row (timestamps are left to the schema)."""
        await db.execute(
            """
            INSERT INTO limit_up_stats 
            (trade_date, total, non_st, shanghai, shenzhen, star, beijing, st)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (stats.trade_date, stats.total, stats.non_st, stats.shanghai,
             stats.shenzhen, stats.star, stats.beijing, stats.st)
        )

    async def _update_stats(self, db: aiosqlite.Connection, stats: LimitUpStats):
        """Overwrite the count columns of the existing row for ``stats.trade_date``."""
        await db.execute(
            """
            UPDATE limit_up_stats 
            SET total = ?, non_st = ?, shanghai = ?, shenzhen = ?, 
                star = ?, beijing = ?, st = ?
            WHERE trade_date = ?
            """,
            (stats.total, stats.non_st, stats.shanghai, stats.shenzhen,
             stats.star, stats.beijing, stats.st, stats.trade_date)
        )

    async def _get_limit_up_stocks_by_date(self, db: aiosqlite.Connection, trade_date: str) -> List[str]:
        """Return the ts_codes stored for ``trade_date``, sorted ascending."""
        cursor = await db.execute(
            "SELECT ts_code FROM limit_up_stocks WHERE trade_date = ? ORDER BY ts_code",
            (trade_date,)
        )
        rows = await cursor.fetchall()
        return [row[0] for row in rows]

    async def _get_market_breakdown(self, db: aiosqlite.Connection, trade_date: str) -> Dict[str, List[str]]:
        """Group the ts_codes stored for ``trade_date`` by their ``market`` column."""
        cursor = await db.execute(
            """
            SELECT market, ts_code 
            FROM limit_up_stocks 
            WHERE trade_date = ? 
            ORDER BY market, ts_code
            """,
            (trade_date,)
        )
        rows = await cursor.fetchall()

        breakdown: Dict[str, List[str]] = {}
        for market, ts_code in rows:
            breakdown.setdefault(market, []).append(ts_code)
        return breakdown

    async def _save_stock_details(self, db: aiosqlite.Connection, trade_date: str,
                                  stock_details: List[StockDailyData]):
        """Replace all limit_up_stocks rows for ``trade_date`` with ``stock_details``.

        Market and ST flag are derived from ts_code / name via the
        classification helpers below. Must run inside the caller's transaction
        so the delete and insert are atomic.
        """
        await db.execute(
            "DELETE FROM limit_up_stocks WHERE trade_date = ?",
            (trade_date,)
        )

        rows = [
            (trade_date, stock.ts_code, stock.name,
             self._classify_stock_market(stock.ts_code),
             self._is_st_stock(stock.name),
             stock.open, stock.close, stock.high, stock.pct_chg)
            for stock in stock_details
        ]
        if rows:
            await db.executemany(
                """
                INSERT INTO limit_up_stocks 
                (trade_date, ts_code, stock_name, market, is_st, 
                 open_price, close_price, high_price, pct_change)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                rows
            )

    def _classify_stock_market(self, ts_code: str) -> str:
        """Classify a stock code into 'star', 'shanghai', 'shenzhen' or 'beijing'.

        A ``.BJ`` exchange suffix always means 'beijing' (this covers newer
        Beijing codes such as 920xxx that no prefix rule matches); otherwise
        the numeric prefix decides. Unknown codes default to 'shanghai'.
        """
        if ts_code.startswith('688'):
            return 'star'
        if ts_code.upper().endswith('.BJ'):
            return 'beijing'
        if ts_code.startswith(('60', '900')):
            return 'shanghai'
        if ts_code.startswith(('00', '30', '200')):
            return 'shenzhen'
        if ts_code.startswith(('8', '4')):
            return 'beijing'
        return 'shanghai'  # default classification

    def _is_st_stock(self, stock_name: str) -> bool:
        """Return True if the name carries an ST / delisting / suspension marker.

        Matching is case-insensitive for the Latin 'ST' marker ('*ST' names are
        covered by it); empty or None names are treated as non-ST.
        """
        if not stock_name:
            return False
        upper_name = stock_name.upper()
        return any(marker in upper_name for marker in ('ST', '退市', '暂停'))