"""
涨停统计专用缓存管理器

提供智能缓存策略，优先从数据库读取，支持数据同步和更新机制
"""

import asyncio
import logging
from typing import Optional, Dict, Any, List
from datetime import datetime, timedelta
import pandas as pd

from .cache import CacheLayer
from .repository import LimitUpStatsRepository
from ..models import LimitUpStats, LimitUpStatsRequest
from .errors import DatabaseError, CacheError

logger = logging.getLogger(__name__)


class LimitUpCacheManager:
    """
    涨停统计缓存管理器
    
    实现智能缓存策略：
    1. 优先从数据库读取已保存的数据
    2. 内存缓存最近查询的结果
    3. 自动同步和更新机制
    4. 支持批量操作优化
    """
    
    def __init__(self, cache_layer: CacheLayer, repository: LimitUpStatsRepository):
        """
        初始化缓存管理器
        
        Args:
            cache_layer: 通用缓存层
            repository: 数据库操作层
        """
        self.cache_layer = cache_layer
        self.repository = repository
        self.logger = logger
        
        # 缓存配置
        self.memory_cache_expire_hours = 2  # 内存缓存2小时
        self.db_cache_expire_days = 30      # 数据库缓存30天
        
        # 统计信息
        self._stats = {
            'memory_hits': 0,
            'db_hits': 0,
            'cache_misses': 0,
            'total_requests': 0,
            'last_sync_time': None
        }
    
    async def get_limit_up_stats(self, request: LimitUpStatsRequest) -> Optional[LimitUpStats]:
        """
        获取涨停统计数据（智能缓存策略）
        
        Args:
            request: 涨停统计请求
            
        Returns:
            涨停统计数据，如果不存在则返回None
        """
        self._stats['total_requests'] += 1
        trade_date = request.trade_date
        
        try:
            # 1. 首先检查内存缓存
            if not request.force_refresh:
                memory_result = await self._get_from_memory_cache(trade_date)
                if memory_result:
                    self._stats['memory_hits'] += 1
                    self.logger.debug(f"内存缓存命中: {trade_date}")
                    return memory_result
            
            # 2. 检查数据库缓存
            if not request.force_refresh:
                db_result = await self._get_from_database(trade_date)
                if db_result:
                    self._stats['db_hits'] += 1
                    self.logger.debug(f"数据库缓存命中: {trade_date}")
                    
                    # 将数据库结果放入内存缓存
                    await self._set_memory_cache(trade_date, db_result)
                    return db_result
            
            # 3. 缓存未命中
            self._stats['cache_misses'] += 1
            self.logger.debug(f"缓存未命中: {trade_date}")
            return None
            
        except Exception as e:
            self.logger.error(f"获取缓存数据失败 {trade_date}: {e}")
            return None
    
    async def set_limit_up_stats(self, stats: LimitUpStats, 
                                save_to_db: bool = True) -> bool:
        """
        设置涨停统计数据到缓存
        
        Args:
            stats: 涨停统计数据
            save_to_db: 是否保存到数据库
            
        Returns:
            是否设置成功
        """
        try:
            trade_date = stats.trade_date
            
            # 1. 保存到内存缓存
            await self._set_memory_cache(trade_date, stats)
            
            # 2. 保存到数据库（如果需要）
            if save_to_db:
                success = await self.repository.save_limit_up_stats(stats)
                if success:
                    self.logger.info(f"涨停统计数据已保存到数据库: {trade_date}")
                else:
                    self.logger.warning(f"涨停统计数据保存到数据库失败: {trade_date}")
                    return False
            
            return True
            
        except Exception as e:
            self.logger.error(f"设置缓存数据失败 {stats.trade_date}: {e}")
            return False
    
    async def invalidate_cache(self, trade_date: str) -> bool:
        """
        使指定日期的缓存失效
        
        Args:
            trade_date: 交易日期
            
        Returns:
            是否成功
        """
        try:
            # 1. 清除内存缓存
            memory_key = self._build_memory_cache_key(trade_date)
            await self.cache_layer.delete(memory_key)
            
            # 2. 从数据库删除（可选，根据需求决定）
            # await self.repository.delete_limit_up_stats(trade_date)
            
            self.logger.info(f"缓存已失效: {trade_date}")
            return True
            
        except Exception as e:
            self.logger.error(f"缓存失效操作失败 {trade_date}: {e}")
            return False
    
    async def batch_get_stats(self, trade_dates: List[str], 
                            force_refresh: bool = False) -> Dict[str, Optional[LimitUpStats]]:
        """
        批量获取涨停统计数据
        
        Args:
            trade_dates: 交易日期列表
            force_refresh: 是否强制刷新
            
        Returns:
            日期到统计数据的映射
        """
        results = {}
        
        if not force_refresh:
            # 1. 批量检查内存缓存
            memory_results = await self._batch_get_from_memory(trade_dates)
            results.update(memory_results)
            
            # 2. 对于内存缓存未命中的，检查数据库
            missing_dates = [date for date in trade_dates if date not in results]
            if missing_dates:
                db_results = await self._batch_get_from_database(missing_dates)
                results.update(db_results)
                
                # 将数据库结果放入内存缓存
                for date, stats in db_results.items():
                    if stats:
                        await self._set_memory_cache(date, stats)
        
        # 3. 对于仍然缺失的日期，返回None
        for date in trade_dates:
            if date not in results:
                results[date] = None
        
        return results
    
    async def sync_database_to_memory(self, days: int = 7) -> Dict[str, Any]:
        """
        将数据库中最近的数据同步到内存缓存
        
        Args:
            days: 同步最近几天的数据
            
        Returns:
            同步结果统计
        """
        try:
            # 计算日期范围
            end_date = datetime.now().strftime('%Y%m%d')
            start_date = (datetime.now() - timedelta(days=days)).strftime('%Y%m%d')
            
            # 从数据库获取数据
            stats_list = await self.repository.query_limit_up_stats(
                start_date=start_date,
                end_date=end_date
            )
            
            # 同步到内存缓存
            synced_count = 0
            for stats in stats_list:
                await self._set_memory_cache(stats.trade_date, stats)
                synced_count += 1
            
            self._stats['last_sync_time'] = datetime.now().isoformat()
            
            result = {
                'synced_count': synced_count,
                'date_range': f"{start_date} - {end_date}",
                'sync_time': self._stats['last_sync_time']
            }
            
            self.logger.info(f"数据库到内存同步完成: {synced_count} 条记录")
            return result
            
        except Exception as e:
            error_msg = f"数据库同步失败: {e}"
            self.logger.error(error_msg)
            raise CacheError(error_msg) from e
    
    async def cleanup_expired_cache(self) -> Dict[str, Any]:
        """
        清理过期的缓存数据
        
        Returns:
            清理结果统计
        """
        try:
            # 清理内存缓存中的过期数据
            memory_cleaned = await self.cache_layer.clear_expired()
            
            # 清理数据库中的过期数据（可选）
            # 这里可以根据需求决定是否删除数据库中的旧数据
            
            result = {
                'memory_cleaned': memory_cleaned,
                'cleanup_time': datetime.now().isoformat()
            }
            
            self.logger.info(f"缓存清理完成: 内存清理 {memory_cleaned} 条")
            return result
            
        except Exception as e:
            error_msg = f"缓存清理失败: {e}"
            self.logger.error(error_msg)
            raise CacheError(error_msg) from e
    
    def get_cache_stats(self) -> Dict[str, Any]:
        """
        获取缓存统计信息
        
        Returns:
            缓存统计信息
        """
        stats = self._stats.copy()
        
        # 计算命中率
        total_requests = stats['total_requests']
        if total_requests > 0:
            memory_hit_rate = stats['memory_hits'] / total_requests
            db_hit_rate = stats['db_hits'] / total_requests
            total_hit_rate = (stats['memory_hits'] + stats['db_hits']) / total_requests
            
            stats.update({
                'memory_hit_rate': memory_hit_rate,
                'db_hit_rate': db_hit_rate,
                'total_hit_rate': total_hit_rate,
                'miss_rate': stats['cache_misses'] / total_requests
            })
        else:
            stats.update({
                'memory_hit_rate': 0.0,
                'db_hit_rate': 0.0,
                'total_hit_rate': 0.0,
                'miss_rate': 0.0
            })
        
        return stats
    
    def reset_stats(self):
        """重置统计信息"""
        self._stats = {
            'memory_hits': 0,
            'db_hits': 0,
            'cache_misses': 0,
            'total_requests': 0,
            'last_sync_time': None
        }
    
    # 私有方法
    
    async def _get_from_memory_cache(self, trade_date: str) -> Optional[LimitUpStats]:
        """从内存缓存获取数据"""
        try:
            cache_key = self._build_memory_cache_key(trade_date)
            cached_data = await self.cache_layer.get(cache_key)
            
            if cached_data is not None and isinstance(cached_data, dict):
                return LimitUpStats.from_dict(cached_data)
            
            return None
            
        except Exception as e:
            self.logger.warning(f"内存缓存读取失败 {trade_date}: {e}")
            return None
    
    async def _get_from_database(self, trade_date: str) -> Optional[LimitUpStats]:
        """从数据库获取数据"""
        try:
            return await self.repository.get_limit_up_stats(trade_date)
        except Exception as e:
            self.logger.warning(f"数据库读取失败 {trade_date}: {e}")
            return None
    
    async def _set_memory_cache(self, trade_date: str, stats: LimitUpStats):
        """设置内存缓存"""
        try:
            cache_key = self._build_memory_cache_key(trade_date)
            await self.cache_layer.set(
                cache_key, 
                stats.to_dict(), 
                self.memory_cache_expire_hours
            )
        except Exception as e:
            self.logger.warning(f"内存缓存设置失败 {trade_date}: {e}")
    
    async def _batch_get_from_memory(self, trade_dates: List[str]) -> Dict[str, LimitUpStats]:
        """批量从内存缓存获取数据"""
        results = {}
        
        for trade_date in trade_dates:
            stats = await self._get_from_memory_cache(trade_date)
            if stats:
                results[trade_date] = stats
        
        return results
    
    async def _batch_get_from_database(self, trade_dates: List[str]) -> Dict[str, LimitUpStats]:
        """批量从数据库获取数据"""
        results = {}
        
        try:
            # 使用日期范围查询优化批量获取
            if trade_dates:
                min_date = min(trade_dates)
                max_date = max(trade_dates)
                
                stats_list = await self.repository.query_limit_up_stats(
                    start_date=min_date,
                    end_date=max_date
                )
                
                # 构建结果映射
                for stats in stats_list:
                    if stats.trade_date in trade_dates:
                        results[stats.trade_date] = stats
        
        except Exception as e:
            self.logger.warning(f"批量数据库读取失败: {e}")
        
        return results
    
    def _build_memory_cache_key(self, trade_date: str) -> str:
        """构建内存缓存键"""
        return f"limit_up_stats:{trade_date}"


class CacheError(Exception):
    """缓存操作异常"""
    pass