"""
Inference script for Unimodal Parallel-LLM
Demonstrates high-speed parallel token generation.
"""
import torch
from transformers import AutoTokenizer
import sys
import os

# Put the project root at the FRONT of sys.path so the in-tree parallel_llm
# package takes precedence over any installed copy. Appending instead would
# let a site-packages install of parallel_llm shadow the source checkout.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, _PROJECT_ROOT)

# Project-local imports (resolvable via the project root added to sys.path above)
from parallel_llm.core import DiffusionTransformer, ModelConfig
from parallel_llm.inference import ParallelGenerator, GenerationConfig

def main():
    """Run one end-to-end parallel-generation demo and print the result.

    Steps:
      1. load the GPT-2 tokenizer,
      2. build a freshly initialised ``DiffusionTransformer`` (not loaded
         from a checkpoint),
      3. wrap it in a ``ParallelGenerator``,
      4. encode a fixed prompt, generate, and print the decoded text of the
         first sequence in the returned batch.

    Because the model here is freshly constructed rather than loaded from
    trained weights, the printed text is presumably not meaningful; it only
    exercises the generation pipeline.
    """
    print("=" * 50)
    print("Parallel-LLM Unimodal Inference")
    print("=" * 50)

    # Choose the device once so the model, the inputs and the CUDA-graph
    # setting can never disagree with each other.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load the tokenizer first: the model's vocabulary is sized from it so
    #    every token ID the tokenizer can emit has an embedding row
    #    (len(tokenizer) == 50257 for GPT-2, matching the previous constant).
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # 2. Load Model
    # In a real scenario, load from a checkpoint:
    # model = DiffusionTransformer.from_pretrained("./checkpoints/unimodal")
    # For demo, initialize fresh.
    config = ModelConfig(
        vocab_size=len(tokenizer),
        hidden_size=1024,
        num_hidden_layers=12,
    )
    model = DiffusionTransformer(config)
    model.eval()
    model = model.to(device)

    # 3. Configure Generation
    gen_config = GenerationConfig(
        max_new_tokens=128,           # Generate 128 tokens
        num_parallel_tokens=64,       # Generate 64 tokens simultaneously per step
        num_refinement_steps=5,       # Diffusion refinement steps
        temperature=0.8,
        top_k=50,
        use_adaptive_refinement=True, # Stop early if confidence is high
    )

    # 4. Initialize Generator (CUDA graphs are only meaningful on a GPU)
    generator = ParallelGenerator(
        model=model,
        config=gen_config,
        use_kv_cache=True,
        use_cuda_graphs=device.type == "cuda",
    )

    # 5. Prepare Input
    prompt = "Artificial intelligence is transforming the world by"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    print(f"Prompt: {prompt}")
    print("Generating...")

    # 6. Generate
    with torch.no_grad():
        # Returns generated token IDs
        output_ids = generator.generate(input_ids)

    # 7. Decode the first (and only) sequence of the batch
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    print("\nGenerated Text:")
    print("-" * 20)
    print(output_text)
    print("-" * 20)

if __name__ == "__main__":
    # Run the demo only when executed as a script, not when this module is imported.
    main()
