from enum import Enum
from typing import Any, MutableMapping, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from pydantic.json_schema import SkipJsonSchema

from .document import Document


class Node(BaseModel):
    """A single operation in a logical plan."""

    # extra="allow" keeps unrecognised keys on the instance (presumably the
    # operator-specific parameters of concrete node kinds — confirm with callers).
    # use_attribute_docstrings turns the string literal after each field into
    # that field's description in the generated JSON schema.
    model_config = ConfigDict(extra="allow", use_attribute_docstrings=True)

    node_type: Optional[str] = None
    """The type of this node."""

    node_id: int
    """A unique integer ID representing this node."""

    # Free-text field tagged via json_schema_extra; the flag is presumably read by
    # plan-comparison code elsewhere so that descriptions do not affect equality.
    description: Optional[str] = Field(
        default=None,
        json_schema_extra={"exclude_from_comparison": True},
    )
    """A detailed description of why this operator was chosen for this query plan."""

    # A literal list default is safe here: pydantic copies non-hashable defaults
    # for each instance, so nodes never share one list object.
    inputs: list[int] = []
    """A list of node IDs that this operation depends on."""


class LogicalPlan(BaseModel):
    """A logical query plan used to evaluate a query."""

    # The string literal after each field becomes its JSON-schema description.
    model_config = ConfigDict(use_attribute_docstrings=True)

    query: str
    """The query that the plan is for."""

    # SerializeAsAny: each node is serialized using the fields of its runtime
    # class (e.g. a Node subclass), not only the fields declared on Node itself.
    nodes: MutableMapping[int, SerializeAsAny[Node]]
    """A mapping of node IDs to nodes in the query plan."""

    # NOTE(review): nothing in this model checks that result_node is a key of
    # `nodes`, or that the keys of `nodes` match each node's node_id.
    result_node: int
    """The ID of the node that is the result of the query."""

    # Typed as Optional[Any], so pydantic accepts any value here unvalidated.
    llm_prompt: Optional[Any] = None
    """The LLM prompt that was used to generate this query plan."""

    llm_plan: Optional[str] = None
    """The result generated by the LLM."""


class Query(BaseModel):
    """A query against a DocSet. Contains either a natural language string or a query plan."""

    # The string literal after each field becomes its description in the
    # generated (OpenAPI) JSON schema, so these texts are user-facing.
    model_config = ConfigDict(use_attribute_docstrings=True)

    docset_id: Optional[str] = None
    """The docset against which to run the query."""

    query: Optional[str] = None
    """The natural language query to run. If specified, `plan` must not be set."""

    plan: Optional[LogicalPlan] = None
    """The logical query plan to run. If specified, `query` must not be set."""

    stream: bool = False
    """If true, query results will be streamed back to the client as they are generated. Applies only when calling the query API."""

    summarize_result: bool = False
    """
    If true, an English summary of the result in context of the original query will be returned.
    Applies only when calling the query API and only available when `stream=True`.
    """

    rag_mode: bool = False
    """
    If true, the query will only run a RAG query plan. Requires `query`; `plan` must not be set.
    """

    # Bookmarks are currently tied to the UI, so we hide them from the
    # generated OpenAPI schema so as not to confuse callers of the API.
    bookmark_source: SkipJsonSchema[Optional[str]] = None

    bookmark_target: SkipJsonSchema[Optional[str]] = None

    @model_validator(mode="after")
    def check_not_both_query_and_plan(self):
        """Require exactly one of `query` or `plan` to be set.

        Raises:
            ValueError: if both are set, or if neither is set (pydantic
                reports it to the caller as a ValidationError).
        """
        if self.query is not None and self.plan is not None:
            raise ValueError("query and plan cannot both be specified")
        if self.query is None and self.plan is None:
            raise ValueError("one of query or plan is required")
        return self

    @model_validator(mode="after")
    def check_rag_mode_query_and_plan(self):
        """Reject an explicit `plan` when `rag_mode` is enabled.

        Raises:
            ValueError: if `rag_mode` is true and `plan` is set (pydantic
                reports it to the caller as a ValidationError).
        """
        if self.rag_mode and self.plan is not None:
            raise ValueError("plan must not be specified when rag_mode is True, use query instead")
        return self


class QueryResult(BaseModel):
    """The result of a non-streaming query."""

    # The string literal after each field becomes its JSON-schema description.
    model_config = ConfigDict(use_attribute_docstrings=True)

    query_id: str
    """The unique ID of the query operation."""

    plan: LogicalPlan
    """The logical query plan that was executed."""

    # Typed as Any: pydantic performs no validation or coercion on this value.
    result: Any
    """The result of the query operation. Depending on the query, this could be a list of documents,
    a single document, a string, an integer, etc.
    """


class QueryTraceDoc(BaseModel):
    """A document in the trace of a query result."""

    # The string literal after each field becomes its JSON-schema description.
    model_config = ConfigDict(use_attribute_docstrings=True)

    # Presumably refers to a Node.node_id in the executed plan; not checked here.
    node_id: int
    """The ID of the node in the query plan that produced this document."""

    # Plain dict with str keys; values are typed as Any and are not validated.
    doc: dict[str, Any]
    """The document data."""


class QueryEventType(str, Enum):
    """The type of event that occurred in the query trace."""

    # The str mixin makes every member an actual str equal to its value
    # (e.g. QueryEventType.PLAN == "plan"), so members can be compared with and
    # serialized as their plain string values.
    # NOTE(review): f-string/format() output of (str, Enum) members changed in
    # newer Python versions (member name vs. value) — use `.value` when the raw
    # string is required.
    COMPLETE = "complete"
    ERROR = "error"
    PLAN = "plan"
    RESULT = "result"
    RESULT_DOC = "result_doc"
    STATUS = "status"
    TRACE_DOC = "trace_doc"
    RESULT_SUMMARY = "result_summary"
    RAG_SOURCE = "rag_source"
    RAG_ANSWER = "rag_answer"
    CITATIONS = "citations"


class QueryEvent(BaseModel):
    """An event in the trace of a query result."""

    # Consistent with every other model in this module: without this setting
    # pydantic ignores the attribute docstrings below, and these fields carry no
    # description in the generated JSON schema.
    model_config = ConfigDict(use_attribute_docstrings=True)

    event_type: QueryEventType
    """The type of event."""

    # NOTE(review): nothing here ties the payload type to event_type; any member
    # of the union is accepted for any event type.
    data: Union[int, str, float, QueryTraceDoc, LogicalPlan, QueryResult, Document]
    """The data associated with the event."""
