"""
QuantumSchool - Параллельное квантовое обучение целого класса моделей

Подобно реальной школе, где учеников обучают целыми классами,
QuantumSchool позволяет одновременно обучать множество моделей,
синхронизируя их через квантовый резонанс.

Особенности:
- Параллельное обучение N учеников
- Общие учителя для всего класса
- Квантовая синхронизация между учениками
- Поддержка equalizer и pyramid методов
- Автоматическая балансировка нагрузки

© 2025 NativeMind
"""

import torch
import multiprocessing as mp
from typing import Dict, List, Optional, Tuple
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np

from .retrain import QuantumRetrain
from .equalizer import QuantumBatteryEqualizer, ModelChannel
from .pyramid import QuantumPyramid


class QuantumSchool:
    """
    Квантовая школа для параллельного обучения моделей
    
    Архитектура:
    
    УЧИТЕЛЯ (Teachers)
         ↓ ↓ ↓ ↓
    ┌────────────────┐
    │  КВАНТОВАЯ     │
    │    ШКОЛА       │ ← Резонанс 440 Hz
    └────────────────┘
         ↓ ↓ ↓ ↓
    УЧЕНИКИ (Students)
     S1  S2  S3  S4
     
    Каждый ученик обучается параллельно,
    но все синхронизированы через резонанс!
    """
    
    def __init__(
        self,
        base_frequency: float = 440.0,
        method: str = "equalizer",  # "equalizer" или "pyramid"
        max_workers: Optional[int] = None,
        debug: bool = True
    ):
        """
        Инициализация Квантовой Школы
        
        Args:
            base_frequency: Базовая частота резонанса (Hz)
            method: Метод синхронизации ("equalizer" или "pyramid")
            max_workers: Максимум параллельных процессов (None = CPU count)
            debug: Режим отладки
        """
        self.base_frequency = base_frequency
        self.method = method
        self.max_workers = max_workers or mp.cpu_count()
        self.debug = debug
        
        # Статистика
        self.students_trained = 0
        self.total_cycles = 0
        self.average_sync = 0.0
        
        self._print_header()
    
    def _print_header(self):
        """Печать заголовка школы"""
        print("\n" + "=" * 80)
        print("🏫 КВАНТОВАЯ ШКОЛА (QuantumSchool)")
        print("=" * 80)
        print(f"Резонансная частота: {self.base_frequency} Hz")
        print(f"Метод синхронизации: {self.method}")
        print(f"Максимум параллельных процессов: {self.max_workers}")
        print(f"Режим отладки: {'Включен' if self.debug else 'Выключен'}")
        print("=" * 80)
        print()
    
    def train_class(
        self,
        teachers: List[str],
        students: List[Dict[str, str]],
        teacher_amplitudes: Optional[List[float]] = None,
        cycles: int = 20,
        learning_rate: float = 0.05,
        auto_save: bool = True,
        save_mode: str = "lora",
        parallel: bool = True
    ) -> Dict:
        """
        Обучение целого класса учеников
        
        Args:
            teachers: Список путей к учителям
            students: Список учеников [{"name": "...", "path": "...", "output": "..."}]
            teacher_amplitudes: Амплитуды учителей
            cycles: Циклов синхронизации
            learning_rate: Скорость обучения
            auto_save: Автосохранение
            save_mode: "lora" или "full"
            parallel: Параллельное обучение (True) или последовательное (False)
        
        Returns:
            Полный отчет об обучении класса
        """
        print("\n" + "=" * 80)
        print(f"📚 НАЧАЛО ОБУЧЕНИЯ КЛАССА")
        print("=" * 80)
        print(f"Учителей: {len(teachers)}")
        print(f"Учеников: {len(students)}")
        print(f"Режим: {'Параллельный' if parallel else 'Последовательный'}")
        print()
        
        if parallel:
            results = self._train_parallel(
                teachers, students, teacher_amplitudes,
                cycles, learning_rate, auto_save, save_mode
            )
        else:
            results = self._train_sequential(
                teachers, students, teacher_amplitudes,
                cycles, learning_rate, auto_save, save_mode
            )
        
        # Сводный отчет
        report = self._generate_class_report(teachers, students, results)
        
        # Финальная визуализация
        self._print_class_report(report)
        
        return report
    
    def _train_parallel(
        self,
        teachers: List[str],
        students: List[Dict[str, str]],
        teacher_amplitudes: Optional[List[float]],
        cycles: int,
        learning_rate: float,
        auto_save: bool,
        save_mode: str
    ) -> List[Dict]:
        """
        Параллельное обучение учеников
        """
        print("🚀 Параллельное обучение запущено...")
        print()
        
        results = []
        
        # Создаем пул процессов
        with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
            # Запускаем обучение для каждого ученика
            futures = {}
            
            for student in students:
                future = executor.submit(
                    self._train_single_student,
                    teachers,
                    student,
                    teacher_amplitudes,
                    cycles,
                    learning_rate,
                    auto_save,
                    save_mode
                )
                futures[future] = student['name']
            
            # Собираем результаты по мере завершения
            for future in as_completed(futures):
                student_name = futures[future]
                try:
                    result = future.result()
                    results.append(result)
                    
                    if self.debug:
                        print(f"✅ {student_name}: Синхронизация {result['final_sync']:.1%}")
                    
                except Exception as e:
                    print(f"❌ {student_name}: Ошибка - {e}")
                    results.append({
                        'student_name': student_name,
                        'success': False,
                        'error': str(e)
                    })
        
        return results
    
    def _train_sequential(
        self,
        teachers: List[str],
        students: List[Dict[str, str]],
        teacher_amplitudes: Optional[List[float]],
        cycles: int,
        learning_rate: float,
        auto_save: bool,
        save_mode: str
    ) -> List[Dict]:
        """
        Последовательное обучение учеников (для отладки)
        """
        print("🐌 Последовательное обучение запущено...")
        print()
        
        results = []
        
        for i, student in enumerate(students, 1):
            print(f"\n--- Ученик {i}/{len(students)}: {student['name']} ---")
            
            try:
                result = self._train_single_student(
                    teachers,
                    student,
                    teacher_amplitudes,
                    cycles,
                    learning_rate,
                    auto_save,
                    save_mode
                )
                results.append(result)
                
                print(f"✅ {student['name']}: Синхронизация {result['final_sync']:.1%}")
                
            except Exception as e:
                print(f"❌ {student['name']}: Ошибка - {e}")
                results.append({
                    'student_name': student['name'],
                    'success': False,
                    'error': str(e)
                })
        
        return results
    
    def _train_single_student(
        self,
        teachers: List[str],
        student: Dict[str, str],
        teacher_amplitudes: Optional[List[float]],
        cycles: int,
        learning_rate: float,
        auto_save: bool,
        save_mode: str
    ) -> Dict:
        """
        Обучение одного ученика
        
        Args:
            teachers: Учителя
            student: {"name": "...", "path": "...", "output": "..."}
            teacher_amplitudes: Амплитуды
            cycles: Циклы
            learning_rate: Скорость
            auto_save: Автосохранение
            save_mode: Режим сохранения
        
        Returns:
            Результат обучения ученика
        """
        student_name = student['name']
        student_path = student['path']
        output_path = student.get('output', f'./quantum_school/{student_name}')
        
        if self.debug:
            print(f"\n🎓 Обучение ученика: {student_name}")
            print(f"   Путь: {student_path}")
            print(f"   Выход: {output_path}")
        
        # Создаем QuantumRetrain для ученика
        retrain = QuantumRetrain(
            base_frequency=self.base_frequency,
            method=self.method
        )
        
        # Полный цикл переобучения
        try:
            result = retrain.full_retrain(
                teacher_models=teachers,
                student_model=student_path,
                teacher_amplitudes=teacher_amplitudes,
                cycles=cycles,
                learning_rate=learning_rate,
                auto_save=auto_save,
                save_mode=save_mode,
                output_path=output_path
            )
            
            # Добавляем имя ученика
            result['student_name'] = student_name
            result['student_path'] = student_path
            result['output_path'] = output_path
            
            return result
            
        except Exception as e:
            if self.debug:
                print(f"❌ Ошибка при обучении {student_name}: {e}")
                import traceback
                traceback.print_exc()
            
            return {
                'student_name': student_name,
                'student_path': student_path,
                'success': False,
                'error': str(e),
                'final_sync': 0.0
            }
    
    def _generate_class_report(
        self,
        teachers: List[str],
        students: List[Dict[str, str]],
        results: List[Dict]
    ) -> Dict:
        """
        Генерация сводного отчета о классе
        """
        # Подсчет статистики
        successful = [r for r in results if r.get('success', False)]
        failed = [r for r in results if not r.get('success', False)]
        
        syncs = [r['final_sync'] for r in successful if 'final_sync' in r]
        avg_sync = np.mean(syncs) if syncs else 0.0
        min_sync = np.min(syncs) if syncs else 0.0
        max_sync = np.max(syncs) if syncs else 0.0
        
        # Обновляем статистику школы
        self.students_trained += len(successful)
        self.average_sync = avg_sync
        
        report = {
            'method': self.method,
            'base_frequency': self.base_frequency,
            'teachers': {
                'count': len(teachers),
                'paths': teachers
            },
            'students': {
                'total': len(students),
                'successful': len(successful),
                'failed': len(failed),
                'names': [s['name'] for s in students]
            },
            'synchronization': {
                'average': avg_sync,
                'min': min_sync,
                'max': max_sync,
                'distribution': syncs
            },
            'results': results,
            'success_rate': len(successful) / len(students) if students else 0.0
        }
        
        return report
    
    def _print_class_report(self, report: Dict):
        """
        Печать сводного отчета о классе
        """
        print("\n" + "=" * 80)
        print("📊 ИТОГОВЫЙ ОТЧЕТ КЛАССА")
        print("=" * 80)
        print()
        
        print(f"Метод: {report['method']}")
        print(f"Резонансная частота: {report['base_frequency']} Hz")
        print()
        
        print(f"👨‍🏫 Учителей: {report['teachers']['count']}")
        print(f"👨‍🎓 Учеников: {report['students']['total']}")
        print(f"   ✅ Успешно: {report['students']['successful']}")
        print(f"   ❌ Ошибок: {report['students']['failed']}")
        print(f"   📈 Успешность: {report['success_rate']:.1%}")
        print()
        
        if report['synchronization']['distribution']:
            print(f"🔄 Синхронизация:")
            print(f"   Средняя: {report['synchronization']['average']:.1%}")
            print(f"   Минимум: {report['synchronization']['min']:.1%}")
            print(f"   Максимум: {report['synchronization']['max']:.1%}")
            print()
        
        # Список учеников с результатами
        print("📋 Результаты учеников:")
        for result in report['results']:
            name = result['student_name']
            if result.get('success', False):
                sync = result.get('final_sync', 0.0)
                status = "✅"
                info = f"Синхронизация: {sync:.1%}"
            else:
                status = "❌"
                error = result.get('error', 'Unknown error')
                info = f"Ошибка: {error[:50]}"
            
            print(f"   {status} {name}: {info}")
        
        print()
        print("=" * 80)
        print("🎉 ОБУЧЕНИЕ КЛАССА ЗАВЕРШЕНО!")
        print("=" * 80)
        print()
    
    def get_statistics(self) -> Dict:
        """
        Получить статистику школы
        
        Returns:
            Словарь со статистикой
        """
        return {
            'students_trained': self.students_trained,
            'total_cycles': self.total_cycles,
            'average_sync': self.average_sync,
            'method': self.method,
            'base_frequency': self.base_frequency
        }


def quick_class_train(
    teachers: List[str],
    students: List[Dict[str, str]],
    method: str = "equalizer",
    parallel: bool = True
) -> Dict:
    """
    Быстрое обучение класса с настройками по умолчанию
    
    Args:
        teachers: Список учителей
        students: Список учеников [{"name": "...", "path": "..."}]
        method: "equalizer" или "pyramid"
        parallel: Параллельное обучение
    
    Returns:
        Отчет о классе
    
    Example:
        >>> teachers = ["model1", "model2"]
        >>> students = [
        ...     {"name": "Sphere027", "path": "base_model", "output": "sphere_027"},
        ...     {"name": "Sphere028", "path": "base_model", "output": "sphere_028"}
        ... ]
        >>> report = quick_class_train(teachers, students)
    """
    school = QuantumSchool(method=method, debug=True)
    
    return school.train_class(
        teachers=teachers,
        students=students,
        cycles=20,
        learning_rate=0.05,
        auto_save=True,
        save_mode="lora",
        parallel=parallel
    )

