"""
Inference script for Multimodal Parallel-LLM
Demonstrates generating text descriptions from images.

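KNOWN ISSUE: PyTorch binaries can fail to import on some Windows setups.
When PyTorch or any other required dependency cannot be imported, this script
only reports what is missing and prints setup guidance instead of running
inference. For actual execution, Linux or WSL is recommended; a CUDA GPU is
used when available, otherwise inference falls back to the CPU.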
"""
import sys
import os

print("=" * 60)
print("Parallel-LLM Multimodal Inference Example")
print("=" * 60)

# Report interpreter and platform up front; both matter when diagnosing the
# platform-specific import failures reported by the dependency checks.
print(f"Python version: {sys.version}")
print(f"Platform: {sys.platform}")

# Make the repository root (the parent of this file's directory) importable so
# `parallel_llm` can be loaded from source without installing it. The path is
# placed at the FRONT of sys.path so the source checkout takes precedence over
# any installed (possibly stale) copy of the package; appending it would let
# an installed version silently win.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, PROJECT_ROOT)
print(f"Added to path: {PROJECT_ROOT}")

# Try to import the optional heavy dependencies. Failures are reported with
# setup guidance instead of a traceback; main() reads these two flags to
# decide whether real inference can run.
PYTORCH_AVAILABLE = False
PARALLEL_LLM_AVAILABLE = False

print("\n[1/3] Checking PyTorch availability...")
try:
    import torch
    from transformers import AutoTokenizer, AutoImageProcessor
    from PIL import Image
    import requests
except Exception as e:
    # Deliberately broad: a broken PyTorch install (notably on Windows) can
    # fail while loading native libraries with errors other than ImportError,
    # and every such failure should lead to guidance here, not a crash.
    print("❌ PyTorch/transformers/PIL not available:")
    print(f"   Error: {e}")
    if sys.platform == "win32":
        # The binary-compatibility hint only applies on Windows; elsewhere it
        # would misdirect users whose packages are simply not installed.
        print("   This is a known issue on Windows with PyTorch binaries.")
        print("   Solutions:")
        print("   - Use WSL (Windows Subsystem for Linux)")
        print("   - Use a Linux environment")
        print("   - Use Docker with CUDA support")
else:
    PYTORCH_AVAILABLE = True
    print("✅ PyTorch, transformers, PIL, and requests imported successfully")

if PYTORCH_AVAILABLE:
    print("\n[2/3] Checking parallel_llm package...")
    try:
        from parallel_llm.core import DiffusionTransformer, MultimodalConfig
        from parallel_llm.inference import ParallelGenerator, GenerationConfig
    except ImportError as e:
        print("❌ parallel_llm package not available:")
        print(f"   Error: {e}")
        print("   Please install the package: pip install -e .")
    else:
        PARALLEL_LLM_AVAILABLE = True
        print("✅ parallel_llm package imported successfully")
else:
    print("\n[2/3] Skipping parallel_llm check (PyTorch not available)")

print("\n[3/3] Configuration check...")
if PYTORCH_AVAILABLE and PARALLEL_LLM_AVAILABLE:
    print("✅ All dependencies available - example can run")
else:
    print("❌ Dependencies missing - example cannot run on this system")
    print("\nTo run this example:")
    print("1. Use Linux or WSL environment")
    print("2. Install CUDA 12.1+")
    print("3. Install PyTorch: pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121")
    print("4. Install transformers and PIL: pip install transformers pillow requests")
    print("5. Install parallel-llm: pip install -e .")

def main():
    """Run the multimodal captioning demo, or describe it if deps are missing.

    When PyTorch, transformers, PIL, requests and ``parallel_llm`` were all
    imported successfully at module load (see the availability flags), a small
    multimodal DiffusionTransformer is built and used to generate text for a
    COCO validation image. Otherwise the steps the demo would perform are
    printed instead.

    Note: the model is constructed from a config only; no pretrained weights
    are loaded here, so the generated "caption" is not expected to be
    meaningful. The demo exercises the inference pipeline end to end.
    """
    if PYTORCH_AVAILABLE and PARALLEL_LLM_AVAILABLE:
        _run_inference()
    else:
        _print_example_structure()


def _run_inference(image_size=64, request_timeout=30.0):
    """Build the demo model, caption one COCO image and print the result.

    Args:
        image_size: Side length, in pixels, of the square images the vision
            encoder is configured for. The image processor resizes to the same
            value so model config and inputs cannot disagree.
        request_timeout: Seconds to wait when downloading the sample image.

    Raises:
        requests.HTTPError: if the sample image cannot be downloaded.
    """
    print("\n🚀 Running actual multimodal inference...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Tokenizer first: the model's vocabulary must cover every id the
    # GPT-2 tokenizer can emit. A hard-coded small vocab (e.g. 1000) would let
    # most GPT-2 ids index past the end of the token-embedding table.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # 2. Model (small, for demo purposes).
    config = MultimodalConfig(
        vocab_size=len(tokenizer),  # must match the tokenizer (see above)
        vision_encoder="vit",
        image_size=image_size,
        fusion_type="cross_attention",
        hidden_size=128,
        num_hidden_layers=2,
        num_diffusion_steps=3,
    )
    model = DiffusionTransformer(config).to(device)
    model.eval()
    if device.type == "cuda":
        print("✅ Using CUDA GPU")
    else:
        print("ℹ️ Using CPU (GPU not available)")

    # 3. Image input.
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # Bounded wait plus an explicit status check: without a timeout the request
    # can block indefinitely, and an HTTP error page would otherwise surface as
    # a confusing PIL decode error. `with` closes the response; convert() forces
    # PIL to read the pixels while the stream is still open.
    with requests.get(url, stream=True, timeout=request_timeout) as response:
        response.raise_for_status()
        image = Image.open(response.raw).convert("RGB")

    # Resize to the resolution the model was configured for; this checkpoint's
    # processor would otherwise produce its default 224x224 pixel values.
    image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
    pixel_values = image_processor(
        images=image,
        size={"height": image_size, "width": image_size},
        return_tensors="pt",
    ).pixel_values.to(device)

    # 4. Text prompt (optional, could be empty for pure captioning).
    prompt = "A picture of"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # 5. Generation.
    gen_config = GenerationConfig(
        max_new_tokens=64,
        num_parallel_tokens=32,
        temperature=0.7,
    )
    generator = ParallelGenerator(model, gen_config)

    print("Generating caption...")
    # The generator is expected to route pixel_values to the vision encoder
    # when the model is configured as multimodal.
    with torch.no_grad():
        output_ids = generator.generate(input_ids, pixel_values=pixel_values)

    # 6. Decode the first (only) sequence in the batch.
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(f"\nResult: {output_text}")
    print("✅ Multimodal inference completed successfully!")


def _print_example_structure():
    """Print the steps the demo would run when dependencies are unavailable."""
    print("\n📋 Example structure demonstration:")
    print("This example would perform the following steps:")
    print("1. Load/create a Multimodal DiffusionTransformer with ViT vision encoder")
    print("2. Download and process an image from COCO dataset")
    print("3. Use ViT image processor to extract pixel values")
    print("4. Tokenize text prompt using GPT-2 tokenizer")
    print("5. Generate caption using multimodal parallel generation")
    print("6. Decode and display the generated caption")
    print("\nTo run this example, use a Linux environment with CUDA support.")

# Invoke main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
