"""
QuantumSignature - Извлечение подписей моделей

Аналогия с аудио подписью:
- Каждая модель имеет уникальную "подпись"
- Подпись = паттерн весов и архитектуры
- Синхронизация = наложение подписей

© 2025 NativeMind
"""

import torch
import numpy as np
from typing import Dict, List, Optional, Union, Tuple
from pathlib import Path
import json
import hashlib
from datetime import datetime


class SignatureExtractor:
    """
    Извлечение подписей AI моделей
    
    Подпись модели включает:
    - Архитектурные параметры
    - Статистики весов
    - Паттерны внимания
    - Частотные характеристики
    """
    
    def __init__(self, model_path: str):
        """
        Инициализация экстрактора подписей
        
        Args:
            model_path: Путь к модели или HuggingFace ID
        """
        self.model_path = model_path
        self.signature = {}
        self.model = None
        self.config = None
        
        print(f"🔍 Извлечение подписи модели: {model_path}")
    
    def extract_full_signature(self) -> Dict:
        """
        Извлечение полной подписи модели
        
        Returns:
            Словарь с подписью модели
        """
        print("📊 Извлечение полной подписи...")
        
        try:
            # Загружаем модель для анализа
            self._load_model()
            
            signature = {
                'metadata': {
                    'model_path': self.model_path,
                    'extraction_time': datetime.now().isoformat(),
                    'model_hash': self._calculate_model_hash()
                },
                'architecture': self._extract_architecture(),
                'weights': self._extract_weights_signature(),
                'attention': self._extract_attention_patterns(),
                'frequency': self._extract_frequency_characteristics(),
                'quantum_signature': self._extract_quantum_signature()
            }
            
            self.signature = signature
            print("✅ Подпись извлечена успешно")
            return signature
            
        except Exception as e:
            print(f"⚠️ Ошибка извлечения подписи: {e}")
            # Возвращаем базовую подпись
            return self._create_basic_signature()
    
    def _load_model(self):
        """Загрузка модели для анализа"""
        try:
            from transformers import AutoModel, AutoConfig
            
            print("📥 Загружаю модель...")
            self.config = AutoConfig.from_pretrained(self.model_path)
            self.model = AutoModel.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16,
                device_map="cpu",
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            print("✅ Модель загружена")
            
        except Exception as e:
            print(f"⚠️ Не удалось загрузить модель: {e}")
            self.model = None
            self.config = None
    
    def _calculate_model_hash(self) -> str:
        """Вычисление хеша модели"""
        try:
            if self.model:
                # Хеш от архитектуры
                arch_str = str(self.model.__class__.__name__)
                return hashlib.sha256(arch_str.encode()).hexdigest()[:16]
            else:
                return hashlib.sha256(self.model_path.encode()).hexdigest()[:16]
        except:
            return "unknown"
    
    def _extract_architecture(self) -> Dict:
        """Извлечение архитектурных параметров"""
        if not self.model:
            return {
                'layers': 0,
                'parameters': 0,
                'model_type': 'unknown',
                'hidden_size': 0,
                'num_attention_heads': 0
            }
        
        try:
            # Подсчет параметров
            total_params = sum(p.numel() for p in self.model.parameters())
            trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
            
            # Извлечение архитектурных параметров
            arch_info = {
                'layers': len(list(self.model.modules())),
                'parameters': total_params,
                'trainable_parameters': trainable_params,
                'model_type': self.model.__class__.__name__,
                'hidden_size': getattr(self.config, 'hidden_size', 0),
                'num_attention_heads': getattr(self.config, 'num_attention_heads', 0),
                'num_hidden_layers': getattr(self.config, 'num_hidden_layers', 0),
                'vocab_size': getattr(self.config, 'vocab_size', 0)
            }
            
            return arch_info
            
        except Exception as e:
            print(f"⚠️ Ошибка извлечения архитектуры: {e}")
            return {
                'layers': 0,
                'parameters': 0,
                'model_type': 'unknown',
                'hidden_size': 0,
                'num_attention_heads': 0
            }
    
    def _extract_weights_signature(self) -> Dict:
        """Извлечение подписи весов"""
        if not self.model:
            return {
                'mean': 0.0,
                'std': 0.0,
                'min': 0.0,
                'max': 0.0,
                'sparsity': 0.0,
                'weight_distribution': 'unknown'
            }
        
        try:
            # Собираем все веса
            all_weights = []
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    all_weights.extend(param.data.flatten().cpu().numpy())
            
            if not all_weights:
                return {
                    'mean': 0.0,
                    'std': 0.0,
                    'min': 0.0,
                    'max': 0.0,
                    'sparsity': 0.0,
                    'weight_distribution': 'unknown'
                }
            
            weights_array = np.array(all_weights)
            
            # Статистики весов
            mean_weight = np.mean(weights_array)
            std_weight = np.std(weights_array)
            min_weight = np.min(weights_array)
            max_weight = np.max(weights_array)
            
            # Спарсити (доля нулевых весов)
            sparsity = np.sum(weights_array == 0) / len(weights_array)
            
            # Распределение весов
            if std_weight < 0.01:
                distribution = 'uniform'
            elif abs(mean_weight) < 0.1 and std_weight < 0.5:
                distribution = 'normal'
            else:
                distribution = 'skewed'
            
            return {
                'mean': float(mean_weight),
                'std': float(std_weight),
                'min': float(min_weight),
                'max': float(max_weight),
                'sparsity': float(sparsity),
                'weight_distribution': distribution,
                'total_weights': len(weights_array)
            }
            
        except Exception as e:
            print(f"⚠️ Ошибка извлечения весов: {e}")
            return {
                'mean': 0.0,
                'std': 0.0,
                'min': 0.0,
                'max': 0.0,
                'sparsity': 0.0,
                'weight_distribution': 'unknown'
            }
    
    def _extract_attention_patterns(self) -> Dict:
        """Извлечение паттернов внимания"""
        if not self.model:
            return {
                'heads': 0,
                'layers': 0,
                'pattern': 'unknown',
                'attention_type': 'unknown'
            }
        
        try:
            # Поиск attention слоев
            attention_layers = []
            for name, module in self.model.named_modules():
                if 'attention' in name.lower() or 'attn' in name.lower():
                    attention_layers.append((name, module))
            
            # Извлечение параметров внимания
            num_heads = 0
            attention_type = 'unknown'
            
            if hasattr(self.config, 'num_attention_heads'):
                num_heads = self.config.num_attention_heads
            elif hasattr(self.config, 'num_heads'):
                num_heads = self.config.num_heads
            
            # Определение типа внимания
            if 'multihead' in str(type(self.model)).lower():
                attention_type = 'multihead'
            elif 'self' in str(type(self.model)).lower():
                attention_type = 'self'
            else:
                attention_type = 'unknown'
            
            return {
                'heads': num_heads,
                'layers': len(attention_layers),
                'pattern': 'multihead' if num_heads > 1 else 'single',
                'attention_type': attention_type,
                'attention_layers': len(attention_layers)
            }
            
        except Exception as e:
            print(f"⚠️ Ошибка извлечения внимания: {e}")
            return {
                'heads': 0,
                'layers': 0,
                'pattern': 'unknown',
                'attention_type': 'unknown'
            }
    
    def _extract_frequency_characteristics(self) -> Dict:
        """Извлечение частотных характеристик"""
        if not self.model:
            return {
                'dominant_freq': 440.0,
                'harmonics': [],
                'spectrum': {},
                'resonance_freq': 440.0
            }
        
        try:
            # Анализ весов через FFT
            all_weights = []
            for name, param in self.model.named_parameters():
                if param.requires_grad and len(param.shape) >= 2:
                    # Берем подвыборку весов
                    flat_weights = param.data.flatten().cpu().numpy()
                    if len(flat_weights) > 1000:
                        flat_weights = flat_weights[:1000]
                    all_weights.extend(flat_weights)
            
            if not all_weights:
                return {
                    'dominant_freq': 440.0,
                    'harmonics': [],
                    'spectrum': {},
                    'resonance_freq': 440.0
                }
            
            # FFT анализ
            weights_array = np.array(all_weights)
            fft_result = np.fft.fft(weights_array)
            freqs = np.fft.fftfreq(len(weights_array))
            
            # Находим доминирующую частоту
            power_spectrum = np.abs(fft_result) ** 2
            dominant_idx = np.argmax(power_spectrum[1:len(power_spectrum)//2]) + 1
            dominant_freq = freqs[dominant_idx] * len(weights_array)
            
            # Гармоники
            harmonics = []
            for i in range(2, 6):  # 2-я, 3-я, 4-я, 5-я гармоники
                harmonic_freq = dominant_freq * i
                harmonics.append(float(harmonic_freq))
            
            # Резонансная частота (базовая)
            resonance_freq = 440.0 if dominant_freq < 0.1 else float(dominant_freq)
            
            return {
                'dominant_freq': float(dominant_freq),
                'harmonics': harmonics,
                'spectrum': {
                    'power_peak': float(np.max(power_spectrum)),
                    'spectral_centroid': float(np.sum(freqs * power_spectrum) / np.sum(power_spectrum))
                },
                'resonance_freq': resonance_freq
            }
            
        except Exception as e:
            print(f"⚠️ Ошибка извлечения частот: {e}")
            return {
                'dominant_freq': 440.0,
                'harmonics': [],
                'spectrum': {},
                'resonance_freq': 440.0
            }
    
    def _extract_quantum_signature(self) -> Dict:
        """Извлечение квантовой подписи модели"""
        try:
            # Квантовая подпись основана на энтропии весов
            if not self.model:
                return {
                    'entropy': 0.0,
                    'quantum_coherence': 0.0,
                    'superposition_states': 0,
                    'quantum_signature': 'unknown'
                }
            
            # Собираем веса для анализа энтропии
            all_weights = []
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    weights = param.data.flatten().cpu().numpy()
                    all_weights.extend(weights)
            
            if not all_weights:
                return {
                    'entropy': 0.0,
                    'quantum_coherence': 0.0,
                    'superposition_states': 0,
                    'quantum_signature': 'unknown'
                }
            
            weights_array = np.array(all_weights)
            
            # Вычисление энтропии Шеннона
            hist, _ = np.histogram(weights_array, bins=50, density=True)
            hist = hist[hist > 0]  # Убираем нулевые бины
            entropy = -np.sum(hist * np.log2(hist + 1e-10))
            
            # Квантовая когерентность (на основе стандартного отклонения)
            coherence = 1.0 / (1.0 + np.std(weights_array))
            
            # Состояния суперпозиции (количество уникальных значений)
            unique_weights = len(np.unique(weights_array))
            superposition_states = min(unique_weights, 1000)  # Ограничиваем для практичности
            
            # Определение типа квантовой подписи
            if entropy > 5.0 and coherence > 0.5:
                quantum_signature = 'highly_quantum'
            elif entropy > 3.0 and coherence > 0.3:
                quantum_signature = 'moderately_quantum'
            else:
                quantum_signature = 'classical'
            
            return {
                'entropy': float(entropy),
                'quantum_coherence': float(coherence),
                'superposition_states': int(superposition_states),
                'quantum_signature': quantum_signature
            }
            
        except Exception as e:
            print(f"⚠️ Ошибка извлечения квантовой подписи: {e}")
            return {
                'entropy': 0.0,
                'quantum_coherence': 0.0,
                'superposition_states': 0,
                'quantum_signature': 'unknown'
            }
    
    def _create_basic_signature(self) -> Dict:
        """Создание базовой подписи при ошибке"""
        return {
            'metadata': {
                'model_path': self.model_path,
                'extraction_time': datetime.now().isoformat(),
                'model_hash': 'error',
                'error': True
            },
            'architecture': {
                'layers': 0,
                'parameters': 0,
                'model_type': 'unknown'
            },
            'weights': {
                'mean': 0.0,
                'std': 0.0,
                'min': 0.0,
                'max': 0.0
            },
            'attention': {
                'heads': 0,
                'layers': 0,
                'pattern': 'unknown'
            },
            'frequency': {
                'dominant_freq': 440.0,
                'harmonics': [],
                'spectrum': {}
            },
            'quantum_signature': {
                'entropy': 0.0,
                'quantum_coherence': 0.0,
                'superposition_states': 0,
                'quantum_signature': 'unknown'
            }
        }
    
    def save_signature(self, output_path: str):
        """Сохранение подписи в файл"""
        try:
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(self.signature, f, indent=2, ensure_ascii=False)
            
            print(f"💾 Подпись сохранена: {output_path}")
            
        except Exception as e:
            print(f"⚠️ Ошибка сохранения подписи: {e}")
    
    @classmethod
    def load_signature(cls, signature_path: str) -> Dict:
        """Загрузка подписи из файла"""
        try:
            with open(signature_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            print(f"⚠️ Ошибка загрузки подписи: {e}")
            return {}
    
    def compare_signatures(self, other_signature: Dict) -> Dict:
        """Сравнение с другой подписью"""
        if not self.signature or not other_signature:
            return {'similarity': 0.0, 'differences': []}
        
        try:
            # Сравнение архитектуры
            arch_sim = self._compare_architecture(
                self.signature.get('architecture', {}),
                other_signature.get('architecture', {})
            )
            
            # Сравнение весов
            weights_sim = self._compare_weights(
                self.signature.get('weights', {}),
                other_signature.get('weights', {})
            )
            
            # Сравнение частот
            freq_sim = self._compare_frequency(
                self.signature.get('frequency', {}),
                other_signature.get('frequency', {})
            )
            
            # Общая схожесть
            overall_similarity = (arch_sim + weights_sim + freq_sim) / 3.0
            
            return {
                'similarity': overall_similarity,
                'architecture_similarity': arch_sim,
                'weights_similarity': weights_sim,
                'frequency_similarity': freq_sim,
                'differences': self._find_differences(other_signature)
            }
            
        except Exception as e:
            print(f"⚠️ Ошибка сравнения подписей: {e}")
            return {'similarity': 0.0, 'differences': []}
    
    def _compare_architecture(self, arch1: Dict, arch2: Dict) -> float:
        """Сравнение архитектур"""
        if not arch1 or not arch2:
            return 0.0
        
        # Сравнение количества слоев
        layers1 = arch1.get('layers', 0)
        layers2 = arch2.get('layers', 0)
        layer_sim = 1.0 - abs(layers1 - layers2) / max(layers1, layers2, 1)
        
        # Сравнение типов моделей
        type1 = arch1.get('model_type', '')
        type2 = arch2.get('model_type', '')
        type_sim = 1.0 if type1 == type2 else 0.0
        
        return (layer_sim + type_sim) / 2.0
    
    def _compare_weights(self, weights1: Dict, weights2: Dict) -> float:
        """Сравнение весов"""
        if not weights1 or not weights2:
            return 0.0
        
        # Сравнение статистик
        mean1 = weights1.get('mean', 0.0)
        mean2 = weights2.get('mean', 0.0)
        mean_sim = 1.0 - abs(mean1 - mean2) / max(abs(mean1), abs(mean2), 1e-6)
        
        std1 = weights1.get('std', 0.0)
        std2 = weights2.get('std', 0.0)
        std_sim = 1.0 - abs(std1 - std2) / max(std1, std2, 1e-6)
        
        return (mean_sim + std_sim) / 2.0
    
    def _compare_frequency(self, freq1: Dict, freq2: Dict) -> float:
        """Сравнение частотных характеристик"""
        if not freq1 or not freq2:
            return 0.0
        
        # Сравнение доминирующих частот
        dom1 = freq1.get('dominant_freq', 440.0)
        dom2 = freq2.get('dominant_freq', 440.0)
        freq_sim = 1.0 - abs(dom1 - dom2) / max(dom1, dom2, 1.0)
        
        return freq_sim
    
    def _find_differences(self, other_signature: Dict) -> List[str]:
        """Поиск различий между подписями"""
        differences = []
        
        try:
            # Сравнение архитектуры
            arch1 = self.signature.get('architecture', {})
            arch2 = other_signature.get('architecture', {})
            
            if arch1.get('layers', 0) != arch2.get('layers', 0):
                differences.append(f"Разное количество слоев: {arch1.get('layers', 0)} vs {arch2.get('layers', 0)}")
            
            if arch1.get('model_type', '') != arch2.get('model_type', ''):
                differences.append(f"Разные типы моделей: {arch1.get('model_type', '')} vs {arch2.get('model_type', '')}")
            
            # Сравнение весов
            weights1 = self.signature.get('weights', {})
            weights2 = other_signature.get('weights', {})
            
            mean_diff = abs(weights1.get('mean', 0.0) - weights2.get('mean', 0.0))
            if mean_diff > 0.1:
                differences.append(f"Значительная разница в средних весах: {mean_diff:.3f}")
            
        except Exception as e:
            differences.append(f"Ошибка сравнения: {e}")
        
        return differences


def apply_signature(
    student_model,
    teacher_signature: Dict,
    learning_rate: float = 0.05
) -> torch.nn.Module:
    """
    Применение подписи учителя к модели-ученику
    
    Args:
        student_model: Модель-ученик
        teacher_signature: Подпись учителя
        learning_rate: Скорость применения
    
    Returns:
        Модифицированная модель
    """
    print(f"📥 Применение подписи (learning_rate={learning_rate:.1%})...")
    
    try:
        # Извлекаем информацию о весах учителя
        teacher_weights = teacher_signature.get('weights', {})
        teacher_freq = teacher_signature.get('frequency', {})
        
        # Применяем изменения к весам ученика
        with torch.no_grad():
            for name, param in student_model.named_parameters():
                if param.requires_grad:
                    # Получаем статистики учителя
                    teacher_mean = teacher_weights.get('mean', 0.0)
                    teacher_std = teacher_weights.get('std', 1.0)
                    
                    # Нормализуем веса ученика к статистикам учителя
                    current_mean = param.data.mean().item()
                    current_std = param.data.std().item()
                    
                    if current_std > 0:
                        # Нормализация
                        normalized = (param.data - current_mean) / current_std
                        # Применение статистик учителя
                        new_weights = normalized * teacher_std + teacher_mean
                        
                        # Плавное применение
                        param.data = param.data * (1 - learning_rate) + new_weights * learning_rate
        
        print("✅ Подпись применена успешно")
        
    except Exception as e:
        print(f"⚠️ Ошибка применения подписи: {e}")
    
    return student_model