"""Utility functions for building database queries."""

from typing import List, Optional
from uuid import UUID

from pydantic import BaseModel
from sqlmodel import col, select

from lightly_studio.models.annotation.annotation_base import AnnotationBaseTable
from lightly_studio.models.annotation_label import AnnotationLabelTable
from lightly_studio.models.metadata import SampleMetadataTable
from lightly_studio.models.sample import SampleTable
from lightly_studio.models.tag import TagTable
from lightly_studio.resolvers.metadata_resolver.metadata_filter import (
    MetadataFilter,
    apply_metadata_filters,
)
from lightly_studio.type_definitions import QueryType


class FilterDimensions(BaseModel):
    """Optional lower/upper bound pair for a single numeric sample dimension.

    Either bound may be left as ``None`` to leave that side of the range open.
    ``SampleFilter.apply`` compares against both bounds inclusively.
    """

    # Smallest accepted value (inclusive); ``None`` means no lower bound.
    min: Optional[int] = None
    # Largest accepted value (inclusive); ``None`` means no upper bound.
    max: Optional[int] = None


class SampleFilter(BaseModel):
    """Bundle of optional constraints used to narrow down a sample query.

    Attributes left as ``None`` (or, for the list attributes, left empty) are
    ignored by :meth:`apply`. Every active constraint is added to the query as
    its own ``where`` clause, so all of them must hold for a row to match.
    """

    # Inclusive bounds checked against ``SampleTable.width`` / ``SampleTable.height``.
    width: Optional[FilterDimensions] = None
    height: Optional[FilterDimensions] = None
    # Keep samples that have at least one annotation whose label ID is listed here.
    annotation_label_ids: Optional[List[UUID]] = None
    # Keep samples that carry at least one of these tags.
    tag_ids: Optional[List[UUID]] = None
    # Metadata conditions, delegated to ``apply_metadata_filters``.
    metadata_filters: Optional[List[MetadataFilter]] = None

    def apply(self, query: QueryType) -> QueryType:
        """Return ``query`` restricted by every filter that is set on this instance.

        Args:
            query: The query to narrow. The added conditions reference
                ``SampleTable`` columns, so it presumably selects samples —
                confirm at call sites.

        Returns:
            The same kind of query with one extra ``where`` clause per active
            constraint.
        """
        # Range constraints: width bounds first, then height; lower bound before upper.
        for column, bounds in (
            (SampleTable.width, self.width),
            (SampleTable.height, self.height),
        ):
            if not bounds:
                continue
            if bounds.min is not None:
                query = query.where(column >= bounds.min)
            if bounds.max is not None:
                query = query.where(column <= bounds.max)

        # Label constraint: sample must own an annotation carrying one of the labels.
        if self.annotation_label_ids:
            label_match = col(AnnotationLabelTable.annotation_label_id).in_(
                self.annotation_label_ids
            )
            labelled_sample_ids = (
                select(AnnotationBaseTable.sample_id)
                .select_from(AnnotationBaseTable)
                .join(AnnotationBaseTable.annotation_label)
                .where(label_match)
                .distinct()
            )
            query = query.where(col(SampleTable.sample_id).in_(labelled_sample_ids))

        # Tag constraint: sample must be linked to at least one requested tag.
        if self.tag_ids:
            tagged_sample_ids = (
                select(SampleTable.sample_id)
                .select_from(SampleTable)
                .join(SampleTable.tags)
                .where(col(TagTable.tag_id).in_(self.tag_ids))
                .distinct()
            )
            query = query.where(col(SampleTable.sample_id).in_(tagged_sample_ids))

        # Metadata constraints: metadata rows are matched to samples by sample_id.
        if self.metadata_filters:
            join_on = SampleMetadataTable.sample_id == SampleTable.sample_id
            query = apply_metadata_filters(
                query,
                self.metadata_filters,
                metadata_model=SampleMetadataTable,
                metadata_join_condition=join_on,
            )
        return query
