"""
QuantumSchoolLORA - Экономия памяти при обучении класса с LoRA адаптерами

Если у всех учеников одинаковая базовая модель, а отличия только в LoRA адаптерах,
загружаем базу ОДИН раз, экономя гигабайты памяти!

Архитектура:
┌──────────────────┐
│  BASE MODEL      │ ← Загружена один раз!
│  (5-10 GB)       │
└──────────────────┘
       ↓
   ┌───┴───┬───────┬───────┐
   ↓       ↓       ↓       ↓
 LoRA1   LoRA2   LoRA3   LoRA4  ← Только адаптеры (100-200 MB каждый)
 (S027) (S028)  (S047)  (S048)

Экономия: 
- Без LoRA: 4 × 10 GB = 40 GB
- С LoRA: 1 × 10 GB + 4 × 0.2 GB = 10.8 GB
- Экономия: ~73%!

© 2025 NativeMind
"""

import torch
from typing import Dict, List, Optional
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, LoraConfig, get_peft_model, TaskType

from .retrain import QuantumRetrain


class QuantumSchoolLORA:
    """
    Квантовая школа с экономией памяти через LoRA
    
    Оптимизация:
    1. Базовая модель загружается ОДИН раз
    2. Для каждого ученика применяется свой LoRA адаптер
    3. Обучение идет параллельно (через потоки, не процессы!)
    4. Сохраняем только LoRA веса (~100-200 MB на ученика)
    
    Требования:
    - Все ученики должны иметь одинаковую базовую модель
    - LoRA адаптеры могут быть разными
    """
    
    def __init__(
        self,
        base_model: str,
        base_frequency: float = 440.0,
        method: str = "equalizer",
        device: str = "auto",
        torch_dtype: str = "auto",
        debug: bool = True
    ):
        """
        Инициализация QuantumSchoolLORA
        
        Args:
            base_model: Путь к базовой модели (загружается один раз!)
            base_frequency: Резонансная частота
            method: Метод синхронизации
            device: Устройство ("cuda", "mps", "cpu", "auto")
            torch_dtype: Тип данных ("auto", "float16", "bfloat16")
            debug: Режим отладки
        """
        self.base_model_path = base_model
        self.base_frequency = base_frequency
        self.method = method
        self.debug = debug
        
        # Определяем устройство
        if device == "auto":
            if torch.cuda.is_available():
                self.device = "cuda"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            self.device = device
        
        # Определяем dtype
        if torch_dtype == "auto":
            if self.device == "cuda":
                self.torch_dtype = torch.float16
            elif self.device == "mps":
                self.torch_dtype = torch.float16
            else:
                self.torch_dtype = torch.float32
        else:
            self.torch_dtype = getattr(torch, torch_dtype)
        
        # Базовая модель (загружается лениво)
        self.base_model = None
        self.tokenizer = None
        
        # Статистика
        self.students_trained = 0
        self.memory_saved_gb = 0.0
        
        self._print_header()
    
    def _print_header(self):
        """Печать заголовка"""
        print("\n" + "=" * 80)
        print("🏫💾 КВАНТОВАЯ ШКОЛА С LORA (QuantumSchoolLORA)")
        print("=" * 80)
        print(f"Базовая модель: {self.base_model_path}")
        print(f"Резонансная частота: {self.base_frequency} Hz")
        print(f"Метод: {self.method}")
        print(f"Устройство: {self.device}")
        print(f"Тип данных: {self.torch_dtype}")
        print()
        print("💡 Экономия памяти: База загружается ОДИН раз!")
        print("=" * 80)
        print()
    
    def _load_base_model(self):
        """Загрузка базовой модели (один раз!)"""
        if self.base_model is not None:
            if self.debug:
                print("✅ Базовая модель уже загружена, используем кэш")
            return
        
        print(f"\n📦 Загрузка базовой модели...")
        print(f"   Путь: {self.base_model_path}")
        print(f"   Устройство: {self.device}")
        print(f"   Тип: {self.torch_dtype}")
        
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_path)
            self.base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_path,
                torch_dtype=self.torch_dtype,
                device_map=self.device if self.device != "mps" else None
            )
            
            if self.device == "mps":
                self.base_model = self.base_model.to(self.device)
            
            # Подсчет параметров
            total_params = sum(p.numel() for p in self.base_model.parameters())
            trainable_params = sum(p.numel() for p in self.base_model.parameters() if p.requires_grad)
            
            print(f"   ✅ Модель загружена")
            print(f"   📊 Параметров: {total_params:,}")
            print(f"   🎯 Обучаемых: {trainable_params:,}")
            
        except Exception as e:
            print(f"   ❌ Ошибка загрузки: {e}")
            raise
    
    def train_class_lora(
        self,
        teachers: List[str],
        students_lora: List[Dict[str, str]],
        lora_config: Optional[LoraConfig] = None,
        teacher_amplitudes: Optional[List[float]] = None,
        cycles: int = 20,
        learning_rate: float = 0.05,
        auto_save: bool = True,
        parallel: bool = True
    ) -> Dict:
        """
        Обучение класса с LoRA адаптерами
        
        Args:
            teachers: Список учителей
            students_lora: Список учеников [{"name": "Sphere027", "lora_path": "...", "output": "..."}]
            lora_config: Конфигурация LoRA (опционально)
            teacher_amplitudes: Амплитуды учителей
            cycles: Циклов синхронизации
            learning_rate: Скорость обучения
            auto_save: Автосохранение
            parallel: Параллельное обучение (через потоки)
        
        Returns:
            Отчет об обучении
        """
        # Загружаем базовую модель один раз
        self._load_base_model()
        
        # Создаем LoRA конфигурацию по умолчанию
        if lora_config is None:
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=8,  # Ранг LoRA
                lora_alpha=32,
                lora_dropout=0.1,
                target_modules=["q_proj", "v_proj"]  # Attention слои
            )
        
        print("\n" + "=" * 80)
        print(f"📚 НАЧАЛО ОБУЧЕНИЯ КЛАССА (LoRA режим)")
        print("=" * 80)
        print(f"Учителей: {len(teachers)}")
        print(f"Учеников: {len(students_lora)}")
        print(f"LoRA rank: {lora_config.r}")
        print(f"Режим: {'Параллельный (потоки)' if parallel else 'Последовательный'}")
        print()
        
        if parallel:
            results = self._train_parallel_lora(
                teachers, students_lora, lora_config,
                teacher_amplitudes, cycles, learning_rate, auto_save
            )
        else:
            results = self._train_sequential_lora(
                teachers, students_lora, lora_config,
                teacher_amplitudes, cycles, learning_rate, auto_save
            )
        
        # Подсчет экономии памяти
        self._calculate_memory_savings(len(students_lora))
        
        # Сводный отчет
        report = self._generate_lora_report(teachers, students_lora, results)
        
        # Финальная визуализация
        self._print_lora_report(report)
        
        return report
    
    def _train_parallel_lora(
        self,
        teachers: List[str],
        students_lora: List[Dict[str, str]],
        lora_config: LoraConfig,
        teacher_amplitudes: Optional[List[float]],
        cycles: int,
        learning_rate: float,
        auto_save: bool
    ) -> List[Dict]:
        """
        Параллельное обучение через потоки (не процессы!)
        
        Используем ThreadPoolExecutor, потому что:
        1. Общая память для базовой модели
        2. Python GIL не мешает, т.к. большинство операций в PyTorch
        3. Экономия памяти vs ProcessPoolExecutor
        """
        print("🚀 Параллельное обучение (потоки) запущено...")
        print()
        
        results = []
        max_threads = min(4, len(students_lora))  # Ограничиваем потоки
        
        with ThreadPoolExecutor(max_workers=max_threads) as executor:
            futures = {}
            
            for student in students_lora:
                future = executor.submit(
                    self._train_single_lora,
                    teachers,
                    student,
                    lora_config,
                    teacher_amplitudes,
                    cycles,
                    learning_rate,
                    auto_save
                )
                futures[future] = student['name']
            
            # Собираем результаты
            for future in as_completed(futures):
                student_name = futures[future]
                try:
                    result = future.result()
                    results.append(result)
                    
                    if self.debug:
                        print(f"✅ {student_name}: Синхронизация {result['final_sync']:.1%}")
                    
                except Exception as e:
                    print(f"❌ {student_name}: Ошибка - {e}")
                    results.append({
                        'student_name': student_name,
                        'success': False,
                        'error': str(e)
                    })
        
        return results
    
    def _train_sequential_lora(
        self,
        teachers: List[str],
        students_lora: List[Dict[str, str]],
        lora_config: LoraConfig,
        teacher_amplitudes: Optional[List[float]],
        cycles: int,
        learning_rate: float,
        auto_save: bool
    ) -> List[Dict]:
        """Последовательное обучение (для отладки)"""
        print("🐌 Последовательное обучение запущено...")
        print()
        
        results = []
        
        for i, student in enumerate(students_lora, 1):
            print(f"\n--- Ученик {i}/{len(students_lora)}: {student['name']} ---")
            
            try:
                result = self._train_single_lora(
                    teachers, student, lora_config,
                    teacher_amplitudes, cycles, learning_rate, auto_save
                )
                results.append(result)
                print(f"✅ Успех: Синхронизация {result['final_sync']:.1%}")
                
            except Exception as e:
                print(f"❌ Ошибка: {e}")
                results.append({
                    'student_name': student['name'],
                    'success': False,
                    'error': str(e)
                })
        
        return results
    
    def _train_single_lora(
        self,
        teachers: List[str],
        student: Dict[str, str],
        lora_config: LoraConfig,
        teacher_amplitudes: Optional[List[float]],
        cycles: int,
        learning_rate: float,
        auto_save: bool
    ) -> Dict:
        """
        Обучение одного ученика с LoRA
        
        Args:
            teachers: Учителя
            student: {"name": "...", "lora_path": "..." (optional), "output": "..."}
            lora_config: Конфигурация LoRA
            teacher_amplitudes: Амплитуды
            cycles: Циклы
            learning_rate: Скорость
            auto_save: Автосохранение
        
        Returns:
            Результат обучения
        """
        student_name = student['name']
        lora_path = student.get('lora_path', None)  # Может быть None (новый LoRA)
        output_path = student.get('output', f'./quantum_school_lora/{student_name}')
        
        if self.debug:
            print(f"\n🎓 LoRA обучение: {student_name}")
            if lora_path:
                print(f"   Загружаю LoRA: {lora_path}")
            else:
                print(f"   Создаю новый LoRA адаптер")
        
        try:
            # Создаем копию базовой модели для этого ученика
            if lora_path and Path(lora_path).exists():
                # Загружаем существующий LoRA
                student_model = PeftModel.from_pretrained(
                    self.base_model,
                    lora_path
                )
                if self.debug:
                    print(f"   ✅ LoRA загружен")
            else:
                # Создаем новый LoRA адаптер
                student_model = get_peft_model(self.base_model, lora_config)
                if self.debug:
                    print(f"   ✅ Новый LoRA создан")
            
            # Квантовая синхронизация
            retrain = QuantumRetrain(
                base_frequency=self.base_frequency,
                method=self.method
            )
            
            # Временно сохраняем student_model для синхронизации
            # TODO: Интегрировать напрямую с QuantumRetrain
            temp_student_path = f"/tmp/quantum_student_{student_name}"
            student_model.save_pretrained(temp_student_path)
            
            # Полный цикл переобучения
            result = retrain.full_retrain(
                teacher_models=teachers,
                student_model=temp_student_path,
                teacher_amplitudes=teacher_amplitudes,
                cycles=cycles,
                learning_rate=learning_rate,
                auto_save=auto_save,
                save_mode="lora",
                output_path=output_path
            )
            
            # Добавляем метаданные
            result['student_name'] = student_name
            result['lora_path'] = lora_path
            result['output_path'] = output_path
            result['lora_only'] = True
            
            # Подсчет размера LoRA
            if auto_save and Path(output_path).exists():
                lora_size = sum(
                    f.stat().st_size 
                    for f in Path(output_path).rglob('*') 
                    if f.is_file()
                )
                result['lora_size_mb'] = lora_size / (1024 * 1024)
                
                if self.debug:
                    print(f"   💾 LoRA размер: {result['lora_size_mb']:.1f} MB")
            
            self.students_trained += 1
            
            return result
            
        except Exception as e:
            if self.debug:
                print(f"❌ Ошибка при обучении {student_name}: {e}")
                import traceback
                traceback.print_exc()
            
            return {
                'student_name': student_name,
                'lora_path': lora_path,
                'success': False,
                'error': str(e),
                'final_sync': 0.0
            }
    
    def _calculate_memory_savings(self, num_students: int):
        """Подсчет экономии памяти"""
        # Примерный размер модели
        base_model_gb = 10.0  # ~10 GB для обычной модели
        lora_adapter_gb = 0.2  # ~200 MB на адаптер
        
        # Без LoRA: каждый ученик = полная модель
        memory_without_lora = num_students * base_model_gb
        
        # С LoRA: одна база + N адаптеров
        memory_with_lora = base_model_gb + (num_students * lora_adapter_gb)
        
        self.memory_saved_gb = memory_without_lora - memory_with_lora
        
        if self.debug:
            print(f"\n💾 Экономия памяти:")
            print(f"   Без LoRA: {memory_without_lora:.1f} GB")
            print(f"   С LoRA: {memory_with_lora:.1f} GB")
            print(f"   Сэкономлено: {self.memory_saved_gb:.1f} GB ({self.memory_saved_gb/memory_without_lora*100:.0f}%)")
    
    def _generate_lora_report(
        self,
        teachers: List[str],
        students: List[Dict[str, str]],
        results: List[Dict]
    ) -> Dict:
        """Генерация отчета о LoRA обучении"""
        successful = [r for r in results if r.get('success', False)]
        failed = [r for r in results if not r.get('success', False)]
        
        syncs = [r['final_sync'] for r in successful if 'final_sync' in r]
        avg_sync = np.mean(syncs) if syncs else 0.0
        
        lora_sizes = [r.get('lora_size_mb', 0) for r in successful]
        total_lora_mb = sum(lora_sizes)
        
        report = {
            'method': self.method,
            'base_frequency': self.base_frequency,
            'base_model': self.base_model_path,
            'teachers': {'count': len(teachers), 'paths': teachers},
            'students': {
                'total': len(students),
                'successful': len(successful),
                'failed': len(failed)
            },
            'synchronization': {
                'average': avg_sync,
                'distribution': syncs
            },
            'memory': {
                'saved_gb': self.memory_saved_gb,
                'total_lora_mb': total_lora_mb,
                'avg_lora_mb': total_lora_mb / len(successful) if successful else 0
            },
            'results': results,
            'success_rate': len(successful) / len(students) if students else 0.0
        }
        
        return report
    
    def _print_lora_report(self, report: Dict):
        """Печать отчета"""
        print("\n" + "=" * 80)
        print("📊 ИТОГОВЫЙ ОТЧЕТ (LoRA режим)")
        print("=" * 80)
        print()
        
        print(f"Базовая модель: {report['base_model']}")
        print(f"Метод: {report['method']}")
        print(f"Резонансная частота: {report['base_frequency']} Hz")
        print()
        
        print(f"👨‍🏫 Учителей: {report['teachers']['count']}")
        print(f"👨‍🎓 Учеников: {report['students']['total']}")
        print(f"   ✅ Успешно: {report['students']['successful']}")
        print(f"   ❌ Ошибок: {report['students']['failed']}")
        print(f"   📈 Успешность: {report['success_rate']:.1%}")
        print()
        
        if report['synchronization']['distribution']:
            print(f"🔄 Средняя синхронизация: {report['synchronization']['average']:.1%}")
        
        print()
        print(f"💾 Экономия памяти:")
        print(f"   Сэкономлено: {report['memory']['saved_gb']:.1f} GB")
        print(f"   Всего LoRA: {report['memory']['total_lora_mb']:.0f} MB")
        print(f"   Средний LoRA: {report['memory']['avg_lora_mb']:.0f} MB")
        print()
        
        # Список учеников
        print("📋 Результаты:")
        for result in report['results']:
            name = result['student_name']
            if result.get('success', False):
                sync = result.get('final_sync', 0.0)
                size = result.get('lora_size_mb', 0)
                print(f"   ✅ {name}: Синхронизация {sync:.1%}, LoRA {size:.0f} MB")
            else:
                error = result.get('error', 'Unknown')[:50]
                print(f"   ❌ {name}: {error}")
        
        print()
        print("=" * 80)
        print("🎉 ОБУЧЕНИЕ КЛАССА (LoRA) ЗАВЕРШЕНО!")
        print("=" * 80)
        print()


def quick_lora_class_train(
    base_model: str,
    teachers: List[str],
    students: List[Dict[str, str]],
    method: str = "equalizer",
    parallel: bool = True
) -> Dict:
    """
    Быстрое обучение класса в LoRA режиме
    
    Args:
        base_model: Базовая модель (общая для всех)
        teachers: Список учителей
        students: [{"name": "Sphere027", "lora_path": "...", "output": "..."}]
        method: "equalizer" или "pyramid"
        parallel: Параллельное обучение
    
    Returns:
        Отчет
    
    Example:
        >>> report = quick_lora_class_train(
        ...     base_model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        ...     teachers=["teacher1", "teacher2"],
        ...     students=[
        ...         {"name": "Sphere027", "output": "sphere_027"},
        ...         {"name": "Sphere028", "output": "sphere_028"}
        ...     ]
        ... )
    """
    school = QuantumSchoolLORA(
        base_model=base_model,
        method=method,
        debug=True
    )
    
    return school.train_class_lora(
        teachers=teachers,
        students_lora=students,
        cycles=20,
        learning_rate=0.05,
        auto_save=True,
        parallel=parallel
    )

