"""
涨跌分布统计数据库缓存层

提供数据库作为缓存层的功能，与内存缓存和Redis缓存协同工作
"""

import logging
from typing import Any, Dict, List, Optional

from .database import DatabaseManager
from .price_distribution_repository import PriceDistributionRepository
from ..models.price_distribution_models import PriceDistributionStats

logger = logging.getLogger(__name__)


class PriceDistributionDatabaseCache:
    """涨跌分布统计数据库缓存层"""
    
    def __init__(self, db_manager: DatabaseManager):
        """
        初始化数据库缓存层
        
        Args:
            db_manager: 数据库管理器实例
        """
        self.db_manager = db_manager
        self.repository = PriceDistributionRepository(db_manager)
        self.logger = logger
    
    async def get(self, trade_date: str) -> Optional[PriceDistributionStats]:
        """
        从数据库缓存获取统计数据
        
        Args:
            trade_date: 交易日期 (YYYYMMDD)
            
        Returns:
            统计数据，如果不存在则返回None
        """
        try:
            stats = await self.repository.get_distribution_stats(trade_date)
            if stats:
                self.logger.debug(f"Database cache hit for {trade_date}")
            else:
                self.logger.debug(f"Database cache miss for {trade_date}")
            return stats
        except Exception as e:
            self.logger.error(f"Database cache get error for {trade_date}: {str(e)}")
            return None
    
    async def set(self, trade_date: str, stats: PriceDistributionStats, ttl: Optional[int] = None) -> bool:
        """
        设置数据库缓存
        
        Args:
            trade_date: 交易日期 (YYYYMMDD)
            stats: 统计数据
            ttl: 生存时间（数据库缓存忽略此参数）
            
        Returns:
            设置是否成功
        """
        try:
            success = await self.repository.save_distribution_stats(stats)
            if success:
                self.logger.debug(f"Database cache set for {trade_date}")
            return success
        except Exception as e:
            self.logger.error(f"Database cache set error for {trade_date}: {str(e)}")
            return False
    
    async def delete(self, trade_date: str) -> bool:
        """
        删除数据库缓存
        
        Args:
            trade_date: 交易日期 (YYYYMMDD)
            
        Returns:
            删除是否成功
        """
        try:
            success = await self.repository.delete_distribution_stats(trade_date)
            if success:
                self.logger.debug(f"Database cache deleted for {trade_date}")
            return success
        except Exception as e:
            self.logger.error(f"Database cache delete error for {trade_date}: {str(e)}")
            return False
    
    async def exists(self, trade_date: str) -> bool:
        """
        检查数据库缓存是否存在
        
        Args:
            trade_date: 交易日期 (YYYYMMDD)
            
        Returns:
            是否存在
        """
        try:
            stats = await self.repository.get_distribution_stats(trade_date)
            return stats is not None
        except Exception as e:
            self.logger.error(f"Database cache exists check error for {trade_date}: {str(e)}")
            return False
    
    async def clear(self, pattern: Optional[str] = None) -> int:
        """
        清理数据库缓存
        
        Args:
            pattern: 清理模式（可选）
                - None: 清理所有数据
                - "old": 清理旧数据（90天前）
                - 具体日期: 清理指定日期
                
        Returns:
            清理的记录数
        """
        try:
            if pattern is None:
                # 清理所有数据 - 谨慎操作
                self.logger.warning("Clearing all database cache data")
                # 这里可以实现清理所有数据的逻辑
                return 0
            elif pattern == "old":
                # 清理旧数据
                result = await self.repository.cleanup_old_data(keep_days=90)
                return result.get('stats_deleted', 0) + result.get('metadata_deleted', 0)
            else:
                # 清理指定日期
                success = await self.repository.delete_distribution_stats(pattern)
                return 1 if success else 0
        except Exception as e:
            self.logger.error(f"Database cache clear error: {str(e)}")
            return 0
    
    async def get_cache_info(self) -> Dict[str, Any]:
        """
        获取数据库缓存信息
        
        Returns:
            缓存信息字典
        """
        try:
            # 获取数据库统计信息
            db_stats = await self.db_manager.get_database_stats()
            
            # 获取统计摘要
            summary = await self.repository.get_stats_summary()
            
            # 获取可用日期
            available_dates = await self.repository.get_available_dates()
            
            return {
                'type': 'database',
                'database_stats': db_stats,
                'summary': summary,
                'available_dates_count': len(available_dates),
                'latest_dates': available_dates[:10] if available_dates else [],
                'cache_size': db_stats.get('db_size', 0),
                'total_records': summary.get('total_records', 0)
            }
        except Exception as e:
            self.logger.error(f"Database cache info error: {str(e)}")
            return {
                'type': 'database',
                'error': str(e)
            }
    
    async def batch_get(self, trade_dates: List[str]) -> Dict[str, Optional[PriceDistributionStats]]:
        """
        批量获取统计数据
        
        Args:
            trade_dates: 交易日期列表
            
        Returns:
            日期到统计数据的映射
        """
        result = {}
        for trade_date in trade_dates:
            try:
                stats = await self.get(trade_date)
                result[trade_date] = stats
            except Exception as e:
                self.logger.error(f"Batch get error for {trade_date}: {str(e)}")
                result[trade_date] = None
        
        return result
    
    async def batch_set(self, stats_dict: Dict[str, PriceDistributionStats]) -> Dict[str, bool]:
        """
        批量设置统计数据
        
        Args:
            stats_dict: 日期到统计数据的映射
            
        Returns:
            日期到设置结果的映射
        """
        result = {}
        stats_list = list(stats_dict.values())
        
        try:
            # 使用批量保存方法
            batch_result = await self.repository.batch_save_distribution_stats(stats_list)
            
            # 根据批量结果设置返回值
            for trade_date in stats_dict.keys():
                result[trade_date] = True  # 假设成功，实际应该根据具体错误判断
            
            # 处理错误
            if batch_result['errors']:
                for error in batch_result['errors']:
                    # 从错误信息中提取日期（简化处理）
                    for trade_date in stats_dict.keys():
                        if trade_date in error:
                            result[trade_date] = False
                            break
            
            return result
            
        except Exception as e:
            self.logger.error(f"Batch set error: {str(e)}")
            # 所有设置都失败
            return {trade_date: False for trade_date in stats_dict.keys()}
    
    async def get_date_range_stats(self, start_date: str, end_date: str) -> List[PriceDistributionStats]:
        """
        获取日期范围内的统计数据
        
        Args:
            start_date: 开始日期 (YYYYMMDD)
            end_date: 结束日期 (YYYYMMDD)
            
        Returns:
            统计数据列表
        """
        try:
            # 获取日期范围内的可用日期
            available_dates = await self.repository.get_available_dates(start_date, end_date)
            
            # 批量获取统计数据
            stats_list = []
            for trade_date in available_dates:
                stats = await self.get(trade_date)
                if stats:
                    stats_list.append(stats)
            
            return stats_list
            
        except Exception as e:
            self.logger.error(f"Get date range stats error: {str(e)}")
            return []
    
    async def validate_cache_integrity(self) -> Dict[str, Any]:
        """
        验证数据库缓存完整性
        
        Returns:
            验证结果
        """
        try:
            validation_result = {
                'valid': True,
                'errors': [],
                'warnings': [],
                'stats': {}
            }
            
            # 验证数据库模式
            schema_validation = await self.db_manager.validate_schema()
            if not schema_validation['valid']:
                validation_result['valid'] = False
                validation_result['errors'].extend(schema_validation['errors'])
            
            # 获取统计信息
            summary = await self.repository.get_stats_summary()
            validation_result['stats'] = summary
            
            # 检查数据一致性
            available_dates = await self.repository.get_available_dates()
            if available_dates:
                # 随机检查几个日期的数据完整性
                import random
                sample_dates = random.sample(available_dates, min(5, len(available_dates)))
                
                for trade_date in sample_dates:
                    try:
                        stats = await self.get(trade_date)
                        if stats:
                            stats.validate()  # 验证数据模型
                        else:
                            validation_result['warnings'].append(f"No data found for {trade_date}")
                    except Exception as e:
                        validation_result['errors'].append(f"Data validation failed for {trade_date}: {str(e)}")
                        validation_result['valid'] = False
            
            return validation_result
            
        except Exception as e:
            return {
                'valid': False,
                'errors': [f"Cache integrity validation failed: {str(e)}"],
                'warnings': [],
                'stats': {}
            }
    
    async def optimize_cache(self) -> Dict[str, Any]:
        """
        优化数据库缓存
        
        Returns:
            优化结果
        """
        try:
            optimization_result = {
                'success': True,
                'actions': [],
                'stats': {}
            }
            
            # 清理旧数据
            cleanup_result = await self.repository.cleanup_old_data(keep_days=90)
            if cleanup_result['success']:
                optimization_result['actions'].append(f"Cleaned up old data: {cleanup_result['stats_deleted']} records")
            
            # 获取优化后的统计信息
            optimization_result['stats'] = await self.repository.get_stats_summary()
            
            return optimization_result
            
        except Exception as e:
            return {
                'success': False,
                'error': str(e),
                'actions': [],
                'stats': {}
            }