"""
Training script for Unimodal (Text-only) Parallel-LLM
Demonstrates distributed training with FSDP, diffusion loss, and custom data loading.

KNOWN ISSUE: PyTorch binaries can fail to import on Windows.
If a dependency cannot be imported, this script only reports which import
failed and prints setup guidance instead of training.
For actual execution, use Linux/WSL with CUDA support.
"""
import os
import sys

# --- Environment report ------------------------------------------------------
_BANNER = "=" * 60
print(_BANNER)
print("Parallel-LLM Unimodal Training Example")
print(_BANNER)

# Show the interpreter and OS up front; import failures below are easier to
# diagnose with this context.
print(f"Python version: {sys.version}")
print(f"Platform: {sys.platform}")

# Make the parent of this script's directory importable so `parallel_llm`
# can be resolved from a source checkout.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(_PROJECT_ROOT)
print(f"Added to path: {_PROJECT_ROOT}")

# --- Dependency probing ------------------------------------------------------
# Each flag flips to True only after every import in its group has succeeded.
PYTORCH_AVAILABLE = False
PARALLEL_LLM_AVAILABLE = False

print("\n[1/3] Checking PyTorch availability...")
try:
    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader
    from torch.utils.data import DistributedSampler
    from transformers import AutoTokenizer
    from datasets import load_dataset
    PYTORCH_AVAILABLE = True
    print("✅ PyTorch, transformers, and datasets imported successfully")
except Exception as import_error:
    # Deliberately broad: any failure while importing these packages is
    # reported here rather than aborting the script.
    print("❌ PyTorch/transformers/datasets not available:")
    print(f"   Error: {import_error}")
    print(
        "   This is a known issue on Windows with PyTorch binaries.",
        "   Solutions:",
        "   - Use WSL (Windows Subsystem for Linux)",
        "   - Use a Linux environment",
        "   - Use Docker with CUDA support",
        sep="\n",
    )

if not PYTORCH_AVAILABLE:
    print("\n[2/3] Skipping parallel_llm check (PyTorch not available)")
else:
    print("\n[2/3] Checking parallel_llm package...")
    try:
        # Import locations follow the v0.4.0 package layout.
        from parallel_llm.core import DiffusionTransformer, ModelConfig
        from parallel_llm.training import DistributedTrainer, TrainingConfig, DiffusionLoss
        from parallel_llm.utils import TextDataset
        PARALLEL_LLM_AVAILABLE = True
        print("✅ parallel_llm package imported successfully")
    except ImportError as import_error:
        print("❌ parallel_llm package not available:")
        print(f"   Error: {import_error}")
        print("   Please install the package: pip install -e .")

# --- Summary -----------------------------------------------------------------
print("\n[3/3] Configuration check...")
if not (PYTORCH_AVAILABLE and PARALLEL_LLM_AVAILABLE):
    print(
        "❌ Dependencies missing - example cannot run on this system",
        "\nTo run this example:",
        "1. Use Linux or WSL environment",
        "2. Install CUDA 12.1+",
        "3. Install PyTorch: pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121",
        "4. Install dependencies: pip install transformers datasets",
        "5. Install parallel-llm: pip install -e .",
        sep="\n",
    )
else:
    print("✅ All dependencies available - example can run")

def setup_distributed():
    """Initialize the distributed process group if a launcher set ``LOCAL_RANK``.

    The presence of the ``LOCAL_RANK`` environment variable is treated as the
    signal that the script was started in distributed mode. In that case a
    process group is created, using the NCCL backend when CUDA and NCCL are
    both available and falling back to Gloo otherwise so that CPU-only
    launches still work.

    Returns:
        int: The local rank parsed from ``LOCAL_RANK``, or ``0`` when the
        variable is not set (non-distributed run).

    Raises:
        ValueError: If ``LOCAL_RANK`` is set but is not an integer.
    """
    raw_local_rank = os.environ.get("LOCAL_RANK")
    if raw_local_rank is None:
        print("Not running in distributed mode. Using CPU/Single GPU.")
        return 0

    # Parse before creating the process group so a malformed value does not
    # leave a half-initialized group behind.
    local_rank = int(raw_local_rank)

    # NCCL only works with CUDA devices; Gloo is the CPU-capable backend.
    cuda_available = torch.cuda.is_available()
    backend = "nccl" if cuda_available and dist.is_nccl_available() else "gloo"
    dist.init_process_group(backend=backend)

    if cuda_available:
        # Bind this process to its own GPU.
        torch.cuda.set_device(local_rank)
    return local_rank

def main():
    """Run the unimodal training demo, or outline it when dependencies are missing.

    When both ``PYTORCH_AVAILABLE`` and ``PARALLEL_LLM_AVAILABLE`` are true,
    this sets up the (optionally distributed) environment, wraps a small
    in-memory text corpus in a ``TextDataset``/``DataLoader``, builds a tiny
    ``DiffusionTransformer`` and runs ``DistributedTrainer.train()``.
    Otherwise it only prints the list of steps the example would perform.

    Any process group created by ``setup_distributed()`` is destroyed before
    returning, including when training raises.
    """
    if not (PYTORCH_AVAILABLE and PARALLEL_LLM_AVAILABLE):
        # Dependencies are missing: describe the example instead of running it.
        print("\n📋 Example structure demonstration:")
        print("This example would perform the following steps:")
        print("1. Set up distributed training environment (NCCL, local rank)")
        print("2. Configure DiffusionTransformer model with GPT-2 vocabulary")
        print("3. Set up training config with FSDP, mixed precision, and diffusion loss")
        print("4. Load WikiText-2 dataset and create TextDataset")
        print("5. Initialize DistributedTrainer with custom data loading")
        print("6. Run training loop with 50K steps, logging, and checkpointing")
        print("\nTo run this example, use a Linux environment with CUDA support.")
        return

    print("\n🚀 Running actual unimodal training...")

    # 1. Setup Distributed Environment
    local_rank = setup_distributed()
    is_main_process = local_rank == 0
    # Same signal setup_distributed() uses to decide whether to create a
    # process group.
    launched_distributed = "LOCAL_RANK" in os.environ

    try:
        if is_main_process:
            print("="*50)
            print("Starting Unimodal Training")
            print("="*50)

        # 2. Tokenizer
        # Loaded before the model config because the model's vocabulary size
        # is derived from it.
        if is_main_process:
            print("Loading tokenizer and dataset...")

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        # GPT-2 ships without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

        # 3. Configuration
        # Model Configuration (very small for demo)
        model_config = ModelConfig(
            # Must cover every id the tokenizer can emit. GPT-2 ids (including
            # the EOS/pad id) go far beyond 1000, so a smaller value would
            # presumably index past the model's embedding table.
            vocab_size=len(tokenizer),
            hidden_size=128,       # Very small model for demo
            num_hidden_layers=2,   # Very few layers for demo
            num_attention_heads=4, # Fewer heads
            num_diffusion_steps=5, # Very few diffusion steps for demo
            use_energy_model=False, # Disable for compatibility
            use_flash_attention=False  # Disable for compatibility
        )

        # Training Configuration
        train_config = TrainingConfig(
            output_dir="./checkpoints/unimodal",
            num_train_steps=50,   # Very small for ultra-quick demo
            batch_size=2,  # Very small batch size for demo
            learning_rate=1e-3,  # Higher learning rate for quick demo
            warmup_steps=5,   # Very short warmup
            use_fsdp=False,  # Disable FSDP for single-device training
            use_deepspeed=False,  # Disable DeepSpeed for single-device training
            mixed_precision="no",  # Disable mixed precision for compatibility
            logging_steps=5,
            save_steps=50,
            eval_steps=25,
            use_wandb=False # Set to True to enable WandB logging
        )

        # 4. Data Preparation
        # Create a simple mock dataset for demonstration
        if is_main_process:
            print("Creating mock dataset for demonstration...")
        mock_texts = [
            "The quick brown fox jumps over the lazy dog.",
            "Machine learning is transforming technology.",
            "Natural language processing enables computers to understand text.",
            "Deep learning models require significant computational resources.",
            "Transformers revolutionized the field of NLP.",
            "Artificial intelligence is becoming more sophisticated.",
            "Neural networks can learn complex patterns.",
            "Computer vision enables machines to see.",
            "Large language models understand context.",
            "Training requires significant computational power."
        ] * 50  # Create 500 samples for training
        dataset = [{"text": text} for text in mock_texts]

        train_dataset = TextDataset(
            dataset=dataset,
            tokenizer=tokenizer,
            max_length=64  # Very short sequences for demo
        )

        # Only use distributed sampler if actually running distributed training
        sampler = None
        if launched_distributed and dist.is_initialized():
            sampler = DistributedSampler(train_dataset)

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=train_config.batch_size,
            sampler=sampler,
            # DataLoader rejects shuffle=True together with an explicit sampler.
            shuffle=(sampler is None),
            num_workers=0,  # Use 0 workers to avoid multiprocessing issues in demos
            pin_memory=False  # Disable pin_memory for CPU training
        )

        # 5. Model Initialization
        if is_main_process:
            print("Initializing DiffusionTransformer model...")

        model = DiffusionTransformer(model_config)

        # 6. Trainer Initialization
        trainer = DistributedTrainer(
            model=model,
            train_config=train_config,
            model_config=model_config,
            train_dataloader=train_dataloader,
            # Custom loss function is automatically handled by DistributedTrainer
            # using DiffusionLoss, but you can override it if needed.
        )

        # 7. Start Training
        if is_main_process:
            print("Starting training loop...")

        trainer.train()

        if is_main_process:
            print("Training complete!")
    finally:
        # Tear down the process group even when training fails. The
        # is_initialized() check keeps this a no-op if it was never created
        # or has already been destroyed.
        if launched_distributed and dist.is_initialized():
            dist.destroy_process_group()

# Entry point: run the example only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
