# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for drhagis dataset."""

import importlib

import pytest
from click.testing import CliRunner


def id_function(val):
    """Build the label pytest shows for one parametrized value.

    Dictionaries are rendered with ``str()``; every other value is rendered
    with ``repr()``.
    """
    return str(val) if isinstance(val, dict) else repr(val)


@pytest.mark.parametrize(
    "split,lengths",
    [("default", {"train": 19, "test": 20})],
    ids=id_function,  # only affects the test IDs pytest displays
)
def test_protocol_consistency(
    database_checkers,
    split: str,
    lengths: dict[str, int],
):
    """Validate a drhagis split definition against its expected sizes.

    The ``<split>.json`` file is loaded through ``make_split`` and handed to
    ``database_checkers.check_split`` together with the expected number of
    samples per dataset, the ``Fundus_Images/`` prefix and an empty list of
    possible labels.
    """
    from mednet.data.split import make_split

    database_split = make_split(
        "mednet.config.segment.data.drhagis",
        f"{split}.json",
    )

    database_checkers.check_split(
        database_split,
        lengths=lengths,
        prefixes=["Fundus_Images/"],
        possible_labels=[],
    )


@pytest.mark.skip_if_rc_var_not_set("datadir.drhagis")
def test_database_check():
    """Invoke the ``check`` command for ``drhagis`` and require exit code 0.

    On failure, the assertion message carries the exit code and the captured
    command output.
    """
    from mednet.scripts.database import check

    arguments = ["drhagis"]
    result = CliRunner().invoke(check, arguments)

    # The message is only evaluated when the assertion fails.
    assert result.exit_code == 0, (
        f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
    )


@pytest.mark.skip_if_rc_var_not_set("datadir.drhagis")
@pytest.mark.parametrize(
    "dataset",
    [
        "train",
        "test",
    ],
)
@pytest.mark.parametrize(
    "name",
    [
        "default",
    ],
)
def test_loading(database_checkers, name: str, dataset: str):
    """Load the first few batches of a dataset and check each one.

    The ``datamodule`` object of the ``name`` configuration module is set up
    for prediction without model transforms. At most ``limit`` batches from
    the ``dataset`` predict data loader are passed to
    ``database_checkers.check_loaded_batch``.

    Parameters
    ----------
    database_checkers
        Fixture providing ``check_loaded_batch``.
    name
        Name of the configuration module, relative to
        ``mednet.config.segment.data.drhagis``.
    dataset
        Key of the data loader to pick from ``predict_dataloader()``.
    """
    from itertools import islice

    datamodule = importlib.import_module(
        f".{name}",
        "mednet.config.segment.data.drhagis",
    ).datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    loader = datamodule.predict_dataloader()[dataset]

    limit = 3  # limit load checking

    # ``islice`` stops pulling from the loader as soon as ``limit`` batches
    # were produced.  A manual countdown tested after the ``for`` statement
    # has fetched its item would load one extra batch just to discard it.
    for batch in islice(loader, limit):
        database_checkers.check_loaded_batch(
            batch,
            batch_size=1,
            color_planes=3,
            expected_num_labels=1,
            expected_meta_size=1,
            prefixes=["Fundus_Images/"],
            possible_labels=[],
        )


@pytest.mark.skip_if_rc_var_not_set("datadir.drhagis")
def test_raw_transforms_image_quality(database_checkers, datadir):
    """Run the image-quality check on the default split, no model transforms.

    The ``datamodule`` of the ``default`` configuration is set up for
    prediction with an empty list of model transforms, then passed to
    ``database_checkers.check_image_quality`` along with the reference
    histogram file located under ``datadir``.
    """
    reference = datadir / "histograms/raw_data/histograms_drhagis_default.json"

    configuration = importlib.import_module(
        ".default",
        package="mednet.config.segment.data.drhagis",
    )
    datamodule = configuration.datamodule

    # Transforms are cleared before setup() is called.
    datamodule.model_transforms = []
    datamodule.setup("predict")

    database_checkers.check_image_quality(datamodule, reference)


@pytest.mark.skip_if_rc_var_not_set("datadir.drhagis")
@pytest.mark.parametrize("model_name", ["lwnet"])
def test_model_transforms_image_quality(database_checkers, datadir, model_name):
    """Run the image-quality check with a model's transforms applied.

    The ``datamodule`` of the ``default`` configuration receives the
    ``model_transforms`` of the ``model_name`` model configuration before
    being set up for prediction.  It is then passed to
    ``database_checkers.check_image_quality`` with the per-model reference
    histogram file, ``compare_type="statistical"`` and
    ``pearson_coeff_threshold=0.005``.
    """
    relative_path = f"histograms/models/histograms_{model_name}_drhagis_default.json"
    reference = datadir / relative_path

    data_configuration = importlib.import_module(
        ".default",
        package="mednet.config.segment.data.drhagis",
    )
    datamodule = data_configuration.datamodule

    model_configuration = importlib.import_module(
        f".{model_name}",
        package="mednet.config.segment.models",
    )
    model = model_configuration.model

    # Transforms are assigned before setup() is called.
    datamodule.model_transforms = model.model_transforms
    datamodule.setup("predict")

    database_checkers.check_image_quality(
        datamodule,
        reference,
        compare_type="statistical",
        pearson_coeff_threshold=0.005,
    )
