"""Collection of Cell Readers from different sources (Pure HDF5, SynTool...)"""

import logging
from collections import defaultdict
from os import path as ospath

import libsonata
import numpy as np

from neurodamus.core import NeuronWrapper as Nd, run_only_rank0
from neurodamus.core.configuration import SimConfig
from neurodamus.metype import METypeManager
from neurodamus.utils.logging import log_verbose

# Shared zero-length uint32 gid vector, returned by readers that select no gids
EMPTY_GIDVEC = np.zeros(0, dtype=np.uint32)


def split_round_robin(all_gids, stride=1, stride_offset=0, total_cells=None):
    """Select a round-robin share of gids.

    Args:
        all_gids: numpy array of gids to split, or None to generate the gids
            from ``total_cells``. The array is left untouched.
        stride: step between two selected entries (1 selects every entry)
        stride_offset: position of the first selected entry
        total_cells: number of cells to generate gids for. Only used, and then
            required, when ``all_gids`` is None. May be 0.

    Returns:
        A sorted numpy array with the selected gids. Generated gids are
        1-based and of dtype uint32.
    """
    if all_gids is not None:
        selected = all_gids[stride_offset::stride] if stride > 1 else all_gids
        # np.sort returns a sorted copy. An in-place sort would reorder the
        # caller's array, either directly or through the strided view.
        return np.sort(selected)

    # Compare against None so that an empty population (0 cells) is accepted
    assert total_cells is not None, "split_round_robin: total_cells required without gids"
    first_gid = stride_offset + 1  # gids start from 1
    return np.arange(first_gid, total_cells + 1, stride, dtype="uint32")


def dry_run_distribution(gid_metype_bundle, stride=1, stride_offset=0):
    """Pick whole metype bundles of gids, round-robin, for a dry run.

    Gids sharing a metype must end up together, so the round-robin is done
    over the inner lists (one per metype) rather than over individual gids.
    The selected inner lists are then flattened into a single array.

    Example:
        gid_metype_bundle = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]

        stride = 2, stride_offset = 0  ->  [1, 2, 3, 7, 8, 9]
        stride = 2, stride_offset = 1  ->  [4, 5, 6, 10]

    Args:
        gid_metype_bundle: list of lists of gids, one inner list per metype
        stride: step between two selected bundles (1 selects all of them)
        stride_offset: position of the first selected bundle

    Returns:
        A flat numpy array with the gids of the selected bundles, or the
        shared empty gid vector when nothing is selected.
    """
    if not gid_metype_bundle:
        return EMPTY_GIDVEC

    # With a stride of 1 every bundle is taken, whatever the offset
    if stride == 1:
        selected = gid_metype_bundle
    else:
        selected = gid_metype_bundle[stride_offset::stride]

    if not selected:
        return EMPTY_GIDVEC
    return np.concatenate(selected)


def load_sonata(  # noqa: C901, PLR0915
    circuit_conf,
    all_gids,
    stride=1,
    stride_offset=0,
    *,
    node_population,
    load_dynamic_props=(),
    has_extra_data=False,
    dry_run_stats=None,
    load_mode=None,
):
    """A reader supporting additional dynamic properties from Sonata files.

    Args:
        circuit_conf: circuit configuration. ``CellLibraryFile`` is opened as the
            nodes file; ``METypePath`` is read when ``has_extra_data`` is set.
        all_gids: 1-based gids to consider, split among ranks round-robin.
            In regular mode None selects all cells of the population.
        stride: round-robin step
        stride_offset: round-robin offset
        node_population: name of the node population to open
        load_dynamic_props: names of extra properties to load into each cell's
            ``extra_attrs``. Names prefixed with "@dynamics:" are read from the
            dynamics attributes, the others from the regular attributes.
        has_extra_data: whether to read the attributes listed by the emodel
            templates (see _getNeededAttributes)
        dry_run_stats: object whose ``metype_counts`` and ``pop_metype_gids``
            are updated when loading in metype mode
        load_mode: "load_nodes_metype" selects metype mode outside a dry run

    Returns:
        A tuple (gidvec, meinfos, total_cells): the gids assigned to this rank,
        the METypeManager holding the loaded cell info, and the population size.

    Raises:
        Exception: if a requested extra property is not in the population
    """
    import libsonata

    node_file = circuit_conf.CellLibraryFile
    node_store = libsonata.NodeStorage(node_file)
    node_pop = node_store.open_population(node_population)
    attr_names = node_pop.attribute_names
    dynamics_attr_names = node_pop.dynamics_attribute_names
    total_cells = node_pop.size

    # In metype mode meinfos only holds a subset of cells per metype, so every
    # step reading per-cell data must use this same flag to stay aligned.
    metype_mode = SimConfig.dry_run or load_mode == "load_nodes_metype"

    def load_base_info_dry_run():
        CELL_NODE_INFO_LIMIT = 100
        log_verbose("Sonata dry run mode: looking for unique metype instances")
        meinfos = METypeManager()
        # Get global METype counts (computed in rank0, broadcasted)
        metype_gids, counts = _retrieve_unique_metypes(node_pop, all_gids)
        dry_run_stats.metype_counts += counts
        dry_run_stats.pop_metype_gids[node_population] = metype_gids
        gid_metype_bundle = list(metype_gids.values())
        gidvec = dry_run_distribution(gid_metype_bundle, stride, stride_offset)

        log_verbose("Loading node attributes... (subset of cells from each metype)")
        for gids in metype_gids.values():
            if not len(gids):
                continue
            gids = gids[:CELL_NODE_INFO_LIMIT]
            node_sel = libsonata.Selection(gids - 1)  # Load 0-based node ids
            morpho_names = node_pop.get_attribute("morphology", node_sel)
            mtypes = node_pop.get_attribute("mtype", node_sel)
            etypes = node_pop.get_attribute("etype", node_sel)
            model_templates = node_pop.get_attribute("model_template", node_sel)
            emodel_templates = [emodel.removeprefix("hoc:") for emodel in model_templates]
            meinfos.load_infoNP(gids, morpho_names, emodel_templates, mtypes, etypes)

        return gidvec, meinfos, total_cells

    def load_nodes_base_info():
        if metype_mode:
            return load_base_info_dry_run()

        meinfos = METypeManager()
        gidvec = split_round_robin(all_gids, stride, stride_offset, total_cells)

        if not len(gidvec):
            # Not enough cells to give this rank a few
            return gidvec, meinfos, total_cells

        log_verbose("Loading nodes info")
        node_sel = libsonata.Selection(gidvec - 1)  # 0-based node indices
        morpho_names = node_pop.get_attribute("morphology", node_sel)
        mtypes = node_pop.get_attribute("mtype", node_sel)
        try:
            etypes = node_pop.get_attribute("etype", node_sel)
        except libsonata.SonataError:
            logging.warning("etype not found in node population, setting to None")
            etypes = None
        model_templates = node_pop.get_attribute("model_template", node_sel)
        emodel_templates = [emodel.removeprefix("hoc:") for emodel in model_templates]
        if {"exc_mini_frequency", "inh_mini_frequency"}.issubset(attr_names):
            exc_mini_freqs = node_pop.get_attribute("exc_mini_frequency", node_sel)
            inh_mini_freqs = node_pop.get_attribute("inh_mini_frequency", node_sel)
        else:
            exc_mini_freqs = None
            inh_mini_freqs = None
        if {"threshold_current", "holding_current"}.issubset(dynamics_attr_names):
            threshold_currents = node_pop.get_dynamics_attribute("threshold_current", node_sel)
            holding_currents = node_pop.get_dynamics_attribute("holding_current", node_sel)
        else:
            threshold_currents = None
            holding_currents = None
        positions = np.array(
            [
                node_pop.get_attribute("x", node_sel),
                node_pop.get_attribute("y", node_sel),
                node_pop.get_attribute("z", node_sel),
            ]
        ).T
        rotations = _get_rotations(node_pop, node_sel)

        # For Sonata and new emodel hoc template, we need additional attributes for building metype
        # TODO: validate it's really the emodel_templates var we should pass here, or etype
        add_params_list = (
            None
            if not has_extra_data
            else _getNeededAttributes(
                node_pop, circuit_conf.METypePath, emodel_templates, gidvec - 1
            )
        )

        meinfos.load_infoNP(
            gidvec,
            morpho_names,
            emodel_templates,
            mtypes,
            etypes,
            threshold_currents,
            holding_currents,
            exc_mini_freqs,
            inh_mini_freqs,
            positions,
            rotations,
            add_params_list,
        )
        return gidvec, meinfos, total_cells

    # If dynamic properties are not specified simply return early
    if not load_dynamic_props:
        return load_nodes_base_info()

    # Check properties exist, eventually removing prefix
    def validate_property(prop_name):
        if prop_name.startswith("@dynamics:"):
            actual_prop_name = prop_name[len("@dynamics:") :]  # remove prefix
            if actual_prop_name not in dynamics_attr_names:
                raise Exception(f"Required Dynamics property {prop_name} not present")
        elif prop_name not in attr_names:
            raise Exception(f"Required extra property {prop_name} not present")

    # Validate everything up front so nothing is read if a property is missing
    for prop_name in load_dynamic_props:
        validate_property(prop_name)

    # All good. Lets start reading!
    gidvec, meinfos, fullsize = load_nodes_base_info()

    if metype_mode:
        # meinfos holds a subset of cells per metype, not gidvec: select exactly
        # the loaded cells so values line up with meinfos.keys() below
        load_nodes = np.fromiter(meinfos.keys(), dtype="uint32") - 1
        node_sel = libsonata.Selection(load_nodes)
    else:
        node_sel = libsonata.Selection(gidvec - 1)  # 0-based node indices

    for prop_name in load_dynamic_props:
        log_verbose("Loading extra property: %s ", prop_name)
        if prop_name.startswith("@dynamics:"):
            # Stored under the un-prefixed name
            prop_name = prop_name[len("@dynamics:") :]
            prop_data = node_pop.get_dynamics_attribute(prop_name, node_sel)
        else:
            prop_data = node_pop.get_attribute(prop_name, node_sel)
        for gid, val in zip(meinfos.keys(), prop_data):
            meinfos[gid].extra_attrs[prop_name] = val

    return gidvec, meinfos, fullsize


def _getNeededAttributes(node_reader, etype_path, emodels, gidvec):
    """Collect, per cell, the dynamics attributes requested by its emodel template.

    Each emodel hoc template is loaded, then the global ``<emodel>_NeededAttributes``
    is looked up on the Neuron wrapper. When present it is split on ";" and each
    listed name is read from the dynamics attributes of the cell.

    Args:
        node_reader: libsonata node population
        etype_path: Location of emodel hoc templates
        emodels: Array of emodel names
        gidvec: Array of 0-based cell gids

    Returns:
        A list with one entry per cell: the list of attribute values, which is
        empty when the emodel does not define the global.
    """
    extra_params = []
    for node_id, emodel_name in zip(gidvec, emodels):
        hoc_file = ospath.join(etype_path, emodel_name) + ".hoc"
        Nd.h.load_file(hoc_file)  # hoc doesn't throw
        # Expected format: "attr1;attr2;attr3"
        needed = getattr(Nd, emodel_name + "_NeededAttributes", None)
        if needed is None:
            extra_params.append([])
            continue
        extra_params.append(
            [node_reader.get_dynamics_attribute(attr, node_id) for attr in needed.split(";")]
        )
    return extra_params


def _get_rotations(node_reader, selection):
    """Get quaternions to rotate the cells

    Quaternions stored as orientation_{x,y,z,w} are preferred. Otherwise any
    available rotation_angle_{x,y,z}axis Euler angles are converted, treating
    a missing axis as a zero rotation for every cell.

    Args:
        node_reader: libsonata node population
        selection: libsonata selection

    Returns:
        double vector of size [N][4] with the rotation quaternions in the order (x,y,z,w),
        or None if the population holds no rotation information
    """
    attr_names = node_reader.attribute_names

    quat_attrs = ("orientation_x", "orientation_y", "orientation_z", "orientation_w")
    if set(quat_attrs).issubset(attr_names):
        # Preferred way to present the rotation as quaternions
        return np.array([node_reader.get_attribute(name, selection) for name in quat_attrs]).T

    euler_attrs = ("rotation_angle_xaxis", "rotation_angle_yaxis", "rotation_angle_zaxis")
    # Each axis is guarded by its own name, so only existing attributes are read
    angles = {
        name: np.asarray(node_reader.get_attribute(name, selection))
        for name in euler_attrs
        if name in attr_names
    }
    if angles:
        # Some sonata nodes files use the Euler angle rotations, convert them to quaternions
        from scipy.spatial.transform import Rotation

        # Missing axes become per-cell zeros: a scalar 0 mixed with arrays
        # would not stack into a regular [N][3] array
        n_cells = len(next(iter(angles.values())))
        no_rotation = np.zeros(n_cells)
        euler_rots = np.column_stack([angles.get(name, no_rotation) for name in euler_attrs])
        return Rotation.from_euler("xyz", euler_rots).as_quat()

    return None


@run_only_rank0
def _retrieve_unique_metypes(node_reader, all_gids, skip_metypes=()) -> tuple[dict, dict]:
    """Group the target gids by unique mtype+etype combination for the dry run.

    Args:
        node_reader: node reader, libsonata only
        all_gids: list of all (1-based) gids in target
        skip_metypes: metype names ("<mtype>-<etype>") to leave out of the gid
            mapping, e.g. because they are already known

    Returns:
        A tuple (gids_by_metype, count_per_metype):
        gids_by_metype maps each non-skipped metype name to a uint32 numpy
        array with all its gids; count_per_metype maps every metype name found,
        skipped ones included, to its number of cells.

    Raises:
        TypeError: if node_reader is not a libsonata.NodePopulation
    """
    gidvec = np.array(all_gids)
    indexes = gidvec - 1  # 0-based node ids
    if len(indexes) < 10:  # Ensure array is not too small (pybind11 #1392)
        indexes = indexes.tolist()

    if not isinstance(node_reader, libsonata.NodePopulation):
        msg = f"Reader type {type(node_reader)} incompatible with dry run."
        raise TypeError(msg)

    node_sel = libsonata.Selection(indexes)
    etypes = node_reader.get_attribute("etype", node_sel)
    mtypes = node_reader.get_attribute("mtype", node_sel)

    gids_per_metype = defaultdict(list)
    count_per_metype = defaultdict(int)
    for gid, mtype, etype in zip(gidvec, mtypes, etypes):
        metype = f"{mtype}-{etype}"
        gids_per_metype[metype].append(gid)
        count_per_metype[metype] += 1

    logging.info(
        "Out of %d cells, found %d unique mtype+emodel combination",
        len(gidvec),
        len(gids_per_metype),
    )
    for metype, gid_list in gids_per_metype.items():
        # Only the first 10 gids are shown, to keep the log line short
        logging.debug(
            "METype: %-20s instances: %-8d gids: %s", metype, len(gid_list), gid_list[:10]
        )

    # Keep the full gid list of every metype, except those already computed
    gid_metype_instantiate = {}
    for metype, gid_list in gids_per_metype.items():
        if metype in skip_metypes:
            log_verbose("Skipping METype '%s' since it's already known", metype)
            continue
        gid_metype_instantiate[metype] = np.array(gid_list, dtype="uint32")

    return gid_metype_instantiate, count_per_metype
