"""
Training script for Unimodal (Text-only) Parallel-LLM using WikiText-2.
Demonstrates distributed training (FSDP is available via `use_fsdp` but disabled by default), diffusion loss, and real-world data loading.
"""
import os
import sys
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer
from datasets import load_dataset
from itertools import islice
from parallel_llm.core import DiffusionTransformer, ModelConfig
from parallel_llm.training import DistributedTrainer, TrainingConfig
from parallel_llm.utils import TextDataset

def setup_distributed() -> int:
    """Initialize the distributed training environment.

    When the launcher exports ``LOCAL_RANK`` (as ``torchrun`` does), this
    binds the process to the CUDA device with that index (if CUDA is
    available) and creates the default process group. Without ``LOCAL_RANK``
    nothing is initialized.

    Returns:
        The local rank of this process, or ``0`` when ``LOCAL_RANK`` is unset.

    Raises:
        ValueError: If ``LOCAL_RANK`` is set but is not an integer.
    """
    local_rank_env = os.environ.get("LOCAL_RANK")
    if local_rank_env is None:
        return 0  # DistributedTrainer will handle multi-GPU auto-detection

    local_rank = int(local_rank_env)
    has_cuda = torch.cuda.is_available()

    # Bind the device before creating the process group so NCCL attaches to
    # the right GPU instead of defaulting every rank to device 0.
    if has_cuda:
        torch.cuda.set_device(local_rank)

    # NCCL only works with CUDA devices; fall back to Gloo for CPU-only runs.
    # Skip initialization if a process group already exists.
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl" if has_cuda else "gloo")

    return local_rank

def main() -> None:
    """Train a text-only DiffusionTransformer on a subset of WikiText-2.

    Initializes the distributed environment, runs the training pipeline, and
    tears down the process group afterwards (also when training fails).
    Progress messages are printed by the main process only.
    """
    # 1. Setup Distributed Environment (done first so that all console output
    # can be restricted to a single process).
    local_rank = setup_distributed()
    distributed = "LOCAL_RANK" in os.environ and dist.is_initialized()

    # Use the global rank when a process group exists: local rank 0 is present
    # on every node, which would duplicate logs in multi-node runs.
    is_main_process = dist.get_rank() == 0 if distributed else local_rank == 0

    try:
        _run_training(is_main_process, distributed)
    finally:
        # Release the process group; re-check in case it was already torn down.
        if distributed and dist.is_initialized():
            dist.destroy_process_group()


def _run_training(is_main_process: bool, distributed: bool) -> None:
    """Build the configs, data pipeline, model and trainer, then train.

    Args:
        is_main_process: Whether this process should print progress messages.
        distributed: Whether a process group is active; if so the training
            data is sharded with a ``DistributedSampler``.
    """

    def log(message: str) -> None:
        """Print *message* on the main process only."""
        if is_main_process:
            print(message)

    log("="*60)
    log("Parallel-LLM Unimodal Training Example (WikiText-2)")
    log("="*60)

    # Detect available GPUs
    has_cuda = torch.cuda.is_available()
    num_gpus = torch.cuda.device_count() if has_cuda else 0
    if num_gpus > 0:
        log(f"\n🎮 Detected {num_gpus} GPU(s)")
        for i in range(num_gpus):
            log(f"   GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        log("\n💻 No GPU detected, using CPU")

    log("\nStarting Unimodal Training")

    # Only query bf16 support when CUDA is present; the query is not safe on
    # every CPU-only PyTorch build.
    bf16_supported = has_cuda and torch.cuda.is_bf16_supported()

    # 2. Configuration
    # Model Configuration (GPU-friendly size)
    model_config = ModelConfig(
        vocab_size=32000,      # Will be updated after loading tokenizer
        hidden_size=768,       # Reduced from 2048
        num_hidden_layers=12,  # Reduced from 22
        num_attention_heads=12,  # Reduced from 32
        num_diffusion_steps=10,  # Reduced from 64
        use_flash_attention=has_cuda,
        dtype=torch.bfloat16 if bf16_supported else torch.float32,
    )

    # Training Configuration
    train_config = TrainingConfig(
        output_dir="./checkpoints/unimodal_wikitext",
        num_train_steps=1000,
        batch_size=4,  # Adjust based on VRAM (4 fits on 16GB with grad checkpointing)
        learning_rate=3e-4,
        warmup_steps=100,
        use_fsdp=False,  # Enable if multiple GPUs available
        mixed_precision="bf16" if bf16_supported else "fp16",
        gradient_checkpointing=True,  # Save memory
        logging_steps=10,
        save_steps=500,
        eval_steps=200,
        use_torch_compile=True,
    )

    # 3. Data Preparation
    log("\n[Step 1/5] Loading tokenizer (TinyLlama)...")

    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    tokenizer.pad_token = tokenizer.eos_token
    # len(tokenizer) counts added tokens as well, whereas tokenizer.vocab_size
    # only covers the base vocabulary; the model must cover every token id.
    vocab_size = len(tokenizer)
    model_config.vocab_size = vocab_size

    log(f"✓ Tokenizer loaded: {vocab_size:,} tokens")
    log("\n[Step 2/5] Loading WikiText-2 dataset...")

    # Load streaming to avoid memory issues
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train", streaming=True)

    # Take a subset for this example
    train_data = list(islice(dataset, 1000))

    log(f"✓ Dataset loaded: {len(train_data)} training samples")

    train_dataset = TextDataset(
        dataset=train_data,
        tokenizer=tokenizer,
        max_length=512,
    )

    sampler = DistributedSampler(train_dataset) if distributed else None

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=train_config.batch_size,
        sampler=sampler,
        # DataLoader forbids shuffle=True together with a sampler (the
        # DistributedSampler shuffles on its own); shuffle otherwise so
        # single-process training does not see the data in file order.
        shuffle=sampler is None,
        num_workers=2,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=has_cuda,
    )

    # 4. Initialize Model
    log("\n[Step 3/5] Initializing DiffusionTransformer model (may take 30-60s)...")

    model = DiffusionTransformer(model_config)

    if is_main_process:
        num_params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f"✓ Model initialized: {num_params:.1f}M parameters")

    # 5. Setup Distributed Trainer
    log("\n[Step 4/5] Setting up DistributedTrainer...")

    trainer = DistributedTrainer(
        model=model,
        train_config=train_config,
        model_config=model_config,
        train_dataloader=train_dataloader,
    )

    # 6. Start Training
    log("\n[Step 5/5] Starting training loop...")
    log("="*60)

    trainer.train()

    log("Training complete!")

# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
