import os

from anemoi.utils.config import DotDict

import bris.checkpoint


def test_metadata():
    """Check metadata, config/graph presence and variable index maps of checkpoint.ckpt.

    The expected ``name_to_index`` and ``index_to_name`` values are both
    derived from a single ordered list of variable names, so the two
    directions cannot drift apart. Each is wrapped in a one-element tuple,
    which is the form the checkpoint's properties are compared against.
    """
    filename = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "files", "checkpoint.ckpt"
    )
    checkpoint = bris.checkpoint.Checkpoint(path=filename)

    # Pressure levels in the order they appear in the expected index maps
    # (plain string order, which is why "400" precedes "50").
    levels = (
        "100",
        "1000",
        "150",
        "200",
        "250",
        "300",
        "400",
        "50",
        "500",
        "600",
        "700",
        "850",
        "925",
    )

    def at_levels(var):
        """Return ``var_<level>`` for every entry of ``levels``, in order."""
        return [f"{var}_{level}" for level in levels]

    # Variable names in index order: the name at position i has index i.
    names = [
        "10u",
        "10v",
        "2d",
        "2t",
        "cos_julian_day",
        "cos_latitude",
        "cos_local_time",
        "cos_longitude",
        "cp",
        "insolation",
        "lsm",
        "msl",
        *at_levels("q"),
        "sdor",
        "sin_julian_day",
        "sin_latitude",
        "sin_local_time",
        "sin_longitude",
        "skt",
        "slor",
        "sp",
        *at_levels("t"),
        "tcw",
        "tp",
        *at_levels("u"),
        *at_levels("v"),
        *at_levels("w"),
        "z",
        *at_levels("z"),
    ]

    # Explicit one-element tuples: the checkpoint exposes these maps as tuples
    # (presumably one mapping per dataset -- confirm against bris.checkpoint).
    n2i = ({name: index for index, name in enumerate(names)},)
    i2n = (dict(enumerate(names)),)

    assert checkpoint.metadata.version == "1.0", "version is not 1.0"
    assert checkpoint.metadata.run_id == "775d1ad8-4457-4268-a430-3df91cc55603", (
        "run_id seems wrong"
    )
    assert isinstance(checkpoint.metadata.config, DotDict), "config is not DotDict"
    assert isinstance(checkpoint.metadata.dataset, DotDict), "dataset is not DotDict"
    assert isinstance(checkpoint.metadata.data_indices, DotDict), (
        "data_indices is not DotDict"
    )

    assert checkpoint.config is not None, "config is None"
    assert checkpoint.graph is not None, "graph is None"

    assert checkpoint.name_to_index == n2i, "name_to_index is not correct"
    assert checkpoint.index_to_name == i2n, "index_to_name is not correct"


def test_model_output_index_to_name():
    """Check ``model_output_index_to_name`` for a single and a multi-output checkpoint.

    Each expected value is a tuple of index->name dicts. There is one dict for
    ``checkpoint.ckpt`` and two for ``multiencdec.ckpt`` (presumably one per
    decoder -- confirm against bris.checkpoint).
    """
    files_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "files")

    # Pressure levels in the order they appear in the expected index maps
    # (plain string order, which is why "400" precedes "50").
    levels = (
        "100",
        "1000",
        "150",
        "200",
        "250",
        "300",
        "400",
        "50",
        "500",
        "600",
        "700",
        "850",
        "925",
    )

    def at_levels(var):
        """Return ``var_<level>`` for every entry of ``levels``, in order."""
        return [f"{var}_{level}" for level in levels]

    # Output variables of checkpoint.ckpt: the name at position i has index i.
    output_names = [
        "10u",
        "10v",
        "2d",
        "2t",
        "cp",
        "msl",
        *at_levels("q"),
        "skt",
        "sp",
        *at_levels("t"),
        "tcw",
        "tp",
        *at_levels("u"),
        *at_levels("v"),
        *at_levels("w"),
        *at_levels("z"),
    ]

    # Keep each checkpoint next to its expected value so the two cannot fall
    # out of step, unlike parallel lists indexed by position.
    cases = [
        (
            os.path.join(files_dir, "checkpoint.ckpt"),
            (dict(enumerate(output_names)),),
        ),
        (
            os.path.join(files_dir, "multiencdec.ckpt"),
            ({0: "10u", 1: "10v", 2: "2t"}, {0: "skt", 1: "msl", 2: "tp"}),
        ),
    ]

    for checkpoint_path, expected in cases:
        checkpoint = bris.checkpoint.Checkpoint(path=checkpoint_path)
        assert checkpoint.model_output_index_to_name == expected, (
            f"model_output_index_to_name is not correct for checkpoint {checkpoint_path}"
        )


def test_model_output_name_to_index():
    """Check ``model_output_name_to_index`` for a single and a multi-output checkpoint.

    Each expected value is a tuple of name->index dicts. There is one dict for
    ``checkpoint.ckpt`` and two for ``multiencdec.ckpt`` (presumably one per
    decoder -- confirm against bris.checkpoint).
    """
    files_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "files")

    # Pressure levels in the order they appear in the expected index maps
    # (plain string order, which is why "400" precedes "50").
    levels = (
        "100",
        "1000",
        "150",
        "200",
        "250",
        "300",
        "400",
        "50",
        "500",
        "600",
        "700",
        "850",
        "925",
    )

    def at_levels(var):
        """Return ``var_<level>`` for every entry of ``levels``, in order."""
        return [f"{var}_{level}" for level in levels]

    # Output variables of checkpoint.ckpt: the name at position i has index i.
    output_names = [
        "10u",
        "10v",
        "2d",
        "2t",
        "cp",
        "msl",
        *at_levels("q"),
        "skt",
        "sp",
        *at_levels("t"),
        "tcw",
        "tp",
        *at_levels("u"),
        *at_levels("v"),
        *at_levels("w"),
        *at_levels("z"),
    ]

    # Keep each checkpoint next to its expected value so the two cannot fall
    # out of step, unlike parallel lists indexed by position.
    cases = [
        (
            os.path.join(files_dir, "checkpoint.ckpt"),
            ({name: index for index, name in enumerate(output_names)},),
        ),
        (
            os.path.join(files_dir, "multiencdec.ckpt"),
            ({"10u": 0, "10v": 1, "2t": 2}, {"skt": 0, "msl": 1, "tp": 2}),
        ),
    ]

    for checkpoint_path, expected in cases:
        checkpoint = bris.checkpoint.Checkpoint(path=checkpoint_path)
        assert checkpoint.model_output_name_to_index == expected, (
            f"model_output_name_to_index is not correct for checkpoint {checkpoint_path}"
        )


# def test__get_copy_model_params():
#     checkpoint_path = (
#         os.path.dirname(os.path.abspath(__file__)) + "/files/checkpoint.ckpt"
#     )
#     checkpoint = bris.checkpoint.Checkpoint(path=checkpoint_path)
#     copy_model_params = checkpoint._get_copy_model_params
#     assert copy_model_params is not None, "copy_model_params is None"
#     assert "model.encoder.proc.lin_key.weight" in copy_model_params, (
#         "copy_model_params does not have model.encoder.proc.lin_key.weight"
#     )


if __name__ == "__main__":
    # Allow running the checks directly (without pytest), in file order.
    for check in (
        test_metadata,
        test_model_output_index_to_name,
        test_model_output_name_to_index,
    ):
        check()
    # test__get_copy_model_params()
