# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from netket.utils.moduletools import export, hide_unexported

from netket.utils.group.axial import cuboid, cuboid_rotations  # noqa: F401
from netket.utils.group.axial import inversion_group as _inv_group
from netket.utils.group.axial import reflection_group as _refl_group
from netket.utils.group.axial import rotation as _rotation
from netket.utils.group._point_group import PointGroup
from netket.utils.group._semigroup import Identity


# NOTE(review): presumably restricts the module's visible names to the exported
# ones; the exact behaviour lives in netket.utils.moduletools — confirm there.
hide_unexported(__name__)

# Descriptive public names. Every entry except "Fd3m" is an alias assigned
# further down in this module to one of the short-named factory functions
# (T, Td, Th, O, Oh, Fd3m). The short-named functions themselves are not listed
# here; they carry the `@export` decorator instead.
# NOTE(review): "Fd3m" is both listed here and decorated with `@export`; if
# `export` appends to `__all__`, the name would appear twice — confirm whether
# that matters.
__all__ = [
    "tetrahedral_rotations",
    "tetrahedral",
    "pyritohedral",
    "octahedral_rotations",
    "cubic_rotations",
    "octahedral",
    "cubic",
    "diamond",
    "pyrochlore",
    "Fd3m",
]


@export
def T() -> PointGroup:
    """
    Proper rotations mapping the tetrahedron with vertices
    (1,1,1), (1,-1,-1), (-1,1,-1), (-1,-1,1) onto itself.

    The returned three-dimensional `PointGroup` lists twelve elements in this
    order: the identity, the +120 degree rotations about the four vertex
    directions, the -120 degree rotations about the same directions, and the
    180 degree rotations about the z, y and x axes.
    """
    vertex_directions = ((1, 1, 1), (1, -1, -1), (-1, 1, -1), (-1, -1, 1))
    cartesian_directions = ((0, 0, 1), (0, 1, 0), (1, 0, 0))

    # (angle in degrees, axes to rotate about), in the order the elements
    # must appear in the group.
    rotation_families = (
        (120, vertex_directions),
        (-120, vertex_directions),
        (180, cartesian_directions),
    )

    elements = [Identity()]
    for angle, directions in rotation_families:
        # Hand `_rotation` a fresh list for every call.
        elements.extend(_rotation(angle, list(axis)) for axis in directions)

    return PointGroup(elements, ndim=3)


# Descriptive alias: the same function object as `T`.
tetrahedral_rotations = T


@export
def Td() -> PointGroup:
    """
    Full symmetry group of the tetrahedron with vertices
    (1,1,1), (1,-1,-1), (-1,1,-1), (-1,-1,1).

    Obtained as the `@` product of the reflection group constructed from the
    vector [1, 1, 0] with the rotation group returned by `T()`.
    """
    mirror_group = _refl_group([1, 1, 0])
    rotation_group = T()
    return mirror_group @ rotation_group


# Descriptive alias: the same function object as `Td`.
tetrahedral = Td


@export
def Th() -> PointGroup:
    """
    Pyritohedral symmetry group.

    Obtained as the `@` product of the inversion group with the tetrahedral
    rotation group returned by `T()`.
    """
    inversion_group = _inv_group()
    rotation_group = T()
    return inversion_group @ rotation_group


# Descriptive alias: the same function object as `Th`.
pyritohedral = Th


@export
def O() -> PointGroup:  # noqa: E741, E743
    """
    Proper rotations of a cube (equivalently, an octahedron) whose axes
    coincide with the Cartesian axes.

    Obtained as the `@` product of {identity, 90 degree rotation about z}
    with the tetrahedral rotation group returned by `T()`.
    """
    # NB: {identity, 90-degree turn} is not closed under composition, so it is
    # not a genuine point group. That is acceptable here because it only
    # serves to generate a coset of T in O.
    coset_representatives = [Identity(), _rotation(90, [0, 0, 1])]
    left_factor = PointGroup(coset_representatives, ndim=3)
    return left_factor @ T()


# Descriptive aliases: both names refer to the same function object as `O`.
octahedral_rotations = cubic_rotations = O


@export
def Oh() -> PointGroup:
    """
    Full symmetry group of a cube (equivalently, an octahedron) whose axes
    coincide with the Cartesian axes.

    Obtained as the `@` product of the inversion group with the rotation
    group returned by `O()`.
    """
    inversion_group = _inv_group()
    rotation_group = O()
    return inversion_group @ rotation_group


# Descriptive aliases: both names refer to the same function object as `Oh`.
octahedral = cubic = Oh


@export
def Fd3m() -> PointGroup:
    """Nonsymmorphic "point group" of the diamond and pyrochlore lattices,
    for a cubic unit cell of side length 1.

    Obtained as the `@` product of the inversion group, with its origin
    changed to (1/8, 1/8, 1/8), and the tetrahedral group returned by `Td()`.
    The result is then given the `unit_cell` matrix defined below via
    `replace`.
    """
    inversion_origin = [1 / 8, 1 / 8, 1 / 8]
    shifted_inversion = _inv_group().change_origin(inversion_origin)
    combined = shifted_inversion @ Td()

    cell = np.asarray(
        [
            [0, 0.5, 0.5],
            [0.5, 0, 0.5],
            [0.5, 0.5, 0],
        ]
    )
    return combined.replace(unit_cell=cell)


# Descriptive aliases: both names refer to the same function object as `Fd3m`.
diamond = pyrochlore = Fd3m
