# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import json
import pathlib
import typing

import click
import tqdm
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup

from ...engine.segment.evaluator import SUPPORTED_METRIC_TYPE
from ...scripts.click import ConfigCommand

# Configure logging for the top-level package (the first dotted component of
# this module's name), printing only the level and the message.
logger = setup(__name__.partition(".")[0], format="%(levelname)s: %(message)s")


@click.command(
    entry_point_group="mednet.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
  1. Produces visualisations for the given predictions:

     .. code:: sh

        $ mednet segment view -vv --predictions=path/to/predictions.json --output-folder=path/to/results

  2. Produces visualisations for the given (dumped) annotations:

     .. code:: sh

        $ mednet segment view -vv --predictions=path/to/annotations.json --output-folder=path/to/results
""",
)
@click.option(
    "--predictions",
    "-p",
    help="""Path to the JSON file describing available predictions. The actual
    predictions are supposed to lie on the same folder.""",
    required=True,
    type=click.Path(
        file_okay=True,
        dir_okay=False,
        writable=False,
        path_type=pathlib.Path,
    ),
    show_default=True,
    cls=ResourceOption,
)
@click.option(
    "--output-folder",
    "-o",
    help="Directory in which to store results (created if does not exist)",
    required=True,
    type=click.Path(
        file_okay=False,
        dir_okay=True,
        writable=True,
        path_type=pathlib.Path,
    ),
    default="results",
    cls=ResourceOption,
)
@click.option(
    "--threshold",
    "-t",
    help="""This number is used to define positives and negatives from
    probability maps, and used to report metrics based on a threshold chosen *a
    priori*. It can be set to a floating-point value, or to the name of dataset
    split in ``--predictions``.
    """,
    default="0.5",
    show_default=True,
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--metric",
    "-m",
    help="""If threshold is set to the name of a split in ``--predictions``,
    then this parameter defines the metric function to be used to evaluate the
    threshold at which the metric reaches its maximum value. All other splits
    are evaluated with respect to this threshold.""",
    default="f1",
    type=click.Choice(typing.get_args(SUPPORTED_METRIC_TYPE), case_sensitive=True),
    show_default=True,
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--steps",
    "-s",
    help="""Number of steps for evaluating metrics on various splits. This
    value is used only when evaluating thresholds on a datamodule split.""",
    default=100,
    type=click.IntRange(10),
    show_default=True,
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--show-errors/--no-show-errors",
    "-e/-E",
    help="""If set, then shows a colorized version of the segmentation map in
    which false-positives are marked in red, and false-negatives in green.
    True positives are always marked in white.""",
    default=False,
    show_default=True,
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--alpha",
    "-a",
    help="""Defines the transparency weighting between the original image and
    the predicted segmentation maps. A value of 1.0 makes the program output
    only segmentation maps.  A value of 0.0 makes the program output only the
    processed image.""",
    default=0.6,
    type=click.FloatRange(0.0, 1.0),
    show_default=True,
    required=True,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def view(
    predictions: pathlib.Path,
    output_folder: pathlib.Path,
    threshold: str | float,
    metric: str,
    steps: int,
    show_errors: bool,
    alpha: float,
    **_,  # ignored
):  # numpydoc ignore=PR01
    """Write PNG visualisations of predictions generated by a model.

    Every sample listed in the ``--predictions`` JSON file is rendered with
    the engine viewer, which binarizes it at the selected threshold and can
    colorize errors. The result is saved as ``<output-folder>/<sample-path>.png``.
    Execution metadata is stored in ``<output-folder>/view.meta.json``.

    The threshold is either the floating-point value given on the command
    line, used as-is, or, when ``--threshold`` names a split in
    ``--predictions``, the point of a ``steps + 1`` grid over [0, 1] at which
    ``--metric`` is maximal on that split.
    """

    import numpy

    from ...engine.segment.evaluator import (
        compute_metric,
        load_count,
        name2metric,
        validate_threshold,
    )

    # aliased so the command's own name (``view``) is not shadowed in its body
    from ...engine.segment.viewer import view as view_sample
    from ...scripts.utils import execution_metadata, save_json_with_backup

    view_filename = "view.json"
    view_file = output_folder / view_filename

    with predictions.open("r") as f:
        predict_data = json.load(f)

    # the --output-folder help promises creation if missing; make sure it
    # exists before the metadata file is written into it
    output_folder.mkdir(parents=True, exist_ok=True)

    # register metadata: record every option that influences the output images
    json_data: dict[str, typing.Any] = execution_metadata()
    json_data.update(
        dict(
            predictions=str(predictions),
            output_folder=str(output_folder),
            threshold=threshold,
            metric=metric,
            steps=steps,
            show_errors=show_errors,
            alpha=alpha,
        ),
    )
    json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
    save_json_with_backup(view_file.with_suffix(".meta.json"), json_data)

    threshold = validate_threshold(threshold, predict_data)

    if isinstance(threshold, str):
        # ``threshold`` names a split: pick the grid point maximizing ``metric``
        logger.info(f"Evaluating threshold on `{threshold}` split using " f"`{metric}`")
        # linspace yields exactly ``steps + 1`` points in [0, 1]; a float-step
        # ``arange`` has numerically unstable length and may overshoot 1.0
        threshold_list = numpy.linspace(0.0, 1.0, steps + 1, dtype=numpy.float64)
        counts = load_count(predictions.parent, predict_data[threshold], threshold_list)
        metric_list = compute_metric(
            counts, name2metric(typing.cast(SUPPORTED_METRIC_TYPE, metric))
        )
        chosen_threshold = float(threshold_list[metric_list.argmax()])
    else:
        # a-priori threshold: --steps only applies to split evaluation, so the
        # user's value is used as given instead of being snapped to a grid
        chosen_threshold = float(threshold)

    logger.info(f"Set --threshold={chosen_threshold:.4f}")

    # create visualisations
    for split_name, sample_list in predict_data.items():
        logger.info(
            f"Creating {len(sample_list)} visualisations for split `{split_name}`"
        )
        for sample in tqdm.tqdm(sample_list):
            # sample[1] is used as a path relative to the predictions folder,
            # and mirrored (as .png) under the output folder
            image = view_sample(
                predictions.parent,
                sample[1],
                threshold=chosen_threshold,
                show_errors=show_errors,
                tp_color=(255, 255, 255),
                fp_color=(255, 0, 0),
                fn_color=(0, 255, 0),
                alpha=alpha,
            )
            dest = (output_folder / sample[1]).with_suffix(".png")
            dest.parent.mkdir(parents=True, exist_ok=True)
            tqdm.tqdm.write(f"{sample[1]} -> {dest}")
            image.save(dest)
