"""Module with repository implementation for SQLAlchemy"""

import copy
import json
import logging
import uuid
from abc import abstractmethod
from enum import Enum
from typing import Any

import sqlalchemy.dialects.postgresql as psql
import sqlalchemy.dialects.mssql as mssql
from sqlalchemy import Column, MetaData, and_, create_engine, or_, orm, text
from sqlalchemy import types as sa_types
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import DatabaseError
from sqlalchemy import inspect
from sqlalchemy.types import CHAR, TypeDecorator

from protean.core.database_model import BaseDatabaseModel
from protean.core.queryset import ResultSet
from protean.core.value_object import BaseValueObject
from protean.exceptions import (
    ConfigurationError,
    DatabaseError as ProteanDatabaseError,
    ObjectNotFoundError,
)
from protean.fields import (
    Auto,
    Boolean,
    Date,
    DateTime,
    Dict,
    Field,
    Float,
    Identifier,
    Integer,
    List,
    String,
    Text,
    ValueObject,
)
from protean.fields.association import Reference, _ReferenceField
from protean.fields.embedded import _ShadowField
from protean.port.dao import BaseDAO, BaseLookup
from protean.port.provider import BaseProvider
from protean.utils import IdentityType
from protean.utils.container import Options
from protean.utils.globals import current_domain, current_uow
from protean.utils.query import Q
from protean.utils.reflection import attributes, id_field

# Raise the threshold of the "sqlalchemy" logger to ERROR so that only errors
# from the library are emitted; this module logs through its own named logger.
logging.getLogger("sqlalchemy").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)


class GUID(TypeDecorator):
    """Platform-independent GUID column type.

    The concrete storage type is chosen per dialect:

    * ``postgresql`` -> native ``UUID``
    * ``mssql``      -> native ``UNIQUEIDENTIFIER``
    * anything else  -> ``CHAR(32)`` holding the zero-padded hex form
    """

    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the dialect-level type descriptor used to store the GUID."""
        backend = dialect.name
        if backend == "postgresql":
            native_type = psql.UUID()
        elif backend == "mssql":
            native_type = mssql.UNIQUEIDENTIFIER()
        else:
            native_type = CHAR(32)
        return dialect.type_descriptor(native_type)

    def process_bind_param(self, value, dialect):
        """Convert a Python value into what the dialect expects on the way in.

        ``None`` passes through untouched. PostgreSQL receives ``str(value)``,
        MSSQL receives a ``uuid.UUID`` instance, and every other dialect
        receives a 32-character lowercase hex string.
        """
        if value is None:
            return None
        if dialect.name == "postgresql":
            return str(value)

        # Both remaining branches work off a UUID object
        as_uuid = value if isinstance(value, uuid.UUID) else uuid.UUID(value)
        if dialect.name == "mssql":
            return as_uuid
        return "%.32x" % as_uuid.int

    def process_result_value(self, value, dialect):
        """Convert a value read from the database into a ``uuid.UUID``.

        ``None`` and values that already are ``uuid.UUID`` are returned as-is.
        """
        if value is None or isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(value)


class MSSQLJSON(TypeDecorator):
    """JSON-in-NVARCHAR column type for MSSQL.

    Values are serialized with `_custom_json_dumps` when written and parsed
    with `json.loads` when read.
    """

    impl = mssql.NVARCHAR
    cache_ok = True

    def __init__(self, length=None):
        """Forward `length` to the underlying NVARCHAR type.

        The default of ``None`` is intended to yield ``NVARCHAR(MAX)``.
        """
        super().__init__(length=length)

    def process_bind_param(self, value, dialect):
        """Return the JSON text for `value`, or ``None`` when `value` is ``None``."""
        return None if value is None else _custom_json_dumps(value)

    def process_result_value(self, value, dialect):
        """Parse the stored JSON text back into Python objects.

        ``None`` is returned unchanged. If the stored value cannot be parsed
        as JSON, the raw value is handed back instead of raising.
        """
        if value is None:
            return None
        try:
            decoded = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            # Not valid JSON (or not a string at all): fall back to raw data
            return value
        return decoded


def _get_identity_type():
    """Return the SQLAlchemy type matching the domain's configured identity type.

    Reads ``current_domain.config["identity_type"]`` and maps it as follows:

    * ``IdentityType.INTEGER`` -> ``sa_types.Integer``
    * ``IdentityType.STRING``  -> ``sa_types.String``
    * ``IdentityType.UUID``    -> ``GUID``

    Raises:
        ConfigurationError: when the configured value is none of the above.

    NOTE(review): there is no fallback here for a missing/inactive
    ``current_domain``; the lookup is performed unconditionally.
    """
    configured = current_domain.config["identity_type"]

    if configured == IdentityType.INTEGER.value:
        return sa_types.Integer
    if configured == IdentityType.STRING.value:
        return sa_types.String
    if configured == IdentityType.UUID.value:
        return GUID

    raise ConfigurationError(f"Unknown Identity Type {configured}")


def _default(value):
    """Fallback encoder for objects `json.dumps` cannot serialize natively.

    Used as the ``default`` callable of `json.dumps`. Value Objects are
    converted through their ``to_dict()`` method.

    Args:
        value: The object that the JSON encoder could not handle.

    Returns:
        The dictionary produced by ``value.to_dict()`` for Value Objects.

    Raises:
        TypeError: for every other type. The message names the offending
            type, mirroring the wording of the standard JSON encoder, so
            that serialization failures are diagnosable.
    """
    if isinstance(value, BaseValueObject):
        return value.to_dict()
    raise TypeError(
        f"Object of type {type(value).__name__} is not JSON serializable"
    )


def _custom_json_dumps(value):
    """Serialize `value` to a JSON string, with support for Value Objects.

    Objects the standard encoder cannot handle are routed through `_default`.
    This function is supplied to SQLAlchemy as the ``json_serializer``
    argument of ``create_engine``.
    """
    encoded = json.dumps(value, default=_default)
    return encoded


class SqlalchemyModel(orm.DeclarativeBase, BaseDatabaseModel):
    """Model representation for the Sqlalchemy Database"""

    def __init_subclass__(subclass, **kwargs):  # noqa: C901
        """Attach SQLAlchemy columns to `subclass` for its entity's attributes.

        Runs whenever a model class is created. When the new class carries a
        ``meta_`` attribute in its own ``__dict__``, every attribute of
        ``meta_.part_of`` that is not already declared on the class is turned
        into a ``Column`` and set on the class. The class is also expected to
        carry an ``engine`` attribute in its ``__dict__``; its dialect name
        selects database-specific column types.

        Finally defers to the parent ``__init_subclass__`` so that declarative
        mapping proceeds with the generated columns in place.
        """
        # Protean field class -> default SQLAlchemy type class.
        # Note: the identity type is resolved (twice) each time a subclass is
        # created, since it depends on the active domain's configuration.
        field_mapping = {
            Boolean: sa_types.Boolean,
            Date: sa_types.Date,
            DateTime: sa_types.DateTime,
            Dict: sa_types.PickleType,
            Float: sa_types.Float,
            Identifier: _get_identity_type(),
            Integer: sa_types.Integer,
            List: sa_types.PickleType,
            String: sa_types.String,
            Text: sa_types.Text,
            _ReferenceField: _get_identity_type(),
            ValueObject: sa_types.PickleType,
        }

        def field_mapping_for(field_obj: Field):
            """Return SQLAlchemy-equivalent type for Protean's field"""
            field_cls = type(field_obj)

            # `Auto` fields are integers when auto-incrementing; otherwise
            # they follow the configured identity type.
            if field_cls is Auto:
                if field_obj.increment is True:
                    return sa_types.Integer
                else:
                    return _get_identity_type()

            # Returns None when the field class has no entry in the mapping
            return field_mapping.get(field_cls)

        # Update the class attrs with the entity attributes
        if "meta_" in subclass.__dict__:
            entity_cls = subclass.__dict__["meta_"].part_of
            for _, field_obj in attributes(entity_cls).items():
                attribute_name = field_obj.attribute_name

                # Map the field if not in attributes
                # (explicitly declared columns on the model class take precedence)
                if attribute_name not in subclass.__dict__:
                    # Derive field based on field enclosed within ShadowField
                    if isinstance(field_obj, _ShadowField):
                        field_obj = field_obj.field_obj

                    field_cls = type(field_obj)
                    # Positional/keyword arguments for the SA type constructor
                    type_args = []
                    type_kwargs = {}

                    # Get the SA type
                    sa_type_cls = field_mapping_for(field_obj)

                    # Upgrade to Database-specific Data Types
                    dialect_name = subclass.__dict__["engine"].dialect.name
                    if dialect_name == "postgresql":
                        # Non-pickled dicts are stored as native JSON
                        if field_cls == Dict and not field_obj.pickled:
                            sa_type_cls = psql.JSON

                        # Non-pickled lists are stored as native arrays, whose
                        # item type is passed as the first positional argument
                        if field_cls == List and not field_obj.pickled:
                            sa_type_cls = psql.ARRAY

                            # Associate Content Type
                            if field_obj.content_type:
                                # Treat `ValueObject` differently because it is a field object instance,
                                #   not a field type class
                                #
                                # `ValueObject` instances are essentially treated as `Dict`. If not pickled,
                                #   they are persisted as JSON.
                                if isinstance(field_obj.content_type, ValueObject):
                                    if not field_obj.pickled:
                                        field_mapping_type = psql.JSON
                                    else:
                                        field_mapping_type = sa_types.PickleType
                                else:
                                    field_mapping_type = field_mapping.get(
                                        field_obj.content_type
                                    )
                                type_args.append(field_mapping_type)
                            else:
                                # No content type declared: array of text
                                type_args.append(sa_types.Text)
                    elif dialect_name == "mssql":
                        # SQL Server doesn't have native JSON/Array types, use custom JSON type for JSON-like data
                        if field_cls == Dict and not field_obj.pickled:
                            sa_type_cls = MSSQLJSON
                            type_args.append(
                                None
                            )  # NVARCHAR(MAX) with JSON serialization

                        if field_cls == List and not field_obj.pickled:
                            sa_type_cls = MSSQLJSON
                            type_args.append(
                                None
                            )  # Store as JSON string with serialization

                    # Default to the String type if no mapping is found
                    if not sa_type_cls:
                        sa_type_cls = sa_types.String

                    # Build the column arguments
                    col_args = {
                        "primary_key": field_obj.identifier,
                        "nullable": not field_obj.required,
                        "unique": field_obj.unique,
                    }

                    # Update the arguments based on the field type
                    if issubclass(field_cls, String):
                        type_kwargs["length"] = field_obj.max_length
                    # If the field is an Auto field and the type is a string, set the length to 255
                    #   Without explicit length, we are leaving the decision to the database. And that
                    #   will not work for MSSQL.
                    elif issubclass(field_cls, Auto) and sa_type_cls == sa_types.String:
                        type_kwargs["length"] = 255

                    # Update the attributes of the class
                    column = Column(sa_type_cls(*type_args, **type_kwargs), **col_args)
                    setattr(subclass, attribute_name, column)  # Set class attribute

        super().__init_subclass__(**kwargs)

    @orm.declared_attr
    def __tablename__(cls):
        """Table name of the mapped class, as given by `derive_schema_name()`."""
        table_name = cls.derive_schema_name()
        return table_name

    @classmethod
    def from_entity(cls, entity):
        """Build a model object from `entity`.

        Walks the attributes of the entity class this model is part of.
        `Reference` attributes contribute their relation's attribute name and
        value; all other attributes are read off `entity` by attribute name.
        """
        column_values = {}
        for field in attributes(cls.meta_.part_of).values():
            if not isinstance(field, Reference):
                name = field.attribute_name
                column_values[name] = getattr(entity, name)
                continue

            # References are persisted through their underlying relation field
            relation = field.relation
            column_values[relation.attribute_name] = relation.value

        return cls(**column_values)

    @classmethod
    def to_entity(cls, model_obj: "SqlalchemyModel"):
        """Build an entity from `model_obj`.

        Each attribute name of the entity class is read from the model object
        (``None`` when absent) and the resulting dictionary is passed to the
        entity class constructor.
        """
        entity_cls = cls.meta_.part_of
        values = {
            name: getattr(model_obj, name, None) for name in attributes(entity_cls)
        }
        return entity_cls(values)


class SADAO(BaseDAO):
    """DAO implementation for Databases compliant with SQLAlchemy"""

    def __repr__(self) -> str:
        """Return a short label naming the entity class this DAO serves."""
        entity_name = self.entity_cls.__name__
        return f"SQLAlchemyDAO <{entity_name}>"

    def _get_session(self):
        """Return the session to run persistence operations on.

        * With an active Unit of Work (and unless this DAO was told to operate
          outside of it via `_outside_uow`), the UoW's session for this
          provider is returned.
        * Otherwise a fresh connection is requested from the provider and
          `begin()` is called on it when it is not yet active.
        """
        if not current_uow or self._outside_uow:
            session = self.provider.get_connection()
            if not session.is_active:
                session.begin()
            return session

        return current_uow.get_session(self.provider.name)

    def _build_filters(self, criteria: Q):
        """Translate a `Q` criteria tree into a SQLAlchemy boolean expression.

        Nested `Q` children are translated recursively. Leaf children are
        ``(key, value)`` pairs resolved to a lookup class by the provider and
        turned into expressions; when `criteria` is negated each leaf
        expression is inverted individually. All resulting clauses are joined
        with ``and_`` or ``or_`` depending on the criteria's connector.
        """
        combine = and_ if criteria.connector == criteria.AND else or_

        clauses = []
        for node in criteria.children:
            if isinstance(node, Q):
                # Sub-tree: recurse and use its combined expression as-is
                clauses.append(self._build_filters(node))
                continue

            bare_key, lookup_cls = self.provider._extract_lookup(node[0])
            expression = lookup_cls(
                bare_key, node[1], self.database_model_cls
            ).as_expression()
            clauses.append(~expression if criteria.negated else expression)

        return combine(*clauses)

    def _filter(
        self, criteria: Q, offset: int = 0, limit: int = 10, order_by: list = ()
    ) -> ResultSet:
        """Query model objects matching `criteria`, one page at a time.

        Args:
            criteria: Filter tree; ignored when it has no children.
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
            order_by: Attribute names to sort on; a leading ``-`` means
                descending.

        Returns:
            A `ResultSet` holding the page of rows and the total count of the
            un-paginated query.
        """
        session = self._get_session()
        query = session.query(self.database_model_cls)

        if criteria.children:
            query = query.filter(self._build_filters(criteria))

        ordering = []
        for key in order_by:
            descending = key.startswith("-")
            column = getattr(self.database_model_cls, key.lstrip("-"))
            ordering.append(column.desc() if descending else column)

        # Always order explicitly. Result order is otherwise undefined on
        #   Postgresql/SQLite, and MSSQL rejects OFFSET/LIMIT without ORDER BY.
        #   Fall back to the primary key, ascending.
        if not ordering:
            pk_name = id_field(self.entity_cls).attribute_name
            ordering.append(getattr(self.database_model_cls, pk_name).asc())

        ordered = query.order_by(*ordering)
        page = ordered.limit(limit).offset(offset)

        try:
            rows = page.all()
            result = ResultSet(
                offset=offset, limit=limit, total=ordered.count(), items=rows
            )
        except Exception as exc:
            logger.error(f"Error while filtering: {exc}")
            raise
        finally:
            # Outside a UoW this method owns the session lifecycle
            if not current_uow:
                session.commit()
                session.close()

        return result

    def _create(self, model_obj):
        """Add `model_obj` to the session and, outside a UoW, commit it.

        Raises:
            ProteanDatabaseError: when the commit fails with a SQLAlchemy
                `DatabaseError`; the session is rolled back first.
        """
        session = self._get_session()
        session.add(model_obj)

        # Inside a Unit of Work the commit is left to the UoW
        if current_uow:
            return model_obj

        try:
            session.commit()
        except DatabaseError as exc:
            logger.error(f"Error while creating: {exc}")
            session.rollback()
            raise ProteanDatabaseError(
                f"Database error during creation: {str(exc)}",
                original_exception=exc,
            )
        finally:
            session.close()

        return model_obj

    def _update(self, model_obj):
        """Copy changed attribute values from `model_obj` onto the stored row.

        The stored row is looked up by identifier. Every non-identifier
        attribute whose value differs is assigned on the stored row; outside
        a UoW the session is then committed and closed.

        Raises:
            ObjectNotFoundError: when no row exists for the identifier (the
                session is rolled back and closed before raising).
            ProteanDatabaseError: when the commit fails with a SQLAlchemy
                `DatabaseError`.
        """
        session = self._get_session()
        pk_name = id_field(self.entity_cls).attribute_name
        persisted = None

        try:
            identifier = getattr(model_obj, pk_name)
            # `get` yields None (rather than raising) for an unknown identifier
            persisted = session.get(self.database_model_cls, identifier)
        except DatabaseError as exc:
            logger.error(f"Database Record not found: {exc}")
            raise

        if persisted is None:
            session.rollback()
            session.close()
            raise ObjectNotFoundError(
                f"`{self.entity_cls.__name__}` object with identifier {identifier} "
                f"does not exist."
            )

        # Mutating the tracked row is enough; the changes are written on commit
        for name in attributes(self.entity_cls):
            if name == pk_name:
                continue
            incoming = getattr(model_obj, name)
            if incoming != getattr(persisted, name):
                setattr(persisted, name, incoming)

        if current_uow:
            return model_obj

        try:
            session.commit()
        except DatabaseError as exc:
            logger.error(f"Error while updating: {exc}")
            session.rollback()
            raise ProteanDatabaseError(
                f"Database error during update: {str(exc)}", original_exception=exc
            )
        finally:
            session.close()

        return model_obj

    def _update_all(self, criteria: Q, *args, **kwargs):
        """Update all objects satisfying the criteria.

        Args:
            criteria: Filter tree selecting the rows to update.
            *args: Optionally, a single mapping of attribute names to new
                values as the first positional argument.
            **kwargs: Attribute names and new values; these take precedence
                over entries of the positional mapping.

        Returns:
            The number of rows reported as updated by the query.
        """
        conn = self._get_session()
        qs = conn.query(self.database_model_cls).filter(self._build_filters(criteria))

        # Work on a copy: merging `kwargs` straight into `args[0]` would
        #   silently modify the dictionary owned by the caller.
        values = dict(args[0]) if args else {}
        values.update(kwargs)

        try:
            updated_count = qs.update(values)
        except DatabaseError as exc:
            logger.error(f"Error while updating all: {exc}")
            raise
        finally:
            # Outside a UoW this method owns the session lifecycle
            if not current_uow:
                conn.commit()
                conn.close()

        return updated_count

    def _delete(self, model_obj):
        """Delete the stored row that corresponds to `model_obj`.

        The row is looked up by identifier and marked for deletion; outside a
        UoW the session is then committed and closed.

        Raises:
            ObjectNotFoundError: when no row exists for the identifier (the
                session is rolled back and closed before raising).
        """
        session = self._get_session()
        pk_name = id_field(self.entity_cls).attribute_name
        persisted = None

        try:
            identifier = getattr(model_obj, pk_name)
            # `get` yields None (rather than raising) for an unknown identifier
            persisted = session.get(self.database_model_cls, identifier)
        except DatabaseError as exc:
            logger.error(f"Database Record not found: {exc}")
            raise

        if persisted is None:
            session.rollback()
            session.close()
            raise ObjectNotFoundError(
                f"`{self.entity_cls.__name__}` object with identifier {identifier} "
                f"does not exist."
            )

        try:
            session.delete(persisted)
        except DatabaseError as exc:
            logger.error(f"Error while deleting: {exc}")
            raise
        finally:
            # Outside a UoW this method owns the session lifecycle
            if not current_uow:
                session.commit()
                session.close()

        return model_obj

    def _delete_all(self, criteria: Q = None):
        """Delete every row matching `criteria`, or all rows when it is falsy.

        Returns:
            The number of rows reported as deleted by the query.
        """
        session = self._get_session()

        query = session.query(self.database_model_cls)
        if criteria:
            query = query.filter(self._build_filters(criteria))

        try:
            deleted = query.delete()
        except DatabaseError as exc:
            logger.error(f"Error while deleting all: {exc}")
            raise
        finally:
            # Outside a UoW this method owns the session lifecycle
            if not current_uow:
                session.commit()
                session.close()

        return deleted

    def _raw(self, query: Any, data: Any = None):
        """Execute a raw SQL string and wrap each result row as an entity.

        Args:
            query: The SQL text to run; must be a ``str``.
            data: Accepted for interface compatibility; it is not used by
                this implementation.

        Returns:
            A `ResultSet` whose items are entities marked as retrieved, with
            `limit` and `total` both equal to the number of rows.
        """
        assert isinstance(query, str)

        session = self._get_session()
        try:
            rows = session.execute(text(query))

            entities = []
            for row in rows:
                entity = self.database_model_cls.to_entity(row)
                entity.state_.mark_retrieved()
                entities.append(entity)

            count = len(entities)
            result = ResultSet(offset=0, limit=count, total=count, items=entities)
        except DatabaseError as exc:
            logger.error(f"Error while running raw query: {exc}")
            raise
        finally:
            # Outside a UoW this method owns the session lifecycle
            if not current_uow:
                session.commit()
                session.close()

        return result

    def has_table(self) -> bool:
        """Report whether this DAO's table exists in the database.

        The lookup uses the provider's engine and the schema configured on
        the provider's metadata (which may be ``None``).
        """
        db_inspector = inspect(self.provider._engine)
        return db_inspector.has_table(
            self.schema_name, schema=self.provider._metadata.schema
        )


class SAProvider(BaseProvider):
    """Provider Implementation class for SQLAlchemy"""

    class databases(Enum):
        """Database identifiers referenced by this provider and its subclasses."""

        postgresql = "postgresql"
        sqlite = "sqlite"
        mssql = "mssql"

    def _additional_engine_args(self):
        """Assemble the extra keyword arguments passed to `create_engine`.

        Starts from the database-specific defaults and overlays every entry
        of `conn_info` except the keys consumed by the provider itself.
        """
        reserved_keys = ["provider", "database_uri", "schema"]

        engine_args = self._get_database_specific_engine_args()
        for key, value in self.conn_info.items():
            # Explicit connection settings win over database defaults
            if key not in reserved_keys:
                engine_args[key] = value

        return engine_args

    def _get_default_schema(self, database):
        """Return the default schema name for `database`, or ``None`` if it has none."""
        default_schemas = {
            self.databases.postgresql.value: "public",
            self.databases.mssql.value: "dbo",
        }
        return default_schemas.get(database)

    def __init__(self, name, domain, conn_info: dict):
        """Create the engine, metadata and model-class cache for this provider.

        The engine is built from ``conn_info["database_uri"]`` plus the
        arguments from `_additional_engine_args`. The metadata schema is
        ``conn_info["schema"]`` when given, else the database's default.
        """
        super().__init__(name, domain, conn_info)

        db_url = make_url(self.conn_info["database_uri"])
        self._engine = create_engine(
            db_url,
            json_serializer=_custom_json_dumps,
            **self._additional_engine_args(),
        )

        # An explicitly configured schema takes precedence over the default
        if "schema" in self.conn_info:
            schema = self.conn_info["schema"]
        else:
            schema = self._get_default_schema(self.__database__)

        self._metadata = MetaData(schema=schema)

        # Cache of model classes already constructed, keyed by schema name
        self._database_model_classes = {}

    @abstractmethod
    def _get_database_specific_engine_args(self) -> dict:
        """Supplies additional database-specific arguments to SQLAlchemy Engine.

        The returned dictionary is the starting point used by
        `_additional_engine_args`, which then overlays settings from
        `conn_info` on top of it.

        Return: a dictionary with database-specific SQLAlchemy Engine arguments.
        """

    @abstractmethod
    def _get_database_specific_session_args(self) -> dict:
        """Set Database specific session parameters.

        Depending on the database in use, this method supplies
        additional arguments while constructing sessions. `get_session`
        passes them as keyword arguments to the session factory.

        Return: a dictionary with additional arguments and values.
        """

    def get_session(self):
        """Build and return a scoped session class bound to this provider's engine.

        The underlying factory disables `expire_on_commit` and adds the
        database-specific session arguments.
        """
        session_args = self._get_database_specific_session_args()
        factory = orm.sessionmaker(
            bind=self._engine, expire_on_commit=False, **session_args
        )
        return orm.scoped_session(factory)

    @abstractmethod
    def _execute_database_specific_connection_statements(self, conn):
        """Execute connection statements depending on the database in use.

        Each database has a unique set of commands and associated format to control
        connection-related parameters. Since we use SQLAlchemy, statements should
        be run dynamically based on the database in use.

        Arguments:
        * conn: An active connection object to the database

        Return: The connection object to use from here on. `get_connection`
        hands the return value of this method to its caller, so
        implementations must return the connection rather than None.
        """

    def get_connection(self, session_cls=None):
        """Return a session after applying database-specific connection setup.

        Args:
            session_cls: Session class to instantiate. When omitted, a new
                one is obtained from `get_session()`.
        """
        registry = self.get_session() if session_cls is None else session_cls
        session = registry()
        return self._execute_database_specific_connection_statements(session)

    def is_alive(self) -> bool:
        """Check if the connection to the database is alive.

        Runs ``SELECT 1`` on a connection obtained from `get_connection`.

        Returns:
            True when the probe succeeds, False when it fails with a
            SQLAlchemy `DatabaseError` (the failure is logged).
        """
        conn = None
        try:
            conn = self.get_connection()
            conn.execute(text("SELECT 1"))
            return True
        except DatabaseError as e:
            # Never write credentials to the log: mask the password, if any
            safe_uri = make_url(self.conn_info["database_uri"]).render_as_string(
                hide_password=True
            )
            logger.error(f"Could not connect to database at {safe_uri}")
            logger.error(f"Error: {e}")
            return False
        finally:
            # Release the connection on the failure path as well; previously
            #   it was closed only when the probe succeeded.
            if conn is not None and not current_uow:
                conn.close()

    def _data_reset(self):
        """Delete all rows from every table known to this provider's metadata.

        All deletes run in a single transaction. On SQLite, foreign key
        enforcement is switched off before the deletes and back on after.
        Any Unit of Work in progress is rolled back afterwards.
        """
        is_sqlite = self.__database__ == self.databases.sqlite.value

        conn = self._engine.connect()
        try:
            transaction = conn.begin()

            if is_sqlite:
                conn.execute(text("PRAGMA foreign_keys = OFF;"))

            for table in self._metadata.sorted_tables:
                conn.execute(table.delete())

            if is_sqlite:
                conn.execute(text("PRAGMA foreign_keys = ON;"))

            transaction.commit()
        finally:
            # Always hand the connection back, even when a statement fails,
            #   to avoid connection leaks
            conn.close()

        # Discard any active Unit of Work
        if current_uow and current_uow.in_progress:
            current_uow.rollback()

    def close(self):
        """Dispose of the provider's engine, if one was created.

        Does nothing when the `_engine` attribute is missing or falsy.
        """
        engine = getattr(self, "_engine", None)
        if engine:
            engine.dispose()

    def _create_database_artifacts(self):
        """Create tables for all registered aggregates, entities and projections."""
        # Read the registry's internal element map directly (rather than its
        #   public properties) so that internal elements are included as well.
        registry_elements = self.domain.registry._elements

        records = {}
        for kind in ("AGGREGATE", "ENTITY", "PROJECTION"):
            if kind in registry_elements:
                records.update(registry_elements[kind])

        # Touch each repository's `_dao` attribute; the value is not used here
        for record in records.values():
            self.domain.repository_for(record.cls)._dao

        # Create all tables in a single transaction
        connection = self._engine.connect()
        try:
            txn = connection.begin()
            self._metadata.create_all(connection)
            txn.commit()
        finally:
            connection.close()

    def _drop_database_artifacts(self):
        """Drop every table in the metadata, then clear the metadata itself."""
        connection = self._engine.connect()
        try:
            # Drop all tables in a single transaction
            txn = connection.begin()
            self._metadata.drop_all(connection)
            txn.commit()
        finally:
            connection.close()

        self._metadata.clear()

    def decorate_database_model_class(self, entity_cls, database_model_cls):
        """Turn a user-supplied model class into a `SqlalchemyModel` subclass.

        Returns the cached class when one exists for the derived schema name,
        and returns `database_model_cls` unchanged when it already derives
        from `SqlalchemyModel`. Otherwise a new class is created that inherits
        from both `SqlalchemyModel` and `database_model_cls`, carrying copies
        of the declared columns plus `meta_`, `engine` and `metadata`.
        """
        schema_name = database_model_cls.derive_schema_name()

        # Reuse a class that was already seen/decorated
        if schema_name in self._database_model_classes:
            return self._database_model_classes[schema_name]

        # Nothing to do for classes that already are SQLAlchemy models
        if issubclass(database_model_cls, SqlalchemyModel):
            return database_model_cls

        declared = vars(database_model_cls)

        # `Column` attributes are deep-copied before being placed on the new
        #   class: https://stackoverflow.com/a/62528033/1858466
        column_attrs = copy.deepcopy(
            {key: value for key, value in declared.items() if isinstance(value, Column)}
        )

        excluded = ["Meta", "__module__", "__doc__", "__weakref__"]
        namespace = {
            key: value
            for key, value in declared.items()
            if key not in excluded and not isinstance(value, Column)
        }
        # Non-column attributes first, followed by the copied columns
        namespace.update(column_attrs)

        meta_ = Options(database_model_cls.meta_)
        meta_.part_of = entity_cls
        declared_schema = meta_.schema_name
        meta_.schema_name = schema_name if declared_schema is None else declared_schema

        namespace["meta_"] = meta_
        namespace["engine"] = self._engine
        namespace["metadata"] = self._metadata

        # FIXME Ensure the custom model attributes are constructed properly
        decorated_cls = type(
            database_model_cls.__name__,
            (SqlalchemyModel, database_model_cls),
            namespace,
        )

        # Memoize the constructed model class
        self._database_model_classes[schema_name] = decorated_cls
        return decorated_cls

    def construct_database_model_class(self, entity_cls):
        """Return the model class for `entity_cls`, creating it on first use.

        Classes are cached by the entity's schema name. A newly created class
        is named ``<EntityName>Model``, derives from `SqlalchemyModel` and
        carries `meta_`, `engine` and `metadata` attributes.
        """
        schema_name = entity_cls.meta_.schema_name

        # Reuse a class that was already seen/decorated
        if schema_name in self._database_model_classes:
            return self._database_model_classes[schema_name]

        meta_ = Options()
        meta_.part_of = entity_cls
        # Without a schema name, sqlalchemy can throw
        #   sqlalchemy.exc.InvalidRequestError: Class does not
        #   have a __table__ or __tablename__ specified and
        #   does not inherit from an existing table-mapped class
        meta_.schema_name = schema_name

        # FIXME Ensure the custom model attributes are constructed properly
        database_model_cls = type(
            entity_cls.__name__ + "Model",
            (SqlalchemyModel,),
            {"meta_": meta_, "engine": self._engine, "metadata": self._metadata},
        )

        # Memoize the constructed model class
        self._database_model_classes[schema_name] = database_model_cls
        return database_model_cls

    def get_dao(self, entity_cls, database_model_cls):
        """Return an `SADAO` for the given entity and model classes, tied to this provider."""
        dao = SADAO(self.domain, self, entity_cls, database_model_cls)
        return dao

    def raw(self, query: Any, data: Any = None):
        """Run raw query on Provider.

        Args:
            query: The SQL text to run; must be a ``str``.
            data: Optional dictionary of bind parameters. ``None`` is treated
                as an empty dictionary.

        Returns:
            The result object returned by the session's ``execute`` call.

        NOTE(review): outside a UoW the connection is closed before the result
        is handed back — confirm callers can still consume it.
        """
        if data is None:
            data = {}
        assert isinstance(query, str)
        # `data` can no longer be None here. The previous check used
        #   `(dict, None)` as class info, which makes `isinstance` itself
        #   raise TypeError for non-dict input because None is not a type.
        assert isinstance(data, dict)

        conn = self.get_connection()
        try:
            result = conn.execute(text(query), data)
            return result
        finally:
            if not current_uow:  # If not in a UoW, we need to close the connection
                conn.close()


class PostgresqlProvider(SAProvider):
    """`SAProvider` specialization registered for the PostgreSQL database."""

    __database__ = SAProvider.databases.postgresql.value

    def _get_database_specific_engine_args(self) -> dict:
        """Return the extra keyword arguments used for the SQLAlchemy Engine.

        Return: a dictionary that sets the engine isolation level to AUTOCOMMIT.
        """
        engine_args = {"isolation_level": "AUTOCOMMIT"}
        return engine_args

    def _get_database_specific_session_args(self) -> dict:
        """Return the extra keyword arguments used while constructing sessions.

        Return: a dictionary that turns session autoflush off.
        """
        session_args = {"autoflush": False}
        return session_args

    def _execute_database_specific_connection_statements(self, conn):
        """Hook for statements to run on a connection; a no-op for PostgreSQL.

        Arguments:
        * conn: An active connection object to the database

        Return: the same connection object, untouched
        """
        return conn


class SqliteProvider(SAProvider):
    """`SAProvider` specialization registered for the SQLite database."""

    __database__ = SAProvider.databases.sqlite.value

    def _get_database_specific_engine_args(self) -> dict:
        """Return the extra keyword arguments used for the SQLAlchemy Engine.

        Return: an empty dictionary; no SQLite-specific engine arguments are set.
        """
        engine_args = {}
        return engine_args

    def _get_database_specific_session_args(self) -> dict:
        """Return the extra keyword arguments used while constructing sessions.

        Return: an empty dictionary; no SQLite-specific session arguments are set.
        """
        session_args = {}
        return session_args

    def _execute_database_specific_connection_statements(self, conn):
        """Run the SQLite-specific statements on a connection.

        Issues the `case_sensitive_like` pragma on the given connection.

        Arguments:
        * conn: An active connection object to the database

        Return: the same connection object, after the pragma was executed
        """
        conn.execute(text("PRAGMA case_sensitive_like = ON;"))
        return conn


class MssqlProvider(SAProvider):
    """`SAProvider` specialization registered for the SQL Server database."""

    __database__ = SAProvider.databases.mssql.value

    def _get_database_specific_engine_args(self) -> dict:
        """Return the extra keyword arguments used for the SQLAlchemy Engine.

        Return: a dictionary that sets the isolation level to AUTOCOMMIT and
        switches on `pool_pre_ping`.
        """
        engine_args = {
            "isolation_level": "AUTOCOMMIT",
            # Enable connection health checks
            "pool_pre_ping": True,
        }
        return engine_args

    def _get_database_specific_session_args(self) -> dict:
        """Return the extra keyword arguments used while constructing sessions.

        Return: a dictionary that turns session autoflush off.
        """
        session_args = {"autoflush": False}
        return session_args

    def _execute_database_specific_connection_statements(self, conn):
        """Hook for statements to run on a connection; a no-op for SQL Server.

        Arguments:
        * conn: An active connection object to the database

        Return: the same connection object, untouched
        """
        # No SQL Server specific statements are issued at present. Session
        # level options (a transaction isolation level, for instance) could
        # be executed on `conn` here if they ever become necessary.
        return conn


# Lookup name -> name of the attribute that `DefaultLookup.as_expression`
# fetches from the source column (via `getattr`) and then calls with the
# processed target value.
operators = {
    # Equality goes through the column's `__eq__`
    "exact": "__eq__",
    # Both case-insensitive lookups resolve to `ilike`. `iexact` hands the
    # target over unchanged, while `IContains.process_target` wraps it in `%`.
    "iexact": "ilike",
    "contains": "contains",
    "icontains": "ilike",
    "startswith": "startswith",
    "endswith": "endswith",
    # Ordering comparisons use the rich-comparison dunder methods
    "gt": "__gt__",
    "gte": "__ge__",
    "lt": "__lt__",
    "lte": "__le__",
    # Membership test against a list/tuple target (see `In.process_target`)
    "in": "in_",
    # NOTE(review): presumably meant for array-typed columns -- confirm
    "any": "any",
    "overlap": "overlap",
}


class DefaultLookup(BaseLookup):
    """Common base for lookups that build an expression from a model column.

    Subclasses set `lookup_name`; the matching entry in `operators` names the
    column attribute that is called to produce the expression.
    """

    def __init__(self, source, target, database_model_cls):
        """Source is LHS and Target is RHS of a comparison.

        `database_model_cls` is the model class on which the source
        attribute is resolved.
        """
        # Stored before delegating to the parent initializer
        self.database_model_cls = database_model_cls
        super().__init__(source, target)

    def process_source(self):
        """Return the model attribute named by `source`, without transformation"""
        return getattr(self.database_model_cls, self.source)

    def process_target(self):
        """Return the target as-is; subclasses override to transform it"""
        return self.target

    def as_expression(self):
        """Call the column operator mapped to `lookup_name` with the target"""
        column = self.process_source()
        operator_name = operators[self.lookup_name]
        build_expression = getattr(column, operator_name)
        return build_expression(self.process_target())


@SAProvider.register_lookup
class Exact(DefaultLookup):
    """Equality lookup (`exact`).

    Resolved through the `operators` table to the column's `__eq__`.
    """

    lookup_name = "exact"


@SAProvider.register_lookup
class IExact(DefaultLookup):
    """Case-insensitive exact match lookup (`iexact`).

    Resolved through the `operators` table to the column's `ilike`, with the
    target passed through unchanged.
    """

    # NOTE(review): since the target is not escaped, `%` and `_` inside it
    # would presumably be treated as wildcards by ILIKE -- confirm intent.
    lookup_name = "iexact"


@SAProvider.register_lookup
class Contains(DefaultLookup):
    """Substring lookup (`contains`).

    Resolved through the `operators` table to the column's `contains`.
    """

    lookup_name = "contains"


@SAProvider.register_lookup
class IContains(DefaultLookup):
    """Case-insensitive substring lookup (`icontains`).

    Resolved through the `operators` table to the column's `ilike`.
    """

    lookup_name = "icontains"

    def process_target(self):
        """Return the target wrapped in `%` wildcards on both sides.

        The target must be a string.
        """
        assert isinstance(self.target, str)
        needle = super().process_target()
        return f"%{needle}%"


@SAProvider.register_lookup
class Startswith(DefaultLookup):
    """Prefix lookup (`startswith`).

    Resolved through the `operators` table to the column's `startswith`.
    """

    lookup_name = "startswith"


@SAProvider.register_lookup
class Endswith(DefaultLookup):
    """Suffix lookup (`endswith`).

    Resolved through the `operators` table to the column's `endswith`.
    """

    lookup_name = "endswith"


@SAProvider.register_lookup
class GreaterThan(DefaultLookup):
    """Greater-than lookup (`gt`).

    Resolved through the `operators` table to the column's `__gt__`.
    """

    lookup_name = "gt"


@SAProvider.register_lookup
class GreaterThanOrEqual(DefaultLookup):
    """Greater-than-or-equal lookup (`gte`).

    Resolved through the `operators` table to the column's `__ge__`.
    """

    lookup_name = "gte"


@SAProvider.register_lookup
class LessThan(DefaultLookup):
    """Less-than lookup (`lt`).

    Resolved through the `operators` table to the column's `__lt__`.
    """

    lookup_name = "lt"


@SAProvider.register_lookup
class LessThanOrEqual(DefaultLookup):
    """Less-than-or-equal lookup (`lte`).

    Resolved through the `operators` table to the column's `__le__`.
    """

    lookup_name = "lte"


@SAProvider.register_lookup
class In(DefaultLookup):
    """Membership lookup (`in`).

    Resolved through the `operators` table to the column's `in_`.
    """

    lookup_name = "in"

    def process_target(self):
        """Return the target unchanged after asserting it is a list or tuple"""
        is_sequence = isinstance(self.target, (list, tuple))
        assert is_sequence
        return super().process_target()


@SAProvider.register_lookup
class Any(DefaultLookup):
    """`any` lookup.

    Resolved through the `operators` table to the column's `any`.
    """

    # NOTE(review): this class name rebinds the module-level `Any` that the
    # file imports from `typing`; annotations evaluated after this point
    # would see this class instead -- confirm nothing below relies on it.
    lookup_name = "any"


@SAProvider.register_lookup
class Overlap(DefaultLookup):
    """`overlap` lookup.

    Resolved through the `operators` table to the column's `overlap`.
    """

    lookup_name = "overlap"


class MSSQLStringLookupMixin:
    """Mixin that makes string lookups use a case-sensitive collation on MSSQL.

    Expects the host class to provide `database_model_cls` and `source`.
    """

    def _is_string_type(self, column):
        """Tell whether `column` carries a string-based column type.

        Objects without a `type` attribute are reported as non-string.
        """
        if not hasattr(column, "type"):
            return False

        sa_type = column.type

        # The custom MSSQLJSON type is listed first: it is treated as
        # string-based for collation purposes, alongside the SQLAlchemy
        # generic and MSSQL-specific string types.
        return isinstance(
            sa_type,
            (
                MSSQLJSON,
                sa_types.String,
                sa_types.Text,
                sa_types.Unicode,
                sa_types.UnicodeText,
                mssql.NVARCHAR,
                mssql.VARCHAR,
                mssql.CHAR,
                mssql.NCHAR,
                mssql.TEXT,
                mssql.NTEXT,
            ),
        )

    def process_source(self):
        """Return the source column, collated case-sensitively if string-based.

        Columns of any other type are returned exactly as found on the model.
        """
        column = getattr(self.database_model_cls, self.source)

        if not self._is_string_type(column):
            # Collation does not apply to non-string columns
            return column

        return column.collate("Latin1_General_CS_AS")


@MssqlProvider.register_lookup
class MSSQLExact(MSSQLStringLookupMixin, DefaultLookup):
    """`exact` lookup for MSSQL.

    String columns are compared under the case-sensitive collation applied
    by `MSSQLStringLookupMixin.process_source`.
    """

    lookup_name = "exact"

    def as_expression(self):
        """Return an equality expression between the source and the target"""
        lhs = self.process_source()
        rhs = self.process_target()
        return lhs == rhs


@MssqlProvider.register_lookup
class MSSQLContains(MSSQLStringLookupMixin, DefaultLookup):
    """`contains` lookup for MSSQL.

    String columns are matched under the case-sensitive collation applied
    by `MSSQLStringLookupMixin.process_source`.
    """

    lookup_name = "contains"

    def as_expression(self):
        """Return the source column's `contains` expression for the target"""
        contains = self.process_source().contains
        return contains(self.process_target())


@MssqlProvider.register_lookup
class MSSQLStartswith(MSSQLStringLookupMixin, DefaultLookup):
    """`startswith` lookup for MSSQL.

    String columns are matched under the case-sensitive collation applied
    by `MSSQLStringLookupMixin.process_source`.
    """

    lookup_name = "startswith"

    def as_expression(self):
        """Return the source column's `startswith` expression for the target"""
        startswith = self.process_source().startswith
        return startswith(self.process_target())


@MssqlProvider.register_lookup
class MSSQLEndswith(MSSQLStringLookupMixin, DefaultLookup):
    """`endswith` lookup for MSSQL.

    String columns are matched under the case-sensitive collation applied
    by `MSSQLStringLookupMixin.process_source`.
    """

    lookup_name = "endswith"

    def as_expression(self):
        """Return the source column's `endswith` expression for the target"""
        endswith = self.process_source().endswith
        return endswith(self.process_target())
