import os

import numpy as np


def _get_shell_index(cls, key):
    """Determine shell index based on distance grouping.

    Parameters
    ----------
    cls : SpinIO
        SpinIO instance
    key : tuple
        (R, i, j) key for exchange parameter

    Returns
    -------
    int
        Shell index (1 for nearest neighbors, 2 for next nearest, etc.)
    """
    if not hasattr(cls, "_distance_shells"):
        # Calculate all distances and group into shells
        distances = []
        for k in cls.exchange_Jdict:
            if cls.distance_dict and k in cls.distance_dict:
                distances.append(cls.distance_dict[k][1])

        if distances:
            # Sort distances and group by rounding to nearest 0.01 Å
            sorted_distances = sorted(set(round(d * 100) for d in distances))
            cls._distance_shells = {}
            for shell_idx, dist_int in enumerate(sorted_distances, 1):
                cls._distance_shells[dist_int / 100] = shell_idx
        else:
            cls._distance_shells = {}

    if cls.distance_dict and key in cls.distance_dict:
        distance = cls.distance_dict[key][1]
        rounded_dist = round(distance * 100)
        return cls._distance_shells.get(rounded_dist / 100, 1)

    return 1


def write_espins(cls, path="TB2J_results/ESPInS"):
    """Write ESPInS format input files.

    Creates the output directory (including missing parents) if needed and
    writes ``espins.in`` inside it.

    Parameters
    ----------
    cls : SpinIO
        SpinIO instance containing exchange parameters
    path : str
        Output directory path
    """
    # exist_ok=True replaces the former "check, then create" sequence, which
    # could raise FileExistsError if the directory appeared between the check
    # and the makedirs call (e.g. several jobs writing results concurrently).
    os.makedirs(path, exist_ok=True)

    write_espins_input(cls, os.path.join(path, "espins.in"))


def write_espins_input(cls, fname):
    """Write the main ESPInS input file.

    The file is written in this order: unit cell, positions of the magnetic
    atoms, default temperature / Monte Carlo / supercell settings,
    Hamiltonian switches, a note on the Jij convention and, only when
    ``cls.has_exchange`` is true, the ``Jij_parameters`` and ``Jij_matrix``
    blocks. Both exchange blocks list the pairs in the same order.

    Parameters
    ----------
    cls : SpinIO
        SpinIO instance
    fname : str
        Output filename
    """
    with open(fname, "w") as myfile:
        # Evaluated once and shared by the atom block and the Jij block
        # instead of being recomputed for every exchange pair.
        scaled_positions = cls.atoms.get_scaled_positions()

        _write_structure(cls, myfile, scaled_positions)
        _write_simulation_defaults(myfile)
        _write_hamiltonian_header(myfile)

        # Single-ion anisotropy (SingleIon_Matrix) and biquadratic exchange
        # (Bij_parameters) blocks are not written; the corresponding
        # switches are set to .False. in the Hamiltonian header.
        if cls.has_exchange:
            # One ordering for both blocks: each Jij_matrix line must
            # correspond to the Jij_parameters line at the same position.
            ordered_keys = _exchange_keys_by_distance(cls)
            _write_jij_parameters(cls, myfile, ordered_keys, scaled_positions)
            _write_jij_matrix(cls, myfile, ordered_keys)


def _write_structure(cls, myfile, scaled_positions):
    """Write the Unit_Cell_Cart and Atoms_Frac blocks."""
    myfile.write("! ESPInS input file generated by TB2J\n")
    myfile.write("! The unit cell in angstrom\n")
    myfile.write("Begin Unit_Cell_Cart\n")
    cell = cls.atoms.get_cell()
    for row in range(3):
        myfile.write("     " + " ".join(f"{x:12.8f}" for x in cell[row]) + "\n")
    myfile.write("End Unit_Cell_Cart\n\n")

    myfile.write("! Atomic positions in reduced coordinates\n")
    myfile.write("Begin Atoms_Frac\n")
    symbols = cls.atoms.get_chemical_symbols()
    for iatom, (symbol, pos) in enumerate(zip(symbols, scaled_positions)):
        # Only magnetic atoms (index_spin >= 0) are listed; every line ends
        # with a constant 1.00 column.
        if cls.index_spin[iatom] >= 0:
            myfile.write(
                f"  {symbol:<8} {pos[0]:10.7f} {pos[1]:10.7f} {pos[2]:10.7f}    1.00\n"
            )
    myfile.write("End Atoms_Frac\n\n\n")


def _write_simulation_defaults(myfile):
    """Write the fixed default temperature, Monte Carlo and supercell settings."""
    # Temperature settings (default values)
    myfile.write("tem_start          =   1\n")
    myfile.write("tem_end            =   30\n")
    myfile.write("tems_num           =   30\n")

    # Monte Carlo parameters (default values); the Pt lines are emitted as
    # "!"-prefixed lines.
    myfile.write("! Pt                 = .True.\n")
    myfile.write("! Pt_steps_swap      = 40\n\n")

    myfile.write("steps_warmup      =       100000\n")
    myfile.write("steps_mc          =       100000\n")
    myfile.write("steps_measure     =           2\n\n")

    myfile.write("initial_sconfig   =        ferro\n")
    myfile.write("mcarlo_mode       =       random\n\n")

    myfile.write("supercell_size    =      10     10     1\n\n")


def _write_hamiltonian_header(myfile):
    """Write the Hamiltonian switches and the note describing the Jij convention."""
    myfile.write(" ## Hamiltonian\n")
    myfile.write("Ham_bij           = .False.\n")
    myfile.write("Ham_jij_matrix    = .True.\n")
    myfile.write("! We don't have single ion anisotropy, put it to .False.\n")
    myfile.write("Ham_singleion_matrix    = .False.\n\n\n")

    # NOTE: the unusual blank lines and leading spaces below reproduce the
    # previously emitted text byte-for-byte so that generated files do not
    # change.
    myfile.write(
        "! f1 is the fractional coordinate of atom i, and f2 is of atom j (rj+Rj). Then jij in eV, sh is index of shell, t1 is index of i, t2 is index of j (counting from 1).  \n"
        "\n"
        "            !The convention of the Hamiltonian is H = - sum_<ij> Jij Si Sj, and i, j are ordererd so that there is no double counting (ij and ji). To convert from TB2J exchange.out (which has both ij and ji), multiply Jij by 2.  \n"
        " \n"
        "            \n"
    )


def _exchange_keys_by_distance(cls):
    """Return the keys of ``cls.exchange_Jdict`` sorted by pair distance.

    Keys without an entry in ``cls.distance_dict`` sort with distance 0.0.
    ``sorted`` is stable, so keys with equal distance keep the iteration
    order of ``cls.exchange_Jdict``. Every key is kept; symmetric
    counterparts (-R, j, i), if present, are not skipped.
    """
    distance_dict = cls.distance_dict

    def pair_distance(key):
        if distance_dict and key in distance_dict:
            return distance_dict[key][1]
        return 0.0

    return sorted(cls.exchange_Jdict, key=pair_distance)


def _write_jij_parameters(cls, myfile, ordered_keys, scaled_positions):
    """Write the Jij_parameters block, one line per exchange pair."""
    myfile.write("Begin Jij_parameters\n")
    for key in ordered_keys:
        R, i, j = key
        jval = cls.exchange_Jdict[key]

        # i and j are passed through cls.iatom() to obtain the atom indices
        # used to look up positions.
        pos_i = scaled_positions[cls.iatom(i)]
        # Fractional coordinate of atom j displaced by the cell vector R.
        pos_jR = np.asarray(scaled_positions[cls.iatom(j)]) + np.asarray(R)

        shell = _get_shell_index(cls, key)

        # Jij is written multiplied by 2 (see the convention note written
        # before this block).
        myfile.write(
            f"  f1=    {pos_i[0]:10.6f},    {pos_i[1]:10.6f},    {pos_i[2]:10.6f}"
            f":f2=    {pos_jR[0]:10.6f},    {pos_jR[1]:10.6f},    {pos_jR[2]:10.6f}"
            f":jij=  {jval * 2:10.8f}!:sh=  {shell}!:t1=  {i + 1}:t2=  {j + 1}\n"
        )
    myfile.write("End Jij_parameters\n\n")


def _write_jij_matrix(cls, myfile, ordered_keys):
    """Write the Jij_matrix block, one full 3x3 tensor per exchange pair."""
    myfile.write(
        "! Each matrix element is corresponding to the previous block. The J tensor include the isotropic, anisotropic exchange and DMI. \n"
    )
    myfile.write("Begin Jij_matrix\n")
    for R, i, j in ordered_keys:
        # Full tensor, scaled by 2 like the scalar Jij values.
        J_tensor = cls.get_J_tensor(i, j, R, Jiso=True, Jani=True, DMI=True) * 2

        # Emits "Jxx=..., Jxy=..., ..., Jzz=..." in row-major order.
        myfile.write(
            ", ".join(
                f"J{row_axis}{col_axis}={J_tensor[row, col]:.8f}"
                for row, row_axis in enumerate("xyz")
                for col, col_axis in enumerate("xyz")
            )
            + "\n"
        )
    myfile.write("End Jij_matrix\n\n")
