"""
增量计算功能单元测试

测试增量式特征工程的核心功能：
1. 构造函数参数
2. 缓存验证
3. 增量计算逻辑
4. to_parquet 增量合并
5. save_as_config 配置固化
"""

import pytest
import numpy as np
import pandas as pd
from pathlib import Path
import tempfile
import shutil

from infra.features_v2.core.feature_set import FeatureSet
from infra.features_v2.core.executor import ExecutionContext


class MockUnifiedExperiment:
    """Lightweight stand-in for ``UnifiedExperiment`` used by these tests.

    Only the surface the tests rely on is provided: ``chip_id`` /
    ``device_id`` attributes, ``_get_experiment()`` returning an object that
    carries ``hdf5_path``, and ``get_v2_feature_dataframe()`` which always
    reports that no cached DataFrame exists.
    """

    def __init__(self, chip_id, device_id, hdf5_path):
        self.chip_id, self.device_id = chip_id, device_id
        self._hdf5_path = hdf5_path

    def _get_experiment(self):
        """Build and return a new minimal experiment object exposing ``hdf5_path``."""
        class MockExperiment:
            def __init__(self, path):
                self.hdf5_path = path

        experiment = MockExperiment(self._hdf5_path)
        return experiment

    def get_v2_feature_dataframe(self, config_name):
        """Pretend there is never a cached feature DataFrame for *config_name*."""
        cached = None  # the mock never has a cache
        return cached


class TestIncrementalConstruction:
    """Tests for the incremental-computation constructor parameters."""

    def test_basic_construction(self):
        """A bare FeatureSet has no experiment/config bindings and version '1.0'."""
        fs = FeatureSet()

        for attr_name in ('experiment', 'unified_experiment', 'config_name'):
            assert getattr(fs, attr_name) is None
        assert fs.config_version == '1.0'

    def test_construction_with_unified_experiment(self):
        """Passing a UnifiedExperiment stores it plus the config name/version."""
        with tempfile.NamedTemporaryFile(suffix='.h5') as h5_file:
            mock_exp = MockUnifiedExperiment('#TEST001', '1', h5_file.name)
            fs = FeatureSet(
                unified_experiment=mock_exp,
                config_name='test_config',
                config_version='2.0',
            )

            assert fs.unified_experiment is mock_exp
            assert (fs.config_name, fs.config_version) == ('test_config', '2.0')
            # The plain experiment is expected to be derived from the unified one
            assert fs.experiment is not None


class TestCacheValidation:
    """Tests for source hashing and cached-DataFrame validation."""

    @staticmethod
    def _write_source_file():
        """Create a closed temporary ``.h5`` file holding fixed bytes; return its path.

        The caller is responsible for deleting the file.
        """
        with tempfile.NamedTemporaryFile(suffix='.h5', delete=False) as handle:
            handle.write(b'test data')
        return handle.name

    def test_compute_source_hash(self):
        """The source hash is non-empty, 16 characters long and stable."""
        source_path = self._write_source_file()
        try:
            fs = FeatureSet(
                unified_experiment=MockUnifiedExperiment('#TEST001', '1', source_path)
            )

            first = fs._compute_source_hash()
            assert first != ""
            assert len(first) == 16  # expected: MD5 digest truncated to 16 chars
            # A second computation over the same file must agree
            assert fs._compute_source_hash() == first
        finally:
            Path(source_path).unlink()

    def test_validate_cache_with_metadata(self):
        """A cache is accepted only when its ``source_hash`` attr matches."""
        source_path = self._write_source_file()
        try:
            fs = FeatureSet(
                unified_experiment=MockUnifiedExperiment('#TEST001', '1', source_path)
            )
            matching_hash = fs._compute_source_hash()
            cached = pd.DataFrame({'step_index': [0, 1, 2], 'feature1': [1.0, 2.0, 3.0]})

            # Matching hash -> cache is valid
            cached.attrs = {'source_hash': matching_hash}
            assert fs._validate_cache(cached) is True

            # Stale hash -> cache is rejected
            cached.attrs = {'source_hash': 'wrong_hash'}
            assert fs._validate_cache(cached) is False
        finally:
            Path(source_path).unlink()

    def test_validate_cache_strict_mode(self):
        """A DataFrame without metadata passes lenient mode but fails strict mode."""
        fs = FeatureSet()
        no_metadata = pd.DataFrame({'feature1': [1, 2, 3]})

        for strict, expected in ((False, True), (True, False)):
            assert fs._validate_cache(no_metadata, strict=strict) is expected


class TestParquetOperations:
    """Tests for Parquet export: metadata embedding and incremental merging."""

    def test_to_parquet_with_metadata(self):
        """to_parquet(save_metadata=True) stores identifying metadata in ``df.attrs``."""
        with tempfile.TemporaryDirectory() as tmpdir, \
             tempfile.TemporaryDirectory() as srcdir:
            # Write the fake HDF5 source completely and close it before FeatureSet
            # reads it, so the bytes are actually on disk when the source hash is
            # computed. Its own TemporaryDirectory handles cleanup, avoiding the
            # unlink of a still-open file (which fails on Windows).
            source_path = Path(srcdir) / 'source.h5'
            source_path.write_bytes(b'test data')

            unified_exp = MockUnifiedExperiment('#TEST001', '1', str(source_path))
            features = FeatureSet(
                unified_experiment=unified_exp,
                config_name='test_config',
                config_version='1.0'
            )

            # Fake a computed result instead of running the pipeline
            features._computed_results = ExecutionContext()
            features._computed_results.set('feature1', np.array([1.0, 2.0, 3.0]), 0)

            output_path = Path(tmpdir) / 'test.parquet'
            features.to_parquet(str(output_path), save_metadata=True)

            assert output_path.exists()

            # The metadata is expected to round-trip into DataFrame.attrs
            df = pd.read_parquet(output_path)
            assert 'chip_id' in df.attrs
            assert df.attrs['chip_id'] == '#TEST001'
            assert df.attrs['device_id'] == '1'
            assert df.attrs['config_name'] == 'test_config'
            assert 'source_hash' in df.attrs

    def test_to_parquet_merge_existing(self):
        """merge_existing=True adds new feature columns to an existing Parquet file."""
        with tempfile.TemporaryDirectory() as tmpdir:
            output_path = Path(tmpdir) / 'test.parquet'

            # First write: feature1 only
            features1 = FeatureSet(config_name='test')
            features1._computed_results = ExecutionContext()
            features1._computed_results.set('feature1', np.array([1.0, 2.0, 3.0]), 0)
            features1.to_parquet(str(output_path), save_metadata=False)

            # Second write in merge mode: feature2 should land next to feature1
            features2 = FeatureSet(config_name='test')
            features2._computed_results = ExecutionContext()
            features2._computed_results.set('feature2', np.array([4.0, 5.0, 6.0]), 0)
            features2.to_parquet(str(output_path), merge_existing=True, save_metadata=False)

            df = pd.read_parquet(output_path)
            assert 'feature1' in df.columns
            assert 'feature2' in df.columns
            assert len(df) == 3


class TestSaveAsConfig:
    """Tests for persisting a FeatureSet as a reusable YAML config."""

    def test_save_as_config_basic(self):
        """save_as_config writes a YAML file describing the registered features."""
        with tempfile.TemporaryDirectory() as tmpdir, \
             tempfile.TemporaryDirectory() as srcdir:
            # Write the fake HDF5 source completely and close it before use, so its
            # bytes are on disk if the source is hashed. Keeping it in a separate
            # directory keeps config_dir clean, and cleanup is automatic (no unlink
            # of a still-open file, which fails on Windows).
            source_path = Path(srcdir) / 'source.h5'
            source_path.write_bytes(b'test data')

            unified_exp = MockUnifiedExperiment('#TEST001', '1', str(source_path))
            features = FeatureSet(
                unified_experiment=unified_exp,
                config_name='test_save',
                config_version='1.0'
            )

            # Fake a computed result instead of running the pipeline
            features._computed_results = ExecutionContext()
            features._computed_results.set('feature1', np.array([1.0, 2.0, 3.0]), 0)

            # Register a matching feature node in the compute graph
            from infra.features_v2.core.compute_graph import ComputeNode
            node = ComputeNode(
                name='feature1',
                func='transfer.gm_max',
                inputs=['transfer'],
                params={},
                is_extractor=True
            )
            features.graph.add_node(node)

            result = features.save_as_config(
                config_name='test_save',
                save_parquet=False,  # skip Parquet output to avoid path dependencies
                config_dir=tmpdir,
                description='Test config'
            )

            assert 'config_file' in result
            assert 'features_added' in result
            assert 'config_version' in result

            config_file = Path(result['config_file'])
            assert config_file.exists()

            # Decode explicitly as UTF-8 rather than the platform default encoding
            import yaml
            config = yaml.safe_load(config_file.read_text(encoding='utf-8'))

            assert config['name'] == 'test_save'
            assert config['config_version'] == '1.0'
            assert len(config['features']) == 1
            assert config['features'][0]['name'] == 'feature1'


if __name__ == '__main__':
    # pytest.main returns the session exit code; propagate it so running this
    # file directly reports test failures to the shell/CI instead of exiting 0.
    raise SystemExit(pytest.main([__file__, '-v']))
