"""
Interactive Chat Interface for Parallel-LLM
Allows real-time conversation with the trained model
"""
import sys
import os
import torch
from transformers import AutoTokenizer
from parallel_llm.core import DiffusionTransformer, ModelConfig
from parallel_llm.inference import ParallelGenerator, GenerationConfig
import time

def load_model(checkpoint_path=None, device="cuda"):
    """Build the tokenizer, model and generator used by the chat interface.

    The tokenizer is loaded from the TinyLlama chat checkpoint, a
    ``DiffusionTransformer`` is constructed from a fixed ``ModelConfig``, and
    weights are restored from ``checkpoint_path`` when that file exists.

    Args:
        checkpoint_path: Optional path to a checkpoint readable by
            ``torch.load``. The file may contain either a raw state dict or a
            dict with a ``model_state_dict`` entry (and optionally ``step``).
            If the path is missing, does not exist, or fails to load, a
            warning is printed and the function continues without raising.
        device: Device the model is moved to, e.g. ``"cuda"``, ``"cuda:0"``
            or ``"cpu"``.

    Returns:
        A ``(tokenizer, generator, device)`` tuple.
    """
    print("="*60)
    print("🤖 Parallel-LLM Interactive Chat")
    print("="*60)

    # Load tokenizer
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    print(f"\n[1/3] Loading tokenizer from {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Configure model
    print("\n[2/3] Initializing model...")

    # Treat "cuda" and indexed forms such as "cuda:0" alike.
    on_cuda = str(device).startswith("cuda")
    # Only ask the CUDA runtime about bf16 when we are actually targeting a
    # usable CUDA device; a CPU run always gets float32.
    use_bf16 = (
        on_cuda
        and torch.cuda.is_available()
        and torch.cuda.is_bf16_supported()
    )

    config = ModelConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_diffusion_steps=10,
        use_flash_attention=on_cuda,
        dtype=torch.bfloat16 if use_bf16 else torch.float32
    )

    model = DiffusionTransformer(config)

    # Load checkpoint if provided
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}...")
        try:
            # SECURITY NOTE: torch.load unpickles the file, which can execute
            # arbitrary code. Only point this at checkpoints you trust.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
                print(f"✓ Loaded checkpoint from step {checkpoint.get('step', 'unknown')}")
            else:
                # No wrapper dict: treat the whole file as the state dict.
                model.load_state_dict(checkpoint)
                print("✓ Loaded checkpoint")
        except Exception as e:
            # Deliberately best-effort: the chat still starts without weights.
            print(f"⚠ Failed to load checkpoint: {e}")
            print("Using randomly initialized model...")
    elif checkpoint_path:
        # A path was given but nothing exists there; say so explicitly instead
        # of claiming that no checkpoint was provided.
        print(f"⚠ Checkpoint not found at {checkpoint_path} - using random weights")
        print("Output will be incoherent. Train model first or provide checkpoint path.")
    else:
        print("⚠ No checkpoint provided - using random weights")
        print("Output will be incoherent. Train model first or provide checkpoint path.")

    model.to(device)
    model.eval()

    # Configure fast generation
    print("\n[3/3] Configuring inference...")
    gen_config = GenerationConfig(
        max_new_tokens=50,  # Faster responses
        num_refinement_steps=3,  # Reduced for speed
        temperature=0.8,
        top_k=40,
        top_p=0.9,
        repetition_penalty=1.3,  # Stronger penalty
        use_torch_compile=False  # Disabled for compatibility
    )

    generator = ParallelGenerator(
        model=model,
        config=gen_config,
        use_kv_cache=False,  # Disabled for simplicity
        use_cuda_graphs=False  # Disabled for compatibility
    )

    return tokenizer, generator, device

def _print_chat_commands():
    """Print the slash commands understood by the chat loop."""
    print("\nCommands:")
    print("  /quit or /exit - Exit chat")
    print("  /clear - Clear conversation history")
    print("  /mode - Toggle generation mode")
    print("  /help - Show this help")


def _mode_label(use_parallel):
    """Return the human-readable name of the current generation mode."""
    return "Parallel (Fast)" if use_parallel else "Autoregressive (Quality)"


def chat_loop(tokenizer, generator, device, use_parallel=True):
    """Run the interactive read-generate-print loop until the user exits.

    Each user message is appended to a rolling history, the most recent
    history lines are joined into a ``User:``/``Assistant:`` prompt, and the
    generator's decoded output is trimmed down to the assistant's reply.

    The loop ends on ``/quit``, ``/exit``, Ctrl-C, or end of input (Ctrl-D or
    exhausted piped stdin). Any other exception raised while handling a turn
    is reported and the loop continues; the unanswered user message is
    removed from the history so later prompts stay well-formed.

    Args:
        tokenizer: Object providing ``encode(text, return_tensors="pt")`` and
            ``decode(ids, skip_special_tokens=True)``.
        generator: Object providing ``generate(input_ids, use_autoregressive=...)``.
        device: Device the encoded prompt is moved to before generation.
        use_parallel: Start in parallel mode when true, autoregressive mode
            otherwise. Can be toggled at runtime with ``/mode``.

    Returns:
        None.
    """
    print("\n" + "="*60)
    print("💬 Chat Interface Ready!")
    print("="*60)
    print("Mode: " + _mode_label(use_parallel))
    _print_chat_commands()
    print("="*60)

    conversation_history = []

    while True:
        # True while the latest user line is in the history without a reply.
        awaiting_reply = False
        try:
            # Get user input
            user_input = input("\n👤 You: ").strip()

            if not user_input:
                continue

            # Handle commands
            command = user_input.lower()

            if command in ('/quit', '/exit'):
                print("\n👋 Goodbye!\n")
                break

            if command == '/clear':
                conversation_history.clear()
                print("🗑️  Conversation cleared!")
                continue

            if command == '/mode':
                use_parallel = not use_parallel
                print(f"🔄 Switched to {_mode_label(use_parallel)} mode")
                continue

            if command == '/help':
                _print_chat_commands()
                continue

            # Add to conversation history
            conversation_history.append(f"User: {user_input}")
            awaiting_reply = True

            # Build the prompt from the six most recent history lines; the
            # last of them is always the message that was just appended.
            recent_history = conversation_history[-6:]
            prompt = "\n".join(recent_history) + "\nAssistant:"

            # Tokenize
            input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

            # Generate response
            print("🤖 Assistant: ", end="", flush=True)
            start_time = time.time()

            with torch.no_grad():
                output_ids = generator.generate(
                    input_ids,
                    use_autoregressive=not use_parallel  # Toggle mode
                )

            # Decode
            full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)

            # Keep only the text after the final "Assistant:" marker
            response = full_output.split("Assistant:")[-1].strip()

            # Cut the reply at the next simulated turn or at a blank line
            if "\nUser:" in response:
                response = response.split("\nUser:")[0].strip()
            if "\n\n" in response:
                response = response.split("\n\n")[0].strip()

            # Timing covers generation plus decoding and trimming.
            elapsed = time.time() - start_time

            print(response)
            print(f"\n⏱️  Generated in {elapsed:.2f}s", end="")

            # Add to history
            conversation_history.append(f"Assistant: {response}")
            awaiting_reply = False

        except (KeyboardInterrupt, EOFError):
            # EOFError must end the loop: once stdin is exhausted every
            # further input() call fails, so continuing would spin forever.
            print("\n\n👋 Chat interrupted. Goodbye!\n")
            break
        except Exception as e:
            if awaiting_reply:
                # Drop the unanswered user line so the history keeps
                # alternating between user and assistant turns.
                conversation_history.pop()
            print(f"\n❌ Error: {e}")
            print("Continuing...\n")

def main():
    """Command-line entry point: read options, load the model, start chatting.

    Recognised flags are ``--checkpoint`` (path to the model checkpoint),
    ``--device`` (defaults to CUDA when available, otherwise CPU) and
    ``--parallel`` (start the chat in parallel generation mode).
    """
    import argparse

    # Pick the default device up front so the parser setup stays declarative.
    if torch.cuda.is_available():
        fallback_device = "cuda"
    else:
        fallback_device = "cpu"

    cli = argparse.ArgumentParser(description="Interactive chat with Parallel-LLM")
    cli.add_argument(
        "--checkpoint",
        default="checkpoints/unimodal_wikitext/checkpoint-1000/model.pt",
        type=str,
        help="Path to model checkpoint",
    )
    cli.add_argument(
        "--device",
        default=fallback_device,
        type=str,
        help="Device to use (cuda/cpu)",
    )
    cli.add_argument(
        "--parallel",
        action="store_true",
        help="Use parallel generation (faster but lower quality)",
    )
    options = cli.parse_args()

    print(f"Using device: {options.device}")

    # Build everything the chat needs, then hand control to the loop.
    loaded = load_model(options.checkpoint, options.device)
    tokenizer, generator, device = loaded
    chat_loop(tokenizer, generator, device, use_parallel=options.parallel)

# Start the chat CLI only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
