"""
=======================================================================
Models for optoelectronic devices (:mod:`optic.models.devices`)
=======================================================================

.. autosummary::
   :toctree: generated/

   pm                    -- Optical phase modulator.
   mzm                   -- Optical Mach-Zehnder modulator.
   iqm                   -- Optical In-Phase/Quadrature Modulator (IQM).
   pbs                   -- Polarization beam splitter (PBS).
   opticalHybrid2x4      -- Optical hybrid 2 x 4 90°.
   voa                   -- Variable optical attenuator (VOA).
   photodiode            -- PIN photodiode.
   balancedPD            -- Balanced photodiode pair.
   coherentReceiver      -- Optical coherent receiver (single polarization).
   pdmCoherentReceiver   -- Optical polarization-multiplexed coherent receiver.
   edfa                  -- Simple EDFA model (gain + AWGN noise).
   basicLaserModel       -- Laser model with Maxwellian random walk phase noise and RIN.
   adc                   -- Analog-to-digital converter (ADC) model.
"""

"""Basic physical models for optical/electronic devices."""
import logging as logg

import numpy as np
import scipy.constants as const

from optic.dsp.core import (
    clockSamplingInterp,
    gaussianComplexNoise,
    lowPassFIR,
    phaseNoise,
    quantizer,
)
from optic.utils import dBm2W, parameters

try:
    from optic.dsp.coreGPU import checkGPU

    if checkGPU():
        from optic.dsp.coreGPU import firFilter
    else:
        from optic.dsp.core import firFilter
except ImportError:
    from optic.dsp.core import firFilter


def pm(Ai, u, Vπ):
    """
    Optical Phase Modulator (PM).

    The output field is ``Ai * exp(j * π * u / Vπ)``.

    Parameters
    ----------
    Ai : scalar or np.array
        Amplitude of the optical field at the input of the PM. A scalar (or a
        0-d array) is broadcast to the shape of ``u``; otherwise its shape
        must match the shape of ``u``.
    u : scalar or np.array
        Electrical driving signal. A scalar is promoted to a 1-element array.
    Vπ : scalar
        PM's Vπ voltage.

    Returns
    -------
    Ao : np.array
        Modulated optical field at the output of the PM.

    References
    ----------
    [1] G. P. Agrawal, Fiber-Optic Communication Systems. Wiley, 2021.

    """
    # promote a plain scalar drive signal to a 1-element array
    if not hasattr(u, "shape"):
        u = np.array([u])

    # broadcast a scalar input field over the drive signal; arrays must match
    if not hasattr(Ai, "shape") or (Ai.shape == () and u.shape != ()):
        Ai = Ai * np.ones(u.shape)
    else:
        assert Ai.shape == u.shape, "Ai and u need to have the same dimensions"

    return Ai * np.exp(1j * (u / Vπ) * np.pi)


def mzm(Ai, u, param=None):
    """
    Optical Mach-Zehnder Modulator (MZM).

    The output field is ``Ai * cos(π * (u + Vb) / (2 * Vpi))``.

    Parameters
    ----------
    Ai : scalar or np.array
        Amplitude of the optical field at the input of the MZM. A scalar (or a
        0-d array) is broadcast to the shape of ``u``; otherwise its shape
        must match the shape of ``u``.
    u : scalar or np.array
        Electrical driving signal. A scalar is promoted to a 1-element array.
    param : optic.utils.parameters object, optional
        Parameters of the MZM model.

        - param.Vpi: MZM's Vpi voltage [V][default: 2 V]
        - param.Vb: MZM's bias voltage [V][default: -1 V]

    Returns
    -------
    Ao : np.array
        Modulated optical field at the output of the MZM.

    References
    ----------
    [1] G. P. Agrawal, Fiber-Optic Communication Systems. Wiley, 2021.

    [2] M. Seimetz, High-Order Modulation for Optical Fiber Transmission. in Springer Series in Optical Sciences. Springer Berlin Heidelberg, 2009.

    """
    # getattr falls back to the defaults when param is None or lacks the field
    Vpi = getattr(param, "Vpi", 2)
    Vb = getattr(param, "Vb", -1)

    if not hasattr(u, "shape"):
        u = np.array([u])

    if not hasattr(Ai, "shape") or (Ai.shape == () and u.shape != ()):
        Ai = Ai * np.ones(u.shape)
    else:
        assert Ai.shape == u.shape, "Ai and u need to have the same dimensions"

    return Ai * np.cos(0.5 / Vpi * (u + Vb) * np.pi)


def iqm(Ai, u, param=None):
    """
    Optical In-Phase/Quadrature Modulator (IQM).

    The input field is split equally between two nested MZMs driven by the
    real (I) and imaginary (Q) parts of ``u``. The Q-arm output passes through
    a phase modulator biased at ``Vphi`` before the two arms are recombined.

    Parameters
    ----------
    Ai : scalar or np.array
        Amplitude of the optical field at the input of the IQM.
    u : complex-valued np.array
        Modulator's driving signal (complex-valued baseband).
    param : optic.utils.parameters object, optional
        Parameters of the MZM models.

        - param.Vpi: MZM's Vpi voltage [V][default: 2 V]
        - param.VbI: I-MZM's bias voltage [V][default: -2 V]
        - param.VbQ: Q-MZM's bias voltage [V][default: -2 V]
        - param.Vphi: PM bias voltage [V][default: 1 V]

    Returns
    -------
    Ao : complex-valued np.array
        Modulated optical field at the output of the IQM.

    References
    ----------
    [1] M. Seimetz, High-Order Modulation for Optical Fiber Transmission. in Springer Series in Optical Sciences. Springer Berlin Heidelberg, 2009.

    """
    Vpi = getattr(param, "Vpi", 2)
    VbI = getattr(param, "VbI", -2)
    VbQ = getattr(param, "VbQ", -2)
    Vphi = getattr(param, "Vphi", 1)

    if not hasattr(u, "shape"):
        u = np.array([u])

    if not hasattr(Ai, "shape") or (Ai.shape == () and u.shape != ()):
        Ai = Ai * np.ones(u.shape)
    else:
        assert Ai.shape == u.shape, "Ai and u need to have the same dimensions"

    def _armParams(bias):
        # MZM parameter set sharing the IQM's Vpi with an arm-specific bias
        armParam = parameters()
        armParam.Vpi = Vpi
        armParam.Vb = bias
        return armParam

    # equal power split between the two arms
    armField = Ai / np.sqrt(2)

    inPhase = mzm(armField, u.real, _armParams(VbI))
    quadrature = mzm(armField, u.imag, _armParams(VbQ))

    # phase-shift the Q arm by the PM bias before recombining
    return inPhase + pm(quadrature, Vphi * np.ones(u.shape), Vpi)


def pbs(E, θ=0):
    """
    Polarization beam splitter (PBS).

    Parameters
    ----------
    E : (N,2) or (N,) np.array
        Input pol. multiplexed optical field. A one-dimensional field is
        treated as fully aligned with the first polarization (second
        polarization set to zero).
    θ : scalar, optional
        Rotation angle of input field in radians. The default is 0.

    Returns
    -------
    Ex : (N,) np.array
        Ex output single pol. field.
    Ey : (N,) np.array
        Ey output single pol. field.

    References
    ----------
    [1] M. Seimetz, High-Order Modulation for Optical Fiber Transmission. in Springer Series in Optical Sciences. Springer Berlin Heidelberg, 2009.

    """
    if E.ndim >= 2:
        assert E.shape[1] == 2, "E need to be a (N,2) or a (N,) np.array"
    else:
        # single-pol. input: put it on the first column, zeros on the second
        samples = E.reshape(-1)
        E = np.column_stack((samples, np.zeros_like(samples)))

    cosθ, sinθ = np.cos(θ), np.sin(θ)
    rotation = np.array([[cosθ, -sinθ], [sinθ, cosθ]]) + 0j

    rotated = E @ rotation

    return rotated[:, 0], rotated[:, 1]


def voa(E, A=0):
    """
    Variable optical attenuator (VOA).

    Parameters
    ----------
    E : np.array
        Input optical field.
    A : float
        Power attenuation [dB][default: 0 dB]. Must be non-negative.

    Returns
    -------
    Eo : np.array
          Output optical field, scaled by ``10 ** (-A / 20)``.

    References
    ----------
    [1] G. P. Agrawal, Fiber-Optic Communication Systems. Wiley, 2021.

    """
    assert A >= 0, "Attenuation should be a positive scalar"

    # A is a power attenuation in dB, so the field amplitude scales by 10^(-A/20)
    fieldScale = 10 ** (-A / 20)

    return E * fieldScale


def photodiode(E, param=None):
    """
    Pin photodiode (PD).

    Parameters
    ----------
    E : np.array
        Input optical field. For a (N, M) array with M > 1 the detected powers
        of all M modes are summed into a single photocurrent.
    param : optic.utils.parameters object, optional
        Parameters of the photodiode model.

        - param.R: photodiode responsivity [A/W][default: 1 A/W]
        - param.Tc: temperature [°C][default: 25°C]
        - param.Id: dark current [A][default: 5e-9 A]
        - param.Ipd_sat: saturation value of the photocurrent [A][default: 5e-3 A]
        - param.RL: impedance load [Ω] [default: 50Ω]
        - param.B: bandwidth [Hz][default: 30e9 Hz]
        - param.Fs: sampling frequency [Hz] [required unless param.ideal is True]
        - param.fType: frequency response type [default: 'rect']
        - param.N: number of the frequency resp. filter taps. [default: 255]
        - param.ideal: consider ideal photodiode (i.e. :math:`i_{pd}(t) = R|E(t)|^2`) [default: False]
        - param.shotNoise: add shot noise to photocurrent. [default: True]
        - param.thermalNoise: add thermal noise to photocurrent. [default: True]
        - param.currentSaturation: consider photocurrent saturation. [default: False]
        - param.bandwidthLimitation: consider bandwidth limitation. [default: True]
        - param.seed: seed for the random number generator [default: None]

    Returns
    -------
    ipd : np.array
          Real-valued photocurrent.

    Raises
    ------
    ValueError
        If the photodiode is not ideal and ``param.Fs`` is not provided.
    AssertionError
        If ``param.R`` is not positive or ``Fs < 2 * B``.

    Notes
    -----
    When ``param.seed`` is given, the global NumPy random state is reseeded
    with it (``np.random.seed``).

    References
    ----------
    [1] G. P. Agrawal, Fiber-Optic Communication Systems. Wiley, 2021.

    """
    if param is None:
        param = []
    kB = const.value("Boltzmann constant")
    q = const.value("elementary charge")

    # check input parameters
    R = getattr(param, "R", 1)
    Tc = getattr(param, "Tc", 25)
    Id = getattr(param, "Id", 5e-9)
    RL = getattr(param, "RL", 50)
    B = getattr(param, "B", 30e9)
    Ipd_sat = getattr(param, "Ipd_sat", 5e-3)
    N = getattr(param, "N", 255)
    fType = getattr(param, "fType", "rect")
    seed = getattr(param, "seed", None)
    ideal = getattr(param, "ideal", False)
    shotNoise = getattr(param, "shotNoise", True)
    thermalNoise = getattr(param, "thermalNoise", True)
    currentSaturation = getattr(param, "currentSaturation", False)
    bandwidthLimitation = getattr(param, "bandwidthLimitation", True)

    assert R > 0, "PD responsivity should be a positive scalar"

    nModes = E.shape[1] if E.ndim > 1 else 1

    # Ideal photocurrent. |E|^2 is real-valued, so every later step works on a
    # real array (no complex comparisons in the saturation step).
    if nModes > 1:
        ipd = R * np.sum(np.abs(E) ** 2, axis=1)  # sum over all modes
    else:
        ipd = R * np.abs(E) ** 2

    if seed is not None:
        np.random.seed(seed)  # set seed for reproducibility

    if not ideal:
        Fs = getattr(param, "Fs", None)
        if Fs is None:
            logg.error("Simulation sampling frequency (Fs) not provided.")
            raise ValueError("Simulation sampling frequency (Fs) not provided.")

        assert Fs >= 2 * B, "Sampling frequency Fs needs to be at least twice of B."

        # Out-of-place updates below: in-place ops would fail (or truncate the
        # saturation level to 0) when ipd has an integer dtype.
        if currentSaturation:
            ipd = np.minimum(ipd, Ipd_sat)  # saturation of the photocurrent

        if shotNoise:
            σ2_s = 2 * q * (ipd + Id) * B  # shot noise variance
            Is = np.sqrt(Fs * (σ2_s / (2 * B))) * np.random.normal(0, 1, ipd.shape)
            ipd = ipd + Is

        if thermalNoise:
            T = Tc + 273.15  # temperature in Kelvin
            σ2_T = 4 * kB * T * B / RL  # thermal noise variance
            It = np.sqrt(Fs * (σ2_T / (2 * B))) * np.random.normal(0, 1, ipd.shape)
            ipd = ipd + It

        if bandwidthLimitation:
            if N % 2 == 0:
                N += 1  # make sure N is odd
                logg.warning(
                    "Number of filter taps (N) was even, incrementing by one to make it odd."
                )
            h = lowPassFIR(B, Fs, N, typeF=fType)
            ipd = firFilter(h, ipd)

    # the filtering backend may return a complex array; the current is real
    return ipd.real


def balancedPD(E1, E2, param=None):
    """
    Balanced photodiode pair (BPD).

    Each field is detected by its own ``photodiode`` and the output is the
    difference of the two photocurrents.

    Parameters
    ----------
    E1 : np.array
        Input optical field of the first (positive) photodiode.
    E2 : np.array
        Input optical field of the second (negative) photodiode. Must have
        the same shape as ``E1``.
    param : optic.utils.parameters object, optional
        Parameters of the photodiode models, forwarded (as copies) to
        ``photodiode``, so the defaults are those of ``photodiode``:

        - param.R: photodiode responsivity [A/W][default: 1 A/W].
        - param.Tc: temperature [°C][default: 25°C].
        - param.Id: dark current [A][default: 5e-9 A].
        - param.RL: impedance load [Ω] [default: 50Ω].
        - param.B: bandwidth [Hz][default: 30e9 Hz].
        - param.Fs: sampling frequency [Hz] [required unless param.ideal is True].
        - param.fType: frequency response type [default: 'rect'].
        - param.N: number of the frequency resp. filter taps. [default: 255].
        - param.ideal: ideal PD (i.e. no noise, no frequency resp.) [default: False].
        - param.seed: seed for the random number generator [default: None].
          The second photodiode uses ``seed + 1`` so that the two noise
          realizations differ.

    Returns
    -------
    ibpd : np.array
           Balanced photocurrent.

    References
    ----------
    [1] M. Seimetz, High-Order Modulation for Optical Fiber Transmission. in Springer Series in Optical Sciences. Springer Berlin Heidelberg, 2009.

    [2] K. Kikuchi, “Fundamentals of Coherent Optical Fiber Communications”, J. Lightwave Technol., JLT, vol. 34, nº 1, p. 157–179, jan. 2016.
    """
    assert E1.shape == E2.shape, "E1 and E2 need to have the same shape"

    if param is None:
        paramPD1 = None
        paramPD2 = None
    else:
        # copy so the caller's object is never modified by the seed offset
        paramPD1 = param.copy()
        paramPD2 = param.copy()

        seed = getattr(paramPD1, "seed", None)
        if seed is not None:
            paramPD2.seed = seed + 1  # distinct noise for the second photodiode

    i1 = photodiode(E1, paramPD1)
    i2 = photodiode(E2, paramPD2)

    return i1 - i2


def opticalHybrid2x4(Es, Elo):
    """
    Optical hybrid 2 x 4 90°.

    The signal drives the first input port and the LO drives the fourth; the
    two middle input ports are left unused (zero field).

    Parameters
    ----------
    Es : (N,) np.array
        Input signal optical field.
    Elo : (N,) np.array
        Input LO optical field.

    Returns
    -------
    Eo : (4, N) np.array
        Optical hybrid outputs, one row per output port.

    References
    ----------
    [1] M. Seimetz, High-Order Modulation for Optical Fiber Transmission. in Springer Series in Optical Sciences. Springer Berlin Heidelberg, 2009.

    [2] K. Kikuchi, “Fundamentals of Coherent Optical Fiber Communications”, J. Lightwave Technol., JLT, vol. 34, nº 1, p. 157–179, jan. 2016.
    """
    assert Es.shape == (len(Es),), "Es need to have a (N,) shape"
    assert Elo.shape == (len(Elo),), "Elo need to have a (N,) shape"
    assert Es.shape == Elo.shape, "Es and Elo need to have the same (N,) shape"

    # 4 x 4 transfer matrix of the 90° hybrid (all entries have magnitude 1/2)
    transfer = 0.5 * np.array(
        [
            [1, 1j, 1j, -1],
            [1j, -1, 1, 1j],
            [1j, 1, -1j, -1],
            [-1, 1j, -1, 1j],
        ]
    )

    unusedPort = np.zeros(Es.size)
    inputs = np.array([Es, unusedPort, unusedPort, Elo])

    return transfer @ inputs


def coherentReceiver(Es, Elo, param=None):
    """
    Single polarization coherent optical front-end.

    The signal and LO are mixed in a 2 x 4 90° optical hybrid and detected by
    two balanced photodiode pairs (I and Q branches).

    Parameters
    ----------
    Es : (N,) np.array
        Input signal optical field.
    Elo : (N,) np.array
        Input LO optical field.
    param : optic.utils.parameters object, optional
        Parameters of the photodiodes (forwarded to ``balancedPD``). If
        ``param.seed`` is set, the Q branch uses ``seed + 2`` so that its
        noise realizations differ from those of the I branch.

    Returns
    -------
    s : (N,) np.array
        Downconverted signal after balanced detection (``sI + 1j * sQ``).

    References
    ----------
    [1] M. Seimetz, High-Order Modulation for Optical Fiber Transmission. in Springer Series in Optical Sciences. Springer Berlin Heidelberg, 2009.

    [2] K. Kikuchi, “Fundamentals of Coherent Optical Fiber Communications”, J. Lightwave Technol., JLT, vol. 34, nº 1, p. 157–179, jan. 2016.
    """
    assert Es.shape == (len(Es),), "Es need to have a (N,) shape"
    assert Elo.shape == (len(Elo),), "Elo need to have a (N,) shape"
    assert Es.shape == Elo.shape, "Es and Elo need to have the same (N,) shape"

    # optical hybrid 2 x 4 90°
    Eo = opticalHybrid2x4(Es, Elo)

    # With a fixed seed, reusing the same parameters for both branches would
    # reseed the RNG identically and make the I and Q noise exact copies.
    # balancedPD derives `seed` and `seed + 1`, so the Q pair starts at seed + 2.
    paramI = param
    paramQ = param
    seed = getattr(param, "seed", None)
    if seed is not None:
        paramQ = param.copy()
        paramQ.seed = seed + 2

    # balanced photodetection
    sI = balancedPD(Eo[1, :], Eo[0, :], paramI)
    sQ = balancedPD(Eo[2, :], Eo[3, :], paramQ)

    return sI + 1j * sQ


def pdmCoherentReceiver(Es, Elo, θsig=0, param=None):
    """
    Polarization multiplexed coherent optical front-end.

    The LO is split at 45° into two orthogonal polarizations and the signal
    is split after a rotation by ``θsig``. Each polarization is then detected
    by a single-polarization coherent receiver.

    Parameters
    ----------
    Es : (N,2) or (N,) np.array
        Input signal optical field.
    Elo : (N,) np.array
        Input LO optical field.
    θsig : scalar, optional
        Input polarization rotation angle in rad. [default: 0 rad].
    param : optic.utils.parameters object, optional
        Parameters of the photodiodes. If ``param.seed`` is set, the Y
        polarization receiver uses ``seed + 4`` so that its noise realizations
        differ from those of the X polarization receiver.

    Returns
    -------
    S : (N,2) np.array
        Downconverted signal after balanced detection (columns: pol. X, pol. Y).

    References
    ----------
    [1] M. Seimetz, High-Order Modulation for Optical Fiber Transmission. in Springer Series in Optical Sciences. Springer Berlin Heidelberg, 2009.

    [2] K. Kikuchi, “Fundamentals of Coherent Optical Fiber Communications”, J. Lightwave Technol., JLT, vol. 34, nº 1, p. 157–179, jan. 2016.
    """
    assert len(Es) == len(Elo), "Es and Elo need to have the same length"

    Elox, Eloy = pbs(Elo, θ=np.pi / 4)  # split LO into two orth. polarizations
    Esx, Esy = pbs(Es, θ=θsig)  # split signal into two orth. polarizations

    # With a fixed seed, reusing the same parameters for both polarizations
    # would reseed the RNG identically and make the X and Y noise exact copies.
    # The offset of 4 keeps the Y-branch seeds clear of the seeds derived from
    # `seed` for the X branch (balancedPD itself uses `seed` and `seed + 1`).
    paramX = param
    paramY = param
    seed = getattr(param, "seed", None)
    if seed is not None:
        paramY = param.copy()
        paramY.seed = seed + 4

    Sx = coherentReceiver(Esx, Elox, paramX)  # coherent detection of pol.X
    Sy = coherentReceiver(Esy, Eloy, paramY)  # coherent detection of pol.Y

    return np.array([Sx, Sy]).T


def edfa(Ei, param=None):
    """
    Implement simple EDFA model.

    The input field is amplified by the linear gain and additive complex
    Gaussian ASE noise is added.

    Parameters
    ----------
    Ei : np.array
        Input signal field.
    param : optic.utils.parameters object
        Parameters of the EDFA model.

        - param.G : amplifier gain [dB][default: 20 dB]
        - param.NF : EDFA noise figure [dB][default: 4.5 dB]
        - param.Fc : central optical frequency [Hz][default: 193.1 THz]
        - param.Fs : sampling frequency in [samples/s] [required]
        - param.seed : random seed for noise generation [default: None]

    Returns
    -------
    Eo : np.array
        Amplified noisy optical signal.

    Raises
    ------
    ValueError
        If ``param.Fs`` is not provided.
    AssertionError
        If ``G <= 0`` dB or ``NF < 3`` dB.

    References
    ----------
    [1] R. -J. Essiambre,et al, "Capacity Limits of Optical Fiber Networks," in Journal of Lightwave Technology, vol. 28, no. 4, pp. 662-701, 2010, doi: 10.1109/JLT.2009.2039464.

    """
    # Fs has no sensible default: it sets the simulated noise bandwidth.
    Fs = getattr(param, "Fs", None)
    if Fs is None:
        logg.error("Simulation sampling frequency (Fs) not provided.")
        raise ValueError("Simulation sampling frequency (Fs) not provided.")

    # check input parameters
    G = getattr(param, "G", 20)
    NF = getattr(param, "NF", 4.5)
    Fc = getattr(param, "Fc", 193.1e12)
    seed = getattr(param, "seed", None)

    # G > 0 dB also keeps G_lin - 1 away from zero in the nsp expression
    assert G > 0, "EDFA gain should be a positive scalar"
    assert NF >= 3, "The minimal EDFA noise figure is 3 dB"

    NF_lin = 10 ** (NF / 10)
    G_lin = 10 ** (G / 10)
    nsp = (G_lin * NF_lin - 1) / (2 * (G_lin - 1))

    # ASE noise power calculation:
    # Ref. Eq.(54) of R. -J. Essiambre,et al, "Capacity Limits of Optical Fiber
    # Networks," in Journal of Lightwave Technology, vol. 28, no. 4,
    # pp. 662-701, Feb.15, 2010, doi: 10.1109/JLT.2009.2039464.
    N_ase = (G_lin - 1) * nsp * const.h * Fc
    p_noise = N_ase * Fs  # noise power within the simulated bandwidth Fs

    noise = gaussianComplexNoise(Ei.shape, p_noise, seed)

    return Ei * np.sqrt(G_lin) + noise


def basicLaserModel(param=None):
    """
    Laser model with Maxwellian random walk phase noise and RIN.

    Parameters
    ----------
    param : optic.utils.parameters object
        Parameters of the laser model.

        - param.P: laser power [dBm] [default: 10 dBm]
        - param.lw: laser linewidth [Hz] [default: 1 kHz]
        - param.RIN_var: variance of the RIN noise [default: 1e-20]
        - param.Fs: sampling rate [samples/s] [required]
        - param.Ns: number of signal samples [default: 1000]
        - param.seed: random seed for noise generation [default: None].
          Phase noise uses ``seed`` and RIN uses ``seed + 1``.

    Returns
    -------
    optical_signal : np.array
          Optical signal with phase noise and RIN.

    Raises
    ------
    ValueError
        If ``param.Fs`` is not provided.

    References
    ----------
    [1] M. Seimetz, High-Order Modulation for Optical Fiber Transmission. in Springer Series in Optical Sciences. Springer Berlin Heidelberg, 2009.

    """
    # Fs sets the phase-noise time step; there is no sensible default.
    Fs = getattr(param, "Fs", None)
    if Fs is None:
        logg.error("Simulation sampling frequency (Fs) not provided.")
        raise ValueError("Simulation sampling frequency (Fs) not provided.")

    P = getattr(param, "P", 10)  # Laser power in dBm
    lw = getattr(param, "lw", 1e3)  # Linewidth in Hz
    RIN_var = getattr(param, "RIN_var", 1e-20)  # RIN variance
    Ns = getattr(param, "Ns", 1000)  # Number of samples of the signal
    seed = getattr(param, "seed", None)  # Seed for the random number generator

    if seed is None:
        seedPN = None
        seedRIN = None
    else:
        seedPN = seed
        seedRIN = seed + 1  # to ensure different seeds for phase noise and RIN

    # Simulate Maxwellian random walk phase noise
    pn = phaseNoise(lw, Ns, 1 / Fs, seedPN)

    # Simulate relative intensity noise (RIN)
    # TODO: check whether complex-valued power fluctuations are the intended model
    deltaP = gaussianComplexNoise(pn.shape, RIN_var, seedRIN)

    # Return optical signal
    return np.sqrt(dBm2W(P) + deltaP) * np.exp(1j * pn)


def adc(Ei, param=None):
    """
    Analog-to-digital converter (ADC) model.

    The input is optionally low-pass filtered, resampled from ``Fs_in`` to
    ``Fs_out`` (with optional clock jitter), and uniformly quantized. For
    complex inputs the real and imaginary parts are processed separately.
    The ``param`` object is only read, never modified.

    Parameters
    ----------
    Ei : ndarray
        Input signal, (N,) or (N, nModes).
    param : optic.utils.parameters object, optional
        Parameters of the ADC model.

        - param.Fs_in  : sampling frequency of the input signal [samples/s][default: 1 sample/s]
        - param.Fs_out : sampling frequency of the output signal [samples/s][default: 1 sample/s]
        - param.jitter_rms : root mean square (RMS) value of the jitter in seconds [s][default: 0 s]
        - param.nBits : number of bits used for quantization [default: 8 bits]
        - param.Vmax : maximum value for the ADC's full-scale range [V][default: 1V]
        - param.Vmin : minimum value for the ADC's full-scale range [V][default: -1V]
        - param.AAF : flag indicating whether to use anti-aliasing filters [default: True]
        - param.N : number of taps of the anti-aliasing filters [default: 201]

    Returns
    -------
    Eo : ndarray
        Resampled and quantized signal; 1D if there is a single mode.

    """
    # Read parameters with defaults into locals (without writing them back onto
    # the caller's object).
    Fs_in = getattr(param, "Fs_in", 1)
    Fs_out = getattr(param, "Fs_out", 1)
    jitter_rms = getattr(param, "jitter_rms", 0)
    nBits = getattr(param, "nBits", 8)
    Vmax = getattr(param, "Vmax", 1)
    Vmin = getattr(param, "Vmin", -1)
    AAF = getattr(param, "AAF", True)
    N = getattr(param, "N", 201)

    # Reshape the input signal if needed to handle single-dimensional inputs
    if Ei.ndim < 2:
        Ei = Ei.reshape(len(Ei), 1)

    # Get the number of modes (columns) in the input signal
    nModes = Ei.shape[1]

    # Apply anti-aliasing filters if AAF is enabled
    if AAF:
        # Anti-aliasing filters (cutoff at the output Nyquist frequency):
        Ntaps = min(Ei.shape[0], N)
        hi = lowPassFIR(Fs_out / 2, Fs_in, Ntaps, typeF="rect")
        ho = lowPassFIR(Fs_out / 2, Fs_out, Ntaps, typeF="rect")

        Ei = firFilter(hi, Ei)

    if np.iscomplexobj(Ei):
        # Signal interpolation to the ADC's sampling frequency
        Eo = clockSamplingInterp(
            Ei.reshape(-1, nModes).real, Fs_in, Fs_out, jitter_rms
        ) + 1j * clockSamplingInterp(
            Ei.reshape(-1, nModes).imag, Fs_in, Fs_out, jitter_rms
        )

        # Uniform quantization of the signal according to the number of bits of the ADC
        Eo = quantizer(Eo.real, nBits, Vmax, Vmin) + 1j * quantizer(
            Eo.imag, nBits, Vmax, Vmin
        )
    else:
        # Signal interpolation to the ADC's sampling frequency
        Eo = clockSamplingInterp(Ei.reshape(-1, nModes), Fs_in, Fs_out, jitter_rms)

        # Uniform quantization of the signal according to the number of bits of the ADC
        Eo = quantizer(Eo, nBits, Vmax, Vmin)

    # Apply anti-aliasing filters to the output if AAF is enabled
    if AAF:
        Eo = firFilter(ho, Eo)

    if Eo.shape[1] == 1:
        # If the output is a single column, return it as a 1D array
        Eo = Eo.flatten()

    return Eo
