"""
EmberGraph: DAG-based orchestration system for financial agents.

Coordinates the execution of multiple agents (Clerk, Auditor, Modeler, Controller)
in a directed acyclic graph structure.
"""

from __future__ import annotations

from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Set
from uuid import uuid4

from pydantic import BaseModel, Field


class NodeStatus(str, Enum):
    """Lifecycle states of a graph node.

    Members also subclass ``str``, so each member compares equal to its
    lowercase string value (e.g. ``NodeStatus.FAILED == "failed"``).
    """

    # Not started yet.
    PENDING = "pending"
    # Currently executing.
    RUNNING = "running"
    # Finished without raising.
    COMPLETED = "completed"
    # Finished by raising an exception.
    FAILED = "failed"
    # Deliberately not executed.
    SKIPPED = "skipped"


class AgentProtocol(Protocol):
    """Structural interface every EmberQuant agent has to satisfy.

    Any object exposing a ``name`` property and an ``execute`` method with
    the signatures below conforms; no inheritance is required.
    """

    @property
    def name(self) -> str:
        """Name identifying the agent."""
        ...

    def execute(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run the agent on ``inputs`` and return its outputs as a dict."""
        ...


@dataclass
class GraphNode:
    """One agent task inside an EmberGraph, plus its run-time bookkeeping."""

    # Agent whose ``execute`` method is invoked for this node.
    agent: AgentProtocol
    # Node identifier; defaults to a freshly generated UUID4 string.
    id: str = field(default_factory=lambda: str(uuid4()))
    # IDs of the nodes this node depends on.
    dependencies: List[str] = field(default_factory=list)
    # Current lifecycle state; every node starts out pending.
    status: NodeStatus = NodeStatus.PENDING
    # Inputs handed to the agent.
    inputs: Dict[str, Any] = field(default_factory=dict)
    # Outputs returned by the agent.
    outputs: Dict[str, Any] = field(default_factory=dict)
    # Error message recorded for the node, if any.
    error: Optional[str] = None
    # Timestamps bracketing the node's execution.
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    @property
    def duration(self) -> Optional[float]:
        """Elapsed seconds between start and completion.

        Returns:
            The difference ``completed_at - started_at`` in seconds, or
            ``None`` when either timestamp is missing.
        """
        if not self.started_at or not self.completed_at:
            return None
        elapsed = self.completed_at - self.started_at
        return elapsed.total_seconds()


class ExecutionPlan(BaseModel):
    """A plan for executing agents in a specific order.

    Attributes are declared as pydantic model fields:

    - ``task_description``: free-text description of the task.
    - ``nodes``: list of dictionaries describing the planned nodes (the
      dictionary schema is not defined in this module).
    - ``estimated_steps``: step count reported in ``repr()``.
    - ``created_at``: timestamp taken when the plan instance is created.
    """

    task_description: str
    nodes: List[Dict[str, Any]]
    estimated_steps: int
    # Use pydantic's Field here: dataclasses.field() is the dataclass API and
    # is not the documented way to declare a default factory on a BaseModel.
    # The factory runs per instance, so each plan gets its own timestamp.
    created_at: datetime = Field(default_factory=datetime.now)

    def __repr__(self) -> str:
        """Return a short representation with the task text and step count."""
        return f"ExecutionPlan(task='{self.task_description}', steps={self.estimated_steps})"


class EmberGraph:
    """
    Directed Acyclic Graph (DAG) orchestration system for financial agents.

    EmberGraph manages dependencies between agents, executes them in the correct
    order, and tracks the flow of data between processing stages.

    Nodes are stored in ``self.nodes`` keyed by node ID. Execution is
    sequential: nodes run one at a time in a topological order, and each
    node receives the merged outputs of its dependencies as extra inputs.
    """

    def __init__(self) -> None:
        """Initialize an empty EmberGraph."""
        self.nodes: Dict[str, GraphNode] = {}
        # Slot for a cached execution order; add_node() resets it to None.
        # execute() currently recomputes the order on every run.
        self._execution_order: Optional[List[str]] = None

    def add_node(
        self,
        agent: AgentProtocol,
        dependencies: Optional[List[str]] = None,
        inputs: Optional[Dict[str, Any]] = None,
        node_id: Optional[str] = None,
    ) -> str:
        """
        Add a node to the graph.

        The ``dependencies`` list and ``inputs`` dict are copied, so later
        changes made by the graph (execute() merges dependency outputs into
        a node's inputs) never leak back into the caller's objects.

        If ``node_id`` matches an existing node, that node is replaced.

        Args:
            agent: The agent to execute at this node
            dependencies: List of node IDs this node depends on
            inputs: Initial inputs for the node
            node_id: Optional custom node ID; auto-generated if not provided

        Returns:
            The node ID

        Raises:
            ValueError: If any dependency ID is not already in the graph
        """
        # Shallow copies: the node must own its containers (see docstring).
        dep_ids: List[str] = list(dependencies) if dependencies else []
        node_inputs: Dict[str, Any] = dict(inputs) if inputs else {}

        # Validate dependencies exist
        for dep_id in dep_ids:
            if dep_id not in self.nodes:
                raise ValueError(f"Dependency node '{dep_id}' does not exist")

        final_node_id = node_id or str(uuid4())
        node = GraphNode(
            agent=agent,
            id=final_node_id,
            dependencies=dep_ids,
            inputs=node_inputs,
        )

        self.nodes[final_node_id] = node
        self._execution_order = None  # Invalidate cached order
        return final_node_id

    def _validate_dag(self) -> None:
        """
        Validate that the graph is a valid DAG (no cycles).

        Runs a depth-first search along dependency edges, tracking the nodes
        on the current search path; reaching a node already on that path
        means there is a cycle.

        Raises:
            ValueError: If a cycle is found
        """
        visited: Set[str] = set()
        rec_stack: Set[str] = set()  # nodes on the current DFS path

        def has_cycle(node_id: str) -> bool:
            visited.add(node_id)
            rec_stack.add(node_id)

            for dep_id in self.nodes[node_id].dependencies:
                if dep_id not in visited:
                    if has_cycle(dep_id):
                        return True
                elif dep_id in rec_stack:
                    # Back edge to a node still being explored.
                    return True

            rec_stack.remove(node_id)
            return False

        for node_id in self.nodes:
            if node_id not in visited and has_cycle(node_id):
                raise ValueError("Graph contains a cycle - not a valid DAG")

    def _topological_sort(self) -> List[str]:
        """
        Perform topological sort to determine execution order.

        Uses Kahn's algorithm. Ties are resolved by node insertion order, so
        the result is deterministic for a given graph.

        Returns:
            List of node IDs in execution order

        Raises:
            ValueError: If the graph contains a cycle
        """
        self._validate_dag()

        # dependents[x] lists the nodes waiting on x, in insertion order.
        dependents: Dict[str, List[str]] = {node_id: [] for node_id in self.nodes}
        in_degree: Dict[str, int] = {}

        for node_id, node in self.nodes.items():
            # De-duplicate (order-preserving): a dependency listed twice is
            # still a single edge, and must only be counted once or the node
            # would never reach in-degree zero.
            unique_deps = dict.fromkeys(node.dependencies)
            in_degree[node_id] = len(unique_deps)
            for dep_id in unique_deps:
                dependents[dep_id].append(node_id)

        # Start with the nodes that have no dependencies.
        ready = deque(node_id for node_id, degree in in_degree.items() if degree == 0)
        result: List[str] = []

        while ready:
            node_id = ready.popleft()
            result.append(node_id)

            # Release every node whose last outstanding dependency this was.
            for other_id in dependents[node_id]:
                in_degree[other_id] -= 1
                if in_degree[other_id] == 0:
                    ready.append(other_id)

        if len(result) != len(self.nodes):
            raise ValueError("Graph contains a cycle")

        return result

    def execute(self, verbose: bool = False) -> Dict[str, Any]:
        """
        Execute the graph in topological order.

        For each node, the outputs of all its dependencies are merged into
        the node's inputs (later dependencies overwrite earlier ones on key
        collisions) before the agent is called. Execution stops at the first
        failing node; that node is marked FAILED and the exception is
        re-raised.

        Args:
            verbose: Whether to print execution progress

        Returns:
            Dictionary mapping node IDs to their outputs

        Raises:
            ValueError: If the graph contains a cycle
            RuntimeError: If a dependency of a node is not in COMPLETED state
            Exception: Whatever the failing agent's ``execute`` raised
        """
        if not self.nodes:
            return {}

        execution_order = self._topological_sort()
        results: Dict[str, Any] = {}

        for node_id in execution_order:
            node = self.nodes[node_id]

            if verbose:
                print(f"Executing: {node.agent.name} (node: {node_id})")

            node.status = NodeStatus.RUNNING
            # Drop any error left over from a previous run of this graph so a
            # now-successful node does not keep reporting a stale failure.
            node.error = None
            node.started_at = datetime.now()

            try:
                # Gather inputs from dependencies
                for dep_id in node.dependencies:
                    dep_node = self.nodes[dep_id]
                    if dep_node.status != NodeStatus.COMPLETED:
                        # .value keeps the message identical across Python
                        # versions (str-mixin enum formatting has changed).
                        raise RuntimeError(
                            f"Dependency {dep_id} not completed (status: {dep_node.status.value})"
                        )
                    # Merge dependency outputs into this node's inputs
                    node.inputs.update(dep_node.outputs)

                # Execute the agent
                node.outputs = node.agent.execute(node.inputs)
                node.status = NodeStatus.COMPLETED
                results[node_id] = node.outputs

                if verbose:
                    print(f"Completed: {node.agent.name}")

            except Exception as e:
                node.status = NodeStatus.FAILED
                node.error = str(e)
                if verbose:
                    print(f"Failed: {node.agent.name} - {e}")
                raise

            finally:
                # Stamp the end time on both success and failure.
                node.completed_at = datetime.now()

        return results

    def get_execution_summary(self) -> Dict[str, Any]:
        """
        Get a summary of the execution.

        Returns:
            Dictionary with the keys ``total_nodes``, ``completed``,
            ``failed``, ``pending`` and ``total_duration_seconds``. Nodes in
            other states are included in ``total_nodes`` only. The duration
            is the sum of the individual node durations.
        """
        status_counts = {status: 0 for status in NodeStatus}
        for node in self.nodes.values():
            status_counts[node.status] += 1

        total_duration = sum(
            node.duration for node in self.nodes.values() if node.duration is not None
        )

        return {
            "total_nodes": len(self.nodes),
            "completed": status_counts[NodeStatus.COMPLETED],
            "failed": status_counts[NodeStatus.FAILED],
            "pending": status_counts[NodeStatus.PENDING],
            "total_duration_seconds": total_duration,
        }

    def visualize(self) -> str:
        """
        Generate a simple text-based visualization of the graph.

        Each node is rendered with a status symbol, the agent name, the first
        eight characters of its ID, its dependencies, its status and, once
        both timestamps are set, its duration.

        Returns:
            String representation of the graph structure
        """
        # Built once rather than per node.
        status_symbols = {
            NodeStatus.PENDING: "⏸",
            NodeStatus.RUNNING: "▶",
            NodeStatus.COMPLETED: "✓",
            NodeStatus.FAILED: "✗",
            NodeStatus.SKIPPED: "⊘",
        }

        lines = ["EmberGraph Structure:", "=" * 50]

        for node_id, node in self.nodes.items():
            status_symbol = status_symbols.get(node.status, "?")

            deps_str = ", ".join(node.dependencies) if node.dependencies else "none"
            lines.append(f"{status_symbol} {node.agent.name} ({node_id[:8]}...)")
            lines.append(f"  Dependencies: {deps_str}")
            lines.append(f"  Status: {node.status.value}")

            # Compare against None: a 0.0s duration is valid and must be shown.
            duration = node.duration
            if duration is not None:
                lines.append(f"  Duration: {duration:.2f}s")

            lines.append("")

        return "\n".join(lines)

    def __repr__(self) -> str:
        """Return a short representation showing the number of nodes."""
        return f"EmberGraph(nodes={len(self.nodes)})"
