import hashlib
import base64
from collections import defaultdict, OrderedDict
from operator import attrgetter

import spglib
import numpy as np

from matid.utils.segfault_protect import segfault_protect
from matid.utils.exceptions import CellNormalizationError, MatIDError
from matid.data.symmetry_data import CHIRALITY_PRESERVING_EUCLIDEAN_NORMALIZERS
from matid.data.symmetry_data import SPACE_GROUP_INFO, WYCKOFF_SETS
from matid.data import constants
from matid.core.system import System
from matid.symmetry import WyckoffSet
import matid.geometry

from ase import Atoms


class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes.

    Used to give dictionary-based spglib symmetry datasets the same
    attribute-style access (``dataset.number``) as the dataset objects
    returned by newer spglib versions.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Make the instance namespace *be* the dictionary, so that
        # ``d.key`` and ``d["key"]`` always refer to the same entry.
        self.__dict__ = self


class SymmetryAnalyzer(object):
    """A base class for getting symmetry related properties of unit cells."""

    def __init__(self, system=None, symmetry_tol=None, min_2d_thickness=1):
        """
        Args:
            system(ASE.Atoms): The system to inspect.
            symmetry_tol(float): The tolerance for the symmetry detection. If
                None, the package default constants.SYMMETRY_TOL is used.
            min_2d_thickness(float): The minimum thickness in angstroms for the
                conventional cell that is returned for 2D systems.
        """
        self._original_system = None
        self._analyzed_system = None
        self.min_2d_thickness = min_2d_thickness
        self.symmetry_tol = (
            constants.SYMMETRY_TOL if symmetry_tol is None else symmetry_tol
        )

        # Delegates validation and caching setup to set_system().
        self.set_system(system)

    def set_system(self, system):
        """Sets a new system for analysis and clears all cached results.

        Systems with three periodic directions are analyzed as-is. Systems
        with two periodic directions are first copied and the cell vector
        along the non-periodic axis is rescaled so that any translational
        symmetry along that axis is broken before spglib sees the structure.

        Args:
            system (ASE.Atoms or None): The system to analyze. ``None`` (the
                constructor default) leaves the analyzer empty; a real system
                must then be set before any analysis method is used.

        Raises:
            ValueError: If the system has neither 3 nor 2 periodic directions.
        """
        self.reset()
        self._original_system = system

        # The constructor advertises system=None as its default and forwards
        # it here; without this guard SymmetryAnalyzer() fails with an
        # AttributeError on None.get_pbc().
        if system is None:
            self._analyzed_system = None
            self.n_pbc = None
            return

        # Determine if the system has three periodic directions or two.
        pbc = system.get_pbc()
        n_pbc = np.sum(pbc)
        self.n_pbc = n_pbc

        # Regular bulk structures
        if n_pbc == 3:
            self._analyzed_system = system
        elif n_pbc == 2:
            # Get the index of the non-periodic axis (a length-1 index array).
            i_pbc = np.argwhere(pbc == False)[0]  # noqa: E712

            # Before calculating the conventional system, make sure that there
            # is enough vacuum in the non-periodic direction to remove any
            # translational symmetries that are smaller than the basis vector
            # in the non-periodic direction.
            symmetry_broken_system = system.copy()
            thickness = max(
                5, 3 * matid.geometry.get_thickness(symmetry_broken_system, i_pbc)
            )
            old_cell = symmetry_broken_system.get_cell()
            old_basis = old_cell[i_pbc, :]
            old_basis_len = np.linalg.norm(old_basis)
            old_basis_norm = old_basis / old_basis_len
            # Keep the direction of the non-periodic cell vector, but set its
            # length to the padded thickness.
            new_basis = thickness * old_basis_norm
            old_cell[i_pbc, :] = new_basis
            symmetry_broken_system.set_cell(old_cell)
            self._analyzed_system = symmetry_broken_system
        else:
            raise ValueError(
                "No symmetry routines defined for system that do not have 3D or"
                " 2D periodicity."
            )

    def reset(self):
        """Clear every cached analysis result so it is recomputed on demand."""
        cached_attributes = (
            # spglib symmetry dataset
            "_symmetry_dataset",
            # MatID-normalized conventional cell
            "_conventional_system",
            "_conventional_wyckoff_letters",
            "_conventional_equivalent_atoms",
            "_conventional_lattice_fit",
            # spglib conventional cell
            "_spglib_conventional_system",
            "_spglib_wyckoff_letters_conventional",
            "_spglib_equivalent_atoms_conventional",
            # spglib primitive cell
            "_spglib_primitive_system",
            "_spglib_wyckoff_letters_primitive",
            "_spglib_equivalent_atoms_primitive",
            "_spglib_primitive_to_original_mapping",
            # MatID-normalized primitive cell
            "_primitive_system",
            "_primitive_wyckoff_letters",
            "_primitive_equivalent_atoms",
            # Transform chosen for the Wyckoff ground state
            "_best_transform",
        )
        for attribute in cached_attributes:
            setattr(self, attribute, None)

    def get_material_id(self):
        """Return a 28-character, URL-safe identifier for this material.

        The identifier is a truncated base64 encoding of the SHA-512 hash of
        a string built from:

         - Space group number
         - Wyckoff position letters, the species occupying them and the
           number of atoms in each set

        Systems with two periodic directions get a "2D " prefix before
        hashing, so a 2D and a 3D material never share an identifier.
        """
        spg_number = self.get_space_group_number()
        set_descriptions = sorted(
            "{} {} {}".format(wset.element, wset.wyckoff_letter, len(wset.indices))
            for wset in self.get_wyckoff_sets_conventional(False)
        )
        string = "{} {}".format(spg_number, ", ".join(set_descriptions))
        if self.n_pbc == 2:
            string = f"2D {string}"

        digest = hashlib.sha512(string.encode("utf-8")).digest()
        # "-" and "_" replace "+" and "/" so the id is web safe.
        encoded = base64.b64encode(digest, altchars=b"-_")
        id_length = 28
        return encoded[:id_length].decode("utf-8")

    def get_space_group_number(self):
        """Return the space group number found by spglib.

        Returns:
            int: The space group number.
        """
        return self.get_symmetry_dataset().number

    def get_space_group_international_short(self):
        """Return the short international (Hermann-Mauguin) space group symbol.

        Returns:
            str: The international space group short symbol.
        """
        return self.get_symmetry_dataset().international

    def get_hall_symbol(self):
        """Return the Hall symbol of the detected space group setting.

        Returns:
            str: The Hall symbol.
        """
        return self.get_symmetry_dataset().hall

    def get_hall_number(self):
        """Return the Hall number of the detected space group setting.

        Returns:
            int: The Hall number.
        """
        return self.get_symmetry_dataset().hall_number

    def get_point_group(self):
        """Return the crystallographic point group in Hermann-Mauguin notation.

        Returns:
            str: point group symbol
        """
        return self.get_symmetry_dataset().pointgroup

    def get_is_chiral(self):
        """Returns a boolean value that tells if this object is chiral or not
        (achiral). A chiral object has symmetry operations that are all proper,
        i.e. their determinant is +1.

        Returns:
            bool: is the object chiral. True when no rotation has a negative
            determinant (also True when there are no rotations at all).
        """
        operations = self.get_symmetry_operations()
        rotations = operations["rotations"]
        # np.linalg.det works in floating point (LU decomposition), so an
        # improper rotation can give e.g. -0.9999999999999998 instead of
        # exactly -1.0. Test the sign rather than exact equality so that such
        # operations are not silently ignored.
        return not any(np.linalg.det(rotation) < 0 for rotation in rotations)

    def get_has_free_wyckoff_parameters(self):
        """Check whether any occupied Wyckoff position has free variables.

        The Wyckoff letters of the original system are looked up in the
        WYCKOFF_SETS table of the detected space group.

        Returns:
            bool: Indicates the presence of Wyckoff positions with free variables.
        """
        space_group = self.get_space_group_number()
        occupied_letters = set(self.get_wyckoff_letters_original())
        letter_table = WYCKOFF_SETS[space_group]
        return any(
            len(letter_table[letter]["variables"]) != 0 for letter in occupied_letters
        )

    def get_crystal_system(self):
        """Return the crystal system belonging to the detected space group.

        The value is read from SPACE_GROUP_INFO. There are seven crystal
        systems:

            - Triclinic
            - Monoclinic
            - Orthorhombic
            - Tetragonal
            - Trigonal
            - Hexagonal
            - Cubic

        Return:
            str: The name of the crystal system.
        """
        return SPACE_GROUP_INFO[self.get_space_group_number()]["crystal_system"]

    def get_bravais_lattice(self):
        """Return Bravais lattice in the Pearson notation, where the first
        lowercase letter indicates the crystal system, and the second uppercase
        letter indicates the centring type.

        Crystal system letters:

            - a = triclinic
            - m = monoclinic
            - o = orthorhombic
            - t = tetragonal
            - h = hexagonal and trigonal
            - c = cubic

        Lattice type letters:

            - P = Primitive
            - S (= A or B or C) = One side/face centred
            - I = Body centered
            - R = Rhombohedral centring
            - F = All faces centred

        The space group is taken from the analyzed system; this method takes
        no arguments.

        Returns:
            str or None: The Bravais lattice in the Pearson notation, e.g.
            "cF", or None if no space group number is available.
        """
        space_group = self.get_space_group_number()
        if space_group is None:
            return None

        bravais_lattice = SPACE_GROUP_INFO[space_group]["bravais_lattice"]

        # The different one-sided centrings are merged into one letter
        if bravais_lattice[1] in ["A", "B", "C"]:
            bravais_lattice = bravais_lattice[0] + "S"

        return bravais_lattice

    def get_primitive_system(self):
        """Return (and cache) the idealized primitive cell of this system.

        The primitive cell is derived from the MatID-normalized conventional
        cell, so atom positions and basis vectors are idealized to the
        symmetries found with the given precision. Quantities such as the
        volume, density, angles between basis vectors and basis vector
        lengths may therefore deviate slightly from the original system.
        The matching Wyckoff letters and equivalent atoms are cached at the
        same time.

        Returns:
            ASE.Atoms: The primitive system.
        """
        if self._primitive_system is None:
            (
                self._primitive_system,
                self._primitive_wyckoff_letters,
                self._primitive_equivalent_atoms,
            ) = self._get_primitive_system(
                self.get_conventional_system(),
                self.get_wyckoff_letters_conventional(),
                self.get_equivalent_atoms_conventional(),
                self.get_space_group_international_short(),
            )
        return self._primitive_system

    def get_conventional_system(self):
        """Used to get the conventional representation of this system.

        This description uses a conventional lattice where positions of the
        atoms, and the cell basis vectors are idealized to follow the
        symmetries that were found with the given precision. This means that
        e.g. the volume, density, angles between basis vectors and basis vector
        lengths may have small deviations from the original system.

        The result is cached together with the matching Wyckoff letters and
        equivalent atoms.

        Returns:
            For 3D systems: a System with the equivalent atoms and Wyckoff
            letters set and full periodicity. For 2D systems: the cell returned
            by matid.geometry.get_minimized_cell, with the non-periodic axis
            moved to the last basis vector (c).

        Raises:
            MatIDError: If the non-periodic direction of a 2D system cannot be
                identified in the normalized cell.
            ValueError: If the system does not have 3 or 2 periodic directions.
        """
        if self._conventional_system is not None:
            return self._conventional_system

        # Determine if the system has three periodic directions or two.
        pbc = self._original_system.get_pbc()
        n_pbc = np.sum(pbc)

        # Regular bulk structures
        if n_pbc == 3:
            spglib_conv_sys = self._get_spglib_conventional_system()

            # Find a proper rigid transformation that produces the best combination
            # of atomic species in the Wyckoff positions.
            space_group = self.get_space_group_number()
            wyckoff_letters = self._get_spglib_wyckoff_letters_conventional()
            equivalent_atoms = self._get_spglib_equivalent_atoms_conventional()
            ideal_sys, ideal_wyckoff = self._find_wyckoff_ground_state(
                space_group, wyckoff_letters, spglib_conv_sys
            )
            ideal_sys = System.from_atoms(ideal_sys)
            ideal_sys.set_equivalent_atoms(equivalent_atoms)
            ideal_sys.set_wyckoff_letters(ideal_wyckoff)

            self._conventional_system = ideal_sys
            self._conventional_wyckoff_letters = ideal_wyckoff
            self._conventional_equivalent_atoms = equivalent_atoms
            # The cached object is the same instance, so this also applies to
            # the cached conventional system.
            ideal_sys.set_pbc(True)
            return ideal_sys
        # 2D materials get a special treatment
        elif n_pbc == 2:
            # Length-1 index array of the originally non-periodic axis; the
            # comparisons on axis[...] below therefore give 1-element arrays.
            i_pbc = np.argwhere(pbc == False)[0]  # noqa: E712

            # Get the full 3D conventional system and its symmetries. It will
            # include some symmetries that have a translational component
            # corresponding to the non-periodic axis, but it does not matter in
            # this case.
            spglib_conv_sys = self._get_spglib_conventional_system()

            # Determine the new non-periodic direction in the normalized cell.
            # The index of the originally non-periodic dimension may not correspond
            # to the one in the normalized system, because the normalized system
            # may use a different coordinate system. The first row of the
            # transformation matrix whose only non-zero entry (beyond prec) is
            # in the originally non-periodic column is taken as that axis.
            transformation_matrix = self.get_symmetry_dataset().transformation_matrix
            nonperiodic_axis = None
            prec = 1e-8
            for i_axis, axis in enumerate(transformation_matrix):
                if (
                    abs(axis[i_pbc]) > prec
                    and abs(axis[(i_pbc + 1) % 3]) < prec
                    and abs(axis[(i_pbc + 2) % 3]) < prec
                ):
                    nonperiodic_axis = i_axis
                    break
            if nonperiodic_axis is None:
                raise MatIDError(
                    "Could not detect the non-periodic direction in the normalized "
                    "2D cell."
                )

            # Find a proper rigid transformation that produces the best combination
            # of atomic species in the Wyckoff positions.
            space_group = self.get_space_group_number()
            wyckoff_letters = self._get_spglib_wyckoff_letters_conventional()
            equivalent_atoms = self._get_spglib_equivalent_atoms_conventional()
            ideal_sys, ideal_wyckoff = self._find_wyckoff_ground_state(
                space_group,
                wyckoff_letters,
                spglib_conv_sys,
            )

            # Center the system in the non-periodic direction, also taking
            # periodicity into account. Without the centering the structure may
            # end up being split at the cell boundary. The
            # get_center_of_mass()-function in MatID takes into account
            # periodicity and can produce the correct CM unlike the similar
            # function in ASE.
            ideal_sys.set_pbc(True)  # Needed temporarily for centering to work
            pbc_cm = matid.geometry.get_center_of_mass(ideal_sys)
            cell_center = 0.5 * np.sum(ideal_sys.get_cell(), axis=0)
            translation = cell_center - pbc_cm
            conv_pbc = np.array([True, True, True])
            conv_pbc[nonperiodic_axis] = False
            # Only translate along the non-periodic axis; the periodic
            # components of the translation are zeroed.
            translation[conv_pbc] = 0
            ideal_sys.translate(translation)
            ideal_sys.wrap()

            # For the final system we set the correct pbc
            ideal_sys.set_pbc(conv_pbc)

            # Swap the cell axes so that the non-periodic one is always the last
            # basis (=c). conv_pbc always has exactly one False entry, so
            # non_periodic_dim is always bound after this loop.
            swap_dim = 2
            for i, periodic in enumerate(ideal_sys.get_pbc()):
                if not periodic:
                    non_periodic_dim = i
                    break
            if non_periodic_dim != swap_dim:
                matid.geometry.swap_basis(ideal_sys, non_periodic_dim, swap_dim)

            # Minimize the cell to only just fit the atoms in the non-periodic
            # direction
            min_conv_cell = matid.geometry.get_minimized_cell(
                ideal_sys, swap_dim, self.min_2d_thickness
            )

            self._conventional_system = min_conv_cell
            self._conventional_wyckoff_letters = ideal_wyckoff
            self._conventional_equivalent_atoms = equivalent_atoms
            return self._conventional_system
        else:
            raise ValueError(
                "The provided system does not have 3 or 2 periodic directions."
            )

    def get_rotations(self):
        """Return the rotational parts of the space group's Seitz matrices.

        The rotation at index i pairs with the translation at the same index
        returned by get_translations().

        Returns:
            np.ndarray: Rotation matrices.
        """
        return self.get_symmetry_dataset().rotations

    def get_translations(self):
        """Return the translational parts of the space group's Seitz matrices.

        The translation at index i pairs with the rotation at the same index
        returned by get_rotations().

        Returns:
            np.ndarray: Translation vectors.
        """
        return self.get_symmetry_dataset().translations

    def get_choice(self):
        """Return the setting choice reported by spglib.

        Returns:
            str: A string specifying the centring, origin and basis vector
            settings.
        """
        return self.get_symmetry_dataset().choice

    def get_wyckoff_letters_original(self):
        """Return the Wyckoff letters of the atoms in the original system.

        spglib's letters are remapped through the letter permutation of the
        transform chosen while building the conventional system, so they
        agree with the letters of the MatID-normalized cell. The conventional
        system is computed first if that transform is not yet known.

        Returns:
            np.ndarray of str: Wyckoff letters for the atoms in the original
            system.
        """
        spglib_letters = self._get_spglib_wyckoff_letters_original()
        if self._best_transform is None:
            self.get_conventional_system()
        letter_map = self._best_transform["permutations"]
        return np.array([letter_map[letter] for letter in spglib_letters])

    def get_equivalent_atoms_original(self):
        """Return the symmetry-equivalence groups of the original atoms.

        These are used exactly as spglib reports them: remapping the Wyckoff
        letters does not change which atoms are equivalent.

        Returns:
            list of int: A list that maps each atom into a symmetry equivalent
                set.
        """
        return self._get_spglib_equivalent_atoms_original()

    def get_wyckoff_letters_conventional(self):
        """Return the Wyckoff letters of the atoms in the conventional system.

        The letters are produced as a side effect of get_conventional_system(),
        which is called if they are not cached yet.

        Returns:
            list of str: Wyckoff letters.
        """
        if self._conventional_wyckoff_letters is not None:
            return self._conventional_wyckoff_letters
        self.get_conventional_system()
        return self._conventional_wyckoff_letters

    def get_wyckoff_sets_conventional(self, return_parameters=True):
        """Group the atoms of the conventional system into Wyckoff sets.

        Each Wyckoff set combines the Wyckoff position, the element occupying
        it and the atoms placed there.

        Args:
            return_parameters (bool): Whether to return the value of possible
                free Wyckoff parameters. Set to false if they are not needed,
                as their determination can take some time.

        Returns:
            list of WyckoffSets: A list of :class:`.WyckoffSet` objects for the
            conventional system.
        """
        spg_number = self.get_space_group_number()
        conventional = self.get_conventional_system()
        letters = self.get_wyckoff_letters_conventional()
        equivalents = self.get_equivalent_atoms_conventional()
        return self._get_wyckoff_sets(
            conventional,
            spg_number,
            letters,
            equivalents,
            precision=self.symmetry_tol,
            return_parameters=return_parameters,
        )

    def get_equivalent_atoms_conventional(self):
        """Return the symmetry-equivalence groups of the idealized system.

        Delegates to the cached spglib-based conventional equivalence groups.

        Returns:
            list of int: A list that maps each atom into a symmetry equivalent
                set.
        """
        return self._get_spglib_equivalent_atoms_conventional()

    def get_wyckoff_letters_primitive(self):
        """Return the Wyckoff letters of the atoms in the primitive system.

        The letters are produced as a side effect of get_primitive_system(),
        which is called if they are not cached yet.

        Returns:
            list of str: Wyckoff letters.
        """
        if self._primitive_wyckoff_letters is not None:
            return self._primitive_wyckoff_letters
        self.get_primitive_system()
        return self._primitive_wyckoff_letters

    def get_equivalent_atoms_primitive(self):
        """Return the symmetry-equivalence groups of the primitive system.

        The groups are produced as a side effect of get_primitive_system(),
        which is called if they are not cached yet.

        Returns:
            list of int: A list that maps each atom into a symmetry equivalent
                set.
        """
        if self._primitive_equivalent_atoms is not None:
            return self._primitive_equivalent_atoms
        self.get_primitive_system()
        return self._primitive_equivalent_atoms

    def get_symmetry_dataset(self):
        """Calculates (and caches) the spglib symmetry dataset for the system.

        The dataset is computed for the analyzed system, which for 2D systems
        is the vacuum-padded copy created in set_system().

        Returns:
            The spglib dataset. Dictionaries returned by older spglib versions
            are wrapped in AttrDict so that fields can always be read as
            attributes.

        Raises:
            CellNormalizationError: If spglib crashes or returns no dataset.
        """
        if self._symmetry_dataset is not None:
            return self._symmetry_dataset

        description = self._system_to_spglib_description(self._analyzed_system)
        # Spglib has been observed to cause segmentation faults when fed with
        # invalid data, so run in separate process to catch those cases
        try:
            symmetry_dataset = segfault_protect(
                spglib.get_symmetry_dataset, description, self.symmetry_tol
            )
        except RuntimeError as err:
            # Chain the original error so the underlying cause stays visible.
            raise CellNormalizationError(
                "Segfault in spglib when finding symmetry dataset. Please check "
                "the given cell, scaled positions and atomic numbers."
            ) from err
        if symmetry_dataset is None:
            raise CellNormalizationError("Spglib error when finding symmetry dataset.")

        # Prior to spglib 2.5.0 the dataset is returned as a dictionary: this
        # provides backwards compatibility
        if isinstance(symmetry_dataset, dict):
            symmetry_dataset = AttrDict(symmetry_dataset)

        self._symmetry_dataset = symmetry_dataset

        return symmetry_dataset

    def _get_spglib_conventional_system(self):
        """Build (and cache) the standardized cell reported by spglib.

        The cell, positions and types come from the std_lattice,
        std_positions and std_types fields of the symmetry dataset.

        Returns:
            ASE.Atoms: The idealized system as defined by spglib.
        """
        if self._spglib_conventional_system is None:
            dataset = self.get_symmetry_dataset()
            std_description = (
                dataset.std_lattice,
                dataset.std_positions,
                dataset.std_types,
            )
            self._spglib_conventional_system = self._spglib_description_to_system(
                std_description
            )
        return self._spglib_conventional_system

    def _get_spglib_wyckoff_letters_original(self):
        """Return spglib's Wyckoff letters for the original atoms.

        Returns:
            np.ndarray of str: Wyckoff letters for the atoms in the original
            system.
        """
        return np.array(self.get_symmetry_dataset().wyckoffs)

    def _get_spglib_equivalent_atoms_original(self):
        """Return spglib's symmetry-equivalence groups for the original atoms.

        crystallographic_orbits is used instead of equivalent_atoms: the
        latter reflects only the symmetry of the given cell, whereas the
        orbits are based on the primitive/conventional cell, which is the
        equivalence wanted here.

        Returns:
            list of int: A list that maps each atom into a symmetry equivalent
                set.
        """
        return self.get_symmetry_dataset().crystallographic_orbits

    def _get_spglib_wyckoff_letters_conventional(self):
        """Return (and cache) spglib's Wyckoff letters for the standardized cell.

        Each standardized-cell atom takes the letter of the primitive-cell
        atom it maps to via std_mapping_to_primitive.

        Returns:
            np.array of str: Wyckoff letters for the atoms in the conventional
            system as defined by spglib.
        """
        if self._spglib_wyckoff_letters_conventional is not None:
            return self._spglib_wyckoff_letters_conventional
        primitive_letters = self._get_spglib_wyckoff_letters_primitive()
        std_to_primitive = self.get_symmetry_dataset().std_mapping_to_primitive
        self._spglib_wyckoff_letters_conventional = primitive_letters[std_to_primitive]
        return self._spglib_wyckoff_letters_conventional

    def _get_spglib_equivalent_atoms_conventional(self):
        """Return (and cache) spglib's equivalence groups for the standardized cell.

        Each standardized-cell atom takes the group of the primitive-cell atom
        it maps to via std_mapping_to_primitive.

        Returns:
            np.array of int: List of numbers where the atoms are grouped to
            symmetrically equivalent groups by number.
        """
        if self._spglib_equivalent_atoms_conventional is not None:
            return self._spglib_equivalent_atoms_conventional
        primitive_groups = self._get_spglib_equivalent_atoms_primitive()
        std_to_primitive = self.get_symmetry_dataset().std_mapping_to_primitive
        self._spglib_equivalent_atoms_conventional = primitive_groups[std_to_primitive]
        return self._spglib_equivalent_atoms_conventional

    def _get_spglib_origin_shift(self):
        """Return the origin shift s mapping original to conventional coordinates.

        Coordinates are related by

        x' = P*x + s

        with x' a conventional-cell coordinate, P the transformation matrix,
        x an original-cell coordinate and s this origin shift.

        Returns:
            3*1 np.ndarray: The shift of the origin as a vector.
        """
        return self.get_symmetry_dataset().origin_shift

    def get_symmetry_operations(self):
        """The symmetry operations of the original structure as rotations and
        translations.

        Returns:
            dict: A dictionary with the entry "rotations" containing an
            np.array of 3x3 matrices, one per symmetry operation, and the entry
            "translations" containing an np.array of translations, one per
            symmetry operation. Entries with the same index belong together.
        """
        dataset = self.get_symmetry_dataset()
        operations = {
            "rotations": dataset.rotations,
            "translations": dataset.translations,
        }

        return operations

    def _get_spglib_transformation_matrix(self):
        """Return the matrix P mapping original to conventional coordinates.

        Coordinates are related by

        x' = P*x + s

        with x' a conventional-cell coordinate, P this transformation matrix,
        x an original-cell coordinate and s the origin shift.

        Returns:
            3x3 np.ndarray: The transformation matrix P.
        """
        return self.get_symmetry_dataset().transformation_matrix

    def _get_spglib_primitive_system(self):
        """Return (and cache) the primitive cell derived from spglib's standard cell.

        The matching spglib primitive Wyckoff letters and equivalent atoms
        are cached at the same time.

        Returns:
            ASE.Atoms: The primitive system.
        """
        if self._spglib_primitive_system is None:
            std_system = self._get_spglib_conventional_system()
            std_letters = self._get_spglib_wyckoff_letters_conventional()
            std_groups = self._get_spglib_equivalent_atoms_conventional()
            symbol = self.get_space_group_international_short()
            (
                self._spglib_primitive_system,
                self._spglib_wyckoff_letters_primitive,
                self._spglib_equivalent_atoms_primitive,
            ) = self._get_primitive_system(std_system, std_letters, std_groups, symbol)
        return self._spglib_primitive_system

    def _get_spglib_wyckoff_letters_primitive(self):
        """Return (and cache) spglib's Wyckoff letters for the primitive atoms.

        Each primitive atom takes the letter of one representative atom in
        the original system.

        Returns:
            list of str: Wyckoff letters.
        """
        if self._spglib_wyckoff_letters_primitive is not None:
            return self._spglib_wyckoff_letters_primitive
        original_letters = self._get_spglib_wyckoff_letters_original()
        representatives = self._get_spglib_primitive_to_original_mapping()
        self._spglib_wyckoff_letters_primitive = original_letters[representatives]
        return self._spglib_wyckoff_letters_primitive

    def _get_spglib_equivalent_atoms_primitive(self):
        """Return (and cache) spglib's equivalence groups for the primitive atoms.

        Each primitive atom takes the group of one representative atom in the
        original system.

        Returns:
            list of int: A list that maps each atom into a symmetry equivalent
                set.
        """
        if self._spglib_equivalent_atoms_primitive is not None:
            return self._spglib_equivalent_atoms_primitive
        original_groups = self._get_spglib_equivalent_atoms_original()
        representatives = self._get_spglib_primitive_to_original_mapping()
        self._spglib_equivalent_atoms_primitive = original_groups[representatives]
        return self._spglib_equivalent_atoms_primitive

    def _get_spglib_primitive_to_original_mapping(self):
        """Return (and cache) a mapping that links each primitive-cell atom to
        one of its copies in the original system.

        For every primitive-atom index in mapping_to_primitive (in sorted
        order), the first original atom carrying that index is chosen.

        Returns:
            np.ndarray: A list of integer indices, one for each atom in the
            primitive system as returned by spglib. The indices refer to an
            atom in the original simulation system.
        """
        if self._spglib_primitive_to_original_mapping is None:
            primitive_index_of_atom = self.get_symmetry_dataset().mapping_to_primitive
            first_occurrences = np.unique(primitive_index_of_atom, return_index=True)[1]
            self._spglib_primitive_to_original_mapping = first_occurrences
        return self._spglib_primitive_to_original_mapping

    def _get_primitive_system(
        self,
        conv_system,
        conv_wyckoff,
        conv_equivalent,
        space_group_international_short,
    ):
        """Returns an primitive description for an idealized system in the
        conventional cell. This description uses a primitive lattice
        where positions of the atoms, and the cell basis vectors are idealized
        to follow the symmetries that were found with the given precision. This
        means that e.g. the volume, angles between basis vectors and basis
        vector lengths may have deviations from the the original system.

        The transformation matrices from conventional system to primitive are
        as given at: https://atztogo.github.io/spglib/definition.html#id8

        Args:
            conv_system (ase.Atoms): The conventional system from which the
                primitive system is created.
            conv_wyckoff (np.array of str): Wyckoff letters of the given
                conventional system
            conv_equivalent (np.array of int): Equivalent atoms of the given
                conventional system
            space_group_international_short (str): The space group symbol in
                international short form

        Returns:
            tuple containing ase.Atoms, wyckoff_letters and equivalent atoms
        """
        centring = space_group_international_short[0]

        # For the primitive centering the conventional lattice is the primitive
        # as well
        if centring == "P":
            return conv_system, conv_wyckoff, conv_equivalent

        primitive_transformations = {
            "A": np.array(
                [
                    [1, 0, 0],
                    [0, 1 / 2, -1 / 2],
                    [0, 1 / 2, 1 / 2],
                ]
            ),
            "C": np.array(
                [
                    [1 / 2, 1 / 2, 0],
                    [-1 / 2, 1 / 2, 0],
                    [0, 0, 1],
                ]
            ),
            "R": np.array(
                [
                    [2 / 3, -1 / 3, -1 / 3],
                    [1 / 3, 1 / 3, -2 / 3],
                    [1 / 3, 1 / 3, 1 / 3],
                ]
            ),
            "I": np.array(
                [
                    [-1 / 2, 1 / 2, 1 / 2],
                    [1 / 2, -1 / 2, 1 / 2],
                    [1 / 2, 1 / 2, -1 / 2],
                ]
            ),
            "F": np.array(
                [
                    [0, 1 / 2, 1 / 2],
                    [1 / 2, 0, 1 / 2],
                    [1 / 2, 1 / 2, 0],
                ]
            ),
        }

        # Transform conventional cell to the primitive cell
        transform = primitive_transformations[centring]
        conv_cell = conv_system.get_cell()
        prim_cell = np.dot(transform.T, conv_cell)

        # Transform all position to the basis of the primitive cell
        conv_pos = conv_system.get_positions()
        prim_cell_inv = np.linalg.inv(prim_cell)
        prim_pos = np.dot(conv_pos, prim_cell_inv)

        # Keep one occurrence for each atom that should be within the cell and
        # wrap it's position to tbe inside the primitive cell.
        conv_num = conv_system.get_atomic_numbers()
        conv_to_prim_map = self._symmetry_dataset.std_mapping_to_primitive
        _, inside_mask = np.unique(conv_to_prim_map, return_index=True)
        prim_pos = prim_pos[inside_mask]
        prim_num = conv_num[inside_mask]

        # Store the wyckoff letters and equivalent atoms
        prim_wyckoff = conv_wyckoff[inside_mask]
        prim_equivalent = conv_equivalent[inside_mask]

        prim_sys = Atoms(
            scaled_positions=prim_pos,
            symbols=prim_num,
            cell=prim_cell,
            pbc=conv_system.get_pbc(),
        )
        prim_sys.wrap()

        return prim_sys, prim_wyckoff, prim_equivalent

    def _system_to_spglib_description(self, system):
        """Transforms the given ASE.Atoms object into a tuple used by spglib."""
        angstrom_cell = self._analyzed_system.get_cell()
        relative_pos = self._analyzed_system.get_scaled_positions()
        atomic_numbers = self._analyzed_system.get_atomic_numbers()
        description = (angstrom_cell, relative_pos, atomic_numbers)

        return description

    def _spglib_description_to_system(self, desc):
        """Builds an ASE.Atoms object from a spglib-style description.

        Args:
            desc (tuple): A (cell, scaled positions, atomic numbers) tuple as
                used by spglib. Only the first three items are read.

        Returns:
            ase.Atoms: The corresponding system.
        """
        cell = desc[0]
        scaled_positions = desc[1]
        numbers = desc[2]
        return Atoms(numbers=numbers, cell=cell, scaled_positions=scaled_positions)

    def _wrap_positions(self, positions, precision=1e-5, copy=True):
        """Wrap positions so that each element in the array is within the
        half-closed interval [0, 1)

        By wrapping values near 1 to 0 we will have a consistent way of
        presenting systems.

        Args:
            positions (np.ndarray): Atomic positions that are given in the unit cell basis.
            precision (float): The precision for wrapping coordinates that are close to

                zero or unity.
            copy (bool): Whether a the returned value is a copy or the values are
                modified in-place.

        Returns:
            np.ndarray: The new wrapped positions.
        """
        if copy:
            wrapped_positions = np.copy(positions)
        else:
            wrapped_positions = positions

        wrapped_positions %= 1
        abs_zero = np.absolute(wrapped_positions)
        abs_unity = np.absolute(abs_zero - 1)

        near_zero = np.where(abs_zero < precision)
        near_unity = np.where(abs_unity < precision)

        wrapped_positions[near_unity] = 0
        wrapped_positions[near_zero] = 0

        return wrapped_positions

    def _search_periodic_positions(
        self, target_pos, positions, cell, accuracy, wrap=True
    ):
        """Searches a list of positions for a match for the target position taking
        into account the periodicity of the system.

        Args:
            target_pos (1x3 np.array): The relative position to search.
            positions (Nx3 np.array): The relative position where to search.
            cell (3x3 np.array): The cell used to find a threshold accuracy in
                cartesian coordinates.
            accuracy (float): The minimum cartesian distance (angstroms) that is
                required for the atoms to be considered identical.
            wrap (bool): Whether the positions are wrapped to be within [0, 1].
                Only set to False if you know that the positions are already
                wrapped.

        Returns:
            If a match is found, returns the index of the match in 'positions'. If
            no match is found, returns None.
        """
        if len(positions.shape) == 1:
            positions = positions[np.newaxis, :]

        # Wrap the positions to be within [0, 1]
        if wrap:
            np.remainder(positions, 1, out=positions)
            np.remainder(target_pos, 1, out=target_pos)

        # Calculate the distances without taking into account periodicity.
        # Here we calculate all the distances although in reality we could loop
        # over the distances and calculate only until the correct one is found.
        # But it turns out it is faster to calculate the distances i one
        # vectorized operation with numpy than a python loop.
        displacements = positions - target_pos

        # Take periodicity into account by wrapping coordinate elements that
        # are bigger than 0.5 or smaller than -0.5
        indices = np.where(displacements > 0.5)
        displacements[indices] = displacements[indices] - 1
        indices = np.where(displacements < -0.5)
        displacements[indices] = displacements[indices] + 1

        # Convert displacements to cartesian coordinates
        displacements = np.dot(displacements, cell.T)

        # Try to find a match for the target by finding a distance that is less
        # than the accuracy
        distances = np.linalg.norm(displacements, axis=1)

        min_index = np.argmin(distances)
        min_distance = distances[min_index]
        if min_distance <= accuracy:
            return min_index
        else:
            return None

    def _find_wyckoff_ground_state(
        self,
        space_group,
        old_wyckoff_letters,
        system,
    ):
        """
        When given a system that has been normalized by spglib, this function
        will find the atomic positions within that cell that are most unique
        (totally unique up to isotropic scaling if no free Wyckoff parameters
        present).

        The function is based on iterating through chirality-preserving
        Euclidean normalizers (more information on normalizers can be found e.g.
        in "Space Groups for Solid State Scientists", page 246, ISBN:
        9780123946157, normalizers can be found for each space group e.g. at the
        Bilbao Crystallographic Server http://www.cryst.ehu.es/), which are
        essentially transforms that change Wyckoff positions of atoms within a
        cell without changing the structure itself. Each of these normalizers
        can be applied to give a different structural representation of the same
        material. The algorithm then goes through each tuple of Wyckoff letter
        and atomic number (W, Z) in a preset order: the first loop goes through
        the Wyckoff letters in alphabetical order, and the second loop goes
        through the atomic numbers from lowest to highest. Whenever some of the
        possible representations has a structural component corresponding to the
        current tuple (W, Z), the number of atoms with this tuple N is
        calculated. Only the representations with the highest N are kept as
        candidates. The algorithm stops whenever the candidate set contains only
        one representation, which will be the standard one.

        As a side effect, self._best_transform is set to the chosen normalizer
        (the identity normalizer dict if no other representation wins).

        Args:
            space_group(int): The space group of the system.
            old_wyckoff_letters(list of strings): Wyckoff letters as detected
                by spglib for the atoms in the given system.
            system(ase.Atoms): The standardized system as given by spglib.

        Returns:
            (ase.Atoms, list of strings): Returns a tuple containing the found
            conventional system and the Wyckoff letters for it. When the space
            group has normalizers, the letters are returned as an np.ndarray
            and the positions are wrapped to [0, 1).

        Raises:
            MatIDError: If several representations remain that are not
                equivalent in their Wyckoff occupation.
        """
        # Gather the allowed transformations. In addition to proper rotations
        # the space group symmetries may allow improper rotations without
        # 'breaking' the structure.
        identity = {
            "transformation": np.identity(4),
            "permutations": {x: x for x in old_wyckoff_letters},
            "identity": True,
        }
        normalizers = [identity]
        normalizers.extend(
            CHIRALITY_PRESERVING_EUCLIDEAN_NORMALIZERS.get(space_group, [])
        )

        # If no normalizers found for this space group, return the same system
        if len(normalizers) == 1:
            self._best_transform = identity
            return system, old_wyckoff_letters

        # Form all available representations: for each normalizer count the
        # atoms per (new Wyckoff letter, atomic number). Atoms whose letter is
        # not covered by the normalizer's permutation are left out. The
        # "identity" flag is carried over so that the identity normalizer can
        # still be recognized once a representation has been chosen.
        atomic_numbers = system.get_atomic_numbers()
        representations = []
        for transform in normalizers:
            perm = transform["permutations"]
            wyckoff_positions = {}
            for old_w, z in zip(old_wyckoff_letters, atomic_numbers):
                new_w = perm.get(old_w)
                if new_w is not None:
                    key = (new_w, z)
                    wyckoff_positions[key] = wyckoff_positions.get(key, 0) + 1
            representations.append(
                {
                    "transformation": transform["transformation"],
                    "permutations": perm,
                    "wyckoff_positions": wyckoff_positions,
                    "identity": transform.get("identity", False),
                }
            )

        # All Wyckoff letters any normalizer can produce (alphabetically) and
        # the atomic numbers present (in increasing order).
        all_letters = sorted(
            {
                new_w
                for transform in normalizers
                for new_w in transform["permutations"].values()
            }
        )
        all_numbers = sorted(set(atomic_numbers))

        # Narrow down the candidates: for each (W, Z) keep only the
        # representations with the most atoms in that combination. Once a
        # single candidate remains, later pairs can no longer change it.
        candidates = representations
        for w, z in ((w, z) for w in all_letters for z in all_numbers):
            if len(candidates) == 1:
                break
            groups = defaultdict(list)
            for rep in candidates:
                count = rep["wyckoff_positions"].get((w, z))
                if count is not None:
                    groups[count].append(rep)
            if groups:
                candidates = groups[max(groups)]

        # If several candidates remain, they must describe the same Wyckoff
        # occupation; then any of them is fine and the first one is chosen.
        if len(candidates) > 1:
            reference = candidates[0]["wyckoff_positions"]
            for other in candidates[1:]:
                if other["wyckoff_positions"] != reference:
                    raise MatIDError(
                        "Could not successfully decide best Wyckoff positions."
                    )
        best_representation = candidates[0]

        if best_representation["identity"]:
            self._best_transform = identity
        else:
            self._best_transform = best_representation

        # Apply the best transform. The same path is used for the identity so
        # that positions are always wrapped and letters always returned as an
        # array, exactly as callers have been receiving them.
        best_transformation_matrix = best_representation["transformation"]
        best_permutations = best_representation["permutations"]
        new_wyckoff_letters = np.array(
            [best_permutations.get(old_w) for old_w in old_wyckoff_letters]
        )

        # Create the homogeneous coordinates
        homogeneous = np.ones((len(system), 4))
        homogeneous[:, 0:3] = system.get_scaled_positions()

        # Apply the augmented matrix used for homogeneous coordinates and drop
        # the extra dimension again.
        transformed_positions = np.dot(homogeneous, best_transformation_matrix.T)
        transformed_positions = transformed_positions[:, 0:3]

        # Wrap the positions to the half-closed interval [0, 1)
        new_system = system.copy()
        new_system.set_scaled_positions(
            matid.geometry.get_wrapped_positions(transformed_positions)
        )

        return new_system, new_wyckoff_letters

    def _get_wyckoff_sets(
        self,
        system,
        space_group,
        wyckoff_letters,
        equivalent_atoms,
        precision,
        return_parameters,
    ):
        r"""Used to get detailed information about the sets of equivalent
        atoms. The detected Wyckoff set variables (x, y, z) are reported
        consistently: the atoms of a set are tried in order of increasing x,
        then y and finally z (coordinates rounded to 9 decimals), and the first
        atom whose variables reproduce the whole set is used.

        If return_parameters is set to True, the possible variables for the
        Wyckoff sets are returned. Because spglib does not currently print out
        detailed information about free parameters for Wyckoff sets, we use
        information from the Bilbao Crystallographic Database to get the values
        of the free variables for each group of symmetry related atoms. The
        positions R can be calculated as (using row vectors):

        .. math::
            \mathbf{R} = \mathbf{W}\mathbf{M} + \mathbf{C}

        where :math:`\mathbf{W}` is a row vector containing the Wyckoff variables,
        :math:`\mathbf{M}` is a matrix defining the multipliers for each Wyckoff
        variable and :math:`\mathbf{C}` is a vector containing the constant
        offsets. The Wyckoff variables :math:`\mathbf{W}` can be simply
        determined from the first representative position of each set. The
        representative position alone uniquely defines all necessary variables.

        Args:
            system (System): The atomic system in which the atomic positions
                are inspected.
            space_group (int): The space group number
            wyckoff_letters (list): List of Wyckoff letters for each atom in
                the system
            equivalent_atoms (list): List of integers that distinguish atoms
                that are related by symmetry.
            precision (float): The precision for matching atoms to Wyckoff
                positions.
            return_parameters (bool): Whether to return the value of possible
                free Wyckoff parameters. Set to False if not needed, as their
                determination can take time.

        Returns:
            list of WyckoffSets: A list of :class:`.WyckoffSet` objects for this
            system, sorted by Wyckoff letter and then atomic number.

        Raises:
            ValueError: If return_parameters is True and the free variables of
                some set cannot be resolved.
        """
        cell = system.get_cell()
        elements = system.get_chemical_symbols()
        numbers = system.get_atomic_numbers()
        positions = system.get_scaled_positions()
        wyckoff_infos = WYCKOFF_SETS[space_group]

        # Form one WyckoffSet per group of symmetry-equivalent atoms. Groups
        # are keyed by their label in equivalent_atoms and described by the
        # first atom that carries the label.
        sets = OrderedDict()
        unique_sets, unique_indices = np.unique(equivalent_atoms, return_index=True)
        for set_index, index in zip(unique_sets, unique_indices):
            letter = wyckoff_letters[index]
            set_data = WyckoffSet(
                wyckoff_letter=str(letter),
                element=str(elements[index]),
                atomic_number=int(numbers[index]),
                space_group=space_group,
                representative=wyckoff_infos[letter]["expressions"][0],
            )
            set_data.indices = []
            sets[set_index] = set_data

        # Add the indices of the atoms that belong to these sets, and add the
        # multiplicity of the set
        for i_atom, set_number in enumerate(equivalent_atoms):
            sets[set_number].indices.append(i_atom)
        for wset in sets.values():
            wset.multiplicity = len(wset.indices)

        # For each set, solve the free variables if any present.
        if return_parameters:
            translations = wyckoff_infos["translations"]
            for wset in sets.values():
                wyckoff_info = wyckoff_infos[wset.wyckoff_letter]
                if wyckoff_info["variables"]:
                    self._resolve_wyckoff_set_variables(
                        wset, wyckoff_info, translations, positions, cell, precision
                    )

        # Sort the list so that sets with Wyckoff letter earlier in the
        # alphabet are first, and sets with the same Wyckoff letter are
        # sorted by atomic number. Groups with the same Wyckoff letter and
        # atomic number keep their relative order (sorted() is stable).
        return sorted(sets.values(), key=attrgetter("wyckoff_letter", "atomic_number"))

    def _resolve_wyckoff_set_variables(
        self, wset, wyckoff_info, translations, positions, cell, precision
    ):
        """Solves the free Wyckoff variables of one set and stores them on it.

        The variables are read from the first representative expression
        R = W M[0] + C[0] for a candidate atom. They are accepted when all
        expressions, plus their centring translations, match atoms of the set.

        Args:
            wset (WyckoffSet): The set whose variables are solved. Its
                ``indices`` must be filled. The attributes named by the
                present variables ("x", "y", "z") are set on it.
            wyckoff_info (dict): The Wyckoff data for the set's letter, with
                "variables", "matrices" and "constants".
            translations (iterable): Centring translations of the space group.
            positions (np.ndarray): Scaled positions of all atoms.
            cell: The cell used for cartesian distance checks.
            precision (float): Matching tolerance for the set positions.

        Raises:
            ValueError: If no atom of the set gives variables that reproduce
                all of the set's positions.
        """
        variables_present = wyckoff_info["variables"]
        Ms = wyckoff_info["matrices"]
        Cs = wyckoff_info["constants"]
        M = Ms[0]
        C = Cs[0]

        # Row of M (0 = x, 1 = y, 2 = z) -> name, for the variables present.
        variable_map = {
            ivar: var
            for ivar, var in enumerate(["x", "y", "z"])
            if var in variables_present
        }

        # For each variable, the first coordinate of the representative
        # position where it appears with multiplier 1: there
        # R[icomp] = W[ivar] + C[icomp], so the variable can be read off
        # directly. If another variable also contributes to that coordinate,
        # the consistency check below rejects the candidate.
        var_components = {}
        for ivar in variable_map:
            for icomp in range(3):
                if M[ivar][icomp] == 1:
                    var_components[ivar] = icomp
                    break

        # Order the candidate atoms by the coordinates that carry x, then y,
        # then z so the same variables are found regardless of atom order.
        # lexsort orders by its last key first, hence the reversal. Values are
        # rounded to avoid spurious ordering differences from float noise.
        indices = np.asarray(wset.indices, dtype=int)
        all_pos = positions[indices]
        sort_columns = np.array(
            [
                all_pos[:, var_components.get(ivar, ivar)]
                for ivar in reversed(list(variable_map))
            ]
        )
        np.around(sort_columns, decimals=9, out=sort_columns)
        order = np.lexsort(sort_columns)
        sorted_pos = all_pos[order]

        W_final = None
        for atom_index in indices[order]:
            R = positions[atom_index]
            W = np.zeros(3)
            for ivar, icomp in var_components.items():
                W[ivar] = R[icomp] - C[icomp]

            # The variables must reproduce the atom they were read from.
            if self._search_periodic_positions(np.dot(W, M) + C, R, cell, 1e-3) is None:
                continue

            # Generate every position of the set (all expressions, plus their
            # translated copies) and require each to match an atom of the set.
            first_test_pos = np.dot(W, Ms) + Cs
            test_positions = [first_test_pos]
            test_positions.extend(first_test_pos + trans for trans in translations)
            test_positions = np.concatenate(test_positions)
            if all(
                self._search_periodic_positions(test_pos, sorted_pos, cell, precision)
                is not None
                for test_pos in test_positions
            ):
                W_final = W
                break

        if W_final is None:
            raise ValueError(
                "Could not resolve the free Wyckoff parameters for "
                "Wyckoff letter '{}' in space group {}. Problem in "
                "determining variables for element '{}' at indices "
                "'{}'.".format(
                    wset.wyckoff_letter,
                    wset.space_group,
                    wset.element,
                    wset.indices,
                )
            )

        # Wrap Wyckoff variables to be between [0, 1] and save to the WyckoffSet
        W_final = matid.geometry.get_wrapped_positions(W_final)
        for ivar, var in variable_map.items():
            setattr(wset, var, W_final[ivar])
