from __future__ import annotations

import collections.abc
import dataclasses
import datetime
import functools
import re
import types
import typing

import dateutil.parser
from zodchy import codex

# Name of a query parameter / model field.
FieldName: typing.TypeAlias = str
# Target Python type that a raw query value is cast to.
FieldType: typing.TypeAlias = type
# Raw query value; ``None`` values are skipped by ``Parser._parse``.
FieldValue: typing.TypeAlias = str | None
# Maps each filterable field name to the type its raw values are cast to.
TypesMapType: typing.TypeAlias = collections.abc.Mapping[FieldName, FieldType]


@dataclasses.dataclass
class ParsingSchema:
    """Names of the reserved query keys that ``Parser`` does not treat as filters.

    Field order defines the positional constructor signature.
    """

    # Key whose comma-separated value lists sort fields ("-" prefix = descending).
    order_by: str = "order_by"
    # Key whose integer value becomes a ``codex.operator.Limit``.
    limit: str = "limit"
    # Key whose integer value becomes a ``codex.operator.Offset``.
    offset: str = "offset"
    # NOTE(review): not read by ``Parser`` — a "fieldset" key is currently
    # parsed as an ordinary filter; confirm the intended use.
    fieldset: str = "fieldset"


def _cast_bool(value: str) -> bool:
    value = value.strip().lower()
    if value == "false":
        return False
    if value == "true":
        return True
    raise ValueError(f"Unable to cast '{value}' to bool")


# Converters from a raw query string to a Python value, keyed by target type.
# Types missing here are constructed directly as ``type_(value)`` by
# ``Parser._cast``. The read-only proxy keeps this shared default immutable.
default_casting_map: collections.abc.Mapping[FieldType, collections.abc.Callable[[str], typing.Any]] = (
    types.MappingProxyType(
        {
            datetime.datetime: dateutil.parser.parse,
            datetime.date: datetime.date.fromisoformat,
            bool: _cast_bool,
        }
    )
)

# Field types for which the range syntax ``(a,b)`` / ``[a,b]`` etc. is accepted.
interval_types = (datetime.datetime, datetime.date, int, float)


@dataclasses.dataclass
class Param:
    """One parsed query parameter: a name and the clause built for it."""

    # Field (or reserved key) name the clause applies to.
    name: str
    # Clause built from the raw value (filter, order, limit or offset operator).
    value: codex.operator.ClauseBit


# Builds a Param from a field name and the value text captured by a handler's regex.
ClauseHandler: typing.TypeAlias = collections.abc.Callable[[FieldName, str], Param]
# Builds a single-operand filter (EQ, LIKE, ...) from an already-cast value.
LiteralFactory: typing.TypeAlias = collections.abc.Callable[[typing.Any], codex.operator.FilterBit[typing.Any]]
# Operators allowed as the lower bound of a RANGE.
RangeLeftClause: typing.TypeAlias = codex.operator.GE[typing.Any] | codex.operator.GT[typing.Any]
# Operators allowed as the upper bound of a RANGE.
RangeRightClause: typing.TypeAlias = codex.operator.LE[typing.Any] | codex.operator.LT[typing.Any]
# (lower-bound factory, upper-bound factory) pair consumed by ``Parser._interval``.
RangeFactories: typing.TypeAlias = tuple[
    collections.abc.Callable[[typing.Any], RangeLeftClause],
    collections.abc.Callable[[typing.Any], RangeRightClause],
]
# Operand types accepted by ``codex.operator.NOT``; used by ``Parser._literal``
# as a ``typing.cast`` target (static typing only).
SupportedNotOperand: typing.TypeAlias = (
    codex.operator.EQ[typing.Any]
    | codex.operator.LE[typing.Any]
    | codex.operator.GE[typing.Any]
    | codex.operator.LT[typing.Any]
    | codex.operator.GT[typing.Any]
    | codex.operator.IS[typing.Any]
    | codex.operator.LIKE[typing.Any]
    | codex.operator.SET[typing.Any]
    | codex.operator.RANGE[typing.Any]
)


class Parser:
    """Translate query parameters into zodchy ``codex`` clause bits.

    A query is either a raw ``name=value&name=value`` string or a mapping of
    parameter names to string values (``None`` values are skipped). Keys named
    by the ``ParsingSchema`` produce order / limit / offset clauses. Every other
    key is a filter: its value syntax selects the operator (see
    ``_pattern_handler_map``), and its type comes from the caller's types map.
    """

    def __init__(
        self,
        casting_map: collections.abc.Mapping[
            FieldType, collections.abc.Callable[[str], typing.Any]
        ] = default_casting_map,
        parsing_schema: ParsingSchema | None = None,
    ):
        """Configure value casting and the reserved query keys.

        Args:
            casting_map: converters from a raw string to a value, keyed by the
                target type; types absent from the map are built as
                ``type_(value)``.
            parsing_schema: names of the reserved keys. ``None`` (the default)
                gives this parser its own fresh ``ParsingSchema()``.
        """
        self._casting_map = casting_map
        # ParsingSchema is a mutable dataclass: a default instance created once
        # at definition time would be shared (and mutable) across all parsers.
        self._parsing_schema = ParsingSchema() if parsing_schema is None else parsing_schema

    def __call__(
        self,
        query: str | collections.abc.Mapping[FieldName, FieldValue],
        types_map: TypesMapType,
    ) -> collections.abc.Generator[tuple[str, codex.operator.ClauseBit], None, None]:
        """Parse *query* lazily and yield ``(name, clause)`` pairs.

        Raises:
            ValueError: if *query* is neither a string nor a mapping, if a
                string query contains no ``=`` at all, if one of its
                ``&``-separated segments has no ``=``, or if a filter name
                is missing from *types_map*.
            TypeError: if range syntax is used on a non-interval field.
        """
        data: collections.abc.Iterable[tuple[str, FieldValue]]
        if isinstance(query, str):
            if "=" not in query:
                raise ValueError("You have to specify name for parameter value")
            # Empty segments ("a=1&&b=2", trailing "&") carry no data; skip them.
            data = (self._split_key_value(pair) for pair in query.split("&") if pair)
        elif isinstance(query, collections.abc.Mapping):
            data = query.items()
        else:
            raise ValueError("Query must be string or mapping")
        for param in self._parse(data, types_map):
            yield param.name, param.value

    def _parse(
        self,
        query: collections.abc.Iterable[tuple[str, FieldValue]],
        types_map: TypesMapType,
    ) -> collections.abc.Generator[Param, None, None]:
        """Dispatch each ``(name, value)`` pair to the matching reserved-key or filter parser."""
        # The handler map depends only on types_map: build it once, not per parameter.
        handlers = self._pattern_handler_map(types_map)
        schema = self._parsing_schema
        for name, value in query:
            if value is None:
                continue
            if name == schema.order_by:
                yield from self._parse_order_param(value)
            elif name == schema.limit:
                yield Param(name=schema.limit, value=codex.operator.Limit(int(value)))
            elif name == schema.offset:
                yield Param(name=schema.offset, value=codex.operator.Offset(int(value)))
            else:
                param = self._parse_filter_param(name, value.strip(), types_map, handlers)
                if param is not None:
                    yield param

    @staticmethod
    def _parse_order_param(names: FieldName) -> collections.abc.Generator[Param, None, None]:
        """Yield one ordering Param per comma-separated name.

        A leading "-" selects DESC, otherwise ASC. The priority is the
        position among the non-empty names, starting at 0. Surrounding blanks
        are ignored, and empty entries (e.g. from "a,,b") are skipped.
        """
        priority = 0
        for raw_name in names.split(","):
            name = raw_name.strip()
            if not name:
                continue
            direction: type[codex.operator.OrderBit] = codex.operator.ASC
            if name.startswith("-"):
                direction = codex.operator.DESC
                name = name[1:]
            yield Param(name=name, value=direction(priority))
            priority += 1

    def _parse_filter_param(
        self,
        name: FieldName,
        value: str,
        types_map: TypesMapType,
        handlers: collections.abc.Mapping[re.Pattern[str], ClauseHandler] | None = None,
    ) -> Param | None:
        """Build the filter Param for *name* using the first pattern that matches *value*.

        Args:
            handlers: prebuilt pattern→handler map; built from *types_map* when omitted.

        Raises:
            ValueError: if *name* has no entry in *types_map*.
        """
        if name not in types_map:
            raise ValueError(f"Type of parameter {name} must be defined in types map")
        if handlers is None:
            handlers = self._pattern_handler_map(types_map)

        for pattern, handler in handlers.items():
            if match := pattern.search(value):
                return handler(name, match.group(1))

        # Unreachable with the default map (its last pattern matches anything);
        # kept for custom maps.
        return None

    @staticmethod
    def _split_key_value(pair: str) -> tuple[str, str]:
        """Split ``"name=value"`` at the first "=".

        Raises:
            ValueError: if *pair* contains no "=".
        """
        key, separator, value = pair.partition("=")
        if not separator:
            raise ValueError(f"Query parameter '{pair}' must be in 'name=value' form")
        return key, value

    def _interval(
        self,
        field_name: str,
        field_value: str,
        operations: RangeFactories,
        types_map: TypesMapType,
    ) -> Param:
        """Build a RANGE clause from a ``"low,high"`` string.

        Either bound may be empty; that side of the RANGE is then ``None``.

        Raises:
            TypeError: if the field's type is not in ``interval_types``.
            ValueError: if the value does not have exactly two comma-separated members.
        """
        field_type = types_map[field_name]
        if field_type not in interval_types:
            raise TypeError(f"Interval cannot be calculated for type {field_type} for field {field_name}")

        bounds = field_value.split(",")
        if len(bounds) != 2:
            raise ValueError(f"Range must contain strictly two members for field {field_name}")
        low, high = bounds

        left: RangeLeftClause | None = operations[0](self._cast(low, field_type)) if low else None
        right: RangeRightClause | None = operations[1](self._cast(high, field_type)) if high else None
        return Param(name=field_name, value=codex.operator.RANGE(left, right))

    def _multitude(
        self,
        field_name: str,
        field_value: str,
        types_map: TypesMapType,
        inversion: bool = False,
    ) -> Param:
        """Build a SET clause from ``"a,b,c"``, wrapped in NOT when *inversion* is true.

        Empty members (e.g. from a trailing comma) are dropped. Each member is
        cast individually and is not stripped.
        """
        clause_value: codex.operator.SET[typing.Any] = codex.operator.SET(
            *(self._cast(item, types_map[field_name]) for item in field_value.split(",") if item)
        )

        value: codex.operator.ClauseBit
        if inversion:
            value = codex.operator.NOT(clause_value)
        else:
            value = clause_value

        return Param(name=field_name, value=value)

    def _literal(
        self,
        field_name: str,
        field_value: str,
        operation: LiteralFactory,
        types_map: TypesMapType,
        inversion: bool = False,
    ) -> Param:
        """Cast *field_value* and wrap it with *operation*, optionally negated with NOT."""
        literal = operation(self._cast(field_value, types_map[field_name]))
        value: codex.operator.ClauseBit

        if inversion:
            # typing.cast is a runtime no-op; it only narrows the type for NOT.
            value = codex.operator.NOT(typing.cast(SupportedNotOperand, literal))
        else:
            value = literal

        return Param(name=field_name, value=value)

    def _cast(self, value: str, type_: type) -> typing.Any:
        """Convert *value* with the registered caster for *type_*, or via ``type_(value)``."""
        caster = self._casting_map.get(type_)
        if caster is not None:
            return caster(value)
        return type_(value)

    def _pattern_handler_map(self, types_map: TypesMapType) -> collections.abc.Mapping[re.Pattern[str], ClauseHandler]:
        """Return the value-syntax → handler table for filter parameters.

        Syntax (first match wins, so insertion order matters):
          ``null`` / ``!null``         -> IS(None) / NOT(IS(None))
          ``(a,b)`` ``[a,b)`` ``(a,b]`` ``[a,b]`` -> RANGE with GT/GE and LT/LE bounds
          ``!{a,b}`` / ``{a,b}``       -> NOT(SET(...)) / SET(...)
          ``~~x`` / ``!~~x``           -> LIKE(x) with codex's default case sensitivity / its NOT
          ``~x`` / ``!~x``             -> LIKE(x, case_sensitive=True) / its NOT
          ``!x`` / anything else       -> NOT(EQ(x)) / EQ(x)
        "~~" must be tried before "~", and the catch-all must stay last.
        """
        # re.DOTALL: without it ".*" stops at a newline, and e.g. "!a\nb" would
        # silently fall through to the catch-all as EQ("!a").
        return {
            re.compile("^(null)$"): lambda x, y: Param(name=x, value=codex.operator.IS(None)),
            re.compile("^(!null)$"): lambda x, y: Param(name=x, value=codex.operator.NOT(codex.operator.IS(None))),
            re.compile(r"^\(([\dTZ:\-,.]+)\)$"): self._interval_handler(
                operations=(codex.operator.GT, codex.operator.LT),
                types_map=types_map,
            ),
            re.compile(r"^\[([\dTZ:\-,.]+)\)$"): self._interval_handler(
                operations=(codex.operator.GE, codex.operator.LT),
                types_map=types_map,
            ),
            re.compile(r"^\(([\dTZ:\-,.]+)]$"): self._interval_handler(
                operations=(codex.operator.GT, codex.operator.LE),
                types_map=types_map,
            ),
            re.compile(r"^\[([\dTZ:\-,.]+)]$"): self._interval_handler(
                operations=(codex.operator.GE, codex.operator.LE),
                types_map=types_map,
            ),
            re.compile(r"^!{(.*)}$", re.DOTALL): self._multitude_handler(inversion=True, types_map=types_map),
            re.compile(r"^{(.*)}$", re.DOTALL): self._multitude_handler(types_map=types_map),
            re.compile(r"^~{2}(.*)$", re.DOTALL): self._literal_handler(
                operation=codex.operator.LIKE,
                types_map=types_map,
            ),
            re.compile(r"^![~]{2}(.*)$", re.DOTALL): self._literal_handler(
                operation=codex.operator.LIKE,
                types_map=types_map,
                inversion=True,
            ),
            re.compile(r"^~(.*)$", re.DOTALL): self._literal_handler(
                operation=functools.partial(codex.operator.LIKE, case_sensitive=True),
                types_map=types_map,
            ),
            re.compile(r"^!~(.*)$", re.DOTALL): self._literal_handler(
                operation=functools.partial(codex.operator.LIKE, case_sensitive=True),
                types_map=types_map,
                inversion=True,
            ),
            re.compile(r"^!(.*)$", re.DOTALL): self._literal_handler(
                operation=codex.operator.EQ,
                types_map=types_map,
                inversion=True,
            ),
            re.compile(r"(.*)", re.DOTALL): self._literal_handler(
                operation=codex.operator.EQ,
                types_map=types_map,
            ),
        }

    def _interval_handler(
        self,
        *,
        operations: RangeFactories,
        types_map: TypesMapType,
    ) -> ClauseHandler:
        """Bind *operations* and *types_map* into a ClauseHandler that calls ``_interval``."""

        def handler(field_name: FieldName, field_value: str) -> Param:
            return self._interval(field_name, field_value, operations, types_map)

        return handler

    def _multitude_handler(self, *, types_map: TypesMapType, inversion: bool = False) -> ClauseHandler:
        """Bind *types_map* and *inversion* into a ClauseHandler that calls ``_multitude``."""

        def handler(field_name: FieldName, field_value: str) -> Param:
            return self._multitude(field_name, field_value, types_map, inversion=inversion)

        return handler

    def _literal_handler(
        self,
        *,
        operation: LiteralFactory,
        types_map: TypesMapType,
        inversion: bool = False,
    ) -> ClauseHandler:
        """Bind *operation*, *types_map* and *inversion* into a ClauseHandler that calls ``_literal``."""

        def handler(field_name: FieldName, field_value: str) -> Param:
            return self._literal(field_name, field_value, operation, types_map, inversion=inversion)

        return handler
