from typing import Optional, Union

import nibabel as nib
import numpy as np
from nibabel.orientations import io_orientation, axcodes2ornt, ornt_transform, inv_ornt_aff

from cavass._io import ensure_output_file_dir_existence


def save_nifti(output_file,
               data,
               voxel_spacing: Optional[Union[float, list[float], tuple[float, ...]]] = None,
               orientation="LPI"):
    """
    Save improc with nii format.

    Args:
        output_file (str):
        data (numpy.ndarray):
        voxel_spacing (sequence or None, optional, default=None): `tuple(x, y, z)`. Voxel spacing of each axis. If None,
            make `voxel_spacing` as `(1.0, 1.0, 1.0)`.
        orientation (str, optional, default="LPI"): "LPI" | "ARI". LPI: Left-Posterior-Inferior;
            ARI: Anterior-Right-Inferior.

    Returns:

    """
    if voxel_spacing is None:
        voxel_spacing = (1.0, 1.0, 1.0)  # replace this with your desired voxel spacing in millimeters

    match orientation:
        case "LPI":
            affine_matrix = np.diag(list(voxel_spacing) + [1.0])
        case "ARI":
            # calculate the affine matrix based on the desired voxel spacing and ARI orientation
            affine_matrix = np.array([
                [0, -voxel_spacing[0], 0, 0],
                [-voxel_spacing[1], 0, 0, 0],
                [0, 0, voxel_spacing[2], 0],
                [0, 0, 0, 1]
            ])
        case _:
            raise ValueError(f"Unsupported orientation {orientation}.")

    # create a NIfTI improc object

    ensure_output_file_dir_existence(output_file)
    nii_img = nib.Nifti1Image(data, affine=affine_matrix)
    nib.save(nii_img, output_file)


def reorient_nifti(input_nifti_image, target_orientation):
    """
    Reorient a NIfTI image so its voxel axes follow `target_orientation`.

    The voxel data are flipped/transposed and the affine is updated so that every voxel keeps its original
    world position.

    Parameters:
        input_nifti_image (nibabel.Nifti1Image): NIfTI image to be reoriented.
        target_orientation (str): Target orientation, e.g., 'ARI', 'RAI', 'LPI'. Codes are interpreted by
            ``nibabel.orientations.axcodes2ornt``: each letter is the direction toward which the corresponding
            voxel index *increases* (an image with an identity affine is 'RAS').

    Returns:
        The input image itself if it is already in the target orientation; otherwise a new
        ``nib.Nifti1Image`` with the reoriented (float) data, the adjusted affine and the input header.
        Header fields tied to axis order (e.g. dim_info) are reused unchanged.
    """
    affine, header = input_nifti_image.affine, input_nifti_image.header

    # Current and target orientation, as (output axis, +1/-1) rows per voxel axis.
    current_ornt = io_orientation(affine)
    target_ornt = axcodes2ornt(tuple(target_orientation))

    # Check if reorientation is needed (before loading the voxel data).
    if np.array_equal(current_ornt, target_ornt):
        print(f"Image already in {target_orientation} orientation.")
        return input_nifti_image

    data = input_nifti_image.get_fdata()

    # transform[i] = (output axis that input axis i moves to, flip flag for input axis i).
    transform = ornt_transform(current_ornt, target_ornt)

    # Flips refer to the *input* axes, so apply them before the transpose.
    for axis, flip in enumerate(transform[:, 1]):
        if flip == -1:
            data = np.flip(data, axis=axis)

    # transform[:, 0] maps input -> output axes, whereas np.transpose expects, for each output axis, the
    # input axis it comes from, i.e. the inverse permutation. Trailing non-spatial axes stay in place.
    axes = np.arange(data.ndim)
    axes[:len(transform)] = np.argsort(transform[:, 0])
    data = np.transpose(data, axes)

    # inv_ornt_aff maps new voxel indices back to old ones; composing it with the original affine keeps
    # voxel sizes, origin and world positions intact.
    new_affine = affine @ inv_ornt_aff(transform, input_nifti_image.shape)
    return nib.Nifti1Image(data, new_affine, header=header)
