from collections.abc import Iterator
from pathlib import Path

import pytest
from cognite.client.data_classes import AssetList, AssetWrite, AssetWriteList
from cognite.client.data_classes.data_modeling import NodeId, Space, ViewId
from cognite.client.data_classes.data_modeling.cdm.v1 import CogniteAsset
from cognite.client.exceptions import CogniteAPIError

from cognite_toolkit._cdf_tk.client import ToolkitClient
from cognite_toolkit._cdf_tk.commands._migrate.command import MigrationCommand
from cognite_toolkit._cdf_tk.commands._migrate.data_mapper import AssetCentricMapper
from cognite_toolkit._cdf_tk.commands._migrate.default_mappings import (
    ASSET_ID,
    EVENT_ID,
    FILE_METADATA_ID,
    TIME_SERIES_ID,
)
from cognite_toolkit._cdf_tk.commands._migrate.migration_io import AssetCentricMigrationIO
from cognite_toolkit._cdf_tk.commands._migrate.selectors import MigrateDataSetSelector, MigrationCSVFileSelector
from tests.test_integration.conftest import HierarchyMinimal
from tests.test_integration.constants import RUN_UNIQUE_ID


@pytest.fixture()
def three_assets(toolkit_client: ToolkitClient, toolkit_space: Space) -> Iterator[AssetList]:
    """Provide three freshly created assets (one root, two children) and remove them afterwards.

    Any assets left over with the same external IDs are removed before creation.

    Args:
        toolkit_client: Client used to create and delete the assets and nodes.
        toolkit_space: Space whose identifier is used to build the node IDs that are deleted.

    Yields:
        The ``AssetList`` returned by ``client.assets.create``.
    """
    space_id = toolkit_space.space
    identifiers = [f"toolkit_asset_test_migration_{index}_{RUN_UNIQUE_ID}" for index in range(3)]
    root_identifier = identifiers[0]
    # The first asset is the root; the remaining assets point to it as their parent.
    to_create = AssetWriteList(
        [
            AssetWrite(
                external_id=identifier,
                name=identifier,
                parent_external_id=root_identifier if index > 0 else None,
            )
            for index, identifier in enumerate(identifiers)
        ]
    )

    leftovers = toolkit_client.assets.retrieve_multiple(
        external_ids=to_create.as_external_ids(), ignore_unknown_ids=True
    )
    if leftovers:
        try:
            toolkit_client.assets.delete(external_id=leftovers.as_external_ids(), ignore_unknown_ids=True)
        except CogniteAPIError:
            # NOTE(review): presumably the asset delete is rejected when a previous run already
            # migrated these assets; in that case the matching nodes are deleted instead.
            stale_nodes = [NodeId(space_id, asset.external_id) for asset in leftovers]
            toolkit_client.data_modeling.instances.delete(stale_nodes)
    created = toolkit_client.assets.create(to_create)

    yield created

    # Teardown: try the nodes first, and only delete the assets directly
    # when no node was reported as deleted.
    removed = toolkit_client.data_modeling.instances.delete(
        [NodeId(space_id, asset.external_id) for asset in created]
    )
    if not removed.nodes:
        toolkit_client.assets.delete(external_id=created.as_external_ids())


class TestMigrateAssetsCommand:
    """Integration tests running ``MigrationCommand.migrate`` for assets."""

    def test_migrate_assets(
        self,
        toolkit_client: ToolkitClient,
        three_assets: AssetList,
        toolkit_space: Space,
        tmp_path: Path,
    ) -> None:
        """Migrate the fixture assets from a CSV file and check one node is retrievable per asset."""
        target_space = toolkit_space.space

        # One CSV row per asset; an unset data set ID is written as an empty field.
        rows = []
        for asset in three_assets:
            data_set_field = asset.data_set_id or ""
            rows.append(f"{asset.id},{data_set_field},{target_space},{asset.external_id}")
        csv_path = tmp_path / "timeseries_migration.csv"
        csv_path.write_text("id,dataSetId,space,externalId\n" + "\n".join(rows) + "\n", encoding="utf-8")

        command = MigrationCommand(skip_tracking=True, silent=True)
        selector = MigrationCSVFileSelector(datafile=csv_path, kind="Assets")
        command.migrate(
            selected=selector,
            data=AssetCentricMigrationIO(toolkit_client),
            mapper=AssetCentricMapper(toolkit_client),
            log_dir=tmp_path / "logs",
            dry_run=False,
            verbose=False,
        )

        expected_nodes = [NodeId(target_space, asset.external_id) for asset in three_assets]
        migrated = toolkit_client.data_modeling.instances.retrieve_nodes(expected_nodes, CogniteAsset)
        assert len(migrated) == len(three_assets), "Not all assets were migrated successfully."

    def test_migrate_assets_by_dataset_dry_run(
        self, toolkit_client: ToolkitClient, migration_hierarchy_minimal: HierarchyMinimal, tmp_path: Path
    ) -> None:
        """Dry-run an asset migration selected by data set and expect two successes for every step."""
        command = MigrationCommand(skip_tracking=True, silent=True)
        selector = MigrateDataSetSelector(
            kind="Assets",
            data_set_external_id=migration_hierarchy_minimal.dataset.external_id,
            ingestion_mapping=ASSET_ID,
            preferred_consumer_view=ViewId("cdf_cdm", "CogniteAsset", "v1"),
        )
        progress = command.migrate(
            selected=selector,
            data=AssetCentricMigrationIO(toolkit_client),
            mapper=AssetCentricMapper(toolkit_client),
            log_dir=tmp_path,
            dry_run=True,
        )

        actual = progress.aggregate()
        expected = dict.fromkeys(((step, "success") for step in command.Steps.list()), 2)
        assert actual == expected


class TestMigrateEventsCommand:
    """Integration tests running ``MigrationCommand.migrate`` for events."""

    def test_migrate_events_by_dataset_dry_run(
        self, toolkit_client: ToolkitClient, migration_hierarchy_minimal: HierarchyMinimal, tmp_path: Path
    ) -> None:
        """Dry-run an event migration selected by data set and expect one success for every step."""
        command = MigrationCommand(skip_tracking=True, silent=True)
        selector = MigrateDataSetSelector(
            kind="Events",
            data_set_external_id=migration_hierarchy_minimal.dataset.external_id,
            ingestion_mapping=EVENT_ID,
            preferred_consumer_view=ViewId("cdf_cdm", "CogniteActivity", "v1"),
        )
        progress = command.migrate(
            selected=selector,
            data=AssetCentricMigrationIO(toolkit_client),
            mapper=AssetCentricMapper(toolkit_client),
            log_dir=tmp_path,
            dry_run=True,
        )

        actual = progress.aggregate()
        expected = dict.fromkeys(((step, "success") for step in command.Steps.list()), 1)
        assert actual == expected


class TestMigrateTimeSeriesCommand:
    """Integration tests running ``MigrationCommand.migrate`` for time series."""

    def test_migrate_time_series_by_dataset_dry_run(
        self, toolkit_client: ToolkitClient, migration_hierarchy_minimal: HierarchyMinimal, tmp_path: Path
    ) -> None:
        """Dry-run a time series migration selected by data set and expect one success for every step."""
        command = MigrationCommand(skip_tracking=True, silent=True)
        selector = MigrateDataSetSelector(
            kind="TimeSeries",
            data_set_external_id=migration_hierarchy_minimal.dataset.external_id,
            ingestion_mapping=TIME_SERIES_ID,
            preferred_consumer_view=ViewId("cdf_cdm", "CogniteTimeSeries", "v1"),
        )
        # This test constructs the IO with skip_linking=True.
        progress = command.migrate(
            selected=selector,
            data=AssetCentricMigrationIO(toolkit_client, skip_linking=True),
            mapper=AssetCentricMapper(toolkit_client),
            log_dir=tmp_path,
            dry_run=True,
        )

        actual = progress.aggregate()
        expected = dict.fromkeys(((step, "success") for step in command.Steps.list()), 1)
        assert actual == expected


class TestMigrateFileMetadataCommand:
    """Integration tests running ``MigrationCommand.migrate`` for file metadata."""

    def test_migrate_file_metadata_by_dataset_dry_run(
        self, toolkit_client: ToolkitClient, migration_hierarchy_minimal: HierarchyMinimal, tmp_path: Path
    ) -> None:
        """Dry-run a file metadata migration selected by data set and expect one success for every step."""
        command = MigrationCommand(skip_tracking=True, silent=True)
        selector = MigrateDataSetSelector(
            kind="FileMetadata",
            data_set_external_id=migration_hierarchy_minimal.dataset.external_id,
            ingestion_mapping=FILE_METADATA_ID,
            preferred_consumer_view=ViewId("cdf_cdm", "CogniteFile", "v1"),
        )
        # This test constructs the IO with skip_linking=False.
        progress = command.migrate(
            selected=selector,
            data=AssetCentricMigrationIO(toolkit_client, skip_linking=False),
            mapper=AssetCentricMapper(toolkit_client),
            log_dir=tmp_path,
            dry_run=True,
        )

        actual = progress.aggregate()
        expected = dict.fromkeys(((step, "success") for step in command.Steps.list()), 1)
        assert actual == expected
