# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488

from functools import reduce

import torch
from torch.nn.modules.module import _addindent


# ignore this space!
def _repr(model: torch.nn.Module) -> tuple[str, int]:
    """Build a nested text representation of a module with parameter counts.

    The layout follows the one produced by ``torch.nn.Module.__repr__``, with
    ``", <N> params"`` appended after the closing parenthesis of every
    (sub-)module.  The count of a module is the number of elements of its own
    registered parameters plus the counts of all of its children.  No
    de-duplication is done: a parameter registered under several modules is
    counted once per registration.

    Parameters
    ----------
    model
        Module to represent.

    Returns
    -------
    tuple[str, int]
        The (possibly multiline) string representation of ``model`` and the
        total number of parameter elements counted for it.
    """
    # We treat the extra repr like the sub-module, one item per line
    extra_repr = model.extra_repr()
    # an empty string would be split into [''], so only split when non-empty
    extra_lines = extra_repr.split("\n") if extra_repr else []

    child_lines = []
    total_params = 0
    for key, module in model._modules.items():  # noqa: SLF001
        if module is None:
            # A child slot may be registered as ``None``; show it the way the
            # builtin module repr does instead of recursing into it.
            mod_str, num_params = "None", 0
        else:
            mod_str, num_params = _repr(module)
        mod_str = _addindent(mod_str, 2)
        child_lines.append("(" + key + "): " + mod_str)
        total_params += num_params
    lines = extra_lines + child_lines

    for p in model._parameters.values():  # noqa: SLF001
        # parameter slots may hold ``None`` (parameter registered but unset)
        if p is not None:
            # The initial value 1 makes a 0-dim (scalar) parameter, whose
            # shape is empty, count as a single element instead of raising.
            total_params += reduce(lambda x, y: x * y, p.shape, 1)

    main_str = model._get_name() + "("  # noqa: SLF001
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += "\n  " + "\n  ".join(lines) + "\n"

    main_str += ")"
    main_str += f", {total_params:,} params"
    return main_str, total_params


def summary(model: torch.nn.Module) -> tuple[str, int]:
    """Summarize a model as text together with its parameter count.

    Parameters
    ----------
    model
        Model to summarize.

    Returns
    -------
    tuple[str, int]
        A tuple containing a multiline string representation of the network,
        followed by the number of parameters.
    """
    description, num_params = _repr(model)
    return description, num_params
