import os
import pytest
import numpy as np
from fine import xarrayIO as xrIO
import geopandas as gpd


@pytest.mark.parametrize("use_saved_file", [False, True])
def test_esm_to_xr_and_back_during_spatial_aggregation(
    use_saved_file, test_esM_for_spagat
):
    """Resulting number of regions would be the same as the original number.

    No aggregation actually takes place. Tests:
        - if the esM instance, created after spatial aggregation
          is run, has all the info originally present.
        - if the saved netCDF file can be reconstructed into an esM instance
          and has all the info originally present (``use_saved_file=True``).
        - if temporal aggregation and optimization run successfully.

    The netCDF file and the shapefile components written next to this test
    module are removed at the end, regardless of whether the test passed.
    """
    # Paths are deliberately built with os.path and kept as plain strings:
    # the basic aggregation functions reject pathlib objects with
    # "TypeError: shapefile must either be a path to a shapefile or a
    # geopandas dataframe".
    SHAPEFILE_PATH = os.path.join(  # noqa: PTH118
        os.path.dirname(__file__),  # noqa: PTH120
        "../../../examples/03_Multi-regional_Energy_System_Workflow/",
        "InputData/SpatialData/ShapeFiles/clusteredRegions.shp",
    )

    PATH_TO_SAVE = os.path.join(  # noqa: PTH118
        os.path.dirname(__file__)  # noqa: PTH120
    )
    netcdf_file_name = "my_xr.nc"
    shp_file_name = "my_shp"

    # Every file the aggregation is expected to write: the netCDF file
    # followed by the individual shapefile components.
    saved_file = os.path.join(PATH_TO_SAVE, netcdf_file_name)  # noqa: PTH118
    generated_files = [saved_file]
    generated_files += [
        os.path.join(PATH_TO_SAVE, f"{shp_file_name}{ext}")  # noqa: PTH118
        for ext in (".cpg", ".dbf", ".prj", ".shp", ".shx")
    ]

    try:
        # FUNCTION CALL
        aggregated_esM = test_esM_for_spagat.aggregateSpatially(
            shapefile=SHAPEFILE_PATH,
            n_groups=8,
            aggregatedResultsPath=PATH_TO_SAVE,
            aggregated_xr_filename=netcdf_file_name,
            aggregated_shp_name=shp_file_name,
            solver="glpk",
        )

        if use_saved_file:
            # Rebuild the esM from the file on disk instead of using the
            # instance returned by aggregateSpatially().
            xr_dss = xrIO.readNetCDFToDatasets(filePath=saved_file)
            aggregated_esM = xrIO.convertDatasetsToEnergySystemModel(xr_dss)

        # ASSERTION
        assert sorted(aggregated_esM.locations) == sorted(
            test_esM_for_spagat.locations
        )

        # time series attribute
        expected_ts = test_esM_for_spagat.getComponentAttribute(
            "Hydrogen demand", "operationRateFix"
        ).values
        output_ts = aggregated_esM.getComponentAttribute(
            "Hydrogen demand", "operationRateFix"
        ).values
        assert np.array_equal(expected_ts, output_ts)

        # 2d attribute
        expected_2d = test_esM_for_spagat.getComponentAttribute(
            "DC cables", "locationalEligibility"
        ).values
        output_2d = aggregated_esM.getComponentAttribute(
            "DC cables", "locationalEligibility"
        ).values
        assert np.array_equal(output_2d, expected_2d)

        # 1d attribute
        expected_1d = test_esM_for_spagat.getComponentAttribute(
            "Pumped hydro storage", "capacityFix"
        ).values
        output_1d = aggregated_esM.getComponentAttribute(
            "Pumped hydro storage", "capacityFix"
        ).values
        assert np.array_equal(output_1d, expected_1d)

        # 0d attribute
        expected_0d = test_esM_for_spagat.getComponentAttribute(
            "Electroylzers", "processedInvestPerCapacity"
        )
        output_0d = aggregated_esM.getComponentAttribute(
            "Electroylzers", "investPerCapacity"
        )

        assert expected_0d.sort_index().equals(output_0d)

        # 0d boolean attribute
        expected_0d_bool = test_esM_for_spagat.getComponentAttribute(
            "CO2 from enviroment", "hasCapacityVariable"
        )
        output_0d_bool = aggregated_esM.getComponentAttribute(
            "CO2 from enviroment", "hasCapacityVariable"
        )
        assert output_0d_bool == expected_0d_bool

        # additionally, check if clustering and optimization run through
        aggregated_esM.aggregateTemporally(numberOfTypicalPeriods=4)
        aggregated_esM.optimize(timeSeriesAggregation=True, solver="glpk")

        # all expected output files must have been written
        missing_files = [
            path
            for path in generated_files
            if not os.path.isfile(path)  # noqa: PTH113
        ]
        assert not missing_files, f"expected output files missing: {missing_files}"
    finally:
        # Always delete the saved files so a failing run does not leave
        # artifacts in the test directory. Missing files are skipped so the
        # cleanup cannot mask the actual test failure.
        for path in generated_files:
            if os.path.exists(path):  # noqa: PTH110
                os.remove(path)  # noqa: PTH107


def test_error_in_reading_shp(test_esM_for_spagat):
    """Checks if relevant errors are raised when an invalid shapefile
    is passed to aggregateSpatially().

    The paths are built outside the ``pytest.raises`` blocks so that only the
    call under test can satisfy the expected exception.
    """
    # Points at the directory that holds the shapefiles, not at a .shp file.
    shapefile_dir = os.path.join(  # noqa: PTH118
        os.path.dirname(__file__),  # noqa: PTH120
        "../../../examples/03_Multi-regional_Energy_System_Workflow/",
        "InputData/SpatialData/ShapeFiles",
    )
    three_regions_shapefile = os.path.join(  # noqa: PTH118
        os.path.dirname(__file__),  # noqa: PTH120
        "../../../examples/03_Multi-regional_Energy_System_Workflow/",
        "InputData/SpatialData/ShapeFiles/three_regions.shp",
    )

    ## Case 1: invalid path
    with pytest.raises(FileNotFoundError):
        test_esM_for_spagat.aggregateSpatially(
            shapefile=shapefile_dir, n_groups=2, solver="glpk"
        )

    ## Case 2: invalid shapefile type (neither a path nor a geodataframe)
    with pytest.raises(TypeError):
        test_esM_for_spagat.aggregateSpatially(
            shapefile=test_esM_for_spagat, n_groups=2, solver="glpk"
        )

    ## Case 3: invalid n_groups for the shapefile
    # NOTE(review): presumably 5 groups exceed the number of regions in
    # three_regions.shp - confirm against the example data.
    with pytest.raises(ValueError):
        test_esM_for_spagat.aggregateSpatially(
            shapefile=three_regions_shapefile, n_groups=5, solver="glpk"
        )


def test_spatial_aggregation_string_based(test_esM_for_spagat):
    """String-based grouping of the clustered regions must yield 8 locations."""
    here = os.path.dirname(__file__)  # noqa: PTH120
    shapefile_path = os.path.join(  # noqa: PTH118
        here,
        "../../../examples/03_Multi-regional_Energy_System_Workflow/",
        "InputData/SpatialData/ShapeFiles/clusteredRegions.shp",
    )

    # FUNCTION CALL
    # aggregatedResultsPath=None is passed so that no output path is given.
    aggregation_kwargs = {
        "shapefile": shapefile_path,
        "grouping_mode": "string_based",
        "aggregatedResultsPath": None,
        "separator": "_",
    }
    result_esM = test_esM_for_spagat.aggregateSpatially(**aggregation_kwargs)

    # ASSERTION
    n_locations = len(result_esM.locations)
    assert n_locations == 8


@pytest.mark.parametrize(
    ("skip_regions", "enforced_groups", "n_expected_groups"),
    [
        # plain distance-based grouping into the requested two groups
        (None, None, 2),
        # one region is skipped during grouping
        (["cluster_3"], None, 3),
        # two groups are enforced explicitly
        (
            None,
            {
                "cluster_1_cluster_2_cluster_3": [
                    "cluster_1",
                    "cluster_2",
                    "cluster_3",
                ],
                "cluster_4_cluster_5_cluster_6_cluster_7": [
                    "cluster_4",
                    "cluster_5",
                    "cluster_6",
                    "cluster_7",
                ],
            },
            4,
        ),
    ],
)
def test_spatial_aggregation_distance_based(
    test_esM_for_spagat, skip_regions, enforced_groups, n_expected_groups
):
    """Distance-based grouping yields the expected number of locations.

    Each parameter set combines optional ``skip_regions`` and
    ``enforced_groups`` arguments with the number of locations the aggregated
    esM is expected to have for ``n_groups=2``.
    """
    here = os.path.dirname(__file__)  # noqa: PTH120
    shapefile_path = os.path.join(  # noqa: PTH118
        here,
        "../../../examples/03_Multi-regional_Energy_System_Workflow/",
        "InputData/SpatialData/ShapeFiles/clusteredRegions.shp",
    )

    # FUNCTION CALL
    aggregation_kwargs = {
        "shapefile": shapefile_path,
        "grouping_mode": "distance_based",
        "n_groups": 2,
        "aggregatedResultsPath": None,
        "skip_regions": skip_regions,
        "enforced_groups": enforced_groups,
    }
    result_esM = test_esM_for_spagat.aggregateSpatially(**aggregation_kwargs)

    # ASSERTION
    n_locations = len(result_esM.locations)
    assert n_locations == n_expected_groups


@pytest.mark.parametrize(
    "aggregation_function_dict",
    [
        None,
        {
            "operationRateMax": ("weighted mean", "capacityMax"),
            "operationRateFix": ("sum", None),
            "capacityMax": ("sum", None),
            "capacityFix": ("sum", None),
            "locationalEligibility": ("bool", None),
        },
    ],
)
@pytest.mark.parametrize("n_regions", [2, 3])
def test_spatial_aggregation_parameter_based(
    test_esM_for_spagat, aggregation_function_dict, n_regions
):
    """Parameter-based grouping yields ``n_regions`` locations and the
    aggregated esM can still be clustered temporally and optimized.

    Runs once without and once with an explicit ``aggregation_function_dict``
    for each requested number of regions.
    """
    SHAPEFILE_PATH = os.path.join(  # noqa: PTH118
        os.path.dirname(__file__),  # noqa: PTH120
        "../../../examples/03_Multi-regional_Energy_System_Workflow/",
        "InputData/SpatialData/ShapeFiles/clusteredRegions.shp",
    )

    # FUNCTION CALL
    aggregated_esM = test_esM_for_spagat.aggregateSpatially(
        shapefile=SHAPEFILE_PATH,
        grouping_mode="parameter_based",
        n_groups=n_regions,
        aggregatedResultsPath=None,
        aggregation_function_dict=aggregation_function_dict,
        var_weights={"1d_vars": 10},
        solver="glpk",
    )

    # ASSERTION
    assert len(aggregated_esM.locations) == n_regions
    #  Additional check - if the optimization runs through
    aggregated_esM.aggregateTemporally(numberOfTypicalPeriods=4)
    # Use the same solver as for the spatial grouping above instead of
    # relying on the library's default solver choice.
    aggregated_esM.optimize(timeSeriesAggregation=True, solver="glpk")


def test_aggregation_of_balanceLimit(balanceLimitConstraint_test_esM):
    """Smoke test: spatial aggregation of an esM taken from the balance limit
    fixture runs through when two regions are merged into one group.

    No assertion is made on the result; the test only fails if
    aggregateSpatially() raises.
    """
    # the first element of the fixture is used as the energy system model
    esM = balanceLimitConstraint_test_esM[0]
    SHAPEFILE_PATH = os.path.join(  # noqa: PTH118
        os.path.dirname(__file__),  # noqa: PTH120
        "../../../examples/03_Multi-regional_Energy_System_Workflow/",
        "InputData/SpatialData/ShapeFiles/clusteredRegions.shp",
    )

    # Keep only the first two geometries. The explicit copy makes the column
    # assignment below operate on an independent frame rather than on a slice
    # of the frame read from disk (avoids pandas' SettingWithCopyWarning).
    gdf = gpd.read_file(SHAPEFILE_PATH).iloc[:2].copy()
    # presumably these names match the locations of the fixture esM -
    # verify against the fixture definition
    gdf["index"] = [f"Region{i}" for i in [1, 2]]

    # FUNCTION CALL
    _ = esM.aggregateSpatially(
        shapefile=gdf,
        grouping_mode="distance_based",
        n_groups=1,
        aggregatedResultsPath=None,
    )


# %%
