import os
import shutil
from typing import Optional

import nibabel as nib
import numpy as np
from pydicom.uid import generate_uid

from cavass._log import logger
from cavass.dicom import init_dicom_dataset
from cavass.nifti import reorient_nifti


def get_patient_orientation_from_nifti_affine(affine):
    """
    Derive a two-letter DICOM-style PatientOrientation from an affine matrix.

    The first and second columns of ``affine`` (the directions of array axes 0 and 1) are
    each mapped to the world axis they are most aligned with. A positive component maps to
    'L' / 'P' / 'H' for world x / y / z, and a negative component maps to 'R' / 'A' / 'F'.

    Args:
        affine: 4x4 (or at least 3x2-indexable) numpy array.

    Returns:
        list[str]: ``[code for axis 0, code for axis 1]``.
    """
    # (letter for a positive component, letter for a negative component) per world axis.
    axis_letters = (('L', 'R'), ('P', 'A'), ('H', 'F'))

    def nearest_letter(direction):
        dominant = np.argmax(np.abs(direction))
        positive, negative = axis_letters[dominant]
        if direction[dominant] < 0:
            return negative
        return positive

    letters = []
    for column in (0, 1):
        vector = affine[:3, column]
        letters.append(nearest_letter(vector / np.linalg.norm(vector)))
    return letters


def _round_and_clip(img, dtype):
    """
    Round ``img`` to the nearest integer and clip it into the value range of integer ``dtype``.

    ``ndarray.astype`` truncates toward zero (2.9999 -> 2) and yields wrapped/undefined values
    for out-of-range inputs, so rounding and clipping are done explicitly first. Out-of-range
    voxels are reported with a warning; NaN voxels are passed through unchanged.

    Returns:
        Float array with the same shape as ``img``, ready to be cast with ``astype(dtype)``.
    """
    info = np.iinfo(dtype)
    img = np.rint(img)
    out_of_range = (img < info.min) | (img > info.max)
    if np.any(out_of_range):
        logger.warning(
            f"{int(np.count_nonzero(out_of_range))} voxel(s) outside the {np.dtype(dtype).name} range "
            f"[{info.min}, {info.max}] were clipped."
        )
        img = np.clip(img, info.min, info.max)
    return img


def nifti2dicom(input_nifti_file, output_dicom_dir, modality: str, signed: int = 1, orientation: Optional[str] = None,
                force_overwrite=False, **kwargs):
    """
    Convert a 3D NIfTI image to a DICOM image series (one file per slice along array axis 2).

    Args:
        input_nifti_file (str): Path of the input NIfTI file.
        output_dicom_dir (str): Directory to create; slices are written as ``slice_001.dcm``, ...
        modality (str): DICOM modality; ``"CT"`` or ``"PET"``.
        signed (int): 0: unsigned integer, 1: signed integer. For unsigned output, a negative
            volume minimum is subtracted from the data and added to ``RescaleIntercept``.
            For PET, ``signed=1`` writes 32-bit float pixels instead of integers.
        orientation (str, optional, default=None): If provided, reorient the image to this
            orientation (via ``reorient_nifti``) before conversion.
        force_overwrite (bool, optional, default=False): If `True`, overwrite `output_dicom_dir` if it exists.
        **kwargs: Extra DICOM attributes forwarded to ``init_dicom_dataset``; attributes computed
            here take precedence over them. Two PET-only options are consumed and NOT forwarded:
            ``rescale_PET_value`` (bool, default False) -- for unsigned PET, multiply intensities
            by ``PET_value_rescale_slope`` (default 100) before integer conversion to keep
            fractional precision, and store the inverse factor in ``RescaleSlope``.

    Returns:
        None

    Raises:
        FileNotFoundError: If `input_nifti_file` does not exist.
        ValueError: If `modality` is unsupported, the image is not (effectively) 3D,
            `PET_value_rescale_slope` is not positive, or `output_dicom_dir` exists and
            `force_overwrite` is False.
    """
    if not os.path.isfile(input_nifti_file):
        raise FileNotFoundError(f"Input NIfTI file {input_nifti_file} not found.")

    # Validate everything that can fail before an existing output directory is deleted, so a
    # bad call with force_overwrite=True does not destroy previous results.
    if modality not in ("CT", "PET"):
        raise ValueError(f"Modality {modality} is not supported.")
    if os.path.exists(output_dicom_dir) and not force_overwrite:
        raise ValueError(f"Output DICOM series dir {output_dicom_dir} already exists.")

    # Conversion options, not DICOM attributes: remove them so they are not forwarded to
    # init_dicom_dataset together with the remaining kwargs.
    rescale_pet_value = kwargs.pop("rescale_PET_value", False)
    pet_value_rescale_slope = kwargs.pop("PET_value_rescale_slope", 100)

    nii_data = nib.load(input_nifti_file)
    if orientation is not None:
        nii_data = reorient_nifti(nii_data, orientation)

    header = nii_data.header
    img = nii_data.get_fdata()
    affine = nii_data.affine

    # A single-volume 4D file (shape (x, y, z, 1)) is treated as 3D; any other non-3D image
    # cannot be written as one 2D slice per file.
    while img.ndim > 3 and img.shape[-1] == 1:
        img = img[..., 0]
    if img.ndim != 3 or img.size == 0:
        raise ValueError(f"Expected a non-empty 3D image, got shape {img.shape}.")

    attrs = {}
    attrs["Modality"] = modality
    attrs["StudyInstanceUID"] = generate_uid()
    attrs["SeriesInstanceUID"] = generate_uid()
    attrs["FrameOfReferenceUID"] = generate_uid()
    # Value 1 is ORIGINAL/DERIVED, value 2 is PRIMARY/SECONDARY (DICOM PS3.3 C.7.6.1.1.2).
    attrs["ImageType"] = ["DERIVED", "SECONDARY"]

    # Unit direction vectors of the three array axes, taken from the affine columns.
    row_normal = affine[:3, 0] / np.linalg.norm(affine[:3, 0])
    col_normal = affine[:3, 1] / np.linalg.norm(affine[:3, 1])
    slice_normal = affine[:3, 2] / np.linalg.norm(affine[:3, 2])

    # NOTE(review): positions and direction cosines are copied straight from the NIfTI affine.
    # nibabel defines that affine in RAS+ world space, whereas DICOM patient coordinates are
    # LPS+. Also, the first ImageOrientationPatient triplet comes from array axis 0 even though
    # axis 0 indexes the Rows of each slice written below. Kept unchanged for compatibility
    # with existing outputs -- confirm against the DICOM Image Plane Module.
    attrs["PatientOrientation"] = get_patient_orientation_from_nifti_affine(affine)
    attrs["ImageOrientationPatient"] = [float(x) for x in np.concatenate([row_normal, col_normal])]

    # voxel size
    voxel_sizes = header.get_zooms()
    attrs["PixelSpacing"] = [float(voxel_sizes[0]), float(voxel_sizes[1])]
    attrs["SpacingBetweenSlices"] = float(np.linalg.norm(affine[:3, 2]))
    attrs["SliceThickness"] = float(voxel_sizes[2])

    # get_fdata() returns voxel values with the NIfTI scl_slope/scl_inter scaling already
    # applied, so the DICOM rescale starts from identity; copying scl_slope/scl_inter into
    # RescaleSlope/RescaleIntercept would apply that scaling a second time.
    rescale_slope = 1.0
    rescale_intercept = 0.0

    if not signed:
        # Shift negative data into the unsigned range and record the shift in the intercept.
        v_min = img.min()
        if v_min < 0:
            img = img - v_min
            rescale_intercept = rescale_intercept + float(v_min)

    attrs["PixelRepresentation"] = signed
    if modality == "CT":
        dtype = np.int16 if signed else np.uint16
        attrs["BitsAllocated"] = 16
        attrs["BitsStored"] = 16
    elif signed:
        # Signed PET is stored as 32-bit floats with PixelRepresentation forced to 0.
        # NOTE(review): the float bytes still go into PixelData (DICOM defines FloatPixelData
        # for float pixels) -- confirm init_dicom_dataset / downstream readers handle this.
        dtype = np.float32
        attrs["BitsAllocated"] = 32
        attrs["BitsStored"] = 32
        attrs["PixelRepresentation"] = 0
    else:
        dtype = np.uint16
        if rescale_pet_value:
            if pet_value_rescale_slope <= 0:
                raise ValueError(f"PET_value_rescale_slope must be positive, got {pet_value_rescale_slope}.")
            # Scale up before the integer cast to keep fractional precision; readers undo it
            # through RescaleSlope.
            img = img * pet_value_rescale_slope
            rescale_slope = rescale_slope / pet_value_rescale_slope
        attrs["BitsAllocated"] = 16
        attrs["BitsStored"] = 16

    attrs["RescaleSlope"] = rescale_slope
    attrs["RescaleIntercept"] = rescale_intercept

    is_integer = np.issubdtype(dtype, np.integer)
    if is_integer:
        img = _round_and_clip(img, dtype)

    # Only now, with all validation done, replace an existing output path.
    if force_overwrite and os.path.exists(output_dicom_dir):
        logger.info(f"Overwrite {output_dicom_dir} as it already exists.")
        if os.path.isdir(output_dicom_dir) and not os.path.islink(output_dicom_dir):
            shutil.rmtree(output_dicom_dir)
        else:
            os.remove(output_dicom_dir)
    os.makedirs(output_dicom_dir)

    # Computed attributes take precedence over user-supplied ones with the same keyword.
    attrs = kwargs | attrs
    for i in range(img.shape[2]):
        slice_data = img[:, :, i].astype(dtype)
        # C-order bytes: array axis 0 is the row index, axis 1 the column index.
        attrs["Rows"] = slice_data.shape[0]
        attrs["Columns"] = slice_data.shape[1]
        if is_integer:
            attrs["SmallestImagePixelValue"] = int(slice_data.min())
            attrs["LargestImagePixelValue"] = int(slice_data.max())
        origin = affine[:3, 3] + i * affine[:3, 2]
        attrs["ImagePositionPatient"] = [float(x) for x in origin]
        attrs["SliceLocation"] = float(np.dot(slice_normal, origin))
        attrs["InstanceNumber"] = i + 1
        attrs["PixelData"] = slice_data.tobytes()
        ds = init_dicom_dataset(**attrs)
        file_name = os.path.join(output_dicom_dir, f"slice_{i + 1:03d}.dcm")
        ds.save_as(file_name)
