"""
Inference script for Unimodal Parallel-LLM
Demonstrates high-speed parallel token generation.

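KNOWN ISSUE: PyTorch has compatibility issues on Windows.
When a dependency cannot be imported, this script does not run inference:
it reports the import error, prints setup guidance, and lists the steps the
example would have performed.
For actual execution, use Linux/WSL, preferably with CUDA support (the
script falls back to CPU when no GPU is available).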
"""
import sys
import os

# ---------------------------------------------------------------------------
# Import-time environment diagnostics.
#
# Everything below runs when the module is loaded (before main()). It prints
# a short report and sets two module-level flags that main() reads:
#   PYTORCH_AVAILABLE       - torch and transformers both imported
#   PARALLEL_LLM_AVAILABLE  - parallel_llm.core / parallel_llm.inference imported
# On success it also binds torch, AutoTokenizer, DiffusionTransformer,
# ModelConfig, ParallelGenerator and GenerationConfig as module globals.
# ---------------------------------------------------------------------------
print("=" * 60)
print("Parallel-LLM Unimodal Inference Example")
print("=" * 60)

# Report interpreter and platform first so a pasted log is self-describing.
print(f"Python version: {sys.version}")
print(f"Platform: {sys.platform}")

# Add the project root (parent of this script's directory) to sys.path so
# parallel_llm can be imported from source. The entry is appended, so any
# copy of the package found earlier on sys.path still takes precedence.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(_PROJECT_ROOT)
print(f"Added to path: {_PROJECT_ROOT}")

# Both flags start False and are flipped only after a successful import.
PYTORCH_AVAILABLE = False
PARALLEL_LLM_AVAILABLE = False

print("\n[1/3] Checking PyTorch availability...")
try:
    import torch
    from transformers import AutoTokenizer
except Exception as e:
    # Kept deliberately broad, as in the original - presumably because a
    # broken binary install can fail with errors other than ImportError.
    print("❌ PyTorch/transformers not available:")
    print(f"   Error: {e}")
    if sys.platform.startswith("win"):
        # Only show the Windows diagnosis on Windows; on other platforms it
        # would send the user off to "use Linux" when they already are.
        print("   This is a known issue on Windows with PyTorch binaries.")
        print("   Solutions:")
        print("   - Use WSL (Windows Subsystem for Linux)")
        print("   - Use a Linux environment")
        print("   - Use Docker with CUDA support")
    else:
        print("   Install the missing packages: pip install torch transformers")
else:
    PYTORCH_AVAILABLE = True
    print("✅ PyTorch and transformers imported successfully")

if PYTORCH_AVAILABLE:
    print("\n[2/3] Checking parallel_llm package...")
    try:
        from parallel_llm.core import DiffusionTransformer, ModelConfig
        from parallel_llm.inference import ParallelGenerator, GenerationConfig
    except ImportError as e:
        print("❌ parallel_llm package not available:")
        print(f"   Error: {e}")
        print("   Please install the package: pip install -e .")
    else:
        PARALLEL_LLM_AVAILABLE = True
        print("✅ parallel_llm package imported successfully")
else:
    # The parallel_llm import is not attempted when step 1 failed.
    print("\n[2/3] Skipping parallel_llm check (PyTorch not available)")

print("\n[3/3] Configuration check...")
if PYTORCH_AVAILABLE and PARALLEL_LLM_AVAILABLE:
    print("✅ All dependencies available - example can run")
else:
    print("❌ Dependencies missing - example cannot run on this system")
    print("\nTo run this example:")
    print("1. Use Linux or WSL environment")
    print("2. Install CUDA 12.1+")
    print("3. Install PyTorch: pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121")
    print("4. Install transformers: pip install transformers")
    print("5. Install parallel-llm: pip install -e .")

def _print_example_outline():
    """Print the steps the example would perform when it cannot actually run."""
    print("\n📋 Example structure demonstration:")
    print("This example would perform the following steps:")
    print("1. Load/create a DiffusionTransformer model with GPT-2 config")
    print("2. Configure generation with parallel token generation")
    print("3. Initialize ParallelGenerator with KV cache and CUDA graphs")
    print("4. Tokenize input prompt using GPT-2 tokenizer")
    print("5. Generate 128 tokens with 64 parallel tokens per step")
    print("6. Decode and display the generated text")
    print("\nTo run this example, use a Linux environment with CUDA support.")


def _run_inference():
    """Build a freshly initialized model, generate from a fixed prompt, print the text."""
    print("\n🚀 Running actual inference...")

    # 1. Load Model
    # In a real scenario, load from a checkpoint:
    # model = DiffusionTransformer.from_pretrained("./checkpoints/unimodal")

    # For demo, initialize fresh. vocab_size is presumably chosen to match
    # the "gpt2" tokenizer loaded below - confirm if the tokenizer changes.
    config = ModelConfig(vocab_size=50257, hidden_size=1024, num_hidden_layers=12)
    model = DiffusionTransformer(config)
    model.eval()

    # Probe CUDA once and reuse the answer so the model placement, the
    # CUDA-graphs flag and the input placement can never disagree.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()
        print("✅ Using CUDA GPU")
    else:
        print("ℹ️ Using CPU (GPU not available)")

    # 2. Configure Generation
    gen_config = GenerationConfig(
        max_new_tokens=128,       # Generate 128 tokens
        num_parallel_tokens=64,   # Generate 64 tokens simultaneously per step
        num_refinement_steps=5,   # Diffusion refinement steps
        temperature=0.8,
        top_k=50,
        use_adaptive_refinement=True,  # Stop early if confidence is high
    )

    # 3. Initialize Generator (CUDA graphs are only requested on a GPU)
    generator = ParallelGenerator(
        model=model,
        config=gen_config,
        use_kv_cache=True,
        use_cuda_graphs=use_cuda,
    )

    # 4. Prepare Input
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    prompt = "Artificial intelligence is transforming the world by"
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    if use_cuda:
        # Keep the input on the same device as the model.
        input_ids = input_ids.cuda()

    print(f"Prompt: {prompt}")
    print("Generating...")

    # 5. Generate (no gradients are needed for inference)
    with torch.no_grad():
        # Returns generated token IDs
        output_ids = generator.generate(input_ids)

    # 6. Decode the first sequence of the returned token IDs
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    print("\nGenerated Text:")
    print("-" * 20)
    print(output_text)
    print("-" * 20)
    print("✅ Inference completed successfully!")


def main():
    """Run the inference demo, or print an outline of it if dependencies are missing.

    Reads the module-level flags PYTORCH_AVAILABLE and PARALLEL_LLM_AVAILABLE.
    When both are true, a fresh DiffusionTransformer is built, a fixed prompt
    is tokenized with the "gpt2" tokenizer, and the generated text is printed.
    Otherwise only a description of those steps is printed.

    Returns:
        None. All results are written to standard output.
    """
    if not (PYTORCH_AVAILABLE and PARALLEL_LLM_AVAILABLE):
        _print_example_outline()
        return
    _run_inference()

# Script entry point: call main() only when this file is executed directly.
# The module-level dependency checks still run whenever the file is imported.
if __name__ == "__main__":
    main()
