"""
Inference script for Unimodal Parallel-LLM using TinyLlama architecture.
This script demonstrates high-speed parallel token generation using a model configured
to match TinyLlama-1.1B dimensions, fitting comfortably on a Tesla P100 (16GB).
"""
import sys
import os
import torch
from transformers import AutoTokenizer
from parallel_llm.core import DiffusionTransformer, ModelConfig
from parallel_llm.inference import ParallelGenerator, GenerationConfig

def main():
    """Run an end-to-end inference demo with ParallelGenerator.

    Steps:
      1. Pick a device (CUDA if available, otherwise CPU) and report GPU info.
      2. Load the TinyLlama tokenizer, falling back to the GPT-2 tokenizer if
         loading fails.
      3. Build a randomly initialized DiffusionTransformer with a reduced,
         GPU-friendly configuration and put it in eval mode.
      4. Configure a ParallelGenerator, run one short warm-up call, then
         generate a continuation of a fixed prompt and print the decoded text.

    The weights are random, so the printed text is not expected to be
    coherent; the script only exercises the pipeline wiring.
    """
    print("="*60)
    print("Parallel-LLM Unimodal Inference Example (TinyLlama-1.1B Config)")
    print("="*60)

    # Check for CUDA
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = device == "cuda"
    print(f"Using device: {device}")

    if use_cuda:
        props = torch.cuda.get_device_properties(device)
        print(f"GPU: {props.name} | VRAM: {props.total_memory / 1024**3:.2f} GB")

    # 1. Load Tokenizer (Real-world tokenizer)
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    print(f"\n[1/4] Loading tokenizer from {model_id}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        # Broad on purpose: any load failure (e.g. no network access) should
        # fall back to a second tokenizer rather than abort the demo.
        print(f"Failed to load tokenizer: {e}")
        print("Fallback to gpt2 tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # 2. Initialize Model (Reduced size for GPU compatibility)
    print("\n[2/4] Initializing DiffusionTransformer with GPU-friendly config...")

    # Only query bf16 support when CUDA is actually in use. On CPU the answer is
    # irrelevant (we use float32), and some torch versions raise from this CUDA
    # query when no GPU/driver is present instead of returning False.
    # NOTE(review): newer torch versions may report True on pre-Ampere GPUs via
    # emulation, which can be slow — confirm on the target hardware.
    use_bf16 = use_cuda and torch.cuda.is_bf16_supported()
    model_dtype = torch.bfloat16 if use_bf16 else torch.float32

    # Smaller config intended to fit comfortably on a ~14-16GB GPU; the exact
    # parameter count is printed after construction.
    config = ModelConfig(
        # len(tokenizer) includes added/special tokens, whereas
        # tokenizer.vocab_size counts only the base vocabulary; using the full
        # length keeps every token id the tokenizer can emit inside the
        # embedding table.
        vocab_size=len(tokenizer),
        hidden_size=768,            # Reduced from 2048
        num_hidden_layers=12,       # Reduced from 22
        num_attention_heads=12,     # Reduced from 32
        num_diffusion_steps=10,     # Reduced from 64
        # NOTE(review): FlashAttention kernels typically need a newer GPU than
        # a Pascal-class P100; confirm DiffusionTransformer falls back cleanly.
        use_flash_attention=use_cuda,
        dtype=model_dtype,
    )

    model = DiffusionTransformer(config)
    model.to(device)
    model.eval()

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

    # 3. Configure Generation
    print("\n[3/4] Configuring Parallel Generation...")
    gen_config = GenerationConfig(
        max_new_tokens=128,
        # Parallel-decoding settings. The generate() calls below pass
        # use_autoregressive=True, so presumably these only take effect in the
        # parallel path — TODO confirm against ParallelGenerator.
        num_parallel_tokens=64,
        num_refinement_steps=5,
        temperature=0.8,
        top_k=50,
        repetition_penalty=1.2,  # Discourage token repetition
        use_adaptive_steps=True,
        use_torch_compile=use_cuda,  # Enable torch.compile for speed on GPU only
    )

    generator = ParallelGenerator(
        model=model,
        config=gen_config,
        use_kv_cache=True,
        use_cuda_graphs=use_cuda,
    )

    # 4. Run Inference
    print("\n[4/4] Running Inference...")
    prompt = "The future of artificial intelligence is"
    print(f"Prompt: {prompt}")

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Warmup: a short call that presumably absorbs one-time costs such as
    # torch.compile tracing or CUDA-graph capture when those are enabled.
    print("Warming up...")
    with torch.no_grad():
        _ = generator.generate(input_ids, max_new_tokens=10, use_autoregressive=True)

    # Actual generation
    print("Generating...")
    print("Mode: Autoregressive (generates tokens one-by-one for coherent output)")
    with torch.no_grad():
        output_ids = generator.generate(input_ids, use_autoregressive=True)

    # Decode the first (and only) sequence in the batch.
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    print("\nGenerated Text:")
    print("-" * 40)
    print(output_text)
    print("-" * 40)
    print("\nNote: Since this model is initialized with random weights, the output text will be incoherent.")
    print("To generate meaningful text, please train the model using `train_unimodal.py`.")

if __name__ == "__main__":
    main()
