"""Database abstraction layer."""

from __future__ import annotations

import copy
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Optional, Type, TypeVar

import pydantic
from bson import ObjectId
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

# Type variable for concrete DatabaseItem subclasses: ForeignKey[T] and the Database
# methods (`schema: Type[T]` -> `T`) use it to tie the returned/referenced model type to
# the class passed in. The bound is a string because DatabaseItem is defined further down.
T = TypeVar("T", bound="DatabaseItem")


class ForeignKey(Generic[T]):
    """A typed reference to another DatabaseItem.

    ``ForeignKey[Target]`` is usable as a pydantic field annotation. The field accepts a
    ``Target`` instance, an identifier string, or an existing ``ForeignKey`` whose target
    is ``Target`` (or a subclass of it), and always stores a ``ForeignKey``. In JSON output
    and in the JSON schema the reference appears as its identifier string.
    """

    def __init__(self, target_type: type[T], identifier: str):
        self.target_type = target_type
        # NOTE: when built from an item by the pydantic validator below, this is the item's
        # ObjectId rather than a str; __eq__/__hash__ normalize both forms via str().
        self.identifier = identifier

    def __eq__(self, other: object) -> bool:
        # Compare identifiers by their string form so a key created from an item
        # (ObjectId identifier) equals one parsed from the same id given as a string.
        return (
            isinstance(other, ForeignKey)
            and self.target_type == other.target_type
            and str(self.identifier) == str(other.identifier)
        )

    def __hash__(self) -> int:
        # Must stay consistent with __eq__, hence the same str() normalization.
        return hash((self.target_type, str(self.identifier)))

    def __repr__(self) -> str:
        return f"ForeignKey({self.target_type.__name__}:{self.identifier})"

    #
    # --- Pydantic integration ---
    #
    @classmethod
    def __class_getitem__(cls, item: type[T]):
        """Return a ForeignKey subclass bound to ``item`` that carries pydantic schema hooks."""
        target_type = item

        def validator(v: Any) -> ForeignKey:
            if isinstance(v, ForeignKey):
                # Reject references to an unrelated model; a reference whose target is a
                # subclass of target_type is still acceptable.
                if target_type not in getattr(v.target_type, "__mro__", ()):
                    raise ValueError(f"Cannot use {v!r} as ForeignKey[{target_type.__name__}]")
                return v
            if isinstance(v, target_type):
                return ForeignKey(target_type, v.identifier)
            if isinstance(v, str):
                return ForeignKey(target_type, v)
            # ValueError (not TypeError) so pydantic reports it as a ValidationError.
            raise ValueError(f"Cannot convert {v!r} to ForeignKey[{target_type.__name__}]")

        def serializer(v: Any) -> str:
            # JSON output is the bare identifier, matching the string JSON schema below.
            # getattr fallback tolerates raw values stored without validation.
            return str(getattr(v, "identifier", v))

        class _ForeignKey(cls):  # type: ignore
            __origin__ = cls
            __args__ = (item,)

            @classmethod
            def __get_pydantic_core_schema__(cls, source_type, handler: GetCoreSchemaHandler):
                return core_schema.no_info_after_validator_function(
                    validator,
                    core_schema.union_schema(
                        [
                            core_schema.is_instance_schema(target_type),
                            core_schema.str_schema(),
                            core_schema.is_instance_schema(ForeignKey),
                        ]
                    ),
                    serialization=core_schema.plain_serializer_function_ser_schema(
                        serializer, info_arg=False, when_used="json"
                    ),
                )

            @classmethod
            def __get_pydantic_json_schema__(cls, _core_schema, handler):
                # Expose as string in OpenAPI
                return handler(core_schema.str_schema())

        return _ForeignKey


class PyObjectId(ObjectId):
    """BSON ObjectId that pydantic can validate, serialize to JSON and describe in a JSON schema."""

    @classmethod
    def __get_pydantic_core_schema__(cls, _source, _handler) -> core_schema.PlainValidatorFunctionSchema:
        return core_schema.no_info_plain_validator_function(
            cls.validate,
            # JSON output is the hex string; Python-mode dumps keep the ObjectId object.
            serialization=core_schema.plain_serializer_function_ser_schema(
                str, info_arg=False, when_used="json"
            ),
        )

    @classmethod
    def __get_pydantic_json_schema__(cls, _core_schema, handler):
        # A plain validator function has no JSON schema of its own, which made
        # model_json_schema()/OpenAPI generation fail; describe the wire format instead.
        return handler(core_schema.str_schema())

    @classmethod
    def validate(cls, v: Any) -> PyObjectId:
        """Convert ``v`` (an ObjectId or anything ``ObjectId.is_valid`` accepts) to PyObjectId.

        Raises:
            ValueError: if ``v`` cannot be interpreted as an ObjectId.
        """
        if isinstance(v, ObjectId) or ObjectId.is_valid(v):
            return cls(v)
        raise ValueError(f"Invalid ObjectId: {v}")


class DatabaseItem(ABC, pydantic.BaseModel):
    """Common base for persisted models.

    Every item has an ``identifier`` (alias ``_id``; populating by field name is also
    allowed) that defaults to a freshly generated PyObjectId. Two items compare equal,
    and hash alike, exactly when their identifiers are equal.
    """

    model_config = pydantic.ConfigDict(
        revalidate_instances="always",
        json_encoders={ObjectId: str},
        populate_by_name=True,
    )
    identifier: PyObjectId = pydantic.Field(alias="_id", default_factory=PyObjectId)

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for anything that is not a DatabaseItem.
        if isinstance(other, DatabaseItem):
            return self.identifier == other.identifier
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identifier)


class DatabaseError(Exception):
    """Raised when a database operation fails; root of this module's database exceptions."""


class UnknownEntityError(DatabaseError):
    """Raised when the requested entity (or its schema) is not present in the database."""


class Database(ABC):
    """Asynchronous storage interface for DatabaseItem objects.

    Entities are addressed by their model class (``schema``) plus their ``identifier``.
    """

    @abstractmethod
    async def update(self, item: DatabaseItem) -> None:
        """Store ``item`` (DictDatabase inserts it or replaces the entry with its identifier)."""

    @abstractmethod
    async def get(self, schema: Type[T], identifier: PyObjectId) -> T:
        """Return the ``schema`` entity with ``identifier``.

        Raises UnknownEntityError when no such entity exists.
        """

    @abstractmethod
    async def get_all(self, schema: Type[T]) -> Dict[str, T]:
        """Return every entity of ``schema``, keyed by identifier.

        NOTE(review): annotated with ``str`` keys, but DictDatabase keys by PyObjectId.
        """

    @abstractmethod
    async def delete(self, schema: Type[T], identifier: PyObjectId, cascade: bool = False) -> None:
        """Delete the ``schema`` entity with ``identifier``.

        With ``cascade`` set, implementations are expected to also remove entities that
        reference the deleted one through a ForeignKey.
        """

    @abstractmethod
    async def find(self, schema: Type[T], **kwargs: str) -> Optional[Dict[PyObjectId, T]]:
        """Return the entities of ``schema`` whose attributes match every ``kwargs`` filter."""

    @abstractmethod
    async def find_one(self, schema: Type[T], **kwargs: str) -> Optional[T]:
        """Return the single entity of ``schema`` matching the filters, or None if none match.

        Raises when more than one entity matches (DictDatabase raises DatabaseError).
        """


class DictDatabase(Database):
    """In-memory Database backed by nested dictionaries.

    ``self.data`` maps each concrete item class to a dict of ``identifier -> item``.
    Items are deep-copied on :meth:`update`, so later changes to the caller's object are
    not persisted until ``update`` is called again. The read methods hand out the stored
    item instances themselves, not copies.
    """

    def __init__(self) -> None:
        self.data: Dict[Type[DatabaseItem], Dict[PyObjectId, DatabaseItem]] = {}

    async def update(self, item: DatabaseItem) -> None:
        """Insert or replace ``item``, stored as a deep copy under its concrete class."""
        self.data.setdefault(type(item), {})[item.identifier] = copy.deepcopy(item)

    async def get(self, schema: Type[T], identifier: PyObjectId) -> T:
        """Return the stored ``schema`` item with ``identifier``.

        Raises:
            UnknownEntityError: if the schema or the identifier is not stored.
        """
        try:
            return self.data[schema][identifier]  # type: ignore
        except KeyError as exc:
            raise UnknownEntityError(f"Unknown identifier: {identifier}") from exc

    async def get_all(self, schema: Type[T]) -> Dict[str, T]:
        """Return every stored item of ``schema`` keyed by identifier.

        NOTE(review): the keys are the items' PyObjectId identifiers despite the
        ``Dict[str, T]`` annotation shared with Database.

        Raises:
            DatabaseError: if no item of ``schema`` has ever been stored.
        """
        try:
            items = self.data[schema]
        except KeyError as exc:
            raise DatabaseError(f"Unknown schema: {schema}") from exc
        # Hand out a copy of the index so callers cannot add/remove entries behind our back.
        return dict(items)  # type: ignore

    async def delete(self, schema: Type[T], identifier: PyObjectId, cascade: bool = False) -> None:
        """Delete the ``schema`` item with ``identifier``.

        With ``cascade=True``, every stored item having a field whose value is a ForeignKey
        pointing at a deleted item is deleted too, transitively. References nested inside
        containers (lists, dicts, ...) are not followed.

        Raises:
            UnknownEntityError: if the schema or the identifier is not stored.
        """
        try:
            del self.data[schema][identifier]
        except KeyError as exc:
            raise UnknownEntityError(f"Unknown identifier: {identifier}") from exc
        if cascade:
            self._cascade_delete(schema, identifier)

    def _cascade_delete(self, schema: type, identifier: Any) -> None:
        """Delete, transitively, every stored item that references (schema, identifier)."""
        # An item is removed before it is queued, so it can never be matched again:
        # reference cycles cannot make this loop forever.
        pending = [(schema, identifier)]
        while pending:
            target_schema, target_id = pending.pop()
            for item_type, items in self.data.items():
                # Collect first, delete afterwards: removing keys while iterating a dict
                # raises RuntimeError.
                doomed = [
                    key for key, item in items.items() if self._references(item, target_schema, target_id)
                ]
                for key in doomed:
                    del items[key]
                    pending.append((item_type, key))

    @staticmethod
    def _references(item: DatabaseItem, schema: type, identifier: Any) -> bool:
        """Return True if a field of ``item`` holds a ForeignKey to (schema, identifier)."""
        wanted = str(identifier)
        for field_name in type(item).model_fields:
            value = getattr(item, field_name, None)
            if (
                isinstance(value, ForeignKey)
                # The reference may target a base class of the deleted item's type.
                and value.target_type in schema.__mro__
                # ForeignKey identifiers may be str or ObjectId; compare canonical forms.
                and str(value.identifier) == wanted
            ):
                return True
        return False

    async def find(self, schema: Type[T], **kwargs: str) -> Optional[Dict[PyObjectId, T]]:
        """Return the ``schema`` items whose attributes equal every ``kwargs`` value.

        Returns an empty dict when nothing matches.

        Raises:
            DatabaseError: if no item of ``schema`` has ever been stored.
        """
        try:
            items = self.data[schema]
        except KeyError as exc:
            raise DatabaseError(f"Unknown schema: {schema}") from exc
        return {  # type: ignore
            item.identifier: item
            for item in items.values()
            if all(getattr(item, key) == value for key, value in kwargs.items())
        }

    async def find_one(self, schema: Type[T], **kwargs: str) -> Optional[T]:
        """Return the single item matching ``kwargs``, or None when nothing matches.

        Raises:
            DatabaseError: if several items match, or the schema is unknown (see find).
        """
        results = await self.find(schema, **kwargs)
        if not results:
            return None
        if len(results) > 1:
            raise DatabaseError(f"Multiple entities found for {schema} with {kwargs}")
        return next(iter(results.values()))
