import bisect
import json
import math
import random
import string
from collections import Counter, defaultdict
from collections.abc import Iterator
from dataclasses import replace
from threading import Lock
from typing import Any

import numpy
from arrakis import Channel


class PartitionMetadata:
    """Metadata that must be tracked per partition, as opposed to per channel.

    At present this is just the next partition index to hand out. Index
    allocation and merging are serialized with a per-instance lock.
    """

    def __init__(self) -> None:
        """
        Initialize a PartitionMetadata object.

        Fields
            _next_index: The next index value to issue for this partition
            _lock: Guards read-modify-write access to _next_index
        """
        self._next_index: int = 0
        self._lock: Lock = Lock()

    def get_next_index(self) -> int:
        """Return the next index value for this partition and advance the
        counter by one."""
        with self._lock:
            idx = self._next_index
            self._next_index += 1
            return idx

    def merge_in(self, other: "PartitionMetadata") -> None:
        """Update this object with other.

        This will update the next index value to be the max of the two objects.
        Only this object is modified; other is left untouched.

        It is safe to call on the same object.
        """
        if self is other:
            return
        # Snapshot the other counter under its own lock and release it before
        # taking ours. Holding both locks at once would allow a lock-order
        # deadlock when a.merge_in(b) and b.merge_in(a) run concurrently.
        with other._lock:
            other_next = other._next_index
        with self._lock:
            self._next_index = max(self._next_index, other_next)

    def to_dict(self) -> dict[str, Any]:
        """Return the metadata as a plain dict with a "next_index" key."""
        return {
            "next_index": self._next_index,
        }

    def to_json(self) -> str:
        """Return the dict form of the metadata serialized as a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> "PartitionMetadata":
        """Build an instance from a JSON string as produced by to_json."""
        return cls.from_dict(json.loads(json_str))

    @classmethod
    def from_dict(cls, obj: dict[str, Any]) -> "PartitionMetadata":
        """Build an instance from a dict as produced by to_dict.

        Raises KeyError if "next_index" is missing.
        """
        # Construct through cls so that subclasses get instances of their own
        # type rather than the base class.
        meta = cls()
        meta._next_index = obj["next_index"]
        return meta


def generate_partition_id(publisher_id: str, channel: Channel | None = None) -> str:
    """Build a partition ID for a publisher, ending in a random suffix.

    The suffix is six characters drawn from the uppercase ASCII letters and
    digits using ``random.SystemRandom``. When a truthy channel is given, its
    ``subsystem`` is placed between the publisher ID and the suffix.
    """
    alphabet = string.ascii_uppercase + string.digits
    rng = random.SystemRandom()
    suffix = "".join([rng.choice(alphabet) for _ in range(6)])
    if not channel:
        return f"{publisher_id}_{suffix}"
    return f"{publisher_id}_{channel.subsystem}_{suffix}"


def grouped(items: list[Any], n: int) -> Iterator[list[Any]]:
    """Yield consecutive slices of items, each holding at most n entries.

    The last slice is shorter when the length of items is not a multiple
    of n.
    """
    total = len(items)
    yield from (items[start : start + n] for start in range(0, total, n))


def partition_channels(
    channels: list[Channel],
    publisher: str,
    metadata: dict[str, Channel] | None = None,
    partition_metadata: dict[str, PartitionMetadata] | None = None,
    max_channels: int = 100,
    partition_fraction: float = 0.8,
) -> list[Channel]:
    """Determine partition IDs and partition indices for channels.

    Parameters
    ----------
    channels : list[Channel]
        List of channels for which to determine partition IDs
    publisher: str
        A publisher ID to apply to all channels being partitioned.
        This will override any publisher already specified in the
        channel metadata returned.
    metadata: dict[str, Channel]
        An existing channel metadata dictionary, from which existing
        partition information will be taken. Entries whose publisher
        differs from `publisher` are ignored. The dictionary passed in
        is not modified.
    partition_metadata: dict[str, PartitionMetadata]
        An existing partition metadata dictionary, from which existing
        partition information can be used. It is updated in place with
        an entry for every partition that is issued a new index.
    max_channels: int
        The maximum number of channels per partition.
    partition_fraction: float
        Fraction of max channels to use in initial partition
        allocation. The resulting group size is never less than one.

    Returns
    -------
    list[Channel]
        The channels in the same order as passed in, updated with
        publisher and partition info. If every channel already has both
        a partition ID and a partition index, the input list itself is
        returned unchanged.

    """
    if partition_metadata is None:
        partition_metadata = {}

    if metadata is None:
        metadata = {}
    else:
        # trim metadata to only contain the channels with the listed
        # publisher; this also yields a private copy that is safe to
        # extend further down
        metadata = {
            name: meta for name, meta in metadata.items() if meta.publisher == publisher
        }

    # determine channels to partition
    channels_to_partition = [
        channel for channel in channels if channel.partition_id is None
    ]
    need_to_index = any(channel.partition_index is None for channel in channels)

    if not channels_to_partition and not need_to_index:
        return channels

    # Size of each newly created partition. Clamp to at least one: a zero
    # size would make grouped() fail on a zero range step, and a negative
    # size would yield no groups and leave channels unassigned.
    max_partition_size = max(1, math.floor(partition_fraction * max_channels))

    # map channels to dtypes
    channels_by_dtype: dict[numpy.dtype | str, list[Channel]] = {}
    for channel in channels_to_partition:
        channels_by_dtype.setdefault(channel.data_type, []).append(channel)

    # handle each data type separately
    updated: dict[str, Channel] = {}
    for subblock in channels_by_dtype.values():
        # split the subblock into channels already known to the metadata
        # (these keep their recorded partition) and unmatched ones
        subblock_names = {channel.name for channel in subblock}
        subpartitions = {
            name: meta.partition_id
            for name, meta in metadata.items()
            if name in subblock_names
        }
        unmatched = [
            channel for channel in subblock if channel.name not in subpartitions
        ]
        part_count = Counter(subpartitions.values())
        ordered = sorted(subpartitions)

        # determine where channel would go in sorted order
        insert_pt = defaultdict(list)
        for channel in unmatched:
            idx = bisect.bisect_left(ordered, channel.name)
            insert_pt[idx].append(channel)

        # assign unmatched into existing or new partitions
        for idx, adjacent in insert_pt.items():
            # channels sorting after every known name go with the last one
            insert_idx = min(idx, len(ordered) - 1)

            if insert_idx == -1:
                # no initial partitions
                partition_id = generate_partition_id(publisher, adjacent[0])
            else:
                id_ = metadata[ordered[insert_idx]].partition_id
                assert isinstance(id_, str)
                partition_id = id_

            if part_count[partition_id] + len(adjacent) > max_channels:
                # target partition would overflow: assign to new partitions,
                # each starting from a fresh index counter
                for group in grouped(adjacent, max_partition_size):
                    new_id = generate_partition_id(publisher, group[0])
                    new_meta = PartitionMetadata()

                    for channel in group:
                        updated[channel.name] = replace(
                            channel,
                            publisher=publisher,
                            partition_id=new_id,
                            partition_index=new_meta.get_next_index(),
                        )

                    part_count[new_id] += len(group)
                    partition_metadata[new_id] = new_meta
            else:
                # assign to existing partition
                # NOTE(review): when partition_metadata has no entry for this
                # partition, indices start again from a fresh counter --
                # confirm callers always supply metadata for existing
                # partitions.
                target_meta = partition_metadata.setdefault(
                    partition_id, PartitionMetadata()
                )
                for channel in adjacent:
                    updated[channel.name] = replace(
                        channel,
                        publisher=publisher,
                        partition_id=partition_id,
                        partition_index=target_meta.get_next_index(),
                    )

                part_count[partition_id] += len(adjacent)

    # fill in any channels that were not newly partitioned
    for channel in channels:
        if channel.name in updated:
            continue
        # We can get here if no metadata was passed in and the channel
        # has a partition_id; in that case the channel is its own record
        known = metadata.setdefault(channel.name, channel)
        assert known.partition_id
        partition_id = known.partition_id

        # prefer the channel's own index, then the recorded one, and only
        # issue a new index when neither is available
        index = channel.partition_index
        if index is None:
            index = known.partition_index
            if index is None:
                meta = partition_metadata.setdefault(partition_id, PartitionMetadata())
                index = meta.get_next_index()

        updated[channel.name] = replace(
            channel,
            publisher=publisher,
            partition_id=partition_id,
            partition_index=index,
        )

    # return same channel list order as passed in
    return [updated[channel.name] for channel in channels]
