from __future__ import annotations

import logging
from collections.abc import Callable
from typing import Any, TypeVar

from injectq import Inject
from litellm.types.utils import ModelResponse

from pyagenity.checkpointer import BaseCheckpointer
from pyagenity.publisher.base_publisher import BasePublisher
from pyagenity.state import AgentState, ExecutionStatus
from pyagenity.state.base_context import BaseContextManager
from pyagenity.state.execution_state import ExecutionState as ExecMeta
from pyagenity.utils import (
    END,
    START,
    Command,
    Message,
    ResponseGranularity,
    add_messages,
)
from pyagenity.utils.background_task_manager import BackgroundTaskManager
from pyagenity.utils.streaming import EventModel


# Module-level type variable bounded by AgentState.
# NOTE(review): the generic functions in this module declare their own PEP 695
# ``StateT`` type parameter, which shadows this name inside them; this TypeVar may
# only be kept for importers elsewhere — confirm before removing.
StateT = TypeVar("StateT", bound=AgentState)

# Module logger; all helpers below log through it.
logger = logging.getLogger(__name__)


async def parse_response(
    state: AgentState,
    messages: list[Message],
    response_granularity: ResponseGranularity = ResponseGranularity.LOW,
) -> dict[str, Any]:
    """Shape the graph output according to the requested response granularity.

    Args:
        state: Final agent state of the run.
        messages: Messages produced during the run.
        response_granularity: How much detail to return. ``FULL`` returns the state
            plus ``messages``; ``PARTIAL`` returns the state's context, its summary and
            ``messages`` (with ``"state"`` set to None); ``LOW`` returns only the
            state's context (or an empty list).

    Returns:
        A dict whose keys depend on the granularity. Any value not matching one of
        the three granularities yields ``{"messages": messages}``.
    """
    if response_granularity == ResponseGranularity.FULL:
        return {"state": state, "messages": messages}

    if response_granularity == ResponseGranularity.PARTIAL:
        # NOTE(review): this branch uses the singular key "message", unlike the
        # other branches; callers may depend on it, so it is kept as-is.
        partial_payload = {
            "state": None,
            "context": state.context,
            "summary": state.context_summary,
            "message": messages,
        }
        return partial_payload

    if response_granularity == ResponseGranularity.LOW:
        # Only the accumulated context is returned; ``messages`` is not used here.
        return {"messages": state.context or []}

    return {"messages": messages}


# Utility to update only provided fields in state
def _update_state_fields(state, partial: dict):
    """Copy the entries of ``partial`` onto ``state`` as attributes.

    Keys that name attributes ``state`` does not have are ignored. So are the
    framework-managed fields ``context``, ``context_summary`` and ``execution_meta``,
    which this helper never overwrites.
    """
    protected_fields = {"context", "context_summary", "execution_meta"}
    # Lazy generator: each key is filtered immediately before it is applied.
    assignable = (
        (field_name, field_value)
        for field_name, field_value in partial.items()
        if field_name not in protected_fields and hasattr(state, field_name)
    )
    for field_name, field_value in assignable:
        setattr(state, field_name, field_value)


async def load_or_create_state[StateT: AgentState](
    input_data: dict[str, Any],
    config: dict[str, Any],
    old_state: StateT,
    checkpointer: BaseCheckpointer = Inject[BaseCheckpointer],  # will be auto-injected
) -> StateT:
    """Load existing state from checkpointer or create new state.

    Attempts to fetch a realtime-synced state first, then falls back to
    the persistent checkpointer. If no existing state is found, creates
    a new state from the `StateGraph`'s prototype state and merges any
    incoming messages. Supports partial state update via 'state' in input_data.
    """
    logger.debug("Loading or creating state with thread_id=%s", config.get("thread_id", "default"))

    # Try to load existing state if checkpointer is available
    if checkpointer:
        logger.debug("Attempting to load existing state from checkpointer")
        # first check realtime-synced state
        existing_state: StateT | None = await checkpointer.aget_state_cache(config)
        if not existing_state:
            logger.debug("No synced state found, trying persistent checkpointer")
            # If no synced state, try to get from persistent checkpointer
            existing_state = await checkpointer.aget_state(config)

        if existing_state:
            logger.info(
                "Loaded existing state with %d context messages, current_node=%s, step=%d",
                len(existing_state.context) if existing_state.context else 0,
                existing_state.execution_meta.current_node,
                existing_state.execution_meta.step,
            )
            # Merge new messages with existing context
            new_messages = input_data.get("messages", [])
            if new_messages:
                logger.debug("Merging %d new messages with existing context", len(new_messages))
                existing_state.context = add_messages(existing_state.context, new_messages)
            # Merge partial state fields if provided
            partial_state = input_data.get("state", {})
            if partial_state and isinstance(partial_state, dict):
                logger.debug("Merging partial state with %d fields", len(partial_state))
                _update_state_fields(existing_state, partial_state)
            # Update current node if available
            if "current_node" in partial_state and partial_state["current_node"] is not None:
                existing_state.set_current_node(partial_state["current_node"])
            return existing_state
    else:
        logger.debug("No checkpointer available, will create new state")

    # Create new state by deep copying the graph's prototype state
    logger.info("Creating new state from graph prototype")
    import copy

    state = copy.deepcopy(old_state)

    # Ensure core AgentState fields are properly initialized
    if hasattr(state, "context") and not isinstance(state.context, list):
        state.context = []
        logger.debug("Initialized empty context list")
    if hasattr(state, "context_summary") and state.context_summary is None:
        state.context_summary = None
        logger.debug("Initialized context_summary as None")
    if hasattr(state, "execution_meta"):
        # Create a fresh execution metadata
        state.execution_meta = ExecMeta(current_node=START)
        logger.debug("Created fresh execution metadata starting at %s", START)

    # Set thread_id in execution metadata
    thread_id = config.get("thread_id", "default")
    state.execution_meta.thread_id = thread_id
    logger.debug("Set thread_id to %s", thread_id)

    # Merge new messages with context
    new_messages = input_data.get("messages", [])
    if new_messages:
        logger.debug("Adding %d new messages to fresh state", len(new_messages))
        state.context = add_messages(state.context, new_messages)
    # Merge partial state fields if provided
    partial_state = input_data.get("state", {})
    if partial_state and isinstance(partial_state, dict):
        logger.debug("Merging partial state with %d fields", len(partial_state))
        _update_state_fields(state, partial_state)

    logger.info(
        "Created new state with %d context messages", len(state.context) if state.context else 0
    )
    if "current_node" in partial_state and partial_state["current_node"] is not None:
        state.set_current_node(partial_state["current_node"])
    return state  # type: ignore[return-value]


# def get_default_event(
#     state: AgentState,
#     source: str = "graph",
#     event_type: str = "initialize",
#     input_data: dict[str, Any] | None = None,
#     config: dict[str, Any] | None = None,
#     meta: dict[str, Any] | None = None,
# ) -> Event:
#     from pyagenity.utils.streaming import EventModel

#     metadata = meta or {}
#     metadata["step"] = state.execution_meta.step
#     metadata["current_node"] = state.execution_meta.current_node

#     # Map old SourceType to new Event enum
#     if source == "message":
#         mapped_event = StreamEvent.MESSAGE
#     elif source == "state":
#         mapped_event = StreamEvent.STATE
#     elif source == "node":
#         mapped_event = StreamEvent.NODE_EXECUTION
#     elif source == "tool":
#         mapped_event = StreamEvent.TOOL_EXECUTION
#     else:  # graph or unknown
#         mapped_event = StreamEvent.STATE

#     # Map old EventType to new EventType enum
#     from pyagenity.utils.streaming import EventType as StreamEventType

#     if event_type == "initialize":
#         mapped_event_type = StreamEventType.START
#     elif event_type == "invoked":
#         mapped_event_type = StreamEventType.START
#     elif event_type == "running":
#         mapped_event_type = StreamEventType.PROGRESS
#     elif event_type == "completed":
#         mapped_event_type = StreamEventType.END
#     elif event_type == "interrupted":
#         mapped_event_type = StreamEventType.END
#     elif event_type == "error":
#         mapped_event_type = StreamEventType.END
#     else:
#         mapped_event_type = StreamEventType.UPDATE

#     return EventModel(
#         event=mapped_event,
#         event_type=mapped_event_type,
#         data={
#             "input_keys": list(input_data.keys()) if input_data else [],
#             "is_resume": state.is_interrupted(),
#             "config": config or {},
#         },
#         metadata=metadata,
#         node_name=state.execution_meta.current_node,
#         run_id=config.get("run_id", "") if config else "",
#     )


def process_node_result[StateT: AgentState](
    result: Any,
    state: StateT,
    messages: list[Message],
) -> tuple[StateT, list[Message], str | None]:
    """
    Processes the result from a node execution, updating the agent state, message list,
    and determining the next node.

    Supports:
        - Handling results of type Command, AgentState, Message, list, str, dict, ModelResponse,
            or other types.
        - Deduplicating messages by message_id.
        - Updating the agent state and its context with new messages.
        - Extracting navigation information (next node) from Command results.

    Args:
        result (Any): The output from a node execution. Can be a Command, AgentState, Message,
            list, str, dict, ModelResponse, or other types.
        state (StateT): The current agent state.
        messages (list[Message]): The list of messages accumulated so far.

    Returns:
        tuple[StateT, list[Message], str | None]:
            - The updated agent state.
            - The updated list of messages (with new, unique messages added).
            - The identifier of the next node to execute, if specified; otherwise, None.
    """
    next_node = None
    existing_ids = {msg.message_id for msg in messages}
    new_messages = []

    def add_unique_message(msg: Message) -> None:
        """Add message only if it doesn't already exist."""
        if msg.message_id not in existing_ids:
            new_messages.append(msg)
            existing_ids.add(msg.message_id)

    def create_and_add_message(content: Any) -> Message:
        """Create message from content and add if unique."""
        if isinstance(content, str):
            msg = Message.from_text(content)
        elif isinstance(content, dict):
            try:
                msg = Message.from_dict(content)
            except Exception as e:
                raise ValueError(f"Invalid message dict: {e}") from e
        elif isinstance(content, ModelResponse):
            msg = Message.from_response(content)
        else:
            msg = Message.from_text(str(content))

        add_unique_message(msg)
        return msg

    def handle_state_message(old_state: StateT, new_state: StateT) -> None:
        """Handle state messages by updating the context."""
        old_messages = {}
        if old_state.context:
            old_messages = {msg.message_id: msg for msg in old_state.context}

        if not new_state.context:
            return
        # now save all the new messages
        for msg in new_state.context:
            if msg.message_id in old_messages:
                continue
            # otherwise save it
            add_unique_message(msg)

    # Process different result types
    if isinstance(result, Command):
        # Handle state updates
        if result.update:
            if isinstance(result.update, AgentState):
                handle_state_message(state, result.update)  # type: ignore[assignment]
                state = result.update  # type: ignore[assignment]
            else:
                create_and_add_message(result.update)

        # Handle navigation
        next_node = result.goto

    elif isinstance(result, AgentState):
        handle_state_message(state, result)  # type: ignore[assignment]
        state = result  # type: ignore[assignment]

    elif isinstance(result, Message):
        add_unique_message(result)

    elif isinstance(result, list):
        # Handle list of items (convert each to message)
        for item in result:
            create_and_add_message(item)
    else:
        # Handle single items (str, dict, ModelResponse, or other)
        create_and_add_message(result)

    # Add new messages to the main list and state context
    if new_messages:
        messages.extend(new_messages)
        state.context = add_messages(state.context, new_messages)

    return state, messages, next_node


async def check_and_handle_interrupt(
    interrupt_before: list[str],
    interrupt_after: list[str],
    current_node: str,
    interrupt_type: str,
    state: AgentState,
    config: dict[str, Any],
    _sync_data: Callable,
) -> bool:
    """Pause execution at ``current_node`` if it is configured as an interrupt point.

    ``interrupt_type == "before"`` consults ``interrupt_before``; any other value
    consults ``interrupt_after``. On a hit, the interrupt is recorded on ``state``
    and the state is persisted by awaiting ``_sync_data(state, config, [])``.

    Returns:
        True when execution was interrupted, False otherwise.
    """
    checking_before = interrupt_type == "before"
    watched_nodes = interrupt_before if checking_before else interrupt_after

    if current_node not in watched_nodes:
        logger.debug(
            "No interrupts found for node '%s', continuing execution",
            current_node,
        )
        return False

    if checking_before:
        status = ExecutionStatus.INTERRUPTED_BEFORE
    else:
        status = ExecutionStatus.INTERRUPTED_AFTER
    reason = f"interrupt_{interrupt_type}: {current_node}"

    state.set_interrupt(current_node, reason, status)
    # Persist the interrupted state (no new messages) so the run can be resumed
    await _sync_data(state, config, [])
    logger.debug("Node '%s' interrupted", current_node)
    return True


def get_next_node(
    current_node: str,
    state: AgentState,
    edges: list,
) -> str:
    """Get the next node to execute based on the outgoing edges of ``current_node``.

    Conditional edges are checked first, in list order:

    * Mapped edge (``condition_result`` set and not None): taken when the
      condition's return value equals ``condition_result``.
    * Otherwise, a string returned by the condition is used directly as the next
      node name, and any other truthy return value selects ``edge.to_node``.

    Each distinct condition callable is evaluated at most once per call, even when
    several mapped edges share it (one edge per route of a single router). An edge
    whose evaluation raises is logged and skipped.

    If no conditional edge is taken, the first static (unconditional) edge wins.
    With no outgoing edges, or nothing taken at all, ``END`` is returned.

    Args:
        current_node: Name of the node that just finished.
        state: Current agent state, passed to each condition.
        edges: All graph edges. Each is read for ``from_node``, ``to_node``,
            ``condition`` and, optionally, ``condition_result``.

    Returns:
        The name of the next node, or ``END``.
    """
    outgoing_edges = [e for e in edges if e.from_node == current_node]

    if not outgoing_edges:
        logger.debug("No outgoing edges from node '%s', ending execution", current_node)
        return END

    # id(condition) -> (condition, result). The callable is stored alongside its result
    # to keep it alive, so its id cannot be reused by another object during this call.
    # Only successful evaluations are cached.
    evaluated: dict[int, tuple[Callable, Any]] = {}

    # Handle conditional edges
    for edge in outgoing_edges:
        condition = edge.condition
        if not condition:
            continue
        try:
            cached = evaluated.get(id(condition))
            if cached is None:
                cached = (condition, condition(state))
                evaluated[id(condition)] = cached
            condition_result = cached[1]

            expected = getattr(edge, "condition_result", None)
            if expected is not None:
                # Mapped conditional edge
                if condition_result == expected:
                    return edge.to_node
            elif isinstance(condition_result, str):
                return condition_result
            elif condition_result:
                return edge.to_node
        except Exception:
            logger.exception("Error evaluating condition for edge: %s", edge)

    # Return first static edge if no conditions matched
    static_edges = [e for e in outgoing_edges if not e.condition]
    if static_edges:
        return static_edges[0].to_node

    logger.debug("No valid edges found from node '%s', ending execution", current_node)
    return END


async def call_realtime_sync(
    state: AgentState,
    config: dict[str, Any],
    checkpointer: BaseCheckpointer = Inject[BaseCheckpointer],  # will be auto-injected
) -> None:
    """Push ``state`` into the checkpointer's realtime cache.

    Does nothing when no checkpointer is available (falsy ``checkpointer``).
    """
    if not checkpointer:
        return
    logger.debug("Calling realtime state sync hook")
    await checkpointer.aput_state_cache(config, state)


async def sync_data(
    state: AgentState,
    config: dict[str, Any],
    messages: list[Message],
    trim: bool = False,
    checkpointer: BaseCheckpointer = Inject[BaseCheckpointer],  # will be auto-injected
    context_manager: BaseContextManager = Inject[BaseContextManager],  # will be auto-injected
) -> bool:
    """Sync the current state and messages to the checkpointer.

    The realtime cache always receives the live ``state``. The persistent store
    receives either the context manager's trimmed result (when ``trim`` is True and
    a context manager is available) or a deep-copy snapshot of ``state``.

    Args:
        state: Current agent state.
        config: Run configuration forwarded to the checkpointer.
        messages: Messages to persist; skipped when empty.
        trim: Whether to run the context manager's ``atrim_context`` first.
        checkpointer: Target storage (auto-injected). Falsy disables persistence.
        context_manager: Context trimmer (auto-injected).

    Returns:
        True if trimming was performed, else False.
    """
    import copy

    is_context_trimmed = False
    state_to_persist = None

    if context_manager and trim:
        # NOTE(review): the live ``state`` (not a copy) is handed to the context
        # manager; whether trimming mutates it depends on the manager implementation.
        state_to_persist = await context_manager.atrim_context(state)
        is_context_trimmed = True
    elif checkpointer:
        # Deep copy only when the snapshot will actually be persisted, so the stored
        # object is decoupled from the live state without paying for a discarded copy.
        state_to_persist = copy.deepcopy(state)

    # first sync with realtime then main db
    await call_realtime_sync(state, config, checkpointer)
    logger.debug("Persisting state and %d messages to checkpointer", len(messages))

    if checkpointer:
        await checkpointer.aput_state(config, state_to_persist)
        if messages:
            await checkpointer.aput_messages(config, messages)

    return is_context_trimmed


async def _publish_event_task(
    event: EventModel,
    publisher: BasePublisher | None,
) -> None:
    """Publish an event if publisher is configured.

    Publishing is best-effort: any exception from ``publisher.publish`` is logged
    (with its traceback) and not re-raised, so a failing publisher cannot break
    the run that emitted the event.

    Args:
        event: Event to publish.
        publisher: Destination. When falsy, nothing happens.
    """
    if not publisher:
        return
    try:
        await publisher.publish(event)
        logger.debug("Published event: %s", event)
    except Exception as e:
        # logger.exception keeps the ERROR level but also records the traceback,
        # which logger.error(..., e) dropped.
        logger.exception("Failed to publish event: %s", e)


def publish_event(
    event: EventModel,
    publisher: BasePublisher | None = Inject[BasePublisher],
    task_manager: BackgroundTaskManager = Inject[BackgroundTaskManager],
) -> None:
    """Schedule ``event`` for publication without waiting for it.

    The actual publish runs as a background task created through ``task_manager``.
    That task does nothing when no publisher is configured.
    """
    publish_coro = _publish_event_task(event, publisher)
    # Going through the task manager keeps a reference to the task, presumably so it
    # is not garbage collected before it completes.
    task_manager.create_task(publish_coro)
