"""
Training script for Unimodal (Text-only) Parallel-LLM
Demonstrates distributed training with FSDP, diffusion loss, and custom data loading.
"""
import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer
from datasets import load_dataset

# Updated imports for v0.4.0 structure
from parallel_llm.core import DiffusionTransformer, ModelConfig
from parallel_llm.training import DistributedTrainer, TrainingConfig, DiffusionLoss
from parallel_llm.utils import TextDataset

def setup_distributed():
    """Initialize the distributed training environment if launched distributed.

    A distributed launch is detected via the ``LOCAL_RANK`` environment
    variable (set by ``torchrun``). In that case this process is pinned to its
    local GPU (when CUDA is available) and the default process group is
    initialized from the launcher's environment variables. NCCL is used on
    CUDA hosts and Gloo otherwise, because NCCL cannot run without CUDA.

    Returns:
        int: This process's local rank, or 0 when not running distributed.
    """
    if "LOCAL_RANK" not in os.environ:
        print("Not running in distributed mode. Using CPU/Single GPU.")
        return 0

    local_rank = int(os.environ["LOCAL_RANK"])
    if torch.cuda.is_available():
        # Bind the device *before* creating the NCCL group so communicators
        # are set up on this rank's GPU rather than defaulting to cuda:0.
        torch.cuda.set_device(local_rank)
        backend = "nccl"
    else:
        # CPU-only host: NCCL would fail here, Gloo works on CPU.
        backend = "gloo"
    dist.init_process_group(backend=backend)
    return local_rank

def _train():
    """Build configs, data pipeline, model and trainer, then run training."""
    # Gate logging on the *global* rank: LOCAL_RANK is 0 on every node, so
    # using it would make the first process of each node print in multi-node
    # runs.
    is_main_process = not dist.is_initialized() or dist.get_rank() == 0
    has_cuda = torch.cuda.is_available()
    # Shard with FSDP only when several ranks were actually launched. The
    # visible GPU count says nothing about whether a process group exists.
    # NOTE(review): restricted to CUDA hosts to mirror the original, which
    # could only enable FSDP when GPUs were present.
    use_fsdp = (
        has_cuda and dist.is_initialized() and dist.get_world_size() > 1
    )

    if is_main_process:
        print("=" * 50)
        print("Starting Unimodal Training")
        print("=" * 50)

    # 2. Configuration
    # Model Configuration
    model_config = ModelConfig(
        vocab_size=50257,         # GPT-2 vocab size (matches the tokenizer below)
        hidden_size=1024,         # Model dimension
        num_hidden_layers=12,     # Number of layers
        num_attention_heads=16,
        num_diffusion_steps=100,  # Diffusion steps
        use_energy_model=True,    # Hybrid energy-based model
        use_flash_attention=has_cuda,
    )

    # Training Configuration
    train_config = TrainingConfig(
        output_dir="./checkpoints/unimodal",
        num_train_steps=50000,
        batch_size=8,
        learning_rate=3e-4,
        warmup_steps=1000,
        use_fsdp=use_fsdp,
        mixed_precision=(
            "bf16" if has_cuda and torch.cuda.is_bf16_supported() else "no"
        ),
        logging_steps=10,
        save_steps=1000,
        use_wandb=False,  # Set to True to enable WandB logging
    )

    # 3. Data Preparation
    if is_main_process:
        print("Loading tokenizer and dataset...")

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token

    # Load a small dataset for demonstration (WikiText-2)
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

    train_dataset = TextDataset(
        dataset=dataset,
        tokenizer=tokenizer,
        max_length=512,
    )

    # Each rank sees a disjoint shard when a process group exists.
    # NOTE(review): DistributedSampler only reshuffles across epochs if
    # sampler.set_epoch(epoch) is called; confirm DistributedTrainer does so.
    sampler = DistributedSampler(train_dataset) if dist.is_initialized() else None

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=train_config.batch_size,
        sampler=sampler,
        shuffle=sampler is None,  # sampler and shuffle are mutually exclusive
        num_workers=4,
        pin_memory=has_cuda,      # pinning only helps host->GPU copies
    )

    # 4. Model Initialization
    if is_main_process:
        print("Initializing DiffusionTransformer model...")

    model = DiffusionTransformer(model_config)

    # 5. Trainer Initialization
    trainer = DistributedTrainer(
        model=model,
        train_config=train_config,
        model_config=model_config,
        train_dataloader=train_dataloader,
        # Custom loss function is automatically handled by DistributedTrainer
        # using DiffusionLoss, but you can override it if needed.
    )

    # 6. Start Training
    if is_main_process:
        print("Starting training loop...")

    trainer.train()

    if is_main_process:
        print("Training complete!")


def main():
    """Train a text-only DiffusionTransformer on WikiText-2.

    Initializes the process group when launched distributed (see
    ``setup_distributed``). It then runs the full training pipeline and
    destroys any process group on exit, including when training raises,
    so every rank shuts down cleanly.
    """
    # 1. Setup Distributed Environment
    setup_distributed()
    try:
        _train()
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()
