"""
Simple Interactive Chat - No advanced dependencies
Works even with library issues
"""
import sys
import os
sys.path.insert(0, os.path.abspath('.'))

import torch
from transformers import AutoTokenizer
from parallel_llm.core import DiffusionTransformer, ModelConfig
from parallel_llm.inference import ParallelGenerator, GenerationConfig

def _build_model(tokenizer, device: str, checkpoint_path: str):
    """Create the DiffusionTransformer, loading weights from *checkpoint_path* if possible.

    Falls back to a freshly initialised (random-weight) model when the
    checkpoint file is missing or cannot be loaded. The returned model has
    been moved to *device* and switched to eval mode.
    """
    config = ModelConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        use_flash_attention=False,  # Disabled for compatibility
    )
    model = DiffusionTransformer(config)

    if not os.path.exists(checkpoint_path):
        print("⚠ No checkpoint found - using random weights")
    else:
        print(f"Loading checkpoint: {checkpoint_path}")
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device)
            # Training checkpoints store the weights under 'model_state_dict';
            # also accept a bare state dict saved via torch.save(model.state_dict()).
            if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
            else:
                state_dict = checkpoint
            model.load_state_dict(state_dict)
        except Exception as err:  # e.g. unreadable file, missing keys, shape mismatch
            print(f"⚠ Using random weights (checkpoint load failed: {err})")
            # load_state_dict may copy some tensors before raising, so rebuild
            # the model to make sure the weights really are all random.
            model = DiffusionTransformer(config)
        else:
            print("✓ Checkpoint loaded")

    model.to(device)
    model.eval()
    return model


def _make_generator(model):
    """Build a ParallelGenerator tuned for speed, with optional accelerations off."""
    gen_config = GenerationConfig(
        max_new_tokens=40,  # Shorter for speed
        num_refinement_steps=2,  # Reduced for speed
        temperature=0.9,
        repetition_penalty=1.5,  # Strong penalty
        top_k=30,
        use_torch_compile=False,
    )
    return ParallelGenerator(
        model=model,
        config=gen_config,
        use_kv_cache=False,
        use_cuda_graphs=False,
    )


def _extract_reply(full_output: str) -> str:
    """Return the first line of text after the last 'Assistant:' marker.

    The decoded output contains the prompt as well, so everything up to the
    final 'Assistant:' is dropped and only the first line of the answer is kept.
    """
    reply = full_output.split("Assistant:")[-1].strip()
    return reply.partition("\n")[0]


def simple_chat(
    *,
    tokenizer_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    checkpoint_path: str = "checkpoints/unimodal_wikitext/checkpoint-1000/model.pt",
) -> None:
    """Run a minimal interactive chat loop on the terminal.

    Loads the tokenizer and a DiffusionTransformer (from *checkpoint_path*
    when it exists and loads cleanly, otherwise with random weights), then
    answers prompts read from stdin. The session ends when the user types
    'quit', submits an empty line, presses Ctrl-C, or stdin reaches EOF.

    Args:
        tokenizer_name: Model id or local path passed to
            ``AutoTokenizer.from_pretrained``.
        checkpoint_path: Path to a ``torch.save`` file holding either a dict
            with a ``'model_state_dict'`` entry or a bare state dict.
    """
    print("="*60)
    print("🤖 Parallel-LLM Simple Chat")
    print("="*60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}\n")

    print("[1/3] Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    print("[2/3] Loading model...")
    model = _build_model(tokenizer, device, checkpoint_path)

    print("[3/3] Configuring generation...")
    generator = _make_generator(model)

    print("\n" + "="*60)
    print("💬 Chat Ready! (Type 'quit' to exit)")
    print("="*60)
    print("⚠ Note: Model has minimal training - expect gibberish")
    print("For better quality, train longer or load pretrained weights")
    print("="*60 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()

            if not user_input or user_input.lower() == 'quit':
                print("\nGoodbye!\n")
                break

            prompt = f"User: {user_input}\nAssistant:"
            input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

            print("Bot: ", end="", flush=True)

            with torch.no_grad():
                # Parallel (non-autoregressive) mode, chosen for speed.
                output_ids = generator.generate(
                    input_ids,
                    use_autoregressive=False,
                )

            full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            print(_extract_reply(full_output) + "\n")

        except (KeyboardInterrupt, EOFError):
            # EOFError means stdin was closed (Ctrl-D / exhausted pipe). If it
            # reached the generic handler below, input() would raise again on
            # every iteration and the loop would never terminate.
            print("\n\nGoodbye!\n")
            break
        except Exception as e:
            # Report the failure but keep the chat session alive.
            print(f"Error: {e}\n")

if __name__ == "__main__":
    simple_chat()
