from collections.abc import AsyncIterable
from datetime import datetime
from typing import ClassVar, Optional, cast

import strawberry
from sqlalchemy import and_, func, or_, select
from sqlalchemy.sql.functions import count
from strawberry import UNSET
from strawberry.relay import Connection, GlobalID, Node, NodeID
from strawberry.scalars import JSON
from strawberry.types import Info

from phoenix.db import models
from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import BadRequest
from phoenix.server.api.input_types.DatasetVersionSort import DatasetVersionSort
from phoenix.server.api.types.DatasetExample import DatasetExample
from phoenix.server.api.types.DatasetVersion import DatasetVersion
from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
from phoenix.server.api.types.node import from_global_id_with_expected_type
from phoenix.server.api.types.pagination import (
    ConnectionArgs,
    CursorString,
    connection_from_list,
)
from phoenix.server.api.types.SortDir import SortDir


def _connection_args(
    first: Optional[int],
    last: Optional[int],
    after: Optional[CursorString],
    before: Optional[CursorString],
) -> ConnectionArgs:
    """
    Bundle the pagination arguments shared by the connection fields of ``Dataset``.

    ``after`` and ``before`` are forwarded only when they are ``CursorString``
    instances; any other value is replaced with None. ``first`` and ``last`` are
    forwarded unchanged.
    """
    return ConnectionArgs(
        first=first,
        after=after if isinstance(after, CursorString) else None,
        last=last,
        before=before if isinstance(before, CursorString) else None,
    )


def _version_rowid(dataset_version_id: Optional[GlobalID]) -> Optional[int]:
    """
    Decode an optional ``DatasetVersion`` global ID.

    Returns None when the argument is falsy (e.g. None); otherwise returns the result
    of ``from_global_id_with_expected_type`` and lets any exception it raises propagate.
    """
    if not dataset_version_id:
        return None
    return from_global_id_with_expected_type(
        global_id=dataset_version_id,
        expected_type_name=DatasetVersion.__name__,
    )


@strawberry.type
class Dataset(Node):
    """GraphQL node for a dataset, with resolvers for its versions, examples and experiments."""

    # Table passed to ``info.context.last_updated_at.get`` by ``last_updated_at``.
    # NOTE(review): this is ``models.Experiment`` (not ``models.Dataset``) yet it is
    # looked up with the dataset's own ID — confirm this pairing is intentional.
    _table: ClassVar[type[models.Base]] = models.Experiment
    id_attr: NodeID[int]
    name: str
    description: Optional[str]
    metadata: JSON
    created_at: datetime
    updated_at: datetime

    @strawberry.field
    async def versions(
        self,
        info: Info[Context, None],
        first: Optional[int] = 50,
        last: Optional[int] = UNSET,
        after: Optional[CursorString] = UNSET,
        before: Optional[CursorString] = UNSET,
        sort: Optional[DatasetVersionSort] = UNSET,
    ) -> Connection[DatasetVersion]:
        """
        Return the versions of this dataset as a paginated connection.

        When ``sort`` is given, rows are ordered by the requested column in the
        requested direction; otherwise they are ordered by ``created_at`` descending.
        In every case the version ID (same direction) is the secondary sort key so
        that rows with equal primary keys come back in a deterministic order.

        All matching versions are loaded and then paginated in memory by
        ``connection_from_list``.
        """
        args = _connection_args(first=first, last=last, after=after, before=before)
        async with info.context.db() as session:
            stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id_attr)
            if sort:
                # For now assume the column names match 1:1 with the enum values
                sort_col = getattr(models.DatasetVersion, sort.col.value)
                if sort.dir is SortDir.desc:
                    stmt = stmt.order_by(sort_col.desc(), models.DatasetVersion.id.desc())
                else:
                    stmt = stmt.order_by(sort_col.asc(), models.DatasetVersion.id.asc())
            else:
                # Tie-break on ID, as the explicit-sort branches do, so that versions
                # sharing a timestamp keep a stable position across requests.
                stmt = stmt.order_by(
                    models.DatasetVersion.created_at.desc(),
                    models.DatasetVersion.id.desc(),
                )
            versions = await session.scalars(stmt)
        data = [
            DatasetVersion(
                id_attr=version.id,
                description=version.description,
                metadata=version.metadata_,
                created_at=version.created_at,
            )
            for version in versions
        ]
        return connection_from_list(data=data, args=args)

    @strawberry.field(
        description="Number of examples in a specific version if version is specified, or in the "
        "latest version if version is not specified."
    )  # type: ignore
    async def example_count(
        self,
        info: Info[Context, None],
        dataset_version_id: Optional[GlobalID] = UNSET,
    ) -> int:
        """
        Count the examples of this dataset whose selected revision is not a ``DELETE``.

        For each example the revision with the highest ID is selected; when
        ``dataset_version_id`` is supplied, only revisions whose version ID is less than
        or equal to that version are considered. Returns 0 when the query yields no value.
        """
        dataset_id = self.id_attr
        version_id = _version_rowid(dataset_version_id)
        # One revision ID per example: the highest revision ID recorded for it.
        revision_ids = (
            select(func.max(models.DatasetExampleRevision.id))
            .join(models.DatasetExample)
            .where(models.DatasetExample.dataset_id == dataset_id)
            .group_by(models.DatasetExampleRevision.dataset_example_id)
        )
        if version_id is not None:
            # The subquery only yields a value when the version belongs to this dataset.
            version_id_subquery = (
                select(models.DatasetVersion.id)
                .where(models.DatasetVersion.dataset_id == dataset_id)
                .where(models.DatasetVersion.id == version_id)
                .scalar_subquery()
            )
            revision_ids = revision_ids.where(
                models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
            )
        stmt = (
            select(count(models.DatasetExampleRevision.id))
            .where(models.DatasetExampleRevision.id.in_(revision_ids))
            .where(models.DatasetExampleRevision.revision_kind != "DELETE")
        )
        async with info.context.db() as session:
            return (await session.scalar(stmt)) or 0

    @strawberry.field
    async def examples(
        self,
        info: Info[Context, None],
        dataset_version_id: Optional[GlobalID] = UNSET,
        first: Optional[int] = 50,
        last: Optional[int] = UNSET,
        after: Optional[CursorString] = UNSET,
        before: Optional[CursorString] = UNSET,
    ) -> Connection[DatasetExample]:
        """
        Return the examples of this dataset as a paginated connection.

        Uses the same revision selection as ``example_count``: the highest revision ID
        per example (restricted to revisions at or before ``dataset_version_id`` when it
        is supplied), excluding examples whose selected revision is a ``DELETE``.
        Results are ordered by example ID descending, and the decoded version ID (or
        None) is attached to every returned ``DatasetExample``.
        """
        args = _connection_args(first=first, last=last, after=after, before=before)
        dataset_id = self.id_attr
        version_id = _version_rowid(dataset_version_id)
        # One revision ID per example: the highest revision ID recorded for it.
        revision_ids = (
            select(func.max(models.DatasetExampleRevision.id))
            .join(models.DatasetExample)
            .where(models.DatasetExample.dataset_id == dataset_id)
            .group_by(models.DatasetExampleRevision.dataset_example_id)
        )
        if version_id is not None:
            # The subquery only yields a value when the version belongs to this dataset.
            version_id_subquery = (
                select(models.DatasetVersion.id)
                .where(models.DatasetVersion.dataset_id == dataset_id)
                .where(models.DatasetVersion.id == version_id)
                .scalar_subquery()
            )
            revision_ids = revision_ids.where(
                models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
            )
        query = (
            select(models.DatasetExample)
            .join(
                models.DatasetExampleRevision,
                onclause=models.DatasetExample.id
                == models.DatasetExampleRevision.dataset_example_id,
            )
            .where(
                and_(
                    models.DatasetExampleRevision.id.in_(revision_ids),
                    models.DatasetExampleRevision.revision_kind != "DELETE",
                )
            )
            .order_by(models.DatasetExampleRevision.dataset_example_id.desc())
        )
        async with info.context.db() as session:
            dataset_examples = [
                DatasetExample(
                    id_attr=example.id,
                    version_id=version_id,
                    created_at=example.created_at,
                )
                async for example in await session.stream_scalars(query)
            ]
        return connection_from_list(data=dataset_examples, args=args)

    @strawberry.field(
        description="Number of experiments for a specific version if version is specified, "
        "or for all versions if version is not specified."
    )  # type: ignore
    async def experiment_count(
        self,
        info: Info[Context, None],
        dataset_version_id: Optional[GlobalID] = UNSET,
    ) -> int:
        """
        Count the experiments attached to this dataset.

        When ``dataset_version_id`` is supplied, only experiments recorded against that
        exact version are counted. Returns 0 when the query yields no value.
        """
        stmt = select(count(models.Experiment.id)).where(
            models.Experiment.dataset_id == self.id_attr
        )
        version_id = _version_rowid(dataset_version_id)
        if version_id is not None:
            stmt = stmt.where(models.Experiment.dataset_version_id == version_id)
        async with info.context.db() as session:
            return (await session.scalar(stmt)) or 0

    @strawberry.field
    async def experiments(
        self,
        info: Info[Context, None],
        first: Optional[int] = 50,
        last: Optional[int] = UNSET,
        after: Optional[CursorString] = UNSET,
        before: Optional[CursorString] = UNSET,
        filter_condition: Optional[str] = UNSET,
        filter_ids: Optional[
            list[GlobalID]
        ] = UNSET,  # this is a stopgap until a query DSL is implemented
    ) -> Connection[Experiment]:
        """
        Return the experiments of this dataset as a paginated connection, newest ID first.

        Args:
            filter_condition: optional text matched case-insensitively as a substring of
                the experiment name or description.
            filter_ids: optional list of ``Experiment`` global IDs to restrict the result to.

        Raises:
            BadRequest: if decoding an entry of ``filter_ids`` raises ``ValueError``.
        """
        args = _connection_args(first=first, last=last, after=after, before=before)
        dataset_id = self.id_attr
        # Ordinal of each experiment by ascending ID; passed to ``to_gql_experiment``.
        # NOTE(review): the window lives in the same SELECT as the filters below, so the
        # numbering presumably counts only rows that survive filtering — confirm intended.
        row_number = func.row_number().over(order_by=models.Experiment.id).label("row_number")
        query = (
            select(models.Experiment, row_number)
            .where(models.Experiment.dataset_id == dataset_id)
            .order_by(models.Experiment.id.desc())
        )
        if filter_condition is not UNSET and filter_condition:
            # Search both name and description columns with case-insensitive partial matching
            # NOTE(review): ``%`` and ``_`` in the input are not escaped, so they act as
            # LIKE wildcards.
            search_filter = or_(
                models.Experiment.name.ilike(f"%{filter_condition}%"),
                models.Experiment.description.ilike(f"%{filter_condition}%"),
            )
            query = query.where(search_filter)

        if filter_ids:
            filter_rowids = []
            for filter_id in filter_ids:
                try:
                    filter_rowids.append(
                        from_global_id_with_expected_type(
                            global_id=filter_id,
                            expected_type_name=Experiment.__name__,
                        )
                    )
                except ValueError as err:
                    # Chain explicitly so the decoding failure stays attached as the cause.
                    raise BadRequest(f"Invalid filter ID: {filter_id}") from err
            query = query.where(models.Experiment.id.in_(filter_rowids))

        async with info.context.db() as session:
            experiments = [
                to_gql_experiment(experiment, sequence_number)
                async for experiment, sequence_number in cast(
                    AsyncIterable[tuple[models.Experiment, int]],
                    await session.stream(query),
                )
            ]
        return connection_from_list(data=experiments, args=args)

    @strawberry.field
    async def experiment_annotation_summaries(
        self, info: Info[Context, None]
    ) -> list[ExperimentAnnotationSummary]:
        """
        Summarize run annotations across all experiments of this dataset, per annotation name.

        For each annotation name the result carries:
          * ``mean_score``: the average, over dataset examples, of the per-example average
            score (so every example contributes equally regardless of how many annotated
            runs it has);
          * ``min_score`` / ``max_score``: extremes over the individual annotation rows;
          * ``count``: number of annotation rows;
          * ``error_count``: SQL ``COUNT`` over the annotation ``error`` column.

        Summaries are ordered by annotation name.
        """
        dataset_id = self.id_attr
        # Step 1: average score per (dataset example, annotation name).
        repetition_mean_scores_by_example_subquery = (
            select(
                models.ExperimentRunAnnotation.name.label("annotation_name"),
                func.avg(models.ExperimentRunAnnotation.score).label("mean_repetition_score"),
            )
            .select_from(models.ExperimentRunAnnotation)
            .join(
                models.ExperimentRun,
                models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
            )
            .join(
                models.Experiment,
                models.ExperimentRun.experiment_id == models.Experiment.id,
            )
            .where(models.Experiment.dataset_id == dataset_id)
            .group_by(
                models.ExperimentRun.dataset_example_id,
                models.ExperimentRunAnnotation.name,
            )
            .subquery()
            .alias("repetition_mean_scores_by_example")
        )
        # Step 2: average the per-example means for each annotation name.
        repetition_mean_scores_subquery = (
            select(
                repetition_mean_scores_by_example_subquery.c.annotation_name.label(
                    "annotation_name"
                ),
                func.avg(repetition_mean_scores_by_example_subquery.c.mean_repetition_score).label(
                    "mean_score"
                ),
            )
            .select_from(repetition_mean_scores_by_example_subquery)
            .group_by(
                repetition_mean_scores_by_example_subquery.c.annotation_name,
            )
            .subquery()
            .alias("repetition_mean_scores")
        )
        # Step 3: min / max / row counts per annotation name over the raw annotation rows.
        repetitions_subquery = (
            select(
                models.ExperimentRunAnnotation.name.label("annotation_name"),
                func.min(models.ExperimentRunAnnotation.score).label("min_score"),
                func.max(models.ExperimentRunAnnotation.score).label("max_score"),
                func.count().label("count"),
                func.count(models.ExperimentRunAnnotation.error).label("error_count"),
            )
            .select_from(models.ExperimentRunAnnotation)
            .join(
                models.ExperimentRun,
                models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
            )
            .join(
                models.Experiment,
                models.ExperimentRun.experiment_id == models.Experiment.id,
            )
            .where(models.Experiment.dataset_id == dataset_id)
            .group_by(models.ExperimentRunAnnotation.name)
            .subquery()
        )
        # Step 4: stitch the mean (step 2) and the extremes/counts (step 3) together by name.
        run_scores_query = (
            select(
                repetition_mean_scores_subquery.c.annotation_name.label("annotation_name"),
                repetition_mean_scores_subquery.c.mean_score.label("mean_score"),
                repetitions_subquery.c.min_score.label("min_score"),
                repetitions_subquery.c.max_score.label("max_score"),
                # Relabelled to ``count_`` before being read as a row attribute below.
                repetitions_subquery.c.count.label("count_"),
                repetitions_subquery.c.error_count.label("error_count"),
            )
            .select_from(repetition_mean_scores_subquery)
            .join(
                repetitions_subquery,
                repetitions_subquery.c.annotation_name
                == repetition_mean_scores_subquery.c.annotation_name,
            )
            .order_by(repetition_mean_scores_subquery.c.annotation_name)
        )
        async with info.context.db() as session:
            return [
                ExperimentAnnotationSummary(
                    annotation_name=scores_tuple.annotation_name,
                    min_score=scores_tuple.min_score,
                    max_score=scores_tuple.max_score,
                    mean_score=scores_tuple.mean_score,
                    count=scores_tuple.count_,
                    error_count=scores_tuple.error_count,
                )
                async for scores_tuple in await session.stream(run_scores_query)
            ]

    @strawberry.field
    def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
        """Return what the context's ``last_updated_at`` store holds for ``_table`` and this ID."""
        return info.context.last_updated_at.get(self._table, self.id_attr)


def to_gql_dataset(dataset: models.Dataset) -> Dataset:
    """
    Build the GraphQL ``Dataset`` node that mirrors an ORM dataset row.

    Every field is copied across unchanged; the ORM attribute ``metadata_`` is
    surfaced on the GraphQL type as ``metadata``.
    """
    row_id = dataset.id
    title = dataset.name
    summary = dataset.description
    extra = dataset.metadata_
    created = dataset.created_at
    updated = dataset.updated_at
    gql_dataset = Dataset(
        id_attr=row_id,
        name=title,
        description=summary,
        metadata=extra,
        created_at=created,
        updated_at=updated,
    )
    return gql_dataset
