"""
涨跌分布统计数据库存储库

提供涨跌分布统计数据的数据库存储和查询功能
"""

import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import aiosqlite

from .database import DatabaseManager
from .errors import DatabaseError, ValidationError
from ..models.price_distribution_models import PriceDistributionStats

logger = logging.getLogger(__name__)


class PriceDistributionRepository:
    """涨跌分布统计数据库存储库"""
    
    def __init__(self, db_manager: DatabaseManager):
        """
        初始化存储库
        
        Args:
            db_manager: 数据库管理器实例
        """
        self.db_manager = db_manager
        self.logger = logger
    
    async def save_distribution_stats(self, stats: PriceDistributionStats) -> bool:
        """
        保存涨跌分布统计数据到数据库
        
        Args:
            stats: 涨跌分布统计数据
            
        Returns:
            保存是否成功
            
        Raises:
            DatabaseError: 数据库操作失败
            ValidationError: 数据验证失败
        """
        try:
            # 验证数据
            stats.validate()
            
            async with self.db_manager.get_connection() as db:
                # 开始事务
                await db.execute("BEGIN TRANSACTION")
                
                try:
                    # 保存元数据
                    await self._save_metadata(db, stats)
                    
                    # 保存分布统计数据
                    await self._save_distribution_data(db, stats)
                    
                    # 提交事务
                    await db.commit()
                    
                    self.logger.info(f"Successfully saved distribution stats for {stats.trade_date}")
                    return True
                    
                except Exception as e:
                    # 回滚事务
                    await db.rollback()
                    raise e
                    
        except Exception as e:
            error_msg = f"Failed to save distribution stats for {stats.trade_date}: {str(e)}"
            self.logger.error(error_msg)
            raise DatabaseError(error_msg) from e
    
    async def _save_metadata(self, db: aiosqlite.Connection, stats: PriceDistributionStats):
        """保存元数据"""
        metadata_sql = """
            INSERT OR REPLACE INTO price_distribution_metadata 
            (trade_date, total_stocks, processing_time, data_quality_score, data_source, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)
        """
        
        await db.execute(metadata_sql, (
            stats.trade_date,
            stats.total_stocks,
            stats.processing_time,
            stats.data_quality_score,
            'quickstock',  # 数据源
            stats.created_at,
            stats.updated_at or datetime.now().isoformat()
        ))
    
    async def _save_distribution_data(self, db: aiosqlite.Connection, stats: PriceDistributionStats):
        """保存分布统计数据"""
        # 删除现有数据
        await db.execute(
            "DELETE FROM price_distribution_stats WHERE trade_date = ?",
            (stats.trade_date,)
        )
        
        # 保存总体市场数据
        await self._save_market_data(db, stats.trade_date, 'total', stats)
        
        # 保存各市场板块数据
        for market_type, market_data in stats.market_breakdown.items():
            await self._save_market_breakdown_data(db, stats.trade_date, market_type, market_data)
    
    async def _save_market_data(self, db: aiosqlite.Connection, trade_date: str, 
                               market_type: str, stats: PriceDistributionStats):
        """保存市场数据"""
        stats_sql = """
            INSERT INTO price_distribution_stats 
            (trade_date, market_type, range_name, stock_count, percentage, stock_codes, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """
        
        current_time = datetime.now().isoformat()
        
        # 保存正涨幅分布
        for range_name, count in stats.positive_ranges.items():
            percentage = stats.positive_percentages.get(range_name, 0.0)
            await db.execute(stats_sql, (
                trade_date, market_type, range_name, count, percentage, 
                None, current_time, current_time
            ))
        
        # 保存负涨幅分布
        for range_name, count in stats.negative_ranges.items():
            percentage = stats.negative_percentages.get(range_name, 0.0)
            await db.execute(stats_sql, (
                trade_date, market_type, range_name, count, percentage,
                None, current_time, current_time
            ))
    
    async def _save_market_breakdown_data(self, db: aiosqlite.Connection, trade_date: str,
                                         market_type: str, market_data: Dict[str, Any]):
        """保存市场板块分布数据"""
        stats_sql = """
            INSERT INTO price_distribution_stats 
            (trade_date, market_type, range_name, stock_count, percentage, stock_codes, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """
        
        current_time = datetime.now().isoformat()
        
        # 保存正涨幅分布
        positive_ranges = market_data.get('positive_ranges', {})
        positive_percentages = market_data.get('positive_percentages', {})
        for range_name, count in positive_ranges.items():
            percentage = positive_percentages.get(range_name, 0.0)
            stock_codes = market_data.get('stock_codes', {}).get(range_name)
            stock_codes_json = json.dumps(stock_codes) if stock_codes else None
            
            await db.execute(stats_sql, (
                trade_date, market_type, range_name, count, percentage,
                stock_codes_json, current_time, current_time
            ))
        
        # 保存负涨幅分布
        negative_ranges = market_data.get('negative_ranges', {})
        negative_percentages = market_data.get('negative_percentages', {})
        for range_name, count in negative_ranges.items():
            percentage = negative_percentages.get(range_name, 0.0)
            stock_codes = market_data.get('stock_codes', {}).get(range_name)
            stock_codes_json = json.dumps(stock_codes) if stock_codes else None
            
            await db.execute(stats_sql, (
                trade_date, market_type, range_name, count, percentage,
                stock_codes_json, current_time, current_time
            ))
    
    async def get_distribution_stats(self, trade_date: str) -> Optional[PriceDistributionStats]:
        """
        从数据库获取涨跌分布统计数据
        
        Args:
            trade_date: 交易日期 (YYYYMMDD)
            
        Returns:
            涨跌分布统计数据，如果不存在则返回None
            
        Raises:
            DatabaseError: 数据库操作失败
        """
        try:
            async with self.db_manager.get_connection() as db:
                # 获取元数据
                metadata = await self._get_metadata(db, trade_date)
                if not metadata:
                    return None
                
                # 获取分布数据
                distribution_data = await self._get_distribution_data(db, trade_date)
                if not distribution_data:
                    return None
                
                # 构建统计对象
                stats = self._build_stats_object(trade_date, metadata, distribution_data)
                
                self.logger.debug(f"Successfully retrieved distribution stats for {trade_date}")
                return stats
                
        except Exception as e:
            error_msg = f"Failed to get distribution stats for {trade_date}: {str(e)}"
            self.logger.error(error_msg)
            raise DatabaseError(error_msg) from e
    
    async def _get_metadata(self, db: aiosqlite.Connection, trade_date: str) -> Optional[Dict[str, Any]]:
        """获取元数据"""
        cursor = await db.execute("""
            SELECT total_stocks, processing_time, data_quality_score, data_source, created_at, updated_at
            FROM price_distribution_metadata 
            WHERE trade_date = ?
        """, (trade_date,))
        
        row = await cursor.fetchone()
        if not row:
            return None
        
        return {
            'total_stocks': row[0],
            'processing_time': row[1],
            'data_quality_score': row[2],
            'data_source': row[3],
            'created_at': row[4],
            'updated_at': row[5]
        }
    
    async def _get_distribution_data(self, db: aiosqlite.Connection, trade_date: str) -> Optional[Dict[str, Any]]:
        """获取分布数据"""
        cursor = await db.execute("""
            SELECT market_type, range_name, stock_count, percentage, stock_codes
            FROM price_distribution_stats 
            WHERE trade_date = ?
            ORDER BY market_type, range_name
        """, (trade_date,))
        
        rows = await cursor.fetchall()
        if not rows:
            return None
        
        # 组织数据结构
        data = {
            'total_positive_ranges': {},
            'total_positive_percentages': {},
            'total_negative_ranges': {},
            'total_negative_percentages': {},
            'market_breakdown': {}
        }
        
        for row in rows:
            market_type, range_name, stock_count, percentage, stock_codes_json = row
            
            # 解析股票代码
            stock_codes = None
            if stock_codes_json:
                try:
                    stock_codes = json.loads(stock_codes_json)
                except json.JSONDecodeError:
                    self.logger.warning(f"Failed to parse stock codes for {trade_date}, {market_type}, {range_name}")
            
            if market_type == 'total':
                # 总体市场数据
                if self._is_positive_range(range_name):
                    data['total_positive_ranges'][range_name] = stock_count
                    data['total_positive_percentages'][range_name] = percentage
                else:
                    data['total_negative_ranges'][range_name] = stock_count
                    data['total_negative_percentages'][range_name] = percentage
            else:
                # 市场板块数据
                if market_type not in data['market_breakdown']:
                    data['market_breakdown'][market_type] = {
                        'positive_ranges': {},
                        'positive_percentages': {},
                        'negative_ranges': {},
                        'negative_percentages': {},
                        'stock_codes': {}
                    }
                
                market_data = data['market_breakdown'][market_type]
                if self._is_positive_range(range_name):
                    market_data['positive_ranges'][range_name] = stock_count
                    market_data['positive_percentages'][range_name] = percentage
                else:
                    market_data['negative_ranges'][range_name] = stock_count
                    market_data['negative_percentages'][range_name] = percentage
                
                if stock_codes:
                    market_data['stock_codes'][range_name] = stock_codes
        
        return data
    
    def _is_positive_range(self, range_name: str) -> bool:
        """判断是否为正涨幅区间"""
        # 负涨幅区间的特征：包含"到-"或以"-"开头或包含"<=-"
        negative_indicators = ['到-', '<=-']
        for indicator in negative_indicators:
            if indicator in range_name:
                return False
        
        # 如果以"-"开头且不是"0到-"的情况，则为负区间
        if range_name.startswith('-') and not range_name.startswith('0到-'):
            return False
        
        return True
    
    def _build_stats_object(self, trade_date: str, metadata: Dict[str, Any], 
                           distribution_data: Dict[str, Any]) -> PriceDistributionStats:
        """构建统计对象"""
        return PriceDistributionStats(
            trade_date=trade_date,
            total_stocks=metadata['total_stocks'],
            positive_ranges=distribution_data['total_positive_ranges'],
            positive_percentages=distribution_data['total_positive_percentages'],
            negative_ranges=distribution_data['total_negative_ranges'],
            negative_percentages=distribution_data['total_negative_percentages'],
            market_breakdown=distribution_data['market_breakdown'],
            created_at=metadata['created_at'],
            updated_at=metadata['updated_at'],
            processing_time=metadata['processing_time'],
            data_quality_score=metadata['data_quality_score']
        )
    
    async def delete_distribution_stats(self, trade_date: str) -> bool:
        """
        删除指定日期的涨跌分布统计数据
        
        Args:
            trade_date: 交易日期 (YYYYMMDD)
            
        Returns:
            删除是否成功
            
        Raises:
            DatabaseError: 数据库操作失败
        """
        try:
            async with self.db_manager.get_connection() as db:
                # 开始事务
                await db.execute("BEGIN TRANSACTION")
                
                try:
                    # 删除分布统计数据
                    cursor = await db.execute(
                        "DELETE FROM price_distribution_stats WHERE trade_date = ?",
                        (trade_date,)
                    )
                    stats_deleted = cursor.rowcount
                    
                    # 删除元数据
                    cursor = await db.execute(
                        "DELETE FROM price_distribution_metadata WHERE trade_date = ?",
                        (trade_date,)
                    )
                    metadata_deleted = cursor.rowcount
                    
                    # 提交事务
                    await db.commit()
                    
                    self.logger.info(f"Deleted distribution stats for {trade_date}: "
                                   f"{stats_deleted} stats records, {metadata_deleted} metadata records")
                    return True
                    
                except Exception as e:
                    # 回滚事务
                    await db.rollback()
                    raise e
                    
        except Exception as e:
            error_msg = f"Failed to delete distribution stats for {trade_date}: {str(e)}"
            self.logger.error(error_msg)
            raise DatabaseError(error_msg) from e
    
    async def get_available_dates(self, start_date: Optional[str] = None, 
                                 end_date: Optional[str] = None) -> List[str]:
        """
        获取可用的交易日期列表
        
        Args:
            start_date: 开始日期 (YYYYMMDD)，可选
            end_date: 结束日期 (YYYYMMDD)，可选
            
        Returns:
            交易日期列表
            
        Raises:
            DatabaseError: 数据库操作失败
        """
        try:
            async with self.db_manager.get_connection() as db:
                sql = "SELECT DISTINCT trade_date FROM price_distribution_metadata"
                params = []
                
                conditions = []
                if start_date:
                    conditions.append("trade_date >= ?")
                    params.append(start_date)
                
                if end_date:
                    conditions.append("trade_date <= ?")
                    params.append(end_date)
                
                if conditions:
                    sql += " WHERE " + " AND ".join(conditions)
                
                sql += " ORDER BY trade_date DESC"
                
                cursor = await db.execute(sql, params)
                rows = await cursor.fetchall()
                
                dates = [row[0] for row in rows]
                self.logger.debug(f"Found {len(dates)} available dates")
                return dates
                
        except Exception as e:
            error_msg = f"Failed to get available dates: {str(e)}"
            self.logger.error(error_msg)
            raise DatabaseError(error_msg) from e
    
    async def get_stats_summary(self, start_date: Optional[str] = None,
                               end_date: Optional[str] = None) -> Dict[str, Any]:
        """
        获取统计摘要信息
        
        Args:
            start_date: 开始日期 (YYYYMMDD)，可选
            end_date: 结束日期 (YYYYMMDD)，可选
            
        Returns:
            统计摘要信息
            
        Raises:
            DatabaseError: 数据库操作失败
        """
        try:
            async with self.db_manager.get_connection() as db:
                # 基础查询
                base_sql = """
                    SELECT 
                        COUNT(*) as total_records,
                        MIN(trade_date) as earliest_date,
                        MAX(trade_date) as latest_date,
                        AVG(total_stocks) as avg_total_stocks,
                        AVG(processing_time) as avg_processing_time,
                        AVG(data_quality_score) as avg_quality_score
                    FROM price_distribution_metadata
                """
                
                params = []
                conditions = []
                
                if start_date:
                    conditions.append("trade_date >= ?")
                    params.append(start_date)
                
                if end_date:
                    conditions.append("trade_date <= ?")
                    params.append(end_date)
                
                if conditions:
                    base_sql += " WHERE " + " AND ".join(conditions)
                
                cursor = await db.execute(base_sql, params)
                summary_row = await cursor.fetchone()
                
                if not summary_row or summary_row[0] == 0:
                    return {
                        'total_records': 0,
                        'date_range': None,
                        'averages': None,
                        'market_stats': {}
                    }
                
                # 获取市场统计
                market_stats_sql = """
                    SELECT 
                        market_type,
                        COUNT(DISTINCT trade_date) as date_count,
                        AVG(stock_count) as avg_stock_count
                    FROM price_distribution_stats
                """
                
                if conditions:
                    market_stats_sql += " WHERE " + " AND ".join(conditions)
                
                market_stats_sql += " GROUP BY market_type"
                
                cursor = await db.execute(market_stats_sql, params)
                market_rows = await cursor.fetchall()
                
                market_stats = {}
                for row in market_rows:
                    market_type, date_count, avg_stock_count = row
                    market_stats[market_type] = {
                        'date_count': date_count,
                        'avg_stock_count': round(avg_stock_count, 2) if avg_stock_count else 0
                    }
                
                return {
                    'total_records': summary_row[0],
                    'date_range': {
                        'earliest': summary_row[1],
                        'latest': summary_row[2]
                    },
                    'averages': {
                        'total_stocks': round(summary_row[3], 2) if summary_row[3] else 0,
                        'processing_time': round(summary_row[4], 3) if summary_row[4] else 0,
                        'quality_score': round(summary_row[5], 3) if summary_row[5] else 0
                    },
                    'market_stats': market_stats
                }
                
        except Exception as e:
            error_msg = f"Failed to get stats summary: {str(e)}"
            self.logger.error(error_msg)
            raise DatabaseError(error_msg) from e
    
    async def batch_save_distribution_stats(self, stats_list: List[PriceDistributionStats]) -> Dict[str, Any]:
        """
        批量保存涨跌分布统计数据
        
        Args:
            stats_list: 统计数据列表
            
        Returns:
            批量保存结果
            
        Raises:
            DatabaseError: 数据库操作失败
        """
        if not stats_list:
            return {'success_count': 0, 'error_count': 0, 'errors': []}
        
        success_count = 0
        error_count = 0
        errors = []
        
        try:
            async with self.db_manager.get_connection() as db:
                for stats in stats_list:
                    try:
                        # 开始事务
                        await db.execute("BEGIN TRANSACTION")
                        
                        # 保存元数据
                        await self._save_metadata(db, stats)
                        
                        # 保存分布统计数据
                        await self._save_distribution_data(db, stats)
                        
                        # 提交事务
                        await db.commit()
                        
                        success_count += 1
                        self.logger.debug(f"Successfully saved stats for {stats.trade_date}")
                        
                    except Exception as e:
                        # 回滚事务
                        await db.rollback()
                        error_count += 1
                        error_msg = f"Failed to save stats for {stats.trade_date}: {str(e)}"
                        errors.append(error_msg)
                        self.logger.error(error_msg)
            
            self.logger.info(f"Batch save completed: {success_count} success, {error_count} errors")
            
            return {
                'success_count': success_count,
                'error_count': error_count,
                'errors': errors
            }
            
        except Exception as e:
            error_msg = f"Batch save failed: {str(e)}"
            self.logger.error(error_msg)
            raise DatabaseError(error_msg) from e
    
    async def cleanup_old_data(self, keep_days: int = 90) -> Dict[str, Any]:
        """
        清理旧数据
        
        Args:
            keep_days: 保留天数
            
        Returns:
            清理结果
            
        Raises:
            DatabaseError: 数据库操作失败
        """
        try:
            from datetime import datetime, timedelta
            
            # 计算截止日期
            cutoff_date = (datetime.now() - timedelta(days=keep_days)).strftime('%Y%m%d')
            
            async with self.db_manager.get_connection() as db:
                # 开始事务
                await db.execute("BEGIN TRANSACTION")
                
                try:
                    # 删除旧的分布统计数据
                    cursor = await db.execute(
                        "DELETE FROM price_distribution_stats WHERE trade_date < ?",
                        (cutoff_date,)
                    )
                    stats_deleted = cursor.rowcount
                    
                    # 删除旧的元数据
                    cursor = await db.execute(
                        "DELETE FROM price_distribution_metadata WHERE trade_date < ?",
                        (cutoff_date,)
                    )
                    metadata_deleted = cursor.rowcount
                    
                    # 提交事务
                    await db.commit()
                    
                    result = {
                        'cutoff_date': cutoff_date,
                        'stats_deleted': stats_deleted,
                        'metadata_deleted': metadata_deleted,
                        'success': True
                    }
                    
                    self.logger.info(f"Cleanup completed: deleted {stats_deleted} stats records "
                                   f"and {metadata_deleted} metadata records before {cutoff_date}")
                    
                    return result
                    
                except Exception as e:
                    # 回滚事务
                    await db.rollback()
                    raise e
                    
        except Exception as e:
            error_msg = f"Failed to cleanup old data: {str(e)}"
            self.logger.error(error_msg)
            raise DatabaseError(error_msg) from e