"""
Inference script for Multimodal Parallel-LLM using TinyLlama + ViT architecture.
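This script demonstrates generating text descriptions from images. It loads the
TinyLlama-1.1B tokenizer (text) and the ViT-Base image processor (vision), but builds a
reduced-size model with randomly initialized weights, so the generated text is incoherent.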
"""
import io
import os
import sys

import requests
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoImageProcessor

from parallel_llm.core import DiffusionTransformer, MultimodalConfig
from parallel_llm.inference import ParallelGenerator, GenerationConfig

def _load_image(url: str, timeout: float = 30.0) -> "Image.Image":
    """Download the image at ``url`` and return it as an RGB PIL image.

    Args:
        url: HTTP(S) location of the image.
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot hang the script.

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            non-2xx HTTP status.
        OSError: If the payload cannot be decoded as an image.
    """
    response = requests.get(url, timeout=timeout)
    # Without this, an HTML error page would be handed to PIL as "image" data.
    response.raise_for_status()
    # Decode from the fully downloaded (and content-decoded) bytes rather than
    # the raw socket stream. ``convert`` forces the lazy decode to happen here
    # and guarantees three channels for the image processor.
    return Image.open(io.BytesIO(response.content)).convert("RGB")


def main() -> None:
    """Run the end-to-end multimodal captioning demo.

    Steps: load the tokenizer and image processor, build a reduced-size
    ``DiffusionTransformer`` with random weights, fetch a sample image (falling
    back to a flat grey image if the download fails), and print the generated
    text. Returns early, after printing the error, if the processors cannot be
    loaded.
    """
    print("="*60)
    print("Parallel-LLM Multimodal Inference Example (TinyLlama + ViT)")
    print("="*60)

    # Check for CUDA
    on_cuda = torch.cuda.is_available()
    device = "cuda" if on_cuda else "cpu"
    print(f"Using device: {device}")

    # Decide the compute dtype once. bf16 is only considered when CUDA is the
    # selected device: querying bf16 support without a CUDA device can raise on
    # some PyTorch builds, and the model and its inputs must agree on dtype.
    use_bf16 = on_cuda and torch.cuda.is_bf16_supported()
    model_dtype = torch.bfloat16 if use_bf16 else torch.float32

    # 1. Load Processors (Real-world models)
    text_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    vision_model_id = "google/vit-base-patch16-224"

    print("\n[1/4] Loading processors...")
    print(f"Text: {text_model_id}")
    print(f"Vision: {vision_model_id}")

    try:
        tokenizer = AutoTokenizer.from_pretrained(text_model_id)
        image_processor = AutoImageProcessor.from_pretrained(vision_model_id)
    except Exception as e:
        # Broad on purpose: hub/network/cache failures surface as many types.
        print(f"Failed to load processors: {e}")
        return

    # 2. Initialize Model (Reduced size for GPU compatibility)
    print("\n[2/4] Initializing Multimodal DiffusionTransformer...")
    config = MultimodalConfig(
        # Text parameters (reduced size)
        # len(tokenizer) includes added/special tokens, whereas
        # tokenizer.vocab_size is the base vocabulary only; sizing by the full
        # length keeps every id the tokenizer can emit inside the embedding table.
        vocab_size=len(tokenizer),
        hidden_size=768,            # Reduced from 2048
        num_hidden_layers=12,       # Reduced from 22

        # Vision parameters (smaller ViT)
        vision_encoder="vit",
        image_size=224,
        patch_size=16,
        vision_hidden_size=384,     # Reduced from 768

        # Fusion parameters
        fusion_type="cross_attention",
        num_cross_attention_layers=4,

        # General
        num_diffusion_steps=10,     # Reduced from 64
        use_flash_attention=on_cuda,
        dtype=model_dtype,
    )

    model = DiffusionTransformer(config)
    model.to(device)
    model.eval()

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

    # 3. Prepare Inputs
    print("\n[3/4] Preparing Inputs...")
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    try:
        print(f"Downloading image from {url}...")
        image = _load_image(url)
    except Exception as e:
        # Deliberately best-effort: the demo must still run offline or when the
        # sample image is unavailable or corrupt.
        print(f"Failed to download image: {e}")
        print("Using mock image...")
        image = Image.new('RGB', (224, 224), color=(128, 128, 128))

    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
    # Match the dtype chosen for the model config above.
    pixel_values = pixel_values.to(device=device, dtype=model_dtype)

    prompt = "A picture of"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # 4. Run Inference
    print("\n[4/4] Running Inference...")

    gen_config = GenerationConfig(
        max_new_tokens=64,
        num_parallel_tokens=64,  # Generate all tokens in parallel
        num_refinement_steps=5,  # Use 5 refinement steps for better quality
        temperature=0.7,
        repetition_penalty=1.2,  # Prevent token repetition
        confidence_threshold=0.5,  # Moderate confidence threshold
        use_torch_compile=on_cuda,
    )

    generator = ParallelGenerator(model, gen_config)

    print("Generating caption...")
    print("Mode: Autoregressive (generates tokens one-by-one for coherent output)")
    with torch.no_grad():
        output_ids = generator.generate(
            input_ids,
            pixel_values=pixel_values,
            use_autoregressive=True  # Enable autoregressive mode
        )

    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(f"\nResult: {output_text}")
    print("\nNote: Since this model is initialized with random weights, the output text will be incoherent.")

# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
