# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import functools
import pathlib

import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup

from ..scripts.click import ConfigCommand

# Module-wide logger obtained from clapper's ``setup`` under the "mednet"
# name, emitting records formatted as "LEVEL: message".  The same object is
# handed to the verbosity option of the ``train`` command defined below, so
# that command-line verbosity flags control what this logger prints.
logger = setup("mednet", format="%(levelname)s: %(message)s")


def reusable_options(f):
    """Wrap reusable training script options (for ``experiment``).

    This decorator equips the target function ``f`` with all (reusable)
    ``train`` script options, so that other commands (e.g. ``experiment``)
    can expose exactly the same training command-line interface.

    Parameters
    ----------
    f
        The target function to equip with options.  This function must have
        parameters that accept such options.

    Returns
    -------
    callable
        The decorated version of function ``f``.
    """

    # NOTE: several options below are marked ``required=True`` while also
    # carrying a ``default``; the default always satisfies the requirement,
    # so click never actually demands them on the command line.
    @click.option(
        "--output-folder",
        "-o",
        help="Directory in which to store results (created if it does not exist)",
        required=True,
        type=click.Path(
            file_okay=False,
            dir_okay=True,
            writable=True,
            path_type=pathlib.Path,
        ),
        default="results",
        show_default=True,
        cls=ResourceOption,
    )
    @click.option(
        "--model",
        "-m",
        help="A lightning module instance implementing the network to be trained",
        required=True,
        cls=ResourceOption,
    )
    @click.option(
        "--datamodule",
        "-d",
        help="A lightning DataModule containing the training and validation sets.",
        required=True,
        cls=ResourceOption,
    )
    @click.option(
        "--batch-size",
        "-b",
        help="Number of samples in every batch (this parameter affects "
        "memory requirements for the network).  If the number of samples in "
        "the batch is larger than the total number of samples available for "
        "training, this value is truncated.  If this number is smaller, then "
        "batches of the specified size are created and fed to the network "
        "until there are no more new samples to feed (epoch is finished).  "
        "If the total number of training samples is not a multiple of the "
        "batch-size, the last batch will be smaller than the first, unless "
        "--drop-incomplete-batch is set, in which case this batch is not used.",
        required=True,
        show_default=True,
        default=1,
        type=click.IntRange(min=1),
        cls=ResourceOption,
    )
    @click.option(
        "--accumulate-grad-batches",
        "-a",
        help="Number of accumulations for backward propagation to accumulate "
        "gradients over k batches before stepping the optimizer. This "
        "parameter, used in conjunction with the batch-size, may be used to "
        "reduce the number of samples loaded in each iteration, to affect memory "
        "usage in exchange for processing time (more iterations). This is "
        "useful when one is training on GPUs with a limited amount "
        "of onboard RAM. The default of 1 forces the whole batch to be "
        "processed at once. Otherwise, gradients are accumulated over "
        "accumulate-grad-batches consecutive batches before each optimizer "
        "step, so the effective batch size becomes batch-size times "
        "accumulate-grad-batches.",
        required=True,
        show_default=True,
        default=1,
        type=click.IntRange(min=1),
        cls=ResourceOption,
    )
    @click.option(
        "--drop-incomplete-batch/--no-drop-incomplete-batch",
        "-D",
        help="If set, the last batch in an epoch will be dropped if "
        "incomplete.  If you set this option, you should also consider "
        "increasing the total number of epochs of training, as the total number "
        "of training steps may be reduced.",
        required=True,
        show_default=True,
        default=False,
        cls=ResourceOption,
    )
    @click.option(
        "--epochs",
        "-e",
        help="""Number of epochs (complete training set passes) to train for.
        If continuing from a saved checkpoint, make sure to provide a greater
        number of epochs than was saved in the checkpoint to be loaded.""",
        show_default=True,
        required=True,
        default=1000,
        type=click.IntRange(min=1),
        cls=ResourceOption,
    )
    @click.option(
        "--validation-period",
        "-p",
        help="""Number of epochs after which validation happens.  By default,
        we run validation after every training epoch (period=1).  You can
        change this to make validation more sparse, by increasing the
        validation period. Notice that this affects checkpoint saving.  While
        checkpoints are created after every training step (the last training
        step always triggers the overriding of latest checkpoint), and
        this process is independent of validation runs, evaluation of the
        'best' model obtained so far based on those will be influenced by this
        setting.""",
        show_default=True,
        required=True,
        default=1,
        type=click.IntRange(min=1),
        cls=ResourceOption,
    )
    @click.option(
        "--device",
        "-x",
        help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
        show_default=True,
        required=True,
        default="cpu",
        cls=ResourceOption,
    )
    @click.option(
        "--cache-samples/--no-cache-samples",
        help="If set to True, loads the samples into memory, "
        "otherwise loads them at runtime.",
        required=True,
        show_default=True,
        default=False,
        cls=ResourceOption,
    )
    @click.option(
        "--seed",
        "-s",
        help="Seed to use for the random number generator",
        show_default=True,
        required=False,
        default=42,
        type=click.IntRange(min=0),
        cls=ResourceOption,
    )
    @click.option(
        "--parallel",
        "-P",
        help="""Use multiprocessing for data loading: if set to -1 (default),
        disables multiprocessing data loading.  Set to 0 to enable as many data
        loading instances as processing cores available in the system.  Set to
        >= 1 to enable that many multiprocessing instances for data
        loading.""",
        type=click.IntRange(min=-1),
        show_default=True,
        required=True,
        default=-1,
        cls=ResourceOption,
    )
    @click.option(
        "--monitoring-interval",
        "-I",
        help="""Time between checks for the use of resources during each training
        epoch, in seconds.  An interval of 5 seconds, for example, will lead to
        CPU and GPU resources being probed every 5 seconds during each training
        epoch. Values registered in the training logs correspond to averages
        (or maxima) observed through possibly many probes in each epoch.
        Notice that setting a very small value may cause the probing process to
        become extremely busy, potentially biasing the overall perception of
        resource usage.""",
        type=click.FloatRange(min=0.1),
        show_default=True,
        required=True,
        default=5.0,
        cls=ResourceOption,
    )
    # NOTE(review): this ``[]`` default is created once per decorated command
    # and click presumably hands the very same list object to every
    # invocation in a process; consumers should not mutate it in place.
    @click.option(
        "--augmentations",
        "-A",
        help="""Models that can be trained in this package are shipped without
        explicit data augmentations. This option allows you to define a list of
        data augmentations to use for training the selected model.""",
        type=click.UNPROCESSED,
        default=[],
        cls=ResourceOption,
    )
    @click.option(
        "--balance-classes/--no-balance-classes",
        "-B/-N",
        help="""If set, balances weights of the random sampler during
        training so that samples from all sample classes are picked
        equitably.""",
        required=True,
        show_default=True,
        default=True,
        cls=ResourceOption,
    )
    # ``functools.wraps`` copies ``f``'s metadata, including its ``__dict__``,
    # onto the wrapper.  Any click parameters attached to ``f`` by decorators
    # applied before this one are therefore kept, and the options above are
    # added to them.
    @functools.wraps(f)
    def wrapper_reusable_options(*args, **kwargs):
        return f(*args, **kwargs)

    return wrapper_reusable_options


@click.command(
    entry_point_group="mednet.config",
    cls=ConfigCommand,
    epilog="""Examples:

1. Train a Pasa model with the montgomery dataset (classification task):

   .. code:: sh

      mednet train -vv pasa montgomery

2. Train a Little WNet model with the drive dataset (vessel segmentation task):

   .. code:: sh

      mednet train -vv lwnet drive
""",
)
@reusable_options
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def train(
    model,
    output_folder,
    epochs,
    batch_size,
    accumulate_grad_batches,
    drop_incomplete_batch,
    datamodule,
    validation_period,
    device,
    cache_samples,
    seed,
    parallel,
    monitoring_interval,
    augmentations,
    balance_classes,
    **_,
) -> None:  # numpydoc ignore=PR01
    """Train a model on a given datamodule (task-specific).

    Training is performed for a configurable number of epochs, and
    generates checkpoints.  Checkpoints are model files with a .ckpt
    extension that are used in subsequent tasks or from which training
    can be resumed.
    """
    # NOTE: the docstring above is also the ``--help`` text click shows for
    # this command; keep it user-facing.
    #
    # Heavy dependencies are imported at call time rather than at module
    # level (presumably to keep CLI start-up and ``--help`` fast).
    from lightning.pytorch import seed_everything

    from ..engine.device import DeviceManager
    from ..engine.trainer import (
        get_checkpoint_file,
        load_checkpoint,
        run,
        setup_datamodule,
        validate_model_datamodule,
    )
    from ..scripts.utils import save_json_metadata

    # Check that model and datamodule fit together before doing any work.
    validate_model_datamodule(model, datamodule)

    # Seed the random number generators for reproducible runs.
    seed_everything(seed)

    # report model/transforms options - set data augmentations
    logger.info(f"Network model: {type(model).__module__}.{type(model).__name__}")
    model.augmentation_transforms = augmentations

    device_manager = DeviceManager(device)

    # reset datamodule with user configurable options
    setup_datamodule(
        datamodule, model, batch_size, drop_incomplete_batch, cache_samples, parallel
    )

    # If asked, rebalances the loss criterion based on the relative proportion
    # of class examples available in the training set.  Also affects the
    # validation loss if a validation set is available on the DataModule.
    if balance_classes:
        if datamodule.task == "classification":
            logger.info("Applying train/valid loss balancing...")
            model.balance_losses(datamodule)
        else:
            # ``Logger.warn`` is a deprecated alias; use ``warning``.
            logger.warning(f"Skipping loss balancing for {datamodule.task} task...")
            # Record in the metadata below that no balancing was applied.
            balance_classes = False

    # Look up a checkpoint in the output folder and load it into the model
    # and datamodule, so that an interrupted training can be resumed.  The
    # same handle is passed to ``run`` below.
    checkpoint_file = get_checkpoint_file(output_folder)
    load_checkpoint(checkpoint_file, datamodule, model)

    # stores all information we can think of, to reproduce this later
    save_json_metadata(
        output_file=output_folder / "train.meta.json",
        datamodule=datamodule,
        model=model,
        augmentations=augmentations,
        device_manager=device_manager,
        output_folder=output_folder,
        epochs=epochs,
        batch_size=batch_size,
        accumulate_grad_batches=accumulate_grad_batches,
        drop_incomplete_batch=drop_incomplete_batch,
        validation_period=validation_period,
        cache_samples=cache_samples,
        seed=seed,
        parallel=parallel,
        monitoring_interval=monitoring_interval,
        balance_classes=balance_classes,
    )

    logger.info(f"Training for at most {epochs} epochs.")

    run(
        model=model,
        datamodule=datamodule,
        validation_period=validation_period,
        device_manager=device_manager,
        max_epochs=epochs,
        output_folder=output_folder,
        monitoring_interval=monitoring_interval,
        accumulate_grad_batches=accumulate_grad_batches,
        checkpoint=checkpoint_file,
    )
