"""Some tools for ranking data."""

from math import isfinite
from operator import add, itemgetter
from typing import Callable, Final, Iterable, TypeVar

from pycommons.types import type_error

#: the type of the element to sort during ranking, i.e., the type of the
#: items provided by the `source` passed to :func:`rank`
T = TypeVar("T")

#: the type of the key to use during ranking, i.e., the return type of the
#: `key` function passed to :func:`rank`
K = TypeVar("K")

#: the rank type: must either be int, float, or a union of both; this is the
#: return type of the `rank_join` function passed to :func:`rank`
R = TypeVar("R", int, float, int | float)

#: the output type var, i.e., the type of the records created by the `output`
#: function and stored in the list returned by :func:`rank`
X = TypeVar("X")


def rank(source: Iterable[T],
         key: Callable[[T], K] = lambda x: x,  # type: ignore
         output: Callable[[R, T, K], X] =  # type: ignore
         lambda rr, tt, _: (rr, tt),  # type: ignore
         rank_join: Callable[[int, int], R] = add,  # type: ignore
         rank_offset: int = 0) -> list[X]:
    """
    Rank the elements in a data source based on a given key function.

    The default behavior of this function is to basically sort the data from
    `source` and return tuples with the rank and the data element in a list.
    The result list is sorted by the keys of the object.

    By default, ranks start at `0` and increase in steps of `2`. The ranks of
    objects that would have the same rank are resolved by averaging their
    ranks. This is why we increment ranks in steps of `2`: This way, the mean
    of two ranks is always an integer:
    >>> rank([3, 6, 6, 12])
    [(0, 3), (3, 6), (3, 6), (6, 12)]

    This averaging can be modified by providing a `rank_join` function that
    computes a joint rank for objects as well as a `rank_offset`:
    >>> rank([3, 6, 6, 12], rank_join=lambda a, b: 0.5 * (a + b))
    [(0.0, 3), (1.5, 6), (1.5, 6), (3.0, 12)]

    >>> rank([3, 6, 6, 12], rank_join=min, rank_offset=1)
    [(1, 3), (2, 6), (2, 6), (4, 12)]

    >>> rank([3, 6, 6, 12], rank_join=max, rank_offset=1)
    [(1, 3), (3, 6), (3, 6), (4, 12)]

    However, the result of `rank_offset + rank_join(a, b)` must always be
    either an `int` or a `float` and also always finite and never negative,
    for any two non-negative integers `a` and `b`.

    The `key` function must compute a key for each element of `source` which
    can be used for sorting. By default, it returns the element itself.
    But it can be customized.
    >>> rank((6, 5, 3, 4, 0, 7))
    [(0, 0), (2, 3), (4, 4), (6, 5), (8, 6), (10, 7)]

    >>> sorted(rank({"a", "c", "X", "y", "x", "xx", "L", "l"}, key=str.lower))
    [(0, 'a'), (2, 'c'), (5, 'L'), (5, 'l'), (9, 'X'), (9, 'x'), (12, 'xx'), \
(14, 'y')]

    The `output` function is used to create the records to be placed in the
    list returned by this function. Its input are the rank, the object, and
    its computed key. By default, it creates tuples of the rank and the object
    obtained from `source`. You can customize this as well:
    >>> rank([5, 7, 4, 9, 2, 1])
    [(0, 1), (2, 2), (4, 4), (6, 5), (8, 7), (10, 9)]

    >>> rank([5, 7, 4, 9, 2, 1], output=lambda rr, oo, kk: f"{rr}:{oo}")
    ['0:1', '2:2', '4:4', '6:5', '8:7', '10:9']

    >>> rank([5, 7, 4, 19, 2, 1], key=str,
    ...         output=lambda rr, oo, kk: (rr, oo, kk))
    [(0, 1, '1'), (2, 19, '19'), (4, 2, '2'), (6, 4, '4'), (8, 5, '5'), \
(10, 7, '7')]

    :param source: the data source
    :param key: a function returning the key for a given object
    :param output: a function creating the output object, receiving the rank,
        original object, and key as input
    :param rank_join: a function for joining a maximum and minimum index of an
        object to a rank; by default this returns the sum of both
    :param rank_offset: an offset to be added to the ranks
    :returns: a list with the objects generated by the `output` function,
        which by default are tuples of rank and object

    >>> rank({})
    []
    >>> rank([12])
    [(0, 12)]
    >>> rank([12, 3])
    [(0, 3), (2, 12)]
    >>> rank([3, 12], rank_offset=2)
    [(2, 3), (4, 12)]
    >>> rank([12, 12])
    [(1, 12), (1, 12)]
    >>> rank([12, 12], output=lambda rr, tt, kk: rr)
    [1, 1]
    >>> rank([-1, 0, 4, 3, 3, 5, 6, 1])
    [(0, -1), (2, 0), (4, 1), (7, 3), (7, 3), (10, 4), (12, 5), (14, 6)]
    >>> rank([-1, 0, 4, 3, 3, 5, 6, 1], rank_join=lambda a, b: 0.5 * (a + b))
    [(0.0, -1), (1.0, 0), (2.0, 1), (3.5, 3), (3.5, 3), (5.0, 4), (6.0, 5), \
(7.0, 6)]
    >>> sorted(rank(("a", "B", "c", "b", "A", "A", "cc"), key=str.casefold))
    [(2, 'A'), (2, 'A'), (2, 'a'), (7, 'B'), (7, 'b'), (10, 'c'), (12, 'cc')]

    >>> try:
    ...     rank(1)
    ... except TypeError as te:
    ...     print(te)
    source should be an instance of typing.Iterable but is int, namely 1.

    >>> try:
    ...     rank([], key=1)
    ... except TypeError as te:
    ...     print(te)
    key should be a callable but is int, namely 1.

    >>> try:
    ...     rank([], output=1)
    ... except TypeError as te:
    ...     print(te)
    output should be a callable but is int, namely 1.

    >>> try:
    ...     rank([], rank_join=1)
    ... except TypeError as te:
    ...     print(te)
    rank_join should be a callable but is int, namely 1.

    >>> try:
    ...     rank([], rank_offset="x")
    ... except TypeError as te:
    ...     print(te)
    rank_offset should be an instance of any in {float, int} \
but is str, namely 'x'.

    >>> from math import inf, nan
    >>> try:
    ...     rank([], rank_offset=inf)
    ... except ValueError as ve:
    ...     print(ve)
    rank_offset=inf should be finite

    >>> try:
    ...     rank([], rank_offset=nan)
    ... except ValueError as ve:
    ...     print(ve)
    rank_offset=nan should be finite

    >>> try:
    ...     rank([1, 2, 3], rank_join=lambda a, b: "x")
    ... except TypeError as te:
    ...     print(te)
    rank_join(0, 0) should be an instance of any in {float, int} \
but is str, namely 'x'.

    >>> try:
    ...     rank([1, 2, 3], rank_join=lambda a, b: inf)
    ... except ValueError as ve:
    ...     print(ve)
    rank inf=rank_join(0, 0) + 0 is not finite and non-negative.

    >>> try:
    ...     rank([1, 2, 3], rank_join=lambda a, b: nan)
    ... except ValueError as ve:
    ...     print(ve)
    rank nan=rank_join(0, 0) + 0 is not finite and non-negative.
    """
    if not isinstance(source, Iterable):
        raise type_error(source, "source", Iterable)
    if not callable(key):
        raise type_error(key, "key", call=True)
    if not callable(output):
        raise type_error(output, "output", call=True)
    if not callable(rank_join):
        raise type_error(rank_join, "rank_join", call=True)
    if not isinstance(rank_offset, int | float):
        raise type_error(rank_offset, "rank_offset", (int, float))
    if not isfinite(rank_offset):
        raise ValueError(
            f"rank_offset={rank_offset} should be finite")

    data: list = [(key(t), t) for t in source]  # convert data to list
    max_hi: Final[int] = list.__len__(data) - 1  # maximum index
    if max_hi < 0:  # data is empty, can return it as-is
        return data

    data.sort(key=itemgetter(0))  # sort the data by the key

    lo: int = 0
    while lo <= max_hi:  # iterate through all the data
        # first, we obtain the index range of objects with the same rank
        lo_key = data[lo][0]
        hi: int = lo
        while (hi < max_hi) and (not (lo_key < data[hi + 1][0])):
            hi += 1

        r: R = rank_join(lo, hi)  # compute the joint rank
        if not isinstance(r, int | float):  # sanity check of rank, part 1
            raise type_error(r, f"rank_join({lo}, {hi})", (int, float))

        r += rank_offset
        if (not isfinite(r)) or (r < 0):  # sanity check of rank, part 2
            raise ValueError(f"rank {r}=rank_join({lo}, {hi}) + {rank_offset}"
                             " is not finite and non-negative.")

        for i in range(lo, hi + 1):  # assign rank and create output objects
            dk, dt = data[i]
            data[i] = output(r, dt, dk)

        lo = hi + 1  # move on to next object

    return data  # return finalized list
