"""
涨跌分布统计数据库存储测试

测试数据库存储和查询功能
"""

import asyncio
import json
import os
import shutil
import tempfile
import unittest
from datetime import datetime
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from quickstock.core.database import DatabaseManager
from quickstock.core.price_distribution_repository import PriceDistributionRepository
from quickstock.core.price_distribution_database_cache import PriceDistributionDatabaseCache
from quickstock.core.price_distribution_cache import PriceDistributionCacheManager
from quickstock.models.price_distribution_models import PriceDistributionStats
from quickstock.core.errors import DatabaseError, ValidationError


class TestPriceDistributionRepository(unittest.IsolatedAsyncioTestCase):
    """涨跌分布统计存储库测试"""
    
    async def asyncSetUp(self):
        """设置测试环境"""
        # 创建临时数据库
        self.temp_dir = tempfile.mkdtemp()
        self.db_path = os.path.join(self.temp_dir, 'test_quickstock.db')
        
        # 初始化数据库管理器
        self.db_manager = DatabaseManager(self.db_path)
        await self.db_manager.initialize()
        
        # 初始化存储库
        self.repository = PriceDistributionRepository(self.db_manager)
        
        # 创建测试数据
        self.test_stats = PriceDistributionStats(
            trade_date='20240115',
            total_stocks=1000,
            positive_ranges={
                '0-3%': 200,
                '3-5%': 150,
                '5-7%': 100,
                '7-10%': 80,
                '>=10%': 50
            },
            positive_percentages={
                '0-3%': 20.0,
                '3-5%': 15.0,
                '5-7%': 10.0,
                '7-10%': 8.0,
                '>=10%': 5.0
            },
            negative_ranges={
                '0到-3%': 180,
                '-3到-5%': 120,
                '-5到-7%': 70,
                '-7到-10%': 40,
                '<=-10%': 20
            },
            negative_percentages={
                '0到-3%': 18.0,
                '-3到-5%': 12.0,
                '-5到-7%': 7.0,
                '-7到-10%': 4.0,
                '<=-10%': 2.0
            },
            market_breakdown={
                'shanghai': {
                    'total_stocks': 400,
                    'positive_ranges': {'0-3%': 80, '3-5%': 60},
                    'positive_percentages': {'0-3%': 20.0, '3-5%': 15.0},
                    'negative_ranges': {'0到-3%': 70, '-3到-5%': 50},
                    'negative_percentages': {'0到-3%': 17.5, '-3到-5%': 12.5},
                    'stock_codes': {
                        '0-3%': ['000001.SZ', '000002.SZ'],
                        '3-5%': ['600000.SH', '600001.SH']
                    }
                },
                'shenzhen': {
                    'total_stocks': 600,
                    'positive_ranges': {'0-3%': 120, '3-5%': 90},
                    'positive_percentages': {'0-3%': 20.0, '3-5%': 15.0},
                    'negative_ranges': {'0到-3%': 110, '-3到-5%': 70},
                    'negative_percentages': {'0到-3%': 18.3, '-3到-5%': 11.7}
                }
            },
            processing_time=2.5,
            data_quality_score=0.95
        )
    
    async def asyncTearDown(self):
        """清理测试环境"""
        # 关闭数据库连接
        self.db_manager.close()
        
        # 删除临时文件
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)
    
    async def test_save_distribution_stats(self):
        """测试保存分布统计数据"""
        # 保存数据
        success = await self.repository.save_distribution_stats(self.test_stats)
        self.assertTrue(success)
        
        # 验证数据是否保存成功
        saved_stats = await self.repository.get_distribution_stats('20240115')
        self.assertIsNotNone(saved_stats)
        self.assertEqual(saved_stats.trade_date, '20240115')
        self.assertEqual(saved_stats.total_stocks, 1000)
        self.assertEqual(saved_stats.positive_ranges['0-3%'], 200)
        self.assertEqual(saved_stats.negative_ranges['0到-3%'], 180)
        self.assertAlmostEqual(saved_stats.processing_time, 2.5)
        self.assertAlmostEqual(saved_stats.data_quality_score, 0.95)
    
    async def test_save_distribution_stats_with_market_breakdown(self):
        """测试保存包含市场板块分布的统计数据"""
        success = await self.repository.save_distribution_stats(self.test_stats)
        self.assertTrue(success)
        
        # 验证市场板块数据
        saved_stats = await self.repository.get_distribution_stats('20240115')
        self.assertIn('shanghai', saved_stats.market_breakdown)
        self.assertIn('shenzhen', saved_stats.market_breakdown)
        
        shanghai_data = saved_stats.market_breakdown['shanghai']
        self.assertEqual(shanghai_data['positive_ranges']['0-3%'], 80)
        self.assertEqual(shanghai_data['negative_ranges']['0到-3%'], 70)
        
        # 验证股票代码是否正确保存和恢复
        if 'stock_codes' in shanghai_data:
            self.assertIn('000001.SZ', shanghai_data['stock_codes']['0-3%'])
    
    async def test_get_nonexistent_distribution_stats(self):
        """测试获取不存在的分布统计数据"""
        stats = await self.repository.get_distribution_stats('20240101')
        self.assertIsNone(stats)
    
    async def test_delete_distribution_stats(self):
        """测试删除分布统计数据"""
        # 先保存数据
        await self.repository.save_distribution_stats(self.test_stats)
        
        # 验证数据存在
        stats = await self.repository.get_distribution_stats('20240115')
        self.assertIsNotNone(stats)
        
        # 删除数据
        success = await self.repository.delete_distribution_stats('20240115')
        self.assertTrue(success)
        
        # 验证数据已删除
        stats = await self.repository.get_distribution_stats('20240115')
        self.assertIsNone(stats)
    
    async def test_get_available_dates(self):
        """测试获取可用日期列表"""
        # 保存多个日期的数据
        dates = ['20240115', '20240116', '20240117']
        for date in dates:
            test_stats = PriceDistributionStats(
                trade_date=date,
                total_stocks=1000,
                positive_ranges={'0-3%': 100},
                positive_percentages={'0-3%': 10.0},
                negative_ranges={'0到-3%': 100},
                negative_percentages={'0到-3%': 10.0},
                market_breakdown={}
            )
            await self.repository.save_distribution_stats(test_stats)
        
        # 获取所有可用日期
        available_dates = await self.repository.get_available_dates()
        self.assertEqual(len(available_dates), 3)
        for date in dates:
            self.assertIn(date, available_dates)
        
        # 测试日期范围过滤
        filtered_dates = await self.repository.get_available_dates('20240116', '20240117')
        self.assertEqual(len(filtered_dates), 2)
        self.assertIn('20240116', filtered_dates)
        self.assertIn('20240117', filtered_dates)
        self.assertNotIn('20240115', filtered_dates)
    
    async def test_get_stats_summary(self):
        """测试获取统计摘要"""
        # 保存测试数据
        await self.repository.save_distribution_stats(self.test_stats)
        
        # 获取统计摘要
        summary = await self.repository.get_stats_summary()
        
        self.assertEqual(summary['total_records'], 1)
        self.assertEqual(summary['date_range']['earliest'], '20240115')
        self.assertEqual(summary['date_range']['latest'], '20240115')
        self.assertEqual(summary['averages']['total_stocks'], 1000.0)
        self.assertAlmostEqual(summary['averages']['processing_time'], 2.5)
        self.assertAlmostEqual(summary['averages']['quality_score'], 0.95)
        
        # 验证市场统计
        self.assertIn('total', summary['market_stats'])
        self.assertIn('shanghai', summary['market_stats'])
        self.assertIn('shenzhen', summary['market_stats'])
    
    async def test_batch_save_distribution_stats(self):
        """测试批量保存分布统计数据"""
        # 创建多个测试数据
        stats_list = []
        for i, date in enumerate(['20240115', '20240116', '20240117']):
            stats = PriceDistributionStats(
                trade_date=date,
                total_stocks=1000 + i * 100,
                positive_ranges={'0-3%': 100 + i * 10},
                positive_percentages={'0-3%': 10.0 + i},
                negative_ranges={'0到-3%': 100 + i * 10},
                negative_percentages={'0到-3%': 10.0 + i},
                market_breakdown={}
            )
            stats_list.append(stats)
        
        # 批量保存
        result = await self.repository.batch_save_distribution_stats(stats_list)
        
        self.assertEqual(result['success_count'], 3)
        self.assertEqual(result['error_count'], 0)
        self.assertEqual(len(result['errors']), 0)
        
        # 验证数据是否保存成功
        for date in ['20240115', '20240116', '20240117']:
            stats = await self.repository.get_distribution_stats(date)
            self.assertIsNotNone(stats)
    
    async def test_cleanup_old_data(self):
        """测试清理旧数据"""
        # 保存测试数据
        await self.repository.save_distribution_stats(self.test_stats)
        
        # 清理数据（保留0天，即清理所有数据）
        result = await self.repository.cleanup_old_data(keep_days=0)
        
        self.assertTrue(result['success'])
        self.assertGreater(result['stats_deleted'], 0)
        self.assertGreater(result['metadata_deleted'], 0)
        
        # 验证数据已被清理
        stats = await self.repository.get_distribution_stats('20240115')
        self.assertIsNone(stats)
    
    async def test_save_invalid_stats(self):
        """测试保存无效的统计数据"""
        # 测试在模型创建阶段就会失败的无效数据
        with self.assertRaises(ValueError):
            invalid_stats = PriceDistributionStats(
                trade_date='invalid_date',  # 无效日期格式
                total_stocks=1000,
                positive_ranges={'0-3%': 100},
                positive_percentages={'0-3%': 10.0},
                negative_ranges={'0到-3%': 100},
                negative_percentages={'0到-3%': 10.0},
                market_breakdown={}
            )
    
    async def test_database_error_handling(self):
        """测试数据库错误处理"""
        # 使用无效的数据库路径来模拟数据库错误
        invalid_db_manager = DatabaseManager('/invalid/path/test.db')
        invalid_repository = PriceDistributionRepository(invalid_db_manager)
        
        # 尝试保存数据应该抛出数据库错误
        with self.assertRaises(DatabaseError):
            await invalid_repository.save_distribution_stats(self.test_stats)


class TestPriceDistributionDatabaseCache(unittest.IsolatedAsyncioTestCase):
    """涨跌分布统计数据库缓存层测试"""
    
    async def asyncSetUp(self):
        """设置测试环境"""
        # 创建临时数据库
        self.temp_dir = tempfile.mkdtemp()
        self.db_path = os.path.join(self.temp_dir, 'test_quickstock.db')
        
        # 初始化数据库管理器
        self.db_manager = DatabaseManager(self.db_path)
        await self.db_manager.initialize()
        
        # 初始化数据库缓存层
        self.db_cache = PriceDistributionDatabaseCache(self.db_manager)
        
        # 创建测试数据
        self.test_stats = PriceDistributionStats(
            trade_date='20240115',
            total_stocks=1000,
            positive_ranges={'0-3%': 200, '3-5%': 150},
            positive_percentages={'0-3%': 20.0, '3-5%': 15.0},
            negative_ranges={'0到-3%': 180, '-3到-5%': 120},
            negative_percentages={'0到-3%': 18.0, '-3到-5%': 12.0},
            market_breakdown={},
            processing_time=2.5,
            data_quality_score=0.95
        )
    
    async def asyncTearDown(self):
        """清理测试环境"""
        # 关闭数据库连接
        self.db_manager.close()
        
        # 删除临时文件
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)
    
    async def test_cache_get_set(self):
        """测试缓存获取和设置"""
        # 初始时缓存为空
        stats = await self.db_cache.get('20240115')
        self.assertIsNone(stats)
        
        # 设置缓存
        success = await self.db_cache.set('20240115', self.test_stats)
        self.assertTrue(success)
        
        # 获取缓存
        cached_stats = await self.db_cache.get('20240115')
        self.assertIsNotNone(cached_stats)
        self.assertEqual(cached_stats.trade_date, '20240115')
        self.assertEqual(cached_stats.total_stocks, 1000)
    
    async def test_cache_delete(self):
        """测试缓存删除"""
        # 设置缓存
        await self.db_cache.set('20240115', self.test_stats)
        
        # 验证缓存存在
        stats = await self.db_cache.get('20240115')
        self.assertIsNotNone(stats)
        
        # 删除缓存
        success = await self.db_cache.delete('20240115')
        self.assertTrue(success)
        
        # 验证缓存已删除
        stats = await self.db_cache.get('20240115')
        self.assertIsNone(stats)
    
    async def test_cache_exists(self):
        """测试缓存存在性检查"""
        # 初始时不存在
        exists = await self.db_cache.exists('20240115')
        self.assertFalse(exists)
        
        # 设置缓存后存在
        await self.db_cache.set('20240115', self.test_stats)
        exists = await self.db_cache.exists('20240115')
        self.assertTrue(exists)
    
    async def test_cache_clear(self):
        """测试缓存清理"""
        # 设置多个缓存
        dates = ['20240115', '20240116', '20240117']
        for date in dates:
            test_stats = PriceDistributionStats(
                trade_date=date,
                total_stocks=1000,
                positive_ranges={'0-3%': 100},
                positive_percentages={'0-3%': 10.0},
                negative_ranges={'0到-3%': 100},
                negative_percentages={'0到-3%': 10.0},
                market_breakdown={}
            )
            await self.db_cache.set(date, test_stats)
        
        # 清理指定日期
        cleared = await self.db_cache.clear('20240115')
        self.assertEqual(cleared, 1)
        
        # 验证指定日期已清理
        exists = await self.db_cache.exists('20240115')
        self.assertFalse(exists)
        
        # 其他日期仍存在
        exists = await self.db_cache.exists('20240116')
        self.assertTrue(exists)
    
    async def test_batch_operations(self):
        """测试批量操作"""
        # 准备批量数据
        stats_dict = {}
        for i, date in enumerate(['20240115', '20240116', '20240117']):
            stats = PriceDistributionStats(
                trade_date=date,
                total_stocks=1000 + i * 100,
                positive_ranges={'0-3%': 100 + i * 10},
                positive_percentages={'0-3%': 10.0 + i},
                negative_ranges={'0到-3%': 100 + i * 10},
                negative_percentages={'0到-3%': 10.0 + i},
                market_breakdown={}
            )
            stats_dict[date] = stats
        
        # 批量设置
        set_results = await self.db_cache.batch_set(stats_dict)
        for date, success in set_results.items():
            self.assertTrue(success, f"Failed to set cache for {date}")
        
        # 批量获取
        get_results = await self.db_cache.batch_get(list(stats_dict.keys()))
        for date, stats in get_results.items():
            self.assertIsNotNone(stats, f"Failed to get cache for {date}")
            self.assertEqual(stats.trade_date, date)
    
    async def test_get_cache_info(self):
        """测试获取缓存信息"""
        # 设置一些缓存数据
        await self.db_cache.set('20240115', self.test_stats)
        
        # 获取缓存信息
        info = await self.db_cache.get_cache_info()
        
        self.assertEqual(info['type'], 'database')
        self.assertIn('database_stats', info)
        self.assertIn('summary', info)
        self.assertGreater(info['total_records'], 0)
    
    async def test_validate_cache_integrity(self):
        """测试缓存完整性验证"""
        # 设置缓存数据
        await self.db_cache.set('20240115', self.test_stats)
        
        # 验证缓存完整性
        validation_result = await self.db_cache.validate_cache_integrity()
        
        self.assertTrue(validation_result['valid'])
        self.assertEqual(len(validation_result['errors']), 0)
        self.assertIn('stats', validation_result)
    
    async def test_optimize_cache(self):
        """测试缓存优化"""
        # 设置一些缓存数据
        await self.db_cache.set('20240115', self.test_stats)
        
        # 优化缓存
        optimization_result = await self.db_cache.optimize_cache()
        
        self.assertTrue(optimization_result['success'])
        self.assertIn('actions', optimization_result)
        self.assertIn('stats', optimization_result)


class TestPriceDistributionCacheManagerIntegration(unittest.IsolatedAsyncioTestCase):
    """涨跌分布统计缓存管理器集成测试"""
    
    async def asyncSetUp(self):
        """设置测试环境"""
        # 创建临时数据库
        self.temp_dir = tempfile.mkdtemp()
        self.db_path = os.path.join(self.temp_dir, 'test_quickstock.db')
        
        # 初始化数据库管理器
        self.db_manager = DatabaseManager(self.db_path)
        await self.db_manager.initialize()
        
        # 初始化缓存管理器（集成数据库缓存）
        self.cache_manager = PriceDistributionCacheManager(
            db_manager=self.db_manager,
            memory_cache_size=100
        )
        
        # 创建测试数据
        self.test_stats = PriceDistributionStats(
            trade_date='20240115',
            total_stocks=1000,
            positive_ranges={'0-3%': 200, '3-5%': 150},
            positive_percentages={'0-3%': 20.0, '3-5%': 15.0},
            negative_ranges={'0到-3%': 180, '-3到-5%': 120},
            negative_percentages={'0到-3%': 18.0, '-3到-5%': 12.0},
            market_breakdown={},
            processing_time=2.5,
            data_quality_score=0.95
        )
    
    async def asyncTearDown(self):
        """清理测试环境"""
        # 关闭数据库连接
        self.db_manager.close()
        
        # 删除临时文件
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)
    
    async def test_integrated_cache_layers(self):
        """测试集成的缓存层"""
        # 设置分布统计缓存
        success = await self.cache_manager.set_distribution_stats(
            '20240115', self.test_stats, 'total', True
        )
        self.assertTrue(success)
        
        # 获取分布统计缓存（应该从内存缓存获取）
        cached_stats = await self.cache_manager.get_distribution_stats(
            '20240115', 'total', True
        )
        self.assertIsNotNone(cached_stats)
        self.assertEqual(cached_stats.trade_date, '20240115')
        
        # 清空内存缓存
        await self.cache_manager.memory_cache.clear()
        
        # 再次获取（应该从数据库缓存获取）
        cached_stats = await self.cache_manager.get_distribution_stats(
            '20240115', 'total', True
        )
        self.assertIsNotNone(cached_stats)
        self.assertEqual(cached_stats.trade_date, '20240115')
    
    async def test_cache_info_integration(self):
        """测试缓存信息集成"""
        # 设置一些缓存数据
        await self.cache_manager.set_distribution_stats(
            '20240115', self.test_stats, 'total', True
        )
        
        # 获取缓存信息
        info = await self.cache_manager.get_cache_info()
        
        self.assertIn('memory_cache', info)
        self.assertIn('structured_database_cache', info)
        self.assertIn('global_stats', info)
        
        # 验证结构化数据库缓存信息
        structured_db_info = info['structured_database_cache']
        self.assertEqual(structured_db_info['type'], 'database')
        self.assertGreater(structured_db_info['total_records'], 0)
    
    async def test_cache_deletion_integration(self):
        """测试缓存删除集成"""
        # 设置缓存
        await self.cache_manager.set_distribution_stats(
            '20240115', self.test_stats, 'total', True
        )
        
        # 验证缓存存在
        cached_stats = await self.cache_manager.get_distribution_stats(
            '20240115', 'total', True
        )
        self.assertIsNotNone(cached_stats)
        
        # 删除缓存
        deleted_count = await self.cache_manager.delete_distribution_stats('20240115')
        self.assertGreater(deleted_count, 0)
        
        # 验证缓存已删除（从所有层）
        cached_stats = await self.cache_manager.get_distribution_stats(
            '20240115', 'total', True
        )
        self.assertIsNone(cached_stats)


class TestDatabaseSchemaValidation(unittest.IsolatedAsyncioTestCase):
    """数据库模式验证测试"""
    
    async def asyncSetUp(self):
        """设置测试环境"""
        # 创建临时数据库
        self.temp_dir = tempfile.mkdtemp()
        self.db_path = os.path.join(self.temp_dir, 'test_quickstock.db')
        
        # 初始化数据库管理器
        self.db_manager = DatabaseManager(self.db_path)
        await self.db_manager.initialize()
    
    async def asyncTearDown(self):
        """清理测试环境"""
        # 关闭数据库连接
        self.db_manager.close()
        
        # 删除临时文件
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)
    
    async def test_database_schema_validation(self):
        """测试数据库模式验证"""
        # 验证数据库模式
        validation_result = await self.db_manager.validate_schema()
        
        self.assertTrue(validation_result['valid'])
        self.assertEqual(len(validation_result['errors']), 0)
        
        # 验证新增的表存在
        self.assertTrue(validation_result['tables']['price_distribution_stats'])
        self.assertTrue(validation_result['tables']['price_distribution_metadata'])
        
        # 验证新增的索引存在
        self.assertTrue(validation_result['indexes']['idx_price_distribution_stats_date'])
        self.assertTrue(validation_result['indexes']['idx_price_distribution_stats_market'])
        self.assertTrue(validation_result['indexes']['idx_price_distribution_stats_range'])
        self.assertTrue(validation_result['indexes']['idx_price_distribution_metadata_date'])
        
        # 验证新增的触发器存在
        self.assertTrue(validation_result['triggers']['update_price_distribution_stats_timestamp'])
        self.assertTrue(validation_result['triggers']['update_price_distribution_metadata_timestamp'])
    
    async def test_database_stats_include_new_tables(self):
        """测试数据库统计信息包含新表"""
        # 获取数据库统计信息
        stats = await self.db_manager.get_database_stats()
        
        # 验证新表包含在统计信息中
        self.assertIn('price_distribution_stats', stats['tables'])
        self.assertIn('price_distribution_metadata', stats['tables'])
        
        # 初始时表应该为空
        self.assertEqual(stats['tables']['price_distribution_stats'], 0)
        self.assertEqual(stats['tables']['price_distribution_metadata'], 0)


if __name__ == '__main__':
    # Executing this file directly runs every test case defined above.
    unittest.main(module='__main__')