"""
Load pretrained TinyLlama weights into Parallel-LLM architecture
This enables coherent text generation without extensive training
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from parallel_llm.core import DiffusionTransformer, ModelConfig
import os

def _transfer_tensor(src_state, dst_state, src_key, dst_key):
    """Copy ``src_state[src_key]`` into ``dst_state[dst_key]`` when compatible.

    The copy only happens if both keys exist and the two tensors have the
    same shape. A shape mismatch is reported and skipped instead of being
    left for ``load_state_dict`` to raise on.

    Returns:
        True if the tensor was copied, False if it was skipped.
    """
    if src_key not in src_state or dst_key not in dst_state:
        return False

    src_tensor = src_state[src_key]
    dst_shape = dst_state[dst_key].shape
    if src_tensor.shape != dst_shape:
        print(
            f"  ⚠ Skipping {dst_key}: source shape {tuple(src_tensor.shape)} "
            f"!= target shape {tuple(dst_shape)}"
        )
        return False

    dst_state[dst_key] = src_tensor
    return True


def load_pretrained_tinyllama(
    model_id: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    save_path: str = "checkpoints/pretrained_tinyllama",
    device: str = "cuda"
):
    """
    Load pretrained TinyLlama weights and adapt them for Parallel-LLM

    Note: This is a weight transfer - the model architecture is different,
    so we initialize our diffusion model and copy compatible weights.
    Tensors whose name or shape has no counterpart in the Parallel-LLM model
    are left at their freshly initialized values.

    Args:
        model_id: HuggingFace model ID of the pretrained causal LM.
        save_path: Directory that receives ``model.pt`` and the tokenizer files.
        device: Only used to decide whether flash attention is enabled
            (enabled when equal to ``"cuda"``).

    Returns:
        A ``(parallel_model, tokenizer)`` tuple.
    """
    print("="*60)
    print("🔄 Loading Pretrained TinyLlama Weights")
    print("="*60)

    # Only query bf16 support when CUDA is actually present; evaluated once
    # so the pretrained model and the new model agree on dtype.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float32

    # Load pretrained model
    print(f"\n[1/4] Downloading pretrained model: {model_id}")
    print("This may take a few minutes on first run...")
    pretrained = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        low_cpu_mem_usage=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    print(f"✓ Loaded {sum(p.numel() for p in pretrained.parameters()) / 1e6:.1f}M parameters")

    # Create Parallel-LLM model with matching dimensions. Sizes are read from
    # the pretrained config; the fallbacks are the TinyLlama-1.1B values.
    print("\n[2/4] Creating Parallel-LLM model...")
    src_config = pretrained.config
    config = ModelConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=getattr(src_config, "hidden_size", 2048),
        num_hidden_layers=getattr(src_config, "num_hidden_layers", 22),
        num_attention_heads=getattr(src_config, "num_attention_heads", 32),
        intermediate_size=getattr(src_config, "intermediate_size", 5632),
        num_diffusion_steps=10,
        use_flash_attention=device == "cuda",
        dtype=dtype
    )

    parallel_model = DiffusionTransformer(config)

    print(f"✓ Created Parallel-LLM with {sum(p.numel() for p in parallel_model.parameters()) / 1e6:.1f}M parameters")

    # Transfer compatible weights
    print("\n[3/4] Transferring weights...")
    transferred = 0

    pretrained_state = pretrained.state_dict()
    parallel_state = parallel_model.state_dict()

    # Transfer embeddings
    if _transfer_tensor(pretrained_state, parallel_state,
                        'model.embed_tokens.weight', 'embed_tokens.weight'):
        transferred += 1
        print("  ✓ Transferred embeddings")

    if _transfer_tensor(pretrained_state, parallel_state,
                        'lm_head.weight', 'lm_head.weight'):
        transferred += 1
        print("  ✓ Transferred output head")

    # Per-layer tensors share the same suffix on both sides; only the
    # prefix differs ("model.layers.N" -> "layers.N").
    layer_suffixes = (
        # Attention
        'self_attn.q_proj.weight',
        'self_attn.k_proj.weight',
        'self_attn.v_proj.weight',
        'self_attn.o_proj.weight',
        # MLP
        'mlp.gate_proj.weight',
        'mlp.up_proj.weight',
        'mlp.down_proj.weight',
        # Norms
        'input_layernorm.weight',
        'post_attention_layernorm.weight',
    )

    # Transfer layer weights; layers missing on either side are skipped
    # by the helper, so no explicit cap is needed.
    for layer_idx in range(config.num_hidden_layers):
        layer_prefix_src = f'model.layers.{layer_idx}'
        layer_prefix_dst = f'layers.{layer_idx}'

        for suffix in layer_suffixes:
            if _transfer_tensor(
                pretrained_state,
                parallel_state,
                f'{layer_prefix_src}.{suffix}',
                f'{layer_prefix_dst}.{suffix}',
            ):
                transferred += 1

    # Transfer final norm
    if _transfer_tensor(pretrained_state, parallel_state,
                        'model.norm.weight', 'norm.weight'):
        transferred += 1

    # Load transferred weights. Every replaced entry already matches the
    # target's key and shape, so this cannot fail on a size mismatch.
    parallel_model.load_state_dict(parallel_state, strict=False)

    print(f"\n  ✓ Transferred {transferred} weight tensors")
    print(f"  ⚠ Skipped {len(parallel_state) - transferred} diffusion-specific weights")
    print("    (These will be fine-tuned during training)")

    # Save checkpoint
    print(f"\n[4/4] Saving checkpoint to {save_path}...")
    os.makedirs(save_path, exist_ok=True)

    checkpoint = {
        'model_state_dict': parallel_model.state_dict(),
        'config': config.__dict__,
        'step': 0,
        'source': model_id,
        'info': 'Pretrained TinyLlama weights transferred to Parallel-LLM architecture'
    }

    torch.save(checkpoint, os.path.join(save_path, 'model.pt'))
    tokenizer.save_pretrained(save_path)

    print(f"✓ Saved checkpoint")
    print(f"\nYou can now use this checkpoint for inference:")
    print(f"  python examples/chat_interactive.py --checkpoint {save_path}/model.pt")

    return parallel_model, tokenizer

if __name__ == "__main__":
    # Script entry point: read the three string options from the command
    # line and hand them to load_pretrained_tinyllama positionally.
    import argparse

    # (flag, default value, help text) for every supported option.
    option_table = (
        (
            "--model-id",
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "HuggingFace model ID",
        ),
        (
            "--save-path",
            "checkpoints/pretrained_tinyllama",
            "Path to save converted checkpoint",
        ),
        (
            "--device",
            "cuda" if torch.cuda.is_available() else "cpu",
            "Device to use",
        ),
    )

    cli = argparse.ArgumentParser(description="Load pretrained TinyLlama weights")
    for flag, fallback, help_text in option_table:
        cli.add_argument(flag, type=str, default=fallback, help=help_text)
    options = cli.parse_args()

    load_pretrained_tinyllama(options.model_id, options.save_path, options.device)
