"""
Training script for Multimodal (Text + Image) Parallel-LLM
Demonstrates training a model that can handle both text and image inputs.
"""
import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoImageProcessor
from datasets import load_dataset

# Parallel-LLM project imports
from parallel_llm.core import DiffusionTransformer, MultimodalConfig
from parallel_llm.training import DistributedTrainer, TrainingConfig
from parallel_llm.utils import MultimodalDataset

def setup_distributed():
    """Initialise ``torch.distributed`` when started by a distributed launcher.

    The presence of ``LOCAL_RANK`` in the environment is treated as the signal
    that a launcher (e.g. ``torchrun``) started this process. When it is
    absent nothing is initialised and the script runs as a single process.

    Returns:
        int: The value of ``LOCAL_RANK``, or ``0`` when the variable is unset.
    """
    if "LOCAL_RANK" not in os.environ:
        return 0

    local_rank = int(os.environ["LOCAL_RANK"])
    cuda_available = torch.cuda.is_available()

    if cuda_available:
        # Bind this process to its GPU *before* creating the process group so
        # that NCCL communicators are created on the intended device.
        torch.cuda.set_device(local_rank)

    # Guard against double initialisation (e.g. the caller already set up the
    # group); init_process_group raises if called twice.
    if not dist.is_initialized():
        # NCCL only works with CUDA devices; fall back to gloo on CPU-only
        # hosts so a launcher-started run does not crash there.
        backend = "nccl" if cuda_available else "gloo"
        dist.init_process_group(backend=backend)

    return local_rank

def main():
    """Configure and run a demo multimodal (text + image) training job.

    Steps performed, in order:
      1. Set up distributed state via ``setup_distributed()``.
      2. Build the ``MultimodalConfig`` and ``TrainingConfig``.
      3. Load the GPT-2 tokenizer and a ViT image processor.
      4. Wrap 100 in-memory random samples in ``MultimodalDataset``.
      5. Build the ``DataLoader`` (sharded per rank when distributed).
      6. Construct ``DiffusionTransformer`` and ``DistributedTrainer`` and
         call ``trainer.train()``.

    Progress banners are printed on the global rank-0 process only.
    """
    local_rank = setup_distributed()
    distributed = dist.is_available() and dist.is_initialized()

    # Use the *global* rank when a process group exists: on multi-node runs
    # every node has a local rank 0, which would duplicate the banners.
    global_rank = dist.get_rank() if distributed else local_rank
    is_main_process = global_rank == 0

    if is_main_process:
        print("="*50)
        print("Starting Multimodal Training")
        print("="*50)

    # 1. Multimodal Configuration
    model_config = MultimodalConfig(
        # Text parameters
        vocab_size=50257,
        hidden_size=1024,
        num_hidden_layers=12,

        # Vision parameters
        vision_encoder="vit",  # 'vit', 'clip', or 'siglip'
        image_size=224,
        patch_size=16,
        vision_hidden_size=768,

        # Fusion parameters
        fusion_type="cross_attention",
        num_cross_attention_layers=4,

        # Training objectives
        use_contrastive=True,  # Enable CLIP-style contrastive loss
        contrastive_temperature=0.07
    )

    train_config = TrainingConfig(
        output_dir="./checkpoints/multimodal",
        num_train_steps=25000,
        batch_size=4,  # Smaller batch size for multimodal
        learning_rate=1e-4,
        mixed_precision="bf16",
        gradient_checkpointing=True
    )

    # 2. Data Preparation
    if is_main_process:
        print("Loading processors and dataset...")

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token

    # Use a standard ViT image processor
    image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

    # With a real dataset (e.g. one exposing 'image' and 'caption' columns),
    # pass the HuggingFace dataset directly instead of the mock list below:
    #   dataset = load_dataset("coco", split="train")
    #   train_dataset = MultimodalDataset(dataset, tokenizer, image_processor,
    #                                     text_column="caption")
    #
    # Demo data: 100 random image tensors paired with a fixed caption.
    train_dataset = MultimodalDataset(
        dataset=[{"image": torch.randn(3, 224, 224), "text": "demo"} for _ in range(100)],
        tokenizer=tokenizer,
        image_processor=image_processor,
        text_column="text",
        image_column="image"
    )

    # When running distributed, shard the dataset so each rank trains on a
    # distinct subset instead of every rank iterating the full dataset.
    # NOTE(review): DistributedSampler needs set_epoch() each epoch to
    # reshuffle; confirm whether DistributedTrainer handles that.
    train_sampler = DistributedSampler(train_dataset, shuffle=True) if distributed else None

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=train_config.batch_size,
        # `shuffle` and `sampler` are mutually exclusive in DataLoader.
        shuffle=train_sampler is None,
        sampler=train_sampler,
        num_workers=2
    )

    # 3. Model & Trainer
    if is_main_process:
        print("Initializing Multimodal DiffusionTransformer...")

    model = DiffusionTransformer(model_config)

    trainer = DistributedTrainer(
        model=model,
        train_config=train_config,
        model_config=model_config,
        train_dataloader=train_dataloader
    )

    # 4. Train
    if is_main_process:
        print("Starting training...")

    try:
        trainer.train()
    finally:
        # Release the process group even if training fails; re-check in case
        # the trainer already tore it down.
        if distributed and dist.is_initialized():
            dist.destroy_process_group()

# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
