"""
Inference script for Multimodal Parallel-LLM
Demonstrates generating text descriptions from images.
"""
import torch
from transformers import AutoTokenizer, AutoImageProcessor
from PIL import Image
import requests

# Updated imports
from parallel_llm.core import DiffusionTransformer, MultimodalConfig
from parallel_llm.inference import ParallelGenerator, GenerationConfig

def main():
    """Caption a sample COCO image with a multimodal Parallel-LLM model.

    Builds a ``DiffusionTransformer`` from a ``MultimodalConfig``, downloads
    one image over HTTP, preprocesses it with the ViT image processor,
    tokenizes a short GPT-2 prompt, runs ``ParallelGenerator.generate`` on
    both inputs and prints the decoded text.

    Raises:
        requests.RequestException: If the image download fails, times out,
            or the server answers with an HTTP error status.
    """
    print("="*50)
    print("Parallel-LLM Multimodal Inference")
    print("="*50)

    # Pick the device once so the model and every input tensor are guaranteed
    # to end up in the same place.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load Model
    # NOTE(review): no checkpoint is loaded in this script; confirm whether
    # DiffusionTransformer(config) restores pretrained weights on its own,
    # otherwise the caption below comes from freshly initialised parameters.
    config = MultimodalConfig(
        vocab_size=50257,
        vision_encoder="vit",
        fusion_type="cross_attention",
    )
    model = DiffusionTransformer(config)
    model.eval()
    model = model.to(device)

    # 2. Prepare Inputs
    # Load Image
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    download_timeout_s = 30  # without a timeout a stalled server hangs forever
    with requests.get(url, stream=True, timeout=download_timeout_s) as response:
        # Fail with a clear HTTP error instead of letting PIL choke on an
        # error page.
        response.raise_for_status()
        # Let urllib3 undo any Content-Encoding (gzip/deflate) on the raw stream.
        response.raw.decode_content = True
        # convert() decodes the pixel data now, before the connection is
        # released, and normalises grayscale/RGBA/palette images to RGB.
        image = Image.open(response.raw).convert("RGB")

    # Process Image
    image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

    # Prepare Text Prompt (optional, can be empty for captioning)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    prompt = "A picture of"
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    pixel_values = pixel_values.to(device)
    input_ids = input_ids.to(device)

    # 3. Configure Generation
    gen_config = GenerationConfig(
        max_new_tokens=64,
        num_parallel_tokens=32,
        temperature=0.7,
    )

    # 4. Initialize Generator
    generator = ParallelGenerator(model, gen_config)

    print("Generating caption...")

    # 5. Generate
    # Note: The generator handles multimodal inputs if the model is configured for it
    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        output_ids = generator.generate(
            input_ids,
            pixel_values=pixel_values,
        )

    # 6. Decode
    # Only the first sequence of the returned batch is decoded.
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(f"\nResult: {output_text}")

# Run the demo only when this file is executed as a script, so importing the
# module does not trigger the model build and image download in main().
if __name__ == "__main__":
    main()
