"""
Copyright (c) 2024, pypolymlp, Atsuto Seko
Copyright (c) 2025, rsspolymlp, Hayato Wakai
"""

import numpy as np
import spglib

from pypolymlp.core.data_format import PolymlpStructure
from pypolymlp.core.interface_vasp import Poscar


class SymCell:
    """Class for using spglib functions.

    The structure is kept in spglib's ``(lattice, positions, types)`` tuple
    form. The methods run spglib on it and return new ``PolymlpStructure``
    objects built from spglib's output.
    """

    def __init__(
        self,
        poscar_name: str = None,
        st: PolymlpStructure = None,
        symprec: float = 1e-4,
    ):
        """Init method.

        Parameters
        ----------
        poscar_name: Path to a POSCAR file. When given, it is read with
            ``Poscar`` and takes precedence over ``st``.
        st: Structure used when ``poscar_name`` is None.
        symprec: Symmetry tolerance passed to every spglib call.

        Raises
        ------
        ValueError: If neither ``poscar_name`` nor ``st`` is given.
        """
        if poscar_name is not None:
            st = Poscar(poscar_name).structure
        if st is None:
            raise ValueError("Either poscar_name or st must be given.")

        if st.comment is not None:
            self.comment = st.comment
        else:
            self.comment = "Generated by pypolymlp and spglib"

        self.n_types = len(st.n_atoms)
        # spglib takes lattice vectors and fractional positions as rows,
        # while PolymlpStructure stores them as columns, hence the transposes.
        self.cell = (
            np.array(st.axis).T,
            np.array(st.positions).T,
            np.array(st.types),
        )

        # Type index -> element symbol, used to relabel spglib output.
        self.element_map = dict(zip(st.types, st.elements))

        self.symprec = symprec

    def _to_structure(self, lattice, positions, types) -> PolymlpStructure:
        """Group spglib output by type index and wrap it in a PolymlpStructure.

        Atoms are ordered type 0 first, then type 1, ... up to
        ``n_types - 1``; the spglib order is kept within each type. Atoms
        whose type index lies outside ``range(n_types)`` are dropped.
        """
        type_list = range(self.n_types)
        order = [
            i for t_ref in type_list for i, t in enumerate(types) if t == t_ref
        ]
        sorted_positions = np.array([positions[i] for i in order])
        sorted_types = [types[i] for i in order]

        n_atoms = [sorted_types.count(t_ref) for t_ref in type_list]
        elements = [self.element_map[t] for t in sorted_types]

        return PolymlpStructure(
            axis=lattice.T,
            positions=sorted_positions.T,
            n_atoms=n_atoms,
            elements=elements,
            types=sorted_types,
            volume=np.linalg.det(lattice),
            comment=self.comment,
        )

    def primitive_cell(self) -> PolymlpStructure:
        """Return the primitive cell found by ``spglib.find_primitive``.

        Raises
        ------
        TypeError: If the spglib call fails (including when it returns no
            cell).
        """
        try:
            lattice, positions, types = spglib.find_primitive(
                self.cell, symprec=self.symprec
            )
        except Exception as exc:
            raise TypeError("spglib.find_primitive failed.") from exc
        return self._to_structure(lattice, positions, types)

    def refine_cell(self, standardize_cell=False) -> PolymlpStructure:
        """Return a symmetry-refined cell.

        Parameters
        ----------
        standardize_cell: If true, use ``spglib.standardize_cell``;
            otherwise use ``spglib.refine_cell``.

        Raises
        ------
        TypeError: If the spglib call fails (including when it returns no
            cell).
        """
        if standardize_cell:
            spglib_func, func_name = spglib.standardize_cell, "standardize_cell"
        else:
            spglib_func, func_name = spglib.refine_cell, "refine_cell"

        try:
            lattice, positions, types = spglib_func(self.cell, symprec=self.symprec)
        except Exception as exc:
            raise TypeError(f"spglib.{func_name} failed.") from exc
        return self._to_structure(lattice, positions, types)

    def get_spacegroup(self):
        """Return space group."""
        return spglib.get_spacegroup(self.cell, symprec=self.symprec)

    def get_spacegroup_multiple_prec(self, symprecs=(1e-2, 1e-3, 1e-4, 1e-5)):
        """Return list of space groups using multiple precisions.

        One ``spglib.get_spacegroup`` result is returned per entry of
        ``symprecs``, in the same order.
        """
        return [spglib.get_spacegroup(self.cell, symprec=p) for p in symprecs]


def standardize_cell(
    cell: PolymlpStructure,
    symprec: float = 1e-5,
) -> PolymlpStructure:
    """Standardize cell for constructing cell basis.

    The structure is converted with ``spglib.standardize_cell`` into the
    conventional (``to_primitive=False``) standardized cell. Its atoms are
    then regrouped so that equal type indices are contiguous and appear in
    ascending order; the spglib order is kept within each type.

    Parameters
    ----------
    cell: Input structure. ``axis`` and ``positions`` are transposed
        before being handed to spglib, which takes vectors as rows.
    symprec: Symmetry tolerance passed to spglib. The default, 1e-5, is
        spglib's own default, so omitting it keeps the previous behavior.

    Returns
    -------
    Standardized structure. ``volume`` and ``comment`` are not set here.

    Raises
    ------
    TypeError: If spglib returns no standardized cell.
    """
    map_elements = dict(zip(cell.types, cell.elements))

    result = spglib.standardize_cell(
        (np.asarray(cell.axis).T, np.asarray(cell.positions).T, cell.types),
        to_primitive=False,
        symprec=symprec,
    )
    if result is None:
        raise TypeError("spglib.standardize_cell failed.")
    lattice, scaled_positions, types = result

    types = np.asarray(types)
    # A stable sort groups atoms by type without reordering atoms of the
    # same type, so the result matches a per-type scan.
    order = np.argsort(types, kind="stable")
    types_reorder = types[order].tolist()
    scaled_positions_reorder = np.asarray(scaled_positions)[order]
    # np.unique returns counts in ascending type order, matching `order`.
    n_atoms = np.unique(types, return_counts=True)[1].tolist()
    elements = [map_elements[t] for t in types_reorder]

    cell_standardized = PolymlpStructure(
        axis=lattice.T,
        positions=scaled_positions_reorder.T,
        n_atoms=n_atoms,
        elements=elements,
        types=types_reorder,
    )
    return cell_standardized


def get_symmetry_dataset(cell: PolymlpStructure, symprec: float = 1e-5):
    """Return symmetry dataset.

    Parameters
    ----------
    cell: Input structure. ``axis`` and ``positions`` are transposed
        before being handed to spglib, which takes vectors as rows.
    symprec: Symmetry tolerance passed to spglib. The default, 1e-5, is
        spglib's own default, so omitting it keeps the previous behavior.

    Returns
    -------
    The object returned by ``spglib.get_symmetry_dataset``, unchanged.
    NOTE(review): depending on the spglib version this may be None when
    symmetry search fails; callers should be prepared for that.
    """
    return spglib.get_symmetry_dataset(
        (np.asarray(cell.axis).T, np.asarray(cell.positions).T, cell.types),
        symprec=symprec,
    )


def _normalize_vector(vec: np.ndarray) -> np.ndarray:
    """Normalize a vector."""
    return vec / np.linalg.norm(vec)


def _crystal_system_basis(
    spg_num: int, international: str
) -> tuple[str, np.ndarray]:
    """Return (crystal-system label, axis basis) for a space-group number.

    Each column of the returned basis is one free direction of the flattened
    axis matrix, with rows ordered ax, bx, cx, ay, by, cy, az, bz, cz.
    """
    if spg_num >= 195:
        basis = np.zeros((9, 1))
        # Single degree of freedom: ax = by = cz, all other components zero.
        basis[:, 0] = _normalize_vector([1, 0, 0, 0, 1, 0, 0, 0, 1])
        return "Cubic", basis

    if 143 <= spg_num <= 194:
        if spg_num >= 168:
            label = "Hexagonal"
        elif "P" in international:
            label = "Trigonal (Hexagonal)"
        else:
            # The standardized conventional cell of an R lattice is given in
            # hexagonal axes, so it shares the hexagonal basis.
            label = "Trigonal (Rhombohedral)"
        basis = np.zeros((9, 2))
        # a = (1, 0, 0) and b = (-1/2, sqrt(3)/2, 0) scale together; cz is free.
        basis[:, 0] = _normalize_vector(
            [1, -0.5, 0, 0, np.sqrt(3) / 2, 0, 0, 0, 0]
        )
        basis[8, 1] = 1.0
        return label, basis

    if 75 <= spg_num <= 142:
        basis = np.zeros((9, 2))
        # ax = by scale together; cz is free.
        basis[:, 0] = _normalize_vector([1, 0, 0, 0, 1, 0, 0, 0, 0])
        basis[8, 1] = 1.0
        return "Tetragonal", basis

    if 16 <= spg_num <= 74:
        basis = np.zeros((9, 3))
        # Independent ax, by, cz.
        basis[[0, 4, 8], [0, 1, 2]] = 1.0
        return "Orthorhombic", basis

    if 3 <= spg_num <= 15:
        basis = np.zeros((9, 4))
        # Independent ax, by, cz plus cx (c tilted in the xz-plane).
        basis[[0, 4, 8, 2], [0, 1, 2, 3]] = 1.0
        return "Monoclinic", basis

    # Space groups 1-2: every axis component is free.
    return "Triclinic", np.eye(9)


def construct_basis_cell(
    cell: PolymlpStructure,
    verbose: bool = False,
) -> tuple[np.ndarray, PolymlpStructure]:
    """Generate a basis set for axis matrix.
    basis (row): In the order of ax, bx, cx, ay, by, cy, az, bz, cz

    The input is standardized with spglib first. The space group of the
    standardized cell selects which axis-matrix components may vary.

    Parameters
    ----------
    cell: Input structure.
    verbose: If True, print the space group and crystal system.

    Returns
    -------
    basis: Array of shape (9, k), with k = 1, 2, 3, 4 or 9 depending on the
        crystal system. Each column is one allowed direction of the
        flattened axis matrix.
    cell_copy: The standardized structure the basis refers to.

    Raises
    ------
    TypeError: If spglib returns no symmetry dataset.
    """
    cell_copy = standardize_cell(cell)
    spg_info = get_symmetry_dataset(cell_copy)
    if spg_info is None:
        raise TypeError("spglib failed to determine the space group.")
    spg_num = spg_info["number"]
    if verbose:
        if len(cell_copy.types) != len(cell.types):
            print("Number of atoms changed by standardization.", flush=True)
        print("Space group:", spg_info["international"], spg_num, flush=True)

    label, basis = _crystal_system_basis(spg_num, spg_info["international"])
    if verbose:
        print("Crystal system:", label, flush=True)
    return basis, cell_copy
