"""
Training script for Multimodal (Text + Image) Parallel-LLM using Conceptual Captions.
Demonstrates training a model that can handle both text and image inputs using real-world data.
"""
import os
import sys
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoImageProcessor
from datasets import load_dataset
from parallel_llm.core import DiffusionTransformer, MultimodalConfig
from parallel_llm.training import DistributedTrainer, TrainingConfig
from parallel_llm.utils import MultimodalDataset

def _select_backend():
    """Return the torch.distributed backend suited to this machine.

    NCCL only works with CUDA devices, so Gloo is used as the CPU fallback.
    """
    return "nccl" if torch.cuda.is_available() else "gloo"


def setup_distributed():
    """Initialise torch.distributed when launched by a distributed launcher.

    Distributed mode is detected via the ``LOCAL_RANK`` environment variable.
    Without it, the script runs as a single process and no process group is
    created.

    Returns:
        The local rank of this process on its node (0 in single-process mode).
    """
    if "LOCAL_RANK" not in os.environ:
        print("Not running in distributed mode. Using single GPU/CPU.")
        return 0

    local_rank = int(os.environ["LOCAL_RANK"])
    # Bind the GPU before creating the process group so NCCL communicators are
    # set up on the right device; skip entirely on CPU-only machines, where
    # torch.cuda.set_device would raise.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    dist.init_process_group(backend=_select_backend())
    return local_rank

def _is_main_process(local_rank):
    """Return True if this process should emit progress logs.

    ``LOCAL_RANK`` is only unique within a node, so a multi-node job has one
    local-rank-0 process per node. When a process group is active, the global
    rank decides instead.

    Args:
        local_rank: Value returned by ``setup_distributed()``.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return local_rank == 0


def main():
    """Train a multimodal DiffusionTransformer on a Conceptual Captions subset.

    Sets up (optional) distributed mode, builds the model and training configs,
    loads the tokenizer, image processor and a 2000-example streaming subset of
    Conceptual Captions, wraps it in a DataLoader and runs
    ``DistributedTrainer.train()``. The process group, if one was created, is
    destroyed on exit, including when any step raises.
    """
    local_rank = setup_distributed()
    is_main_process = _is_main_process(local_rank)

    try:
        if is_main_process:
            print("="*60)
            print("Parallel-LLM Multimodal Training Example (Conceptual Captions)")
            print("="*60)
            print("Starting Multimodal Training")

        cuda_available = torch.cuda.is_available()
        # Check CUDA first: on CPU-only setups some PyTorch versions raise from
        # torch.cuda.is_bf16_supported() instead of returning False.
        use_bf16 = cuda_available and torch.cuda.is_bf16_supported()

        # 1. Multimodal Configuration (GPU-friendly size)
        model_config = MultimodalConfig(
            # Text parameters (reduced size)
            vocab_size=32000,      # Overwritten below with the tokenizer's size
            hidden_size=768,       # Reduced from 2048
            num_hidden_layers=12,  # Reduced from 22

            # Vision parameters (smaller ViT)
            vision_encoder="vit",
            image_size=224,
            patch_size=16,
            vision_hidden_size=384,  # Reduced from 768

            # Fusion parameters
            fusion_type="cross_attention",
            num_cross_attention_layers=4,

            # Training objectives
            use_contrastive=True,
            contrastive_temperature=0.07,

            dtype=torch.bfloat16 if use_bf16 else torch.float32,
        )

        train_config = TrainingConfig(
            output_dir="./checkpoints/multimodal_cc",
            num_train_steps=1000,
            batch_size=2,  # Smaller batch size for multimodal
            learning_rate=1e-4,
            warmup_steps=100,
            # NOTE(review): "fp16" is also selected on CPU-only machines;
            # confirm how DistributedTrainer handles that case.
            mixed_precision="bf16" if use_bf16 else "fp16",
            gradient_checkpointing=True,
            use_fsdp=False,
            logging_steps=10,
            save_steps=500,
            eval_steps=200,
            use_torch_compile=True,
        )

        # 2. Data Preparation
        if is_main_process:
            print("Loading processors and dataset...")

        tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        tokenizer.pad_token = tokenizer.eos_token
        # len(tokenizer) counts added special tokens; tokenizer.vocab_size does
        # not, and would undersize the embedding table if any were added.
        model_config.vocab_size = len(tokenizer)

        image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

        # Stream Conceptual Captions and keep only the first 2000 examples.
        # Images are referenced by URL, so downloads can be slow or fail. This
        # script does not filter unreachable URLs itself; how failed downloads
        # are handled is up to MultimodalDataset (not visible here).
        dataset = load_dataset("conceptual_captions", split="train", streaming=True)
        dataset = dataset.take(2000)

        train_dataset = MultimodalDataset(
            dataset=dataset,
            tokenizer=tokenizer,
            image_processor=image_processor,
            text_column="caption",
            # Assumed: MultimodalDataset downloads images when this column holds URLs.
            image_column="image_url",
            max_length=128,
        )

        # Shard data across ranks only when a process group is active.
        # NOTE(review): DistributedSampler requires a map-style dataset with
        # __len__; confirm MultimodalDataset provides one over a streaming source.
        sampler = None
        if "LOCAL_RANK" in os.environ and dist.is_initialized():
            sampler = DistributedSampler(train_dataset)

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=train_config.batch_size,
            sampler=sampler,
            shuffle=False,
            num_workers=4,
            # Pinned host memory only helps host-to-GPU copies.
            pin_memory=cuda_available,
        )

        # 3. Model & Trainer
        if is_main_process:
            print("Initializing Multimodal DiffusionTransformer...")

        model = DiffusionTransformer(model_config)

        trainer = DistributedTrainer(
            model=model,
            train_config=train_config,
            model_config=model_config,
            train_dataloader=train_dataloader,
        )

        # 4. Train
        if is_main_process:
            print("Starting training...")

        trainer.train()
    finally:
        # Release the process group (if any) even when setup or training fails.
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()

if __name__ == "__main__":
    main()
