from pathlib import Path
from shutil import copyfile
from tempfile import TemporaryDirectory

import toml

from arrakis_server.metadata import ChannelConfigBackend
from arrakis_server.partition import PartitionMetadata, partition_channels
from arrakis_server.tests.test_partition import create_channel


def setup_metadata(dest_dir: str | Path, metdata_filename: str) -> Path:
    dest_path = Path(dest_dir) / "metdata.toml"
    src_path = Path(__file__).parent / "data" / metdata_filename
    copyfile(src_path, dest_path)
    return dest_path


def test_channel_config_backend_load_metadata():
    """A backend built from a cached metadata file exposes its contents.

    Both fixture channels must be present in ``backend.metadata``. Exactly
    the two fixture partitions must be loaded, each with the ``_next_index``
    stored in the fixture file.
    """
    expected_channels = ("H1:TEST-CHANNEL_A", "H1:TEST-CHANNEL_B")
    expected_next_index = {
        "existing_partition": 5,
        "existing_partition2": 42,
    }

    with TemporaryDirectory() as tmp_dir:
        backend = ChannelConfigBackend(
            cache_file=setup_metadata(tmp_dir, "metadata_test_1.toml")
        )

        for channel_name in expected_channels:
            assert channel_name in backend.metadata

        assert len(backend.partition_metadata) == len(expected_next_index)
        for partition_id, next_index in expected_next_index.items():
            assert backend.partition_metadata[partition_id]._next_index == next_index


def test_channel_config_backend_save_metadata():
    """Updating the backend persists partition metadata to its cache file.

    Two channels are partitioned, one per partition ID, starting from empty
    partition metadata. The result is passed to ``backend.update``. The cache
    file is then re-read and each partition's stored ``next_index`` is
    checked.
    """
    partition_ids = ["existing_partition", "existing_partition2"]

    with TemporaryDirectory() as cache_dir:
        cache_path = setup_metadata(cache_dir, "metadata_test_empty.toml")
        backend = ChannelConfigBackend(cache_file=cache_path)
        assert backend.cache_file is not None

        channels = [
            create_channel(
                "H1:TEST-CHANNEL_A",
                publisher="test_publisher",
                partition_id="existing_partition",
            ),
            create_channel(
                "H1:TEST-CHANNEL_B",
                publisher="test_publisher",
                partition_id="existing_partition2",
            ),
        ]
        # partition_channels populates this dict, which is then handed to the
        # backend so it can be written to the cache file.
        partition_metadata: dict[str, PartitionMetadata] = {}
        channels = partition_channels(
            channels, publisher="test_publisher", partition_metadata=partition_metadata
        )
        backend.update(channels, partition_metadata)

        # Explicit encoding so the read does not depend on the platform locale.
        with open(cache_path, "rt", encoding="utf-8") as f:
            objs = toml.load(f)

        saved_partitions = objs["__internal_partition_metadata"]
        # Check every partition. The previous version asserted on
        # "existing_partition" twice and never inspected "existing_partition2".
        # NOTE(review): assumes each single-channel partition ends with
        # next_index == 1, matching the original first assertion.
        for partition_id in partition_ids:
            assert saved_partitions[partition_id]["next_index"] == 1
