from dataclasses import dataclass, fields

import numpy as np

from boltzgen.data import const
from boltzgen.data.const import (
    prot_letter_to_token,
    protein_letters_3to1_extended,
    nucleic_letters_3to1_extended,
)
from boltzgen.data.data import (
    Token,
    TokenBond,
    Tokenized,
    Structure,
    convert_ccd,
)


def tokendata_to_tuple(token):
    """Return the field values of the dataclass instance ``token`` as a tuple.

    Values are listed in field-declaration order and are not copied, so the
    result can be used directly as one row of a numpy structured array.
    """
    field_names = [spec.name for spec in fields(token)]
    return tuple(getattr(token, name) for name in field_names)


@dataclass
class TokenData:
    """One token as produced by ``Tokenizer.tokenize``.

    Instances are flattened with ``tokendata_to_tuple`` (field-declaration
    order) and packed into a numpy array with ``dtype=Token``, so the order of
    the fields below is significant. NOTE(review): presumably it must match
    the field order of the ``Token`` dtype in ``boltzgen.data.data`` — confirm
    before reordering or adding fields.
    """

    token_idx: int  # running index of this token in the tokenized output
    atom_idx: int  # index of the token's first atom in ``struct.atoms``
    atom_num: int  # number of atoms covered (1 for per-atom tokens)
    res_idx: int  # the residue record's ``res_idx`` field
    res_type: int  # token type id of the residue / atom
    res_name: str  # residue name as stored in the residue record
    sym_id: int  # copied from the chain record
    asym_id: int  # copied from the chain record
    entity_id: int  # copied from the chain record
    mol_type: int  # chain molecule type (compared against const.chain_type_ids)
    center_idx: int  # atom index of the center atom (the atom itself if atomized)
    disto_idx: int  # atom index of the disto atom (the atom itself if atomized)
    center_coords: np.ndarray  # coordinates of the center atom (first conformer)
    disto_coords: np.ndarray  # coordinates of the disto atom (first conformer)
    resolved_mask: bool  # residue present AND center atom present
    disto_mask: bool  # residue present AND disto atom present
    modified: bool  # True for non-standard residues of polymer chains
    frame_rot: np.ndarray  # flattened 3x3 rotation; identity when no frame computed
    frame_t: np.ndarray  # frame origin (CA coordinates); zeros when no frame computed
    frame_mask: bool  # N, CA and C all present (standard protein residues only)
    cyclic_period: int  # copied from the chain record
    is_standard: bool  # whether the token comes from a standard residue
    design: bool  # tokenizer sets False; presumably overwritten downstream
    binding_type: int  # a const.binding_type_ids value; tokenizer sets UNSPECIFIED
    structure_group: int  # tokenizer sets 0; presumably overwritten downstream
    ccd: np.ndarray  # residue name encoded by convert_ccd
    target_msa_mask: bool  # tokenizer sets 0; presumably overwritten downstream
    design_ss_mask: bool  # tokenizer sets 0; presumably overwritten downstream
    # Initialized to asym_id / res_idx by the tokenizer. NOTE(review): presumably
    # kept separate so features can be remapped independently — confirm.
    feature_asym_id: int
    feature_res_idx: int


def compute_frame(
    n: np.ndarray,
    ca: np.ndarray,
    c: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute the local backbone frame of a residue.

    The frame is built by Gram-Schmidt orthogonalization: the first axis
    points from CA to C, the second is the component of CA->N orthogonal to
    the first, and the third is their cross product (right-handed).

    Parameters
    ----------
    n : np.ndarray
        Coordinates of the N atom, shape (3,).
    ca : np.ndarray
        Coordinates of the CA atom, shape (3,).
    c : np.ndarray
        Coordinates of the C atom, shape (3,).

    Returns
    -------
    rot : np.ndarray
        (3, 3) rotation matrix whose columns are the frame axes.
    t : np.ndarray
        Frame origin, i.e. ``ca`` itself (returned as-is, not copied).

    """
    v1 = c - ca
    v2 = n - ca
    # The small epsilon avoids division by zero for degenerate geometry.
    e1 = v1 / (np.linalg.norm(v1) + 1e-10)
    # Remove the e1 component from v2 so that e2 is orthogonal to e1.
    u2 = v2 - e1 * np.dot(e1, v2)
    e2 = u2 / (np.linalg.norm(u2) + 1e-10)
    e3 = np.cross(e1, e2)
    rot = np.column_stack([e1, e2, e3])
    t = ca
    return rot, t


def map_modified_residue_to_res_type(chain, res, map_to_closest_residue=False):
    """Pick a token type for a non-standard residue of a polymer chain.

    When ``map_to_closest_residue`` is True and the residue name appears in
    the extended 3-to-1 letter tables, the residue is mapped to the token of
    its one-letter code for the chain's polymer type. Otherwise, or when that
    letter has no token for the chain's polymer type, the polymer-specific
    unknown token is used.

    Parameters
    ----------
    chain : structured record
        Chain record; only ``chain["mol_type"]`` is read.
    res : structured record
        Residue record; only ``res["name"]`` is read.
    map_to_closest_residue : bool
        Whether to map to the closest standard residue instead of UNK.

    Returns
    -------
    The ``const.token_ids`` entry of the chosen token.

    Raises
    ------
    Exception
        If a letter mapping is attempted for a chain that is not protein,
        DNA or RNA.

    """
    mol_type = chain["mol_type"]
    res_name = res["name"]

    # Polymer-specific unknown token; anything that is not DNA/RNA is treated
    # as protein.
    if mol_type == const.chain_type_ids["DNA"]:
        unk_token = const.unk_token["DNA"]
    elif mol_type == const.chain_type_ids["RNA"]:
        unk_token = const.unk_token["RNA"]
    else:
        unk_token = const.unk_token["PROTEIN"]

    if map_to_closest_residue:
        # The protein table takes precedence when a name is in both tables.
        if res_name in protein_letters_3to1_extended:
            modified_letter = protein_letters_3to1_extended[res_name]
        elif res_name in nucleic_letters_3to1_extended:
            modified_letter = nucleic_letters_3to1_extended[res_name]
        else:
            modified_letter = None

        if modified_letter is not None:
            if mol_type == const.chain_type_ids["PROTEIN"]:
                letter_to_token = prot_letter_to_token
            elif mol_type == const.chain_type_ids["DNA"]:
                letter_to_token = const.dna_letter_to_token
            elif mol_type == const.chain_type_ids["RNA"]:
                letter_to_token = const.rna_letter_to_token
            else:
                msg = "Only polymers should be present here"
                raise Exception(msg)

            # The letter may come from the "other" table (e.g. a protein code
            # on a nucleic-acid chain) and then has no token for this polymer
            # type: fall back to UNK instead of raising a bare KeyError.
            token_type_name = letter_to_token.get(modified_letter, unk_token)
            return const.token_ids[token_type_name]

    return const.token_ids[unk_token]


class Tokenizer:
    """Tokenize an input structure for training."""

    def __init__(
        self,
        atomize_modified_residues: bool = False,
        map_to_closest_residue: bool = False,
    ) -> None:
        """Initialize the Tokenizer.

        Parameters
        ----------
        atomize_modified_residues : bool
            Whether to atomize modified residues, i.e. emit one token per
            atom for non-standard residues of polymer chains instead of one
            token per residue.
        map_to_closest_residue : bool
            Whether to map modified residues to the closest residue (via the
            extended 3-to-1 letter tables) instead of the unknown token.

        """
        self.atomize_modified_residues = atomize_modified_residues
        self.map_to_closest_residue = map_to_closest_residue

    def tokenize(  # noqa: C901, PLR0915
        self,
        struct: Structure,
        inverse_fold: bool = False,
    ) -> Tokenized:
        """Tokenize the input data.

        Standard residues become one token each. Residues of non-polymer
        chains (and non-standard polymer residues when
        ``atomize_modified_residues`` is set) become one token per atom. Any
        other non-standard polymer residue becomes one token whose type comes
        from ``map_modified_residue_to_res_type``.

        Parameters
        ----------
        struct : Structure
            The input structure. Only chains selected by ``struct.mask`` are
            tokenized, and center/disto coordinates are read from the first
            ensemble conformer.
        inverse_fold : bool
            If True, backbone frames of protein residues are not computed:
            ``frame_rot``/``frame_t`` keep their identity/zero defaults, while
            ``frame_mask`` still reflects whether N, CA and C are present.

        Returns
        -------
        Tokenized
            Built from the token table, the token bonds, the input structure
            and the token-to-residue index map.

        """
        # Create token data
        token_data = []

        # Keep track of atom_idx -> token_idx and token_idx -> residue index
        token_idx = 0
        atom_to_token = {}
        token_to_res = []

        # Filter to valid chains only
        chains = struct.chains[struct.mask]

        # Ensemble atom id start in coords table.
        # For cropper and other operations, hardcoded to 0th conformer.
        offset = struct.ensemble[0]["atom_coord_idx"]

        for chain in chains:
            # Get residue indices
            res_start = chain["res_idx"]
            res_end = chain["res_idx"] + chain["res_num"]
            is_protein = chain["mol_type"] == const.chain_type_ids["PROTEIN"]

            for res_index_local, res in enumerate(struct.residues[res_start:res_end]):
                res_index_global = res_index_local + res_start

                # Get atom indices
                atom_start = res["atom_idx"]
                atom_end = res["atom_idx"] + res["atom_num"]

                # Standard residues are tokens
                if res["is_standard"]:
                    # Get center and disto atoms
                    center = struct.atoms[res["atom_center"]]
                    disto = struct.atoms[res["atom_disto"]]

                    # Token is present if centers are
                    is_present = res["is_present"] & center["is_present"]
                    is_disto_present = res["is_present"] & disto["is_present"]

                    # Read center/disto coordinates from the first conformer
                    c_coords = struct.coords[offset + res["atom_center"]]["coords"]
                    d_coords = struct.coords[offset + res["atom_disto"]]["coords"]

                    # If protein, compute frame, only used for templates
                    frame_rot = np.eye(3).flatten()
                    frame_t = np.zeros(3)
                    frame_mask = False

                    if is_protein:
                        # Get frame atoms
                        atoms = struct.atoms[atom_start:atom_end]

                        # Atoms are always in the order N, CA, C
                        atom_n = atoms[0]
                        atom_ca = atoms[1]
                        atom_c = atoms[2]

                        # Compute frame and mask
                        frame_mask = atom_ca["is_present"]
                        frame_mask &= atom_c["is_present"]
                        frame_mask &= atom_n["is_present"]
                        frame_mask = bool(frame_mask)
                        # In inverse-folding mode the identity/zero frame
                        # defaults are kept even when the atoms are present.
                        if frame_mask and not inverse_fold:
                            # NOTE(review): frames use struct.atoms coords while
                            # center/disto use the conformer coords table —
                            # confirm this is intended.
                            frame_rot, frame_t = compute_frame(
                                atom_n["coords"],
                                atom_ca["coords"],
                                atom_c["coords"],
                            )
                            frame_rot = frame_rot.flatten()

                    # Create token
                    token = TokenData(
                        token_idx=token_idx,
                        atom_idx=res["atom_idx"],
                        atom_num=res["atom_num"],
                        res_idx=res["res_idx"],
                        res_type=res["res_type"],
                        res_name=res["name"],
                        sym_id=chain["sym_id"],
                        asym_id=chain["asym_id"],
                        entity_id=chain["entity_id"],
                        mol_type=chain["mol_type"],
                        center_idx=res["atom_center"],
                        disto_idx=res["atom_disto"],
                        center_coords=c_coords,
                        disto_coords=d_coords,
                        resolved_mask=is_present,
                        disto_mask=is_disto_present,
                        modified=False,
                        frame_rot=frame_rot,
                        frame_t=frame_t,
                        frame_mask=frame_mask,
                        cyclic_period=chain["cyclic_period"],
                        is_standard=True,
                        design=False,
                        binding_type=const.binding_type_ids["UNSPECIFIED"],
                        structure_group=0,
                        ccd=convert_ccd(res["name"]),
                        target_msa_mask=0,
                        design_ss_mask=0,
                        feature_asym_id=chain["asym_id"],
                        feature_res_idx=res["res_idx"],
                    )
                    token_data.append(tokendata_to_tuple(token))

                    # Update atom_idx to token_idx
                    for atom_idx in range(atom_start, atom_end):
                        atom_to_token[atom_idx] = token_idx
                    # Update token_idx to res_idx
                    token_to_res.append(res_index_global)
                    token_idx += 1

                # Non-standard are tokenized per atom
                elif (
                    chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]
                    or self.atomize_modified_residues
                ):
                    # We use the unk protein token as res_type
                    unk_token = const.unk_token["PROTEIN"]
                    unk_id = const.token_ids[unk_token]

                    # Get atom coordinates
                    atom_data = struct.atoms[atom_start:atom_end]
                    atom_coords = struct.coords[
                        offset + atom_start : offset + atom_end
                    ]["coords"]

                    # Tokenize each atom
                    for i, atom in enumerate(atom_data):
                        # Token is present if atom is
                        is_present = res["is_present"] & atom["is_present"]
                        index = atom_start + i

                        # Create token
                        token = TokenData(
                            token_idx=token_idx,
                            atom_idx=index,
                            atom_num=1,
                            res_idx=res["res_idx"],
                            res_type=unk_id,
                            res_name=res["name"],
                            sym_id=chain["sym_id"],
                            asym_id=chain["asym_id"],
                            entity_id=chain["entity_id"],
                            mol_type=chain["mol_type"],
                            center_idx=index,
                            disto_idx=index,
                            center_coords=atom_coords[i],
                            disto_coords=atom_coords[i],
                            resolved_mask=is_present,
                            disto_mask=is_present,
                            modified=chain["mol_type"]
                            != const.chain_type_ids["NONPOLYMER"],
                            frame_rot=np.eye(3).flatten(),
                            frame_t=np.zeros(3),
                            frame_mask=False,
                            cyclic_period=chain["cyclic_period"],
                            is_standard=False,
                            design=False,
                            binding_type=const.binding_type_ids["UNSPECIFIED"],
                            structure_group=0,
                            ccd=convert_ccd(res["name"]),
                            target_msa_mask=0,
                            design_ss_mask=0,
                            feature_asym_id=chain["asym_id"],
                            feature_res_idx=res["res_idx"],
                        )
                        token_data.append(tokendata_to_tuple(token))

                        # Update atom_idx to token_idx
                        atom_to_token[index] = token_idx
                        # Update token_idx to res_idx
                        token_to_res.append(res_index_global)
                        token_idx += 1

                # Remaining non-standard polymer residues: one token each
                else:
                    res_type = map_modified_residue_to_res_type(
                        chain, res, self.map_to_closest_residue
                    )

                    # Get center and disto atoms
                    center = struct.atoms[res["atom_center"]]
                    disto = struct.atoms[res["atom_disto"]]

                    # Token is present if centers are
                    is_present = res["is_present"] & center["is_present"]
                    is_disto_present = res["is_present"] & disto["is_present"]

                    # Read center/disto coordinates from the first conformer
                    c_coords = struct.coords[offset + res["atom_center"]]["coords"]
                    d_coords = struct.coords[offset + res["atom_disto"]]["coords"]

                    # Create token
                    token = TokenData(
                        token_idx=token_idx,
                        atom_idx=res["atom_idx"],
                        atom_num=res["atom_num"],
                        res_idx=res["res_idx"],
                        res_type=res_type,
                        res_name=res["name"],
                        sym_id=chain["sym_id"],
                        asym_id=chain["asym_id"],
                        entity_id=chain["entity_id"],
                        mol_type=chain["mol_type"],
                        center_idx=res["atom_center"],
                        disto_idx=res["atom_disto"],
                        center_coords=c_coords,
                        disto_coords=d_coords,
                        resolved_mask=is_present,
                        disto_mask=is_disto_present,
                        modified=True,
                        frame_rot=np.eye(3).flatten(),
                        frame_t=np.zeros(3),
                        frame_mask=False,
                        cyclic_period=chain["cyclic_period"],
                        is_standard=False,
                        design=False,
                        binding_type=const.binding_type_ids["UNSPECIFIED"],
                        structure_group=0,
                        ccd=convert_ccd(res["name"]),
                        target_msa_mask=0,
                        design_ss_mask=0,
                        feature_asym_id=chain["asym_id"],
                        feature_res_idx=res["res_idx"],
                    )
                    token_data.append(tokendata_to_tuple(token))

                    # Update atom_idx to token_idx
                    for atom_idx in range(atom_start, atom_end):
                        atom_to_token[atom_idx] = token_idx
                    # Update token_idx to res_idx
                    token_to_res.append(res_index_global)
                    token_idx += 1

        # Create token bonds: keep every atom bond whose two atoms both belong
        # to a tokenized (unmasked) chain, mapped onto their tokens.
        token_bonds = []
        for bond in struct.bonds:
            if (
                bond["atom_1"] not in atom_to_token
                or bond["atom_2"] not in atom_to_token
            ):
                continue
            token_bond = (
                atom_to_token[bond["atom_1"]],
                atom_to_token[bond["atom_2"]],
                # NOTE(review): shifted by one, presumably so 0 can mean
                # "no bond" downstream — confirm.
                bond["type"] + 1,
            )
            token_bonds.append(token_bond)

        # Consider adding missing bond for modified residues to standard?
        # I'm not sure it's necessary because the bond is probably always
        # the same and the model can use the residue indices to infer it
        token_data = np.array(token_data, dtype=Token)
        token_bonds = np.array(token_bonds, dtype=TokenBond)
        if token_to_res:
            token_to_res = np.array(token_to_res)
        else:
            # np.array([]) defaults to float64, which cannot be used as an
            # index array; keep an integer dtype for the empty case.
            token_to_res = np.zeros(0, dtype=np.int64)
        tokenized = Tokenized(token_data, token_bonds, struct, token_to_res)
        return tokenized
