"""
Training script for Unimodal (Text-only) Parallel-LLM
Demonstrates distributed training with FSDP, diffusion loss, and custom data loading.

KNOWN ISSUE: PyTorch binaries can fail to import on Windows.
When the required packages cannot be imported, this script only reports
what is missing, prints setup guidance, and outlines the training steps.
For actual execution, use Linux/WSL with CUDA support.
"""
import os
import sys

# Banner shown once at import time.
_BANNER = "=" * 60
print(_BANNER)
print("Parallel-LLM Unimodal Training Example")
print(_BANNER)

# Report the interpreter and platform to help diagnose import failures.
print(f"Python version: {sys.version}")
print(f"Platform: {sys.platform}")

# Make the parent of this script's directory importable so that
# parallel_llm can be imported straight from the source tree.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(_PROJECT_ROOT)
print(f"Added to path: {_PROJECT_ROOT}")

# Availability flags; each is flipped to True only after its imports succeed.
PYTORCH_AVAILABLE = False
PARALLEL_LLM_AVAILABLE = False

# Step 1: third-party stack (PyTorch, transformers, datasets).
print("\n[1/3] Checking PyTorch availability...")
try:
    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, DistributedSampler
    from transformers import AutoTokenizer
    from datasets import load_dataset
    PYTORCH_AVAILABLE = True
    print("✅ PyTorch, transformers, and datasets imported successfully")
except Exception as import_error:
    # Deliberately broad: any failure while importing is reported, not raised.
    print("❌ PyTorch/transformers/datasets not available:")
    print(f"   Error: {import_error}")
    for _hint in (
        "   This is a known issue on Windows with PyTorch binaries.",
        "   Solutions:",
        "   - Use WSL (Windows Subsystem for Linux)",
        "   - Use a Linux environment",
        "   - Use Docker with CUDA support",
    ):
        print(_hint)

# Step 2: the parallel_llm package itself (only attempted when step 1 passed).
if not PYTORCH_AVAILABLE:
    print("\n[2/3] Skipping parallel_llm check (PyTorch not available)")
else:
    print("\n[2/3] Checking parallel_llm package...")
    try:
        # Updated imports for v0.4.0 structure
        from parallel_llm.core import DiffusionTransformer, ModelConfig
        from parallel_llm.training import DistributedTrainer, TrainingConfig, DiffusionLoss
        from parallel_llm.utils import TextDataset
        PARALLEL_LLM_AVAILABLE = True
        print("✅ parallel_llm package imported successfully")
    except ImportError as import_error:
        print("❌ parallel_llm package not available:")
        print(f"   Error: {import_error}")
        print("   Please install the package: pip install -e .")

# Step 3: summarise whether the example can actually run here.
print("\n[3/3] Configuration check...")
if not (PYTORCH_AVAILABLE and PARALLEL_LLM_AVAILABLE):
    for _hint in (
        "❌ Dependencies missing - example cannot run on this system",
        "\nTo run this example:",
        "1. Use Linux or WSL environment",
        "2. Install CUDA 12.1+",
        "3. Install PyTorch: pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121",
        "4. Install dependencies: pip install transformers datasets",
        "5. Install parallel-llm: pip install -e .",
    ):
        print(_hint)
else:
    print("✅ All dependencies available - example can run")

def setup_distributed():
    """Initialize the distributed training environment when one is requested.

    Distributed mode is detected by the presence of the ``LOCAL_RANK``
    environment variable (as set by launchers such as ``torchrun``). In that
    case the default process group is initialized and, when CUDA is
    available, the current CUDA device is bound to the local rank.

    The NCCL backend is used only when CUDA and NCCL are both available;
    otherwise Gloo is used so that a CPU-only distributed launch does not
    crash during initialization.

    Returns:
        int: The local rank read from ``LOCAL_RANK``, or ``0`` when the
        variable is not set (non-distributed run).

    Raises:
        ValueError: If ``LOCAL_RANK`` is set but is not an integer.
    """
    raw_local_rank = os.environ.get("LOCAL_RANK")
    if raw_local_rank is None:
        print("Not running in distributed mode. Using CPU/Single GPU.")
        return 0

    # Parse first so a malformed value fails before any process group exists.
    local_rank = int(raw_local_rank)

    cuda_ready = torch.cuda.is_available()
    # NCCL requires CUDA devices; Gloo is the CPU-capable fallback.
    backend = "nccl" if cuda_ready and dist.is_nccl_available() else "gloo"

    # Skip re-initialization if a default group already exists.
    if not dist.is_initialized():
        dist.init_process_group(backend=backend)

    if cuda_ready:
        torch.cuda.set_device(local_rank)
    return local_rank

def main():
    """Run the unimodal training example, or outline it if it cannot run.

    When either dependency flag (``PYTORCH_AVAILABLE`` /
    ``PARALLEL_LLM_AVAILABLE``) is false, only a description of the steps is
    printed. Otherwise the function:

    1. sets up the (optionally distributed) environment,
    2. builds the model and training configurations,
    3. tokenizes WikiText-2 and wraps it in a ``DataLoader``,
    4. builds the ``DiffusionTransformer`` model,
    5. builds a ``DistributedTrainer`` and
    6. runs ``trainer.train()``.

    Progress messages are printed only by the process with global rank 0
    (or by the single process in a non-distributed run). After training
    completes, the default process group is destroyed if one is initialized.
    """
    if not (PYTORCH_AVAILABLE and PARALLEL_LLM_AVAILABLE):
        print("\n📋 Example structure demonstration:")
        print("This example would perform the following steps:")
        print("1. Set up distributed training environment (NCCL, local rank)")
        print("2. Configure DiffusionTransformer model with GPT-2 vocabulary")
        print("3. Set up training config with FSDP, mixed precision, and diffusion loss")
        print("4. Load WikiText-2 dataset and create TextDataset")
        print("5. Initialize DistributedTrainer with custom data loading")
        print("6. Run training loop with 50K steps, logging, and checkpointing")
        print("\nTo run this example, use a Linux environment with CUDA support.")
        return

    print("\n🚀 Running actual unimodal training...")

    # 1. Setup Distributed Environment
    local_rank = setup_distributed()
    # Use the global rank when a process group exists so that only one
    # process prints in a multi-node run; the local rank is 0 on every node.
    global_rank = dist.get_rank() if dist.is_initialized() else local_rank
    is_main_process = global_rank == 0
    cuda_available = torch.cuda.is_available()

    if is_main_process:
        print("="*50)
        print("Starting Unimodal Training")
        print("="*50)

    # 2. Configuration
    # Model Configuration
    model_config = ModelConfig(
        vocab_size=50257,      # GPT-2 vocab size
        hidden_size=1024,      # Model dimension
        num_hidden_layers=12,  # Number of layers
        num_attention_heads=16,
        num_diffusion_steps=100, # Diffusion steps
        use_energy_model=True,   # Hybrid energy-based model
        use_flash_attention=cuda_available,
    )

    # Training Configuration
    train_config = TrainingConfig(
        output_dir="./checkpoints/unimodal",
        num_train_steps=50000,
        batch_size=8,
        learning_rate=3e-4,
        warmup_steps=1000,
        use_fsdp=torch.cuda.device_count() > 1,
        mixed_precision="bf16" if cuda_available and torch.cuda.is_bf16_supported() else "no",
        logging_steps=10,
        save_steps=1000,
        use_wandb=False,  # Set to True to enable WandB logging
    )

    # 3. Data Preparation
    if is_main_process:
        print("Loading tokenizer and dataset...")

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # Reuse the EOS token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    # Load a small dataset for demonstration (WikiText-2)
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

    train_dataset = TextDataset(
        dataset=dataset,
        tokenizer=tokenizer,
        max_length=512,
    )

    # A DistributedSampler is used only when a process group is active;
    # otherwise the DataLoader shuffles on its own (the two are exclusive).
    sampler = DistributedSampler(train_dataset) if dist.is_initialized() else None

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=train_config.batch_size,
        sampler=sampler,
        shuffle=(sampler is None),
        num_workers=4,
        # Pinned host memory only matters for host-to-GPU copies.
        pin_memory=cuda_available,
    )

    # 4. Model Initialization
    if is_main_process:
        print("Initializing DiffusionTransformer model...")

    model = DiffusionTransformer(model_config)

    # 5. Trainer Initialization
    trainer = DistributedTrainer(
        model=model,
        train_config=train_config,
        model_config=model_config,
        train_dataloader=train_dataloader,
        # Custom loss function is automatically handled by DistributedTrainer
        # using DiffusionLoss, but you can override it if needed.
    )

    # 6. Start Training
    if is_main_process:
        print("Starting training loop...")

    trainer.train()

    if is_main_process:
        print("Training complete!")

    # Tear down the process group on the success path. The check makes this
    # a no-op if the group was never created or has already been destroyed.
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
