"""Plugins for dependency injection tools."""

from collections import defaultdict
from collections.abc import Callable
from functools import wraps
from typing import Any

from funcy_bear.context.di.types import Bindable, CollectionChoice, Params, Return, ReturnedCallable
from funcy_bear.context.di.wiring import ParamsReturn, parse_params
from funcy_bear.tools.names import Names
from funcy_bear.type_stuffs.validate import is_mapping


class Getter[V](Bindable):
    """A tool for getting fields from a document."""

    doc: Any

    def bind(self, doc: Any, **kwargs) -> None:  # noqa: ARG002
        """Bind a document to the getter."""
        self.doc = doc

    def __call__(self, field: str, doc: Any | None = None) -> V:
        """Retrieve a field from the bound document."""
        if doc is not None:
            self.doc = doc
        if is_mapping(self.doc):
            return self.doc[field]
        return getattr(self.doc, field)


class Setter[V](Bindable):
    """A tool for setting fields on a document."""

    doc: Any

    def bind(self, doc: Any, **kwargs) -> None:  # noqa: ARG002
        """Bind a document to the setter."""
        self.doc = doc

    def __call__(self, field: str, value: V, return_val: bool = False, doc: Any | None = None) -> V | None:
        """Set a field on the bound document."""
        if doc is not None:
            self.doc = doc
        if is_mapping(self.doc):
            self.doc[field] = value
            if return_val:
                return value
        else:
            setattr(self.doc, field, value)
            if return_val:
                return value
        return None


class Deleter(Bindable):
    """Remove one field from a document, by key for mappings or by attribute otherwise."""

    # The document deletions are performed against; set by ``bind`` or by passing ``doc``.
    doc: Any

    def bind(self, doc: Any, **kwargs) -> None:  # noqa: ARG002
        """Store ``doc`` as the document used by later deletions; extra kwargs are ignored."""
        self.doc = doc

    def __call__(self, field: str, doc: Any | None = None) -> None:
        """Delete ``field`` from the bound document.

        Args:
            field: Key (when the document is a mapping) or attribute name to remove.
            doc: Optional document; when not None it replaces the bound document
                and stays bound for subsequent calls.
        """
        if doc is not None:
            self.doc = doc
        target = self.doc
        if not is_mapping(target):
            delattr(target, field)
            return
        del target[field]  # type: ignore[index]


class Factory(Bindable):
    """A factory tool for creating collections.

    Calling an instance returns a new collection. Before ``bind`` is called
    that is an empty ``dict``; afterwards it is produced by the constructor
    that ``default_factory`` resolved from the ``default_factory`` keyword.
    """

    def bind(self, doc: Any, **kwargs) -> None:  # noqa: ARG002
        """Configure which collection constructor the factory uses.

        Args:
            doc: Accepted for interface parity with the other tools; unused.
            **kwargs: ``default_factory`` (optional, default ``"dict"``) is passed
                to ``default_factory`` to pick the constructor. Other keys are ignored.

        Raises:
            ValueError: Propagated from ``default_factory`` for an unsupported choice.
        """
        default: ReturnedCallable = default_factory(kwargs.pop("default_factory", "dict"))
        # Defensive fallback: never store None as the constructor.
        default = dict if default is None else default
        self._factory_override: Callable = default

    def __call__(self) -> Any:
        """Create and return a new collection.

        Uses the constructor configured by ``bind`` when one was bound, and a
        plain ``dict`` constructor otherwise.
        """
        choice: CollectionChoice = "dict"
        if hasattr(self, "_factory_override"):
            return self._factory_override()
        # ``default_factory`` returns a constructor; call it so both paths hand
        # back a collection instance rather than the ``dict`` type itself.
        return default_factory(choice=choice)()


def default_factory(choice: CollectionChoice = "dict", **kwargs) -> ReturnedCallable:
    """Resolve a collection constructor from its name.

    Args:
        choice: One of ``"list"``, ``"set"``, ``"dict"`` or ``"defaultdict"``.
        **kwargs: A truthy ``override`` short-circuits the lookup and is returned
            as-is; any other keys are ignored.

    Returns:
        The constructor matching ``choice``, or ``override`` when it is truthy.

    Raises:
        ValueError: If no truthy override is given and ``choice`` is unsupported.
    """
    override = kwargs.get("override", False)
    if override:
        return override
    known = (("list", list), ("set", set), ("dict", dict), ("defaultdict", defaultdict))
    for name, constructor in known:
        if choice == name:
            return constructor
    raise ValueError(f"Invalid choice: {choice}")


class ToolContext(Names):
    """Named container holding the tool instances used to manipulate documents.

    Every keyword argument is forwarded unchanged to the ``Names`` initializer.
    """

    def __init__(self, **tools: Any) -> None:
        """Create the context by passing ``tools`` straight through to ``Names``."""
        super().__init__(**tools)


def inject_tools(**kws) -> Callable[..., Callable[..., Return]]:  # pyright: ignore[reportInvalidTypeVarUse]
    """Build a decorator that resolves an operation's arguments via ``parse_params``.

    Args:
        **kws: Captured once and handed to ``parse_params`` on every call of the
            decorated function under the ``"__passed_kwargs__"`` keyword.

    Returns:
        A decorator that wraps an operation function.
    """

    def decorator(op_func: Callable[Params, Return]) -> Callable[Params, Return]:
        """Wrap ``op_func`` so each call is first routed through ``parse_params``."""

        @wraps(op_func)
        def wrapper(*args: Params.args, **kwargs: Params.kwargs) -> Return:
            # The same ``kws`` dict object is injected on every call (not a copy).
            kwargs["__passed_kwargs__"] = kws
            parsed: ParamsReturn = parse_params(op_func, *args, **kwargs)
            # A non-None payload is invoked in place of op_func.
            # NOTE(review): presumably a factory built by parse_params — confirm.
            target: Callable[..., Return] = op_func if parsed.payload is None else parsed.payload
            return target(*parsed.args, **parsed.kwargs)

        return wrapper

    return decorator  # pyright: ignore[reportReturnType]
