"""
涨跌分布统计缓存管理器测试

测试PriceDistributionCacheManager的各项功能
"""

import asyncio
import json
import os
import tempfile
import unittest
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch

import pandas as pd
import pytest

from quickstock.core.price_distribution_cache import (
    CacheKeyManager,
    CacheEntry,
    MemoryCacheLayer,
    DatabaseCacheLayer,
    PriceDistributionCacheManager
)
from quickstock.models import (
    PriceDistributionStats,
    PriceDistributionRequest
)


class TestCacheKeyManager(unittest.TestCase):
    """Tests for CacheKeyManager: key generation, hashing, parsing and TTL lookup."""

    def setUp(self):
        self.key_manager = CacheKeyManager()

    def test_generate_key_price_distribution(self):
        """A price-distribution key starts with its fixed fields and has 5 segments."""
        key = self.key_manager.generate_key(
            'price_distribution',
            trade_date='20240101',
            market='total',
            include_st=True
        )

        # The fixed fields must form the *prefix* of the key; a plain
        # substring check would also accept keys with extra text before them.
        self.assertTrue(
            key.startswith('price_dist:20240101:total:True:'),
            msg=f'unexpected key layout: {key!r}'
        )
        self.assertEqual(len(key.split(':')), 5)

    def test_generate_key_stock_data(self):
        """A stock-data key is 'stock_daily:<trade_date>'."""
        key = self.key_manager.generate_key(
            'stock_data',
            trade_date='20240101'
        )

        self.assertEqual(key, 'stock_daily:20240101')

    def test_generate_key_invalid_template(self):
        """An unknown template name raises ValueError."""
        with self.assertRaises(ValueError):
            self.key_manager.generate_key('invalid_template')

    def test_generate_key_missing_params(self):
        """Omitting the template's required parameters raises ValueError."""
        with self.assertRaises(ValueError):
            self.key_manager.generate_key('price_distribution')

    def test_generate_hash(self):
        """Hashes ignore dict insertion order, differ for different values, and are 8 chars long."""
        data1 = {'a': 1, 'b': 2}
        data2 = {'b': 2, 'a': 1}  # same content, different insertion order
        data3 = {'a': 1, 'b': 3}  # different value

        hash1 = self.key_manager.generate_hash(data1)
        hash2 = self.key_manager.generate_hash(data2)
        hash3 = self.key_manager.generate_hash(data3)

        # Equal content must hash identically regardless of key order
        self.assertEqual(hash1, hash2)
        # Different content must hash differently
        self.assertNotEqual(hash1, hash3)
        # The hash is truncated to 8 characters
        self.assertEqual(len(hash1), 8)

    def test_generate_hash_none(self):
        """Hashing None yields the literal 'none'."""
        hash_value = self.key_manager.generate_hash(None)
        self.assertEqual(hash_value, "none")

    def test_parse_key_price_distribution(self):
        """A price-distribution key parses into its typed components."""
        key = 'price_dist:20240101:total:True:abcd1234'
        parsed = self.key_manager.parse_key(key)

        expected = {
            'type': 'price_dist',
            'trade_date': '20240101',
            'market': 'total',
            'include_st': True,
            'hash': 'abcd1234'
        }

        self.assertEqual(parsed, expected)

    def test_parse_key_stock_data(self):
        """A stock-data key parses into its type and trade date."""
        key = 'stock_daily:20240101'
        parsed = self.key_manager.parse_key(key)

        expected = {
            'type': 'stock_daily',
            'trade_date': '20240101'
        }

        self.assertEqual(parsed, expected)

    def test_parse_key_invalid(self):
        """A key without separators parses to just its type."""
        key = 'invalid'
        parsed = self.key_manager.parse_key(key)

        self.assertEqual(parsed, {'type': 'invalid'})

    def test_get_ttl(self):
        """Known and unknown templates both resolve to a 3600-second TTL."""
        ttl = self.key_manager.get_ttl('price_distribution')
        self.assertEqual(ttl, 3600)

        # Unknown templates fall back to the default TTL
        ttl = self.key_manager.get_ttl('unknown_template')
        self.assertEqual(ttl, 3600)


class TestCacheEntry(unittest.TestCase):
    """Tests for CacheEntry: construction, expiry checks, access tracking, serialization."""

    def test_cache_entry_creation(self):
        """A fresh entry keeps its payload, has no expiry and zero accesses."""
        payload = {'test': 'data'}
        cache_entry = CacheEntry(payload)

        self.assertEqual(cache_entry.data, payload)
        self.assertIsNone(cache_entry.expire_time)
        self.assertIsInstance(cache_entry.created_time, datetime)
        self.assertEqual(cache_entry.access_count, 0)

    def test_cache_entry_with_expiry(self):
        """An entry expiring an hour from now stores that time and is not expired."""
        one_hour_later = datetime.now() + timedelta(hours=1)
        cache_entry = CacheEntry({'test': 'data'}, one_hour_later)

        self.assertEqual(cache_entry.expire_time, one_hour_later)
        self.assertFalse(cache_entry.is_expired())

    def test_cache_entry_expired(self):
        """An entry whose expiry lies an hour in the past reports itself expired."""
        one_hour_ago = datetime.now() - timedelta(hours=1)
        stale_entry = CacheEntry({'test': 'data'}, one_hour_ago)

        self.assertTrue(stale_entry.is_expired())

    def test_update_access(self):
        """update_access() moves the access timestamp forward and bumps the counter."""
        cache_entry = CacheEntry({'test': 'data'})
        before_time, before_count = cache_entry.access_time, cache_entry.access_count

        # Pause briefly so the refreshed timestamp is strictly later.
        import time
        time.sleep(0.01)

        cache_entry.update_access()

        self.assertGreater(cache_entry.access_time, before_time)
        self.assertEqual(cache_entry.access_count, before_count + 1)

    def test_to_dict(self):
        """to_dict() exposes payload, metadata and the access bookkeeping fields."""
        payload = {'test': 'data'}
        extra = {'source': 'test'}
        serialized = CacheEntry(payload, metadata=extra).to_dict()

        self.assertEqual(serialized['data'], payload)
        self.assertEqual(serialized['metadata'], extra)
        for field_name in ('created_time', 'access_time', 'access_count'):
            self.assertIn(field_name, serialized)


class TestMemoryCacheLayer(unittest.IsolatedAsyncioTestCase):
    """Tests for the in-memory cache layer: CRUD, TTL expiry, LRU eviction, stats."""

    def setUp(self):
        # A capacity of 3 keeps the LRU eviction scenario small.
        self.cache = MemoryCacheLayer(max_size=3)

    async def test_set_and_get(self):
        """A stored value can be read back unchanged."""
        payload = {'test': 'data'}

        self.assertTrue(await self.cache.set('test_key', payload))
        self.assertEqual(await self.cache.get('test_key'), payload)

    async def test_get_nonexistent(self):
        """Reading a missing key yields None."""
        self.assertIsNone(await self.cache.get('nonexistent'))

    async def test_set_with_ttl(self):
        """A value with a 1-second TTL is readable at once and gone after expiry."""
        payload = {'test': 'data'}

        self.assertTrue(await self.cache.set('test_key', payload, ttl=1))

        # Still fresh right after insertion.
        self.assertEqual(await self.cache.get('test_key'), payload)

        # Outlive the TTL, then the entry must be gone.
        await asyncio.sleep(1.1)
        self.assertIsNone(await self.cache.get('test_key'))

    async def test_delete(self):
        """delete() reports True for an existing key and False for a missing one."""
        await self.cache.set('test_key', {'test': 'data'})

        self.assertTrue(await self.cache.delete('test_key'))
        self.assertIsNone(await self.cache.get('test_key'))
        self.assertFalse(await self.cache.delete('nonexistent'))

    async def test_lru_eviction(self):
        """Inserting past capacity evicts the least recently used key."""
        for idx in range(3):
            await self.cache.set(f'key_{idx}', f'data_{idx}')

        # Touch key_0 so that key_1 becomes the least recently used entry.
        await self.cache.get('key_0')

        # A fourth insertion must push key_1 out.
        await self.cache.set('key_3', 'data_3')

        expectations = [
            ('key_0', True),   # recently touched
            ('key_1', False),  # evicted
            ('key_2', True),   # untouched but still within capacity
            ('key_3', True),   # newly inserted
        ]
        for name, should_exist in expectations:
            found = await self.cache.get(name)
            if should_exist:
                self.assertIsNotNone(found)
            else:
                self.assertIsNone(found)

    async def test_clear(self):
        """clear() returns the number of removed entries and empties the cache."""
        for idx in range(3):
            await self.cache.set(f'key_{idx}', f'data_{idx}')

        self.assertEqual(await self.cache.clear(), 3)

        for idx in range(3):
            self.assertIsNone(await self.cache.get(f'key_{idx}'))

    async def test_clear_expired(self):
        """clear_expired() removes only entries whose TTL has elapsed."""
        # key_1 and key_3 expire after one second; key_2 has no explicit TTL.
        await self.cache.set('key_1', 'data_1', ttl=1)
        await self.cache.set('key_2', 'data_2')
        await self.cache.set('key_3', 'data_3', ttl=1)

        await asyncio.sleep(1.1)

        self.assertEqual(await self.cache.clear_expired(), 2)

        self.assertIsNone(await self.cache.get('key_1'))
        self.assertIsNotNone(await self.cache.get('key_2'))
        self.assertIsNone(await self.cache.get('key_3'))

    def test_get_stats(self):
        """get_stats() reports size, capacity and hit/miss counters."""
        stats = self.cache.get_stats()

        for stat_key in ('size', 'max_size', 'usage_ratio', 'hit_rate', 'hits', 'misses'):
            self.assertIn(stat_key, stats)

        self.assertEqual(stats['max_size'], 3)


class TestDatabaseCacheLayer(unittest.IsolatedAsyncioTestCase):
    """Tests for the database-backed cache layer; each test uses a fresh DB file."""

    def setUp(self):
        import shutil

        self.temp_dir = tempfile.mkdtemp()
        # Register removal before constructing the cache: cleanups still run
        # when setUp raises (tearDown does not), so the temp dir never leaks.
        self.addCleanup(shutil.rmtree, self.temp_dir)
        self.db_path = os.path.join(self.temp_dir, 'test_cache.db')
        self.cache = DatabaseCacheLayer(self.db_path)

    async def test_set_and_get_dict(self):
        """A dict value round-trips through the database cache."""
        key = 'test_key'
        data = {'test': 'data', 'number': 123}

        success = await self.cache.set(key, data)
        self.assertTrue(success)

        retrieved = await self.cache.get(key)
        self.assertEqual(retrieved, data)

    async def test_set_and_get_dataframe(self):
        """A pandas DataFrame round-trips through the database cache."""
        key = 'test_df'
        data = pd.DataFrame({
            'A': [1, 2, 3],
            'B': ['a', 'b', 'c']
        })

        success = await self.cache.set(key, data)
        self.assertTrue(success)

        retrieved = await self.cache.get(key)
        pd.testing.assert_frame_equal(retrieved, data)

    async def test_set_and_get_list(self):
        """A list value round-trips through the database cache."""
        key = 'test_list'
        data = [1, 2, 3, 'test']

        success = await self.cache.set(key, data)
        self.assertTrue(success)

        retrieved = await self.cache.get(key)
        self.assertEqual(retrieved, data)

    async def test_get_nonexistent(self):
        """Reading a missing key yields None."""
        result = await self.cache.get('nonexistent')
        self.assertIsNone(result)

    async def test_set_with_ttl(self):
        """A value with a 1-second TTL is readable at once and gone after expiry."""
        key = 'test_key'
        data = {'test': 'data'}

        # Expire after one second
        success = await self.cache.set(key, data, ttl=1)
        self.assertTrue(success)

        # Immediately readable
        retrieved = await self.cache.get(key)
        self.assertEqual(retrieved, data)

        # Outlive the TTL
        await asyncio.sleep(1.1)

        # Expired entries are no longer returned
        retrieved = await self.cache.get(key)
        self.assertIsNone(retrieved)

    async def test_delete(self):
        """delete() reports True for an existing key and False for a missing one."""
        key = 'test_key'
        data = {'test': 'data'}

        await self.cache.set(key, data)

        # Deleting an existing key succeeds
        success = await self.cache.delete(key)
        self.assertTrue(success)

        # The key is really gone
        retrieved = await self.cache.get(key)
        self.assertIsNone(retrieved)

        # Deleting a missing key reports failure
        success = await self.cache.delete('nonexistent')
        self.assertFalse(success)

    async def test_clear(self):
        """clear() returns the number of removed entries and empties the cache."""
        for i in range(3):
            await self.cache.set(f'key_{i}', f'data_{i}')

        count = await self.cache.clear()
        self.assertEqual(count, 3)

        for i in range(3):
            result = await self.cache.get(f'key_{i}')
            self.assertIsNone(result)

    async def test_clear_expired(self):
        """clear_expired() removes only entries whose TTL has elapsed."""
        # key_1 and key_3 expire after one second; key_2 has no explicit TTL
        await self.cache.set('key_1', 'data_1', ttl=1)
        await self.cache.set('key_2', 'data_2')
        await self.cache.set('key_3', 'data_3', ttl=1)

        await asyncio.sleep(1.1)

        cleared = await self.cache.clear_expired()
        self.assertEqual(cleared, 2)

        self.assertIsNone(await self.cache.get('key_1'))
        self.assertIsNotNone(await self.cache.get('key_2'))
        self.assertIsNone(await self.cache.get('key_3'))

    async def test_delete_by_pattern(self):
        """delete_by_pattern() removes only keys matching the '%'-wildcard pattern."""
        await self.cache.set('price_dist:20240101:total', 'data1')
        await self.cache.set('price_dist:20240101:shanghai', 'data2')
        await self.cache.set('price_dist:20240102:total', 'data3')
        await self.cache.set('stock_daily:20240101', 'data4')

        deleted = await self.cache.delete_by_pattern('price_dist:20240101:%')
        self.assertEqual(deleted, 2)

        self.assertIsNone(await self.cache.get('price_dist:20240101:total'))
        self.assertIsNone(await self.cache.get('price_dist:20240101:shanghai'))
        self.assertIsNotNone(await self.cache.get('price_dist:20240102:total'))
        self.assertIsNotNone(await self.cache.get('stock_daily:20240101'))

    def test_get_stats(self):
        """get_stats() reports entry counts, DB size and hit rate."""
        stats = self.cache.get_stats()

        self.assertIn('total_entries', stats)
        self.assertIn('expired_entries', stats)
        self.assertIn('valid_entries', stats)
        self.assertIn('db_size_bytes', stats)
        self.assertIn('db_size_mb', stats)
        self.assertIn('hit_rate', stats)


class TestPriceDistributionCacheManager(unittest.IsolatedAsyncioTestCase):
    """Tests for PriceDistributionCacheManager (memory cache plus optional DB cache)."""

    def setUp(self):
        import shutil

        self.temp_dir = tempfile.mkdtemp()
        # Register removal before constructing the manager: cleanups still run
        # when setUp raises (tearDown does not), so the temp dir never leaks.
        self.addCleanup(shutil.rmtree, self.temp_dir)
        self.db_path = os.path.join(self.temp_dir, 'test_cache.db')
        self.cache_manager = PriceDistributionCacheManager(
            db_path=self.db_path,
            memory_cache_size=10
        )

    async def test_get_set_distribution_stats(self):
        """Distribution stats round-trip through set/get."""
        trade_date = '20240101'
        stats_data = {
            'trade_date': trade_date,
            'total_stocks': 100,
            'positive_ranges': {'0-3%': 20, '3-5%': 15},
            'negative_ranges': {'0到-3%': 25, '-3到-5%': 10}
        }

        success = await self.cache_manager.set_distribution_stats(
            trade_date, stats_data
        )
        self.assertTrue(success)

        retrieved = await self.cache_manager.get_distribution_stats(trade_date)
        self.assertEqual(retrieved, stats_data)

    async def test_get_set_stock_data_cache(self):
        """A stock-data DataFrame round-trips through the stock data cache."""
        trade_date = '20240101'
        stock_data = pd.DataFrame({
            'ts_code': ['000001.SZ', '000002.SZ'],
            'close': [10.0, 20.0],
            'pct_chg': [2.5, -1.5]
        })

        success = await self.cache_manager.set_stock_data_cache(
            trade_date, stock_data
        )
        self.assertTrue(success)

        retrieved = await self.cache_manager.get_stock_data_cache(trade_date)
        pd.testing.assert_frame_equal(retrieved, stock_data)

    async def test_get_set_market_classification_cache(self):
        """A market classification mapping round-trips through its cache."""
        trade_date = '20240101'
        classification = {
            'shanghai': ['000001.SH', '000002.SH'],
            'shenzhen': ['000001.SZ', '000002.SZ']
        }

        success = await self.cache_manager.set_market_classification_cache(
            trade_date, classification
        )
        self.assertTrue(success)

        retrieved = await self.cache_manager.get_market_classification_cache(trade_date)
        self.assertEqual(retrieved, classification)

    async def test_batch_operations(self):
        """Batch set reports success per key and batch get returns every value."""
        cache_data = {
            'key1': {'data': 'value1'},
            'key2': {'data': 'value2'},
            'key3': {'data': 'value3'}
        }

        results = await self.cache_manager.batch_set_cache(cache_data)
        # Check the key set first: looping over an empty result dict would
        # otherwise pass without asserting anything.
        self.assertEqual(set(results), set(cache_data))
        for key, success in results.items():
            with self.subTest(key=key):
                self.assertTrue(success)

        retrieved = await self.cache_manager.batch_get_cache(list(cache_data.keys()))
        self.assertEqual(len(retrieved), 3)
        for key, data in cache_data.items():
            with self.subTest(key=key):
                self.assertEqual(retrieved[key], data)

    async def test_delete_distribution_stats(self):
        """Deleting one market's stats leaves other markets for the same date intact."""
        trade_date = '20240101'

        await self.cache_manager.set_distribution_stats(
            trade_date, {'data': 'total'}, market='total'
        )
        await self.cache_manager.set_distribution_stats(
            trade_date, {'data': 'shanghai'}, market='shanghai'
        )

        deleted = await self.cache_manager.delete_distribution_stats(
            trade_date, market='total'
        )
        self.assertEqual(deleted, 1)

        result1 = await self.cache_manager.get_distribution_stats(
            trade_date, market='total'
        )
        result2 = await self.cache_manager.get_distribution_stats(
            trade_date, market='shanghai'
        )

        self.assertIsNone(result1)
        self.assertIsNotNone(result2)

    async def test_delete_by_trade_date(self):
        """Deleting by trade date removes every cache kind stored for that date."""
        trade_date = '20240101'

        await self.cache_manager.set_distribution_stats(trade_date, {'data': 'stats'})
        await self.cache_manager.set_stock_data_cache(
            trade_date, pd.DataFrame({'A': [1, 2]})
        )
        await self.cache_manager.set_market_classification_cache(
            trade_date, {'shanghai': []}
        )

        deleted = await self.cache_manager.delete_by_trade_date(trade_date)
        self.assertGreater(deleted, 0)

        self.assertIsNone(
            await self.cache_manager.get_distribution_stats(trade_date)
        )
        self.assertIsNone(
            await self.cache_manager.get_stock_data_cache(trade_date)
        )
        self.assertIsNone(
            await self.cache_manager.get_market_classification_cache(trade_date)
        )

    async def test_clear_expired_cache(self):
        """clear_expired_cache() removes entries whose TTL has elapsed."""
        trade_date = '20240101'

        # Store an entry that expires after one second
        await self.cache_manager.set_distribution_stats(
            trade_date, {'data': 'test'}, ttl=1
        )

        await asyncio.sleep(1.1)

        cleared = await self.cache_manager.clear_expired_cache()
        self.assertGreater(cleared, 0)

        result = await self.cache_manager.get_distribution_stats(trade_date)
        self.assertIsNone(result)

    async def test_clear_all_cache(self):
        """clear_all_cache() removes entries for every trade date."""
        await self.cache_manager.set_distribution_stats('20240101', {'data': '1'})
        await self.cache_manager.set_distribution_stats('20240102', {'data': '2'})

        cleared = await self.cache_manager.clear_all_cache()
        self.assertGreater(cleared, 0)

        result1 = await self.cache_manager.get_distribution_stats('20240101')
        result2 = await self.cache_manager.get_distribution_stats('20240102')

        self.assertIsNone(result1)
        self.assertIsNone(result2)

    async def test_get_cache_info(self):
        """get_cache_info() reports per-layer info, plus per-date info when asked."""
        trade_date = '20240101'
        await self.cache_manager.set_distribution_stats(trade_date, {'data': 'test'})

        info = await self.cache_manager.get_cache_info()

        self.assertIn('memory_cache', info)
        self.assertIn('database_cache', info)
        self.assertIn('global_stats', info)

        # Passing a trade date adds date-specific information
        date_info = await self.cache_manager.get_cache_info(trade_date)
        self.assertIn('trade_date_info', date_info)

    async def test_get_cache_statistics(self):
        """get_cache_statistics() reports per-layer stats and an overall hit rate."""
        # Populate and read once so the counters have something to report
        await self.cache_manager.set_distribution_stats('20240101', {'data': 'test'})
        await self.cache_manager.get_distribution_stats('20240101')

        stats = await self.cache_manager.get_cache_statistics()

        self.assertIn('memory_cache', stats)
        self.assertIn('database_cache', stats)
        self.assertIn('global_stats', stats)
        self.assertIn('overall_hit_rate', stats)

    async def test_validate_cache_consistency(self):
        """validate_cache_consistency() reports per-layer entries for a trade date."""
        trade_date = '20240101'

        await self.cache_manager.set_distribution_stats(trade_date, {'data': 'test'})

        result = await self.cache_manager.validate_cache_consistency(trade_date)

        self.assertIn('trade_date', result)
        self.assertIn('consistent', result)
        self.assertIn('memory_entries', result)
        self.assertIn('database_entries', result)
        self.assertEqual(result['trade_date'], trade_date)

    async def test_refresh_cache(self):
        """refresh_cache() succeeds; a forced refresh also deletes existing entries."""
        trade_date = '20240101'

        await self.cache_manager.set_distribution_stats(trade_date, {'data': 'test'})

        # Plain refresh
        result = await self.cache_manager.refresh_cache(trade_date)
        self.assertIn('refreshed', result)
        self.assertTrue(result['refreshed'])

        # Forced refresh
        result = await self.cache_manager.refresh_cache(trade_date, force=True)
        self.assertIn('refreshed', result)
        self.assertTrue(result['refreshed'])
        self.assertGreater(result['deleted_entries'], 0)

    async def test_multilayer_cache_strategy(self):
        """Data survives a memory-cache wipe because it is also stored in the DB layer."""
        trade_date = '20240101'
        data = {'test': 'multilayer'}

        # Expected to populate both the memory and the database layer
        await self.cache_manager.set_distribution_stats(trade_date, data)

        # Wipe only the memory layer
        await self.cache_manager.memory_cache.clear()

        # Must now be served from the database layer
        retrieved = await self.cache_manager.get_distribution_stats(trade_date)
        self.assertEqual(retrieved, data)

        # Presumably served from the memory layer after back-fill; this test
        # only checks the value, not which layer answered.
        retrieved2 = await self.cache_manager.get_distribution_stats(trade_date)
        self.assertEqual(retrieved2, data)

    async def test_cache_without_database(self):
        """With db_path=None the manager works from memory only and omits DB stats."""
        cache_manager = PriceDistributionCacheManager(
            db_path=None,  # memory-only: no database layer
            memory_cache_size=10
        )

        trade_date = '20240101'
        data = {'test': 'memory_only'}

        success = await cache_manager.set_distribution_stats(trade_date, data)
        self.assertTrue(success)

        retrieved = await cache_manager.get_distribution_stats(trade_date)
        self.assertEqual(retrieved, data)

        stats = await cache_manager.get_cache_statistics()
        self.assertIn('memory_cache', stats)
        self.assertNotIn('database_cache', stats)


if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main(module='__main__')