from datetime import datetime
from typing import Any, Literal, Optional

from dateutil.parser import isoparse
from fastapi import APIRouter, HTTPException
from pydantic import Field, model_validator
from starlette.requests import Request
from strawberry.relay import GlobalID
from typing_extensions import Self

from phoenix.db import models
from phoenix.db.helpers import SupportedSQLDialect
from phoenix.db.insertion.helpers import insert_on_conflict
from phoenix.server.api.types.node import from_global_id_with_expected_type
from phoenix.server.dml_event import ExperimentRunAnnotationInsertEvent

from .models import V1RoutesBaseModel
from .utils import ResponseBody, add_errors_to_responses

# Router for the experiment-evaluation endpoints in this module; its routes are
# tagged "experiments" and included in the generated OpenAPI schema.
router = APIRouter(
    include_in_schema=True,
    tags=["experiments"],
)


class ExperimentEvaluationResult(V1RoutesBaseModel):
    # Outcome reported by an evaluator. Every field defaults to None, so a
    # client may send any subset of label / score / explanation.
    # (Commented rather than docstring'd: pydantic would surface a class
    # docstring as the schema description.)
    label: Optional[str] = Field(
        None,
        description="The label assigned by the evaluation",
    )
    score: Optional[float] = Field(
        None,
        description="The score assigned by the evaluation",
    )
    explanation: Optional[str] = Field(
        None,
        description="Explanation of the evaluation result",
    )


class UpsertExperimentEvaluationRequestBody(V1RoutesBaseModel):
    # JSON body for creating or updating an evaluation of one experiment run.
    # Required: experiment_run_id, name, annotator_kind, start_time, end_time.
    # At least one of `result` / `error` must be present (enforced below).
    experiment_run_id: str = Field(description="The ID of the experiment run being evaluated")
    name: str = Field(description="The name of the evaluation")
    annotator_kind: Literal["LLM", "CODE", "HUMAN"] = Field(
        description="The kind of annotator used for the evaluation"
    )
    start_time: datetime = Field(description="The start time of the evaluation in ISO format")
    end_time: datetime = Field(description="The end time of the evaluation in ISO format")
    result: Optional[ExperimentEvaluationResult] = Field(
        default=None,
        description="The result of the evaluation. Either result or error must be provided.",
    )
    error: Optional[str] = Field(
        default=None,
        description=(
            "Error message if the evaluation encountered an error. "
            "Either result or error must be provided."
        ),
    )
    metadata: Optional[dict[str, Any]] = Field(
        default=None,
        description="Metadata for the evaluation",
    )
    trace_id: Optional[str] = Field(
        default=None,
        description="Optional trace ID for tracking",
    )

    @model_validator(mode="after")
    def validate_result_or_error(self) -> Self:
        # Runs after field validation: reject bodies that carry neither a
        # result nor an error (supplying both is allowed).
        has_outcome = self.result is not None or self.error is not None
        if not has_outcome:
            raise ValueError("Either 'result' or 'error' must be provided")
        return self


class UpsertExperimentEvaluationResponseBodyData(V1RoutesBaseModel):
    # Payload of a successful upsert: the string form of the evaluation's
    # GlobalID, as built by the upsert endpoint in this module.
    id: str = Field(
        description="The ID of the upserted experiment evaluation",
    )


class UpsertExperimentEvaluationResponseBody(
    ResponseBody[UpsertExperimentEvaluationResponseBodyData]
):
    # No fields of its own: this only binds the generic ResponseBody to the
    # upsert payload type, giving the endpoint a concrete, named response model
    # whose `data` holds an UpsertExperimentEvaluationResponseBodyData.
    ...


@router.post(
    "/experiment_evaluations",
    operation_id="upsertExperimentEvaluation",
    summary="Create or update evaluation for an experiment run",
    responses=add_errors_to_responses(
        [{"status_code": 404, "description": "Experiment run not found"}]
    ),
)
async def upsert_experiment_evaluation(
    request: Request, request_body: UpsertExperimentEvaluationRequestBody
) -> UpsertExperimentEvaluationResponseBody:
    # Insert an ExperimentRunAnnotation for the given experiment run, or update
    # the existing one with the same (experiment_run_id, name), then publish an
    # ExperimentRunAnnotationInsertEvent and return the row's global ID.
    #
    # Every value comes from the already-validated `request_body`; the raw JSON
    # is not re-read. This stores exactly what pydantic accepted. For example,
    # start_time/end_time are already `datetime` objects here, including inputs
    # that `isoparse` on the raw value could not handle.
    #
    # Raises HTTPException(404) when `experiment_run_id` is not a well-formed
    # GlobalID or does not refer to an ExperimentRun.
    #
    # (Comments instead of a docstring: FastAPI would publish a docstring as
    # this operation's OpenAPI description.)
    try:
        # Both calls raise ValueError: from_id on malformed input,
        # from_global_id_with_expected_type on a type mismatch.
        experiment_run_gid = GlobalID.from_id(request_body.experiment_run_id)
        experiment_run_id = from_global_id_with_expected_type(experiment_run_gid, "ExperimentRun")
    except ValueError as exc:
        raise HTTPException(
            detail=f"ExperimentRun with ID {request_body.experiment_run_id} does not exist",
            status_code=404,
        ) from exc
    result = request_body.result
    values = dict(
        experiment_run_id=experiment_run_id,
        name=request_body.name,
        annotator_kind=request_body.annotator_kind,
        label=result.label if result else None,
        score=result.score if result else None,
        explanation=result.explanation if result else None,
        error=request_body.error,
        metadata_=request_body.metadata or {},  # `metadata_` must match database
        start_time=request_body.start_time,
        end_time=request_body.end_time,
        trace_id=request_body.trace_id,
    )
    async with request.app.state.db() as session:
        dialect = SupportedSQLDialect(session.bind.dialect.name)
        exp_eval_run = await session.scalar(
            insert_on_conflict(
                values,
                dialect=dialect,
                table=models.ExperimentRunAnnotation,
                unique_by=("experiment_run_id", "name"),
            ).returning(models.ExperimentRunAnnotation)
        )
        # Capture the primary key while the session is still open rather than
        # touching the ORM instance after the session context has exited.
        exp_eval_run_id = exp_eval_run.id
    evaluation_gid = GlobalID("ExperimentEvaluation", str(exp_eval_run_id))
    request.state.event_queue.put(ExperimentRunAnnotationInsertEvent((exp_eval_run_id,)))
    return UpsertExperimentEvaluationResponseBody(
        data=UpsertExperimentEvaluationResponseBodyData(id=str(evaluation_gid))
    )
