import os
import json
import torch
from typing import List, Tuple, Optional

from transformers.configuration_utils import PretrainedConfig
from llm_layer_collector.auto.auto_rms import AutoRMSNorm
from llm_layer_collector.auto.auto_layer import AutoDecoderLayer

from language_pipes.util import size_of_tensor, tensor_hash
from language_pipes.job_manager.enums import ModelPartType
from language_pipes.util.meta import MetaComputed

def get_size_of_layer(config: PretrainedConfig, layer_idx: int) -> Tuple[float, str]:
    """Construct decoder layer ``layer_idx`` from ``config`` in float16 and measure it.

    Args:
        config: Model configuration used to build the layer.
        layer_idx: Index of the decoder layer to build.

    Returns:
        A ``(size, fingerprint)`` pair. ``size`` is the sum of ``size_of_tensor``
        over the attention (q/k/v/o) and MLP (gate/up/down) projection weights.
        ``fingerprint`` is ``tensor_hash`` of the q_proj weight.

    NOTE(review): the layer is constructed from ``config`` here rather than
    loaded from weight files. If AutoDecoderLayer initializes parameters
    randomly, the fingerprint is not reproducible across runs unless the RNG
    is seeded elsewhere -- confirm.
    """
    layer = AutoDecoderLayer(config, layer_idx).to(dtype=torch.float16)
    attn = layer.self_attn
    mlp = layer.mlp
    q_weight = attn.q_proj.weight
    projection_weights = (
        q_weight,
        attn.k_proj.weight,
        attn.v_proj.weight,
        attn.o_proj.weight,
        mlp.gate_proj.weight,
        mlp.up_proj.weight,
        mlp.down_proj.weight,
    )
    total_size = sum(size_of_tensor(weight) for weight in projection_weights)
    return total_size, tensor_hash(q_weight)

def get_avg_layer_size(model_path: str) -> Tuple[float, List[str]]:
    """Compute the mean decoder-layer size and the per-layer hashes of a model.

    Args:
        model_path: Directory that contains the model's ``config.json``.

    Returns:
        ``(avg_layer_size, layer_hashes)``. ``avg_layer_size`` is the total
        size reported by ``get_size_of_layer`` divided by the layer count, or
        ``0.0`` when the config declares no layers. ``layer_hashes`` holds one
        entry per layer, in layer order.

    Raises:
        FileNotFoundError: if ``model_path`` or its ``config.json`` is missing.
            The function previously returned a bare ``-1`` here, which callers
            that unpack the result could not handle.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f'Model {model_path} not found')
    config_file = os.path.join(model_path, 'config.json')
    if not os.path.exists(config_file):
        raise FileNotFoundError(f'Config file not found for model {model_path}')
    with open(config_file) as f:
        config = PretrainedConfig.from_dict(json.load(f))

    num_layers = config.num_hidden_layers
    total_size = 0
    layer_hashes: List[str] = []
    for layer_idx in range(num_layers):
        size, layer_hash = get_size_of_layer(config, layer_idx)
        total_size += size
        layer_hashes.append(layer_hash)

    # Guard against a config that declares zero layers (avoids ZeroDivisionError).
    avg_layer_size = total_size / num_layers if num_layers else 0.0
    return avg_layer_size, layer_hashes

def data_of_type(typ: ModelPartType, model_path: str) -> Tuple[float, str]:
    """Build one non-layer model part from ``config.json`` and measure its weight.

    Args:
        typ: Which part to build: EMBED (``torch.nn.Embedding``), NORM
            (``AutoRMSNorm``) or HEAD (``torch.nn.Linear``). Any other value
            yields ``(0, '')``.
        model_path: Directory that contains ``config.json``.

    Returns:
        ``(size_of_tensor(weight), tensor_hash(weight))`` for the part's
        float16 weight.

    NOTE(review): the modules are freshly constructed, not loaded from
    checkpoint files. Embedding/Linear weights are presumably randomly
    initialized, so their hashes may differ between runs unless the RNG is
    seeded elsewhere -- confirm.
    """
    config_path = os.path.join(model_path, 'config.json')
    with open(config_path) as config_file:
        config = PretrainedConfig.from_dict(json.load(config_file))

    if typ == ModelPartType.EMBED:
        module = torch.nn.Embedding(config.vocab_size, config.hidden_size)
    elif typ == ModelPartType.NORM:
        module = AutoRMSNorm(config)
    elif typ == ModelPartType.HEAD:
        module = torch.nn.Linear(config.hidden_size, config.vocab_size)
    else:
        return 0, ''

    weight = module.to(dtype=torch.float16).weight
    return size_of_tensor(weight), tensor_hash(weight)

def get_computed_data(model_path: str):
    """Return size/hash metadata for a model, computing and caching it on first use.

    On the first call the result is written to ``<model_path>/computed.json``.
    Later calls load that file instead of rebuilding the model parts.

    Args:
        model_path: Model root directory. ``config.json`` is read from its
            ``data`` subdirectory.

    Returns:
        A dict with keys ``embed_size``, ``embed_hash``, ``head_size``,
        ``head_hash``, ``avg_layer_size`` and ``layer_hashes``.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist. Errors raised
            while reading the model config propagate unchanged.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f'Model {model_path} not found')
    computed_path = os.path.join(model_path, 'computed.json')
    if os.path.exists(computed_path):
        with open(computed_path) as f:
            return json.load(f)

    # Keep model_path intact; the config lives one level down in 'data'.
    data_path = os.path.join(model_path, 'data')
    embed_size, embed_hash = data_of_type(ModelPartType.EMBED, data_path)
    # The NORM part used to be computed here and immediately discarded; it is
    # not part of the cached record, so it is no longer built at all.
    head_size, head_hash = data_of_type(ModelPartType.HEAD, data_path)
    avg_layer_size, layer_hashes = get_avg_layer_size(data_path)

    computed = {
        'embed_size': embed_size,
        'embed_hash': embed_hash,
        'head_size': head_size,
        'head_hash': head_hash,
        'avg_layer_size': avg_layer_size,
        'layer_hashes': layer_hashes,
    }

    # Write to a temporary file and atomically swap it into place. An
    # interrupted write then cannot leave a truncated computed.json that
    # every later call would try (and fail) to load.
    tmp_path = computed_path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(computed, f)
    os.replace(tmp_path, computed_path)

    return computed

class ComputedData:
    """Size and hash fingerprint of a model, as produced by ``get_computed_data``.

    Construct with a model directory to load (or compute) the data. Use
    ``from_dict`` / ``from_meta`` to rebuild an instance from serialized data.
    ``ComputedData(None)`` returns an instance with no fields set; callers
    must assign them before calling ``to_json`` / ``to_meta``.
    """
    embed_size: int
    head_size: int
    # Produced by a division (total size / layer count), so it is a float.
    avg_layer_size: float

    embed_hash: str
    head_hash: str
    layer_hashes: List[str]

    def __init__(self, model_dir: Optional[str]):
        if model_dir is None:
            return
        self._load_dict(get_computed_data(model_dir))

    def _load_dict(self, data: dict) -> None:
        """Copy every field from a dict shaped like ``to_json()`` onto ``self``.

        Raises KeyError if a field is missing.
        """
        self.embed_size = data['embed_size']
        self.head_size = data['head_size']
        self.avg_layer_size = data['avg_layer_size']
        self.embed_hash = data['embed_hash']
        self.head_hash = data['head_hash']
        self.layer_hashes = data['layer_hashes']

    def to_json(self) -> dict:
        """Return the fields as a plain JSON-serializable dict."""
        return {
            'embed_size': self.embed_size,
            'head_size': self.head_size,
            'avg_layer_size': self.avg_layer_size,
            'embed_hash': self.embed_hash,
            'head_hash': self.head_hash,
            'layer_hashes': self.layer_hashes,
        }

    def to_meta(self) -> 'MetaComputed':
        """Return the fields packed into a ``MetaComputed``."""
        # to_json() keys match MetaComputed's keyword arguments one-to-one.
        return MetaComputed(**self.to_json())

    @staticmethod
    def from_meta(data: 'MetaComputed') -> 'ComputedData':
        """Build an instance from the attributes of a ``MetaComputed``."""
        c = ComputedData(None)
        c.embed_size = data.embed_size
        c.embed_hash = data.embed_hash
        c.head_size = data.head_size
        c.head_hash = data.head_hash
        c.avg_layer_size = data.avg_layer_size
        c.layer_hashes = data.layer_hashes
        return c

    @staticmethod
    def from_dict(data: dict) -> 'ComputedData':
        """Build an instance from a dict shaped like ``to_json()``."""
        c = ComputedData(None)
        c._load_dict(data)
        return c
    
def validate_model(c1: MetaComputed, c2: MetaComputed):
    """Return True when two computed-data records carry identical hashes.

    Compares the embedding hash, the head hash and the per-layer hash list.
    Sizes are not compared.
    """
    if c1.embed_hash != c2.embed_hash:
        return False
    if c1.head_hash != c2.head_hash:
        return False
    return c1.layer_hashes == c2.layer_hashes
