#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ---------------------------------------------------------------------------
# Copyright 2022 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------------
# Created By  : Tomography Team at DLS <scientificsoftware@diamond.ac.uk>
# Created Date: 21 September 2023
# ---------------------------------------------------------------------------
"""Modules for memory estimation for stripe removal methods"""

import math
from typing import Tuple
import numpy as np

from httomo_backends.cufft import CufftType, cufft_estimate_1d


# Names exported by `from <module> import *`. They are listed explicitly because
# underscore-prefixed names are otherwise skipped by a star import.
__all__ = [
    "_calc_memory_bytes_remove_stripe_ti",
    "_calc_memory_bytes_remove_all_stripe",
    "_calc_memory_bytes_raven_filter",
]


def _calc_memory_bytes_remove_stripe_ti(
    non_slice_dims_shape: Tuple[int, int],
    dtype: np.dtype,
    **kwargs,
) -> Tuple[int, int]:
    """Return a rough memory estimate as ``(total_bytes, gamma_bytes)``.

    Args:
        non_slice_dims_shape: Shape of the two non-slice dimensions.
        dtype: Data type of the input; only its ``itemsize`` is used.
        **kwargs: Accepted but not used by this estimator.

    Returns:
        A tuple whose first element is the truncated sum of the input-slice,
        mean, FFT-plan and temporary-buffer estimates, and whose second
        element is the size of one ``float64`` per element along the second
        non-slice dimension.
    """
    # NOTE: this is a coarse approximation; the multipliers below are rough
    # factors rather than exact allocation sizes.
    width = non_slice_dims_shape[1]
    item_bytes = dtype.itemsize

    gamma_bytes = width * np.float64().itemsize

    slice_bytes = np.prod(non_slice_dims_shape) * item_bytes
    mean_bytes = width * item_bytes * 2
    fft_plan_bytes = mean_bytes * 3.5
    temp_bytes = mean_bytes * 8

    total_bytes = int(slice_bytes + mean_bytes + fft_plan_bytes + temp_bytes)
    return (total_bytes, gamma_bytes)


def _calc_memory_bytes_remove_all_stripe(
    non_slice_dims_shape: Tuple[int, int],
    dtype: np.dtype,
    **kwargs,
) -> Tuple[int, int]:
    """Return a memory estimate as ``(total_bytes, method_allocation_bytes)``.

    Args:
        non_slice_dims_shape: Shape of the two non-slice dimensions.
        dtype: Data type of the input; only its ``itemsize`` is used.
        **kwargs: Accepted but not used by this estimator.

    Returns:
        A tuple whose first element is the combined size of one input slice
        and one equally sized output slice, and whose second element is 30
        times the size of one input slice.
    """
    # The method is very memory hungry but processes data slice-by-slice, so
    # its internal allocations are reported separately from the in/out buffers.
    slice_bytes = np.prod(non_slice_dims_shape) * dtype.itemsize

    # Input and output slices have the same size.
    in_out_bytes = int(slice_bytes + slice_bytes)
    working_bytes = int(30 * slice_bytes)

    return (in_out_bytes, working_bytes)


def _calc_memory_bytes_raven_filter(
    non_slice_dims_shape: Tuple[int, int],
    dtype: np.dtype,
    **kwargs,
) -> Tuple[int, int]:
    """Return a peak-memory estimate as ``(total_bytes, 0)``.

    Args:
        non_slice_dims_shape: Shape of the two non-slice dimensions.
        dtype: Data type of the input; only its ``itemsize`` is used.
        **kwargs: Must contain ``pad_x`` and ``pad_y``. ``pad_y`` is added
            twice to the first non-slice dimension and ``pad_x`` twice to
            the second.

    Returns:
        A tuple whose first element is the truncated sum of all estimated
        allocations and whose second element is always ``0``.

    Raises:
        KeyError: If ``pad_x`` or ``pad_y`` is missing from ``kwargs``.
    """
    pad_x, pad_y = kwargs["pad_x"], kwargs["pad_y"]
    item_bytes = dtype.itemsize

    # Input before padding
    unpadded_bytes = np.prod(non_slice_dims_shape) * item_bytes

    # Input after padding both non-slice dimensions on each side
    padded_rows = non_slice_dims_shape[0] + 2 * pad_y
    padded_cols = non_slice_dims_shape[1] + 2 * pad_x
    padded_bytes = padded_rows * padded_cols * item_bytes

    # `fft2()` implicitly converts the padded input to `complex64`
    complex_bytes = padded_bytes / item_bytes * np.complex64().nbytes

    # The 2D FFT is carried out as two 1D FFTs (possibly because it is applied
    # to the non-adjacent axes 0 and 2), so a 1D plan is estimated instead of a
    # 2D one.
    # NOTE(review): `batch` uses the unpadded second dimension while `nx` uses
    # the padded first dimension - confirm this is intended.
    plan_1d_bytes = cufft_estimate_1d(
        nx=padded_rows,
        fft_type=CufftType.CUFFT_C2C,
        batch=non_slice_dims_shape[1],
    )

    # fftshift of the FFT result produces a copy of the same size
    shifted_bytes = complex_bytes

    # The 2D IFFT becomes a loop over two 1D IFFTs; applying a 1D IFFT to the
    # non-adjacent axes 0 and 2 leaves the data non C-contiguous, so it is
    # copied to obtain a C-contiguous version. This yields two `complex64`
    # copies.
    #
    # NOTE: The forward 2D FFT generates the same kind of copies, but they are
    # allocated and freed in an order that keeps them out of the peak GPU
    # memory usage, so only the IFFT copies are counted here.
    ifft_copy_bytes = 2 * complex_bytes

    contributions = (
        unpadded_bytes,
        padded_bytes,
        complex_bytes,
        2 * plan_1d_bytes,
        shifted_bytes,
        ifft_copy_bytes,
    )
    total_bytes = int(sum(contributions))

    return (total_bytes, 0)
