from __future__ import annotations

import inspect
import pathlib
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from contextvars import ContextVar, Token
from copy import copy, deepcopy
from dataclasses import InitVar, dataclass, field
from inspect import Parameter, currentframe, signature
from types import FunctionType
from typing import (  # type: ignore[attr-defined]
    TYPE_CHECKING,
    Any,
    ClassVar,
    Generic,
    Literal,
    NoReturn,
    Protocol,
    TypedDict,
    TypeVar,
    Union,
    _GenericAlias,
    cast,
    get_args,
    get_origin,
    get_type_hints,
    overload,
)

import graphviz
from typing_extensions import ParamSpec, Self, Unpack

from egglog.declarations import REFLECTED_BINARY_METHODS, Declarations

from . import bindings
from .declarations import *
from .ipython_magic import IN_IPYTHON
from .runtime import *
from .runtime import _resolve_callable, class_to_ref, convert_to_same_type

if TYPE_CHECKING:
    import ipywidgets

    from .builtins import Bool, PyObject, String, f64, i64


# Names exported by `from <this module> import *`.
__all__ = [
    "EGraph",
    "Module",
    "BUILTINS",
    "Expr",
    "Unit",
    "rewrite",
    "birewrite",
    "eq",
    "panic",
    "let",
    "delete",
    "union",
    "set_",
    "rule",
    "var",
    "vars_",
    "Fact",
    "expr_parts",
    "Schedule",
    "run",
    "seq",
    "Command",
]

# Type variables used by the overloaded registration methods below.
T = TypeVar("T")
# Parameters of a decorated function, preserved by the `method`/`function` decorator overloads
P = ParamSpec("P")
# An expression class, as accepted and returned by `class_`
TYPE = TypeVar("TYPE", bound="type[Expr]")
# Any callable; the `method`/`function` overloads type the decorated callable as unchanged
CALLABLE = TypeVar("CALLABLE", bound=Callable)
EXPR = TypeVar("EXPR", bound="Expr")
# Argument types of `relation`, which is overloaded for arities up to four
E1 = TypeVar("E1", bound="Expr")
E2 = TypeVar("E2", bound="Expr")
E3 = TypeVar("E3", bound="Expr")
E4 = TypeVar("E4", bound="Expr")
# Attributes which are sometimes added to classes by the interpreter, the dataclass decorator, or IPython.
# We ignore these when collecting the methods of a class to register.
IGNORED_ATTRIBUTES = {
    "__module__",
    "__doc__",
    "__dict__",
    "__weakref__",
    "__orig_bases__",
    "__annotations__",
    "__hash__",
    # Stored in every class namespace by the compiler starting with Python 3.13
    "__firstlineno__",
    "__static_attributes__",
    # Ignore all reflected binary methods
    *REFLECTED_BINARY_METHODS.keys(),
}


# Declarations of the builtins module. They are set once by `_Builtins.__post_init__` and then included
# in the declarations of every module created afterwards.
_BUILTIN_DECLS: Declarations | None = None

# Methods which are always registered as mutating their first argument, even without `mutates_self=True`.
ALWAYS_MUTATES_SELF = {"__setitem__", "__delitem__"}


class PyObjectFunction(Protocol):
    """
    Structural type of a callable which takes any number of positional `PyObject` arguments
    and returns a `PyObject`.
    """

    def __call__(self, *__args: PyObject) -> PyObject:
        """Call with positional `PyObject` arguments."""


@dataclass
class _BaseModule(ABC):
    """
    Base Module which provides methods to register sorts, expressions, actions etc.

    Inherited by:
    - EGraph: Holds a live EGraph instance
    - Builtins: Stores a list of the builtins which have already been pre-registered
    - Module: Stores a list of commands and additional declarations

    Every registration method turns its input into egglog commands and hands them to
    `_process_commands`, which each subclass implements differently.
    """

    # Any modules you want to depend on.
    # A list default is used because dataclasses reject `default_factory` on an InitVar. It is safe
    # here since `__post_init__` only iterates over it.
    modules: InitVar[list[Module]] = []  # noqa: RUF008
    # All dependencies flattened. Each module appears once, after its own dependencies.
    _flatted_deps: list[Module] = field(init=False, default_factory=list)
    # This module's declarations, layered over those of the builtins and of all dependencies
    _mod_decls: ModuleDeclarations = field(init=False)

    def __post_init__(self, modules: list[Module]) -> None:
        """
        Flatten the dependencies of `modules` and create this module's declarations on top of theirs.
        """
        # The builtins are only included once they have been registered
        included_decls = [_BUILTIN_DECLS] if _BUILTIN_DECLS is not None else []
        # Traverse all the included modules to flatten all their dependencies and add to the included declarations
        for mod in modules:
            for child_mod in [*mod._flatted_deps, mod]:
                # A module may be reachable through several dependencies; only include it once
                if child_mod not in self._flatted_deps:
                    self._flatted_deps.append(child_mod)
                    included_decls.append(child_mod._mod_decls._decl)
        self._mod_decls = ModuleDeclarations(Declarations(), included_decls)

    @abstractmethod
    def _process_commands(self, cmds: Iterable[bindings._Command]) -> None:
        """
        Process the commands generated by this module.
        """
        raise NotImplementedError

    @overload
    def class_(self, *, egg_sort: str) -> Callable[[TYPE], TYPE]:
        ...

    @overload
    def class_(self, cls: TYPE, /) -> TYPE:
        ...

    def class_(self, *args, **kwargs) -> Any:
        """
        Registers a class.

        Use it either directly (`@module.class_`) or with an explicit egglog sort name
        (`@module.class_(egg_sort="...")`). Type hints are resolved using the caller's locals and
        globals, so classes defined inside functions are supported.
        """
        # Grab the caller's frame, whose namespaces are used to resolve the type hints
        frame = currentframe()
        assert frame
        prev_frame = frame.f_back
        assert prev_frame

        if kwargs:
            assert set(kwargs.keys()) == {"egg_sort"}
            return lambda cls: self._class(cls, prev_frame.f_locals, prev_frame.f_globals, kwargs["egg_sort"])
        assert len(args) == 1
        return self._class(args[0], prev_frame.f_locals, prev_frame.f_globals)

    def _class(  # noqa: PLR0912
        self,
        cls: type[Expr],
        hint_locals: dict[str, Any],
        hint_globals: dict[str, Any],
        egg_sort: str | None = None,
    ) -> RuntimeClass:
        """
        Registers a class.

        This registers the sort itself, then every `ClassVar` annotation as a class constant, then
        every method in the class body: plain methods, classmethods, properties and `__init__`.
        Type hints are resolved with `hint_locals` and `hint_globals`.

        Raises NotImplementedError for any non-`ClassVar` class annotation and for classmethod properties.
        Returns a runtime class which acts like the registered class.
        """
        cls_name = cls.__name__
        # Get all the methods from the class. Attributes added by the interpreter/dataclass are skipped,
        # unless they were explicitly wrapped with `method`.
        cls_dict: dict[str, Any] = {
            k: v for k, v in cls.__dict__.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
        }
        # Type parameters of the class, if it is generic
        parameters: list[TypeVar] = cls_dict.pop("__parameters__", [])

        n_type_vars = len(parameters)
        self._process_commands(self._mod_decls.register_class(cls_name, n_type_vars, egg_sort))
        # The type ref of self is parameterized by the type vars
        slf_type_ref = TypeRefWithVars(cls_name, tuple(ClassTypeVarRef(i) for i in range(n_type_vars)))

        # First register any class vars as constants.
        # The class name is added to the globals so that annotations can refer to the class being registered.
        hint_globals = hint_globals.copy()
        hint_globals[cls_name] = cls
        for k, v in get_type_hints(cls, globalns=hint_globals, localns=hint_locals).items():
            # Use get_origin/get_args rather than `__origin__`/`__args__`. Annotations which are not generic
            # aliases (e.g. a plain class) then reach the NotImplementedError below, instead of failing
            # with an AttributeError.
            if get_origin(v) is ClassVar:
                (inner_tp,) = get_args(v)
                self._register_constant(ClassVariableRef(cls_name, k), inner_tp, None, (cls, cls_name))
            else:
                msg = "The only supported annotations on class attributes are class vars"
                raise NotImplementedError(msg)

        # Then register each of its methods
        for method_name, method in cls_dict.items():
            is_init = method_name == "__init__"
            # Don't register the init methods for literals, since those don't use the type checking mechanisms
            if is_init and cls_name in LIT_CLASS_NAMES:
                continue
            if isinstance(method, _WrappedMethod):
                fn = method.fn
                egg_fn = method.egg_fn
                cost = method.cost
                default = method.default
                merge = method.merge
                on_merge = method.on_merge
                mutates_first_arg = method.mutates_self
                unextractable = method.unextractable
                if method.preserve:
                    self._mod_decls.register_preserved_method(cls_name, method_name, fn)
                    continue
            else:
                fn = method
                egg_fn, cost, default, merge, on_merge = None, None, None, None, None
                unextractable = False
                mutates_first_arg = False
            if isinstance(fn, classmethod):
                fn = fn.__func__
                is_classmethod = True
            else:
                # We count __init__ as a classmethod since it is called on the class
                is_classmethod = is_init

            if isinstance(fn, property):
                fn = fn.fget
                is_property = True
                if is_classmethod:
                    msg = "Can't have a classmethod property"
                    raise NotImplementedError(msg)
            else:
                is_property = False
            ref: FunctionCallableRef = (
                ClassMethodRef(cls_name, method_name)
                if is_classmethod
                else PropertyRef(cls_name, method_name)
                if is_property
                else MethodRef(cls_name, method_name)
            )
            self._register_function(
                ref,
                egg_fn,
                fn,
                hint_locals,
                default,
                cost,
                merge,
                on_merge,
                mutates_first_arg or method_name in ALWAYS_MUTATES_SELF,
                "cls" if is_classmethod and not is_init else slf_type_ref,
                parameters,
                is_init,
                # If this is an i64, use the runtime class for the alias so that i64Like is resolved properly
                # Otherwise, this might be a Map in which case pass in the original cls so that we
                # can do Map[T, V] on it, which is not allowed on the runtime class
                cls_type_and_name=(
                    RuntimeClass(self._mod_decls, cls_name) if cls_name in {"i64", "String"} else cls,
                    cls_name,
                ),
                unextractable=unextractable,
            )

        # Register != as a method so we can print it as a string
        self._mod_decls._decl.register_callable_ref(MethodRef(cls_name, "__ne__"), "!=")
        return RuntimeClass(self._mod_decls, cls_name)

    # We separate the function and method overloads, so that it is simpler to know whether we are modifying a
    # function or a method. This lets us add the functions eagerly to the registry, and wait on the methods
    # until we process the class.

    @overload
    def method(
        self,
        *,
        preserve: Literal[True],
    ) -> Callable[[CALLABLE], CALLABLE]:
        ...

    # We have to separate method/function overloads for those that use the T params and those that don't.
    # Otherwise, if you only pass in `cost`, the T param is inferred as `Nothing` and
    # it will break the typing.

    @overload
    def method(
        self,
        *,
        egg_fn: str | None = None,
        cost: int | None = None,
        merge: Callable[[Any, Any], Any] | None = None,
        on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None,
        mutates_self: bool = False,
        unextractable: bool = False,
    ) -> Callable[[CALLABLE], CALLABLE]:
        ...

    @overload
    def method(
        self,
        *,
        egg_fn: str | None = None,
        cost: int | None = None,
        default: EXPR | None = None,
        merge: Callable[[EXPR, EXPR], EXPR] | None = None,
        on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
        mutates_self: bool = False,
        unextractable: bool = False,
    ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
        ...

    def method(
        self,
        *,
        egg_fn: str | None = None,
        cost: int | None = None,
        default: EXPR | None = None,
        merge: Callable[[EXPR, EXPR], EXPR] | None = None,
        on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
        preserve: bool = False,
        mutates_self: bool = False,
        unextractable: bool = False,
    ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
        """
        Decorator for a method of a class which will be registered with `class_`.

        The options are only stored on a `_WrappedMethod`. They are applied when the enclosing class
        is registered. With `preserve=True`, the method is registered as a preserved method instead of
        as an egglog function.
        """
        return lambda fn: _WrappedMethod(
            egg_fn, cost, default, merge, on_merge, fn, preserve, mutates_self, unextractable
        )

    @overload
    def function(self, fn: CALLABLE, /) -> CALLABLE:
        ...

    @overload
    def function(
        self,
        *,
        egg_fn: str | None = None,
        cost: int | None = None,
        merge: Callable[[Any, Any], Any] | None = None,
        on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None,
        mutates_first_arg: bool = False,
        unextractable: bool = False,
    ) -> Callable[[CALLABLE], CALLABLE]:
        ...

    @overload
    def function(
        self,
        *,
        egg_fn: str | None = None,
        cost: int | None = None,
        default: EXPR | None = None,
        merge: Callable[[EXPR, EXPR], EXPR] | None = None,
        on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
        mutates_first_arg: bool = False,
        unextractable: bool = False,
    ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
        ...

    def function(self, *args, **kwargs) -> Any:
        """
        Registers a function.

        Use it either directly (`@module.function`) or with keyword options
        (`@module.function(egg_fn=..., cost=..., ...)`). The decorated name is bound to a runtime
        function which acts like the declaration.
        """
        # Type hints are resolved against the caller's locals, to support functions defined inside other functions
        fn_locals = currentframe().f_back.f_locals  # type: ignore[union-attr]

        # If we have any positional args, then we are calling it directly on a function
        if args:
            assert len(args) == 1
            return self._function(args[0], fn_locals)
        # otherwise, we are passing some keyword args, so save those, and then return a partial
        return lambda fn: self._function(fn, fn_locals, **kwargs)

    def _function(
        self,
        fn: Callable[..., RuntimeExpr],
        hint_locals: dict[str, Any],
        mutates_first_arg: bool = False,
        egg_fn: str | None = None,
        cost: int | None = None,
        default: RuntimeExpr | None = None,
        merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None,
        on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None,
        unextractable: bool = False,
    ) -> RuntimeFunction:
        """
        Uncurried version of function decorator
        """
        name = fn.__name__
        # Save function declaration
        self._register_function(
            FunctionRef(name),
            egg_fn,
            fn,
            hint_locals,
            default,
            cost,
            merge,
            on_merge,
            mutates_first_arg,
            unextractable=unextractable,
        )
        # Return a runtime function which will act like the declaration
        return RuntimeFunction(self._mod_decls, name)

    def _register_function(  # noqa: C901, PLR0912
        self,
        ref: FunctionCallableRef,
        egg_name: str | None,
        fn: object,
        # Pass in the locals, retrieved from the frame when wrapping,
        # so that we support classes and function defined inside of other functions (which won't show up in the globals)
        hint_locals: dict[str, Any],
        default: RuntimeExpr | None,
        cost: int | None,
        merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None,
        on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None,
        mutates_first_arg: bool,
        # The first arg is either cls, for a classmethod, a self type, or none for a function
        first_arg: Literal["cls"] | TypeOrVarRef | None = None,
        cls_typevars: list[TypeVar] | None = None,
        is_init: bool = False,
        cls_type_and_name: tuple[type | RuntimeClass, str] | None = None,
        unextractable: bool = False,
    ) -> None:
        """
        Builds a function declaration from the signature and type hints of `fn`, and registers it under `ref`.

        Only positional-or-keyword parameters are supported, optionally followed by a single `*args`.
        The return type is chosen as follows:
        - for `__init__`, the self type;
        - for a function which mutates its first argument, the type of that argument;
        - otherwise, the return annotation.

        `merge` and `on_merge` are called with the variables `old` and `new` to build the merge
        expression and the merge actions.
        """
        if cls_typevars is None:
            cls_typevars = []
        if not isinstance(fn, FunctionType):
            raise NotImplementedError(f"Can only generate function decls for functions not {fn}  {type(fn)}")

        hint_globals = fn.__globals__.copy()

        if cls_type_and_name:
            hint_globals[cls_type_and_name[1]] = cls_type_and_name[0]
        hints = get_type_hints(fn, hint_globals, hint_locals)

        params = list(signature(fn).parameters.values())
        arg_names = tuple(t.name for t in params)
        arg_defaults = tuple(expr_parts(p.default).expr if p.default is not Parameter.empty else None for p in params)
        # If this is an init function, or a classmethod, remove the first arg name
        if is_init or first_arg == "cls":
            arg_names = arg_names[1:]
            arg_defaults = arg_defaults[1:]
        # Remove first arg if this is a classmethod or a method, since it won't have an annotation
        if first_arg is not None:
            first, *params = params
            if first.annotation != Parameter.empty:
                raise ValueError(f"First arg of a method must not have an annotation, not {first.annotation}")

        # Check that all the params are positional or keyword, and that there is only one var arg at the end
        found_var_arg = False
        for param in params:
            if found_var_arg:
                msg = "Can only have a single var arg at the end"
                raise ValueError(msg)
            kind = param.kind
            if kind == Parameter.VAR_POSITIONAL:
                found_var_arg = True
            elif kind != Parameter.POSITIONAL_OR_KEYWORD:
                raise ValueError(f"Can only register functions with positional or keyword args, not {param.kind}")

        if found_var_arg:
            *params, var_arg_param = params
            # For now, we don't use the variable arg name
            arg_names = arg_names[:-1]
            arg_defaults = arg_defaults[:-1]
            var_arg_type = self._resolve_type_annotation(hints[var_arg_param.name], cls_typevars, cls_type_and_name)
        else:
            var_arg_type = None
        arg_types = tuple(self._resolve_type_annotation(hints[t.name], cls_typevars, cls_type_and_name) for t in params)
        # If the first arg is a self, and this not an __init__ fn, add this as a typeref
        if isinstance(first_arg, ClassTypeVarRef | TypeRefWithVars) and not is_init:
            arg_types = (first_arg, *arg_types)

        # If this is an init fn use the first arg as the return type
        if is_init:
            assert not mutates_first_arg
            if not isinstance(first_arg, ClassTypeVarRef | TypeRefWithVars):
                msg = "Init function must have a self type"
                raise ValueError(msg)
            return_type = first_arg
        elif mutates_first_arg:
            return_type = arg_types[0]
        else:
            return_type = self._resolve_type_annotation(hints["return"], cls_typevars, cls_type_and_name)

        default_decl = None if default is None else default.__egg_typed_expr__.expr
        merge_decl = (
            None
            if merge is None
            else merge(
                RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
                RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
            ).__egg_typed_expr__.expr
        )
        merge_action = (
            []
            if on_merge is None
            else _action_likes(
                on_merge(
                    RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
                    RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
                )
            )
        )
        fn_decl = FunctionDecl(
            return_type=return_type,
            var_arg_type=var_arg_type,
            arg_types=arg_types,
            arg_names=arg_names,
            arg_defaults=arg_defaults,
            mutates_first_arg=mutates_first_arg,
        )
        self._process_commands(
            self._mod_decls.register_function_callable(
                ref,
                fn_decl,
                egg_name,
                cost,
                default_decl,
                merge_decl,
                [a._to_egg_action(self._mod_decls) for a in merge_action],
                unextractable,
            )
        )

    def _resolve_type_annotation(
        self,
        tp: object,
        cls_typevars: list[TypeVar],
        cls_type_and_name: tuple[type | RuntimeClass, str] | None,
    ) -> TypeOrVarRef:
        """
        Converts a Python type annotation into an egglog type reference.

        - A TypeVar becomes a reference to its index in `cls_typevars`.
        - A Union resolves to its first member. The other members are types convertible to it.
        - `object` resolves to `PyObject`.
        - The class currently being registered (optionally parameterized) resolves to its name.
        - Registered runtime classes resolve to their type reference.

        Raises a TypeError for any other annotation.
        """
        if isinstance(tp, TypeVar):
            return ClassTypeVarRef(cls_typevars.index(tp))
        # If there is a union, then we assume the first item is the type we want, and the others are types that can be converted to that type.
        if get_origin(tp) == Union:
            first, *_rest = get_args(tp)
            return self._resolve_type_annotation(first, cls_typevars, cls_type_and_name)
        # If the type is `object` then this is assumed to be a PyObjectLike, i.e. converted into a PyObject
        if tp == object:
            return TypeRefWithVars("PyObject")
        # If this is the type for the class, use the class name
        if cls_type_and_name and tp == cls_type_and_name[0]:
            return TypeRefWithVars(cls_type_and_name[1])

        # If this is the class for this method and we have a parameterized class, recurse
        if cls_type_and_name and isinstance(tp, _GenericAlias) and tp.__origin__ == cls_type_and_name[0]:
            return TypeRefWithVars(
                cls_type_and_name[1],
                tuple(self._resolve_type_annotation(a, cls_typevars, cls_type_and_name) for a in tp.__args__),
            )

        if isinstance(tp, RuntimeClass | RuntimeParamaterizedClass):
            return class_to_ref(tp).to_var()
        raise TypeError(f"Unexpected type annotation {tp}")

    def register(self, command_or_generator: CommandLike | CommandGenerator, *commands: CommandLike) -> None:
        """
        Registers any number of rewrites or rules.

        Either pass the commands directly, or pass a single function whose generated commands are
        registered instead.
        """
        if isinstance(command_or_generator, FunctionType):
            assert not commands
            commands = tuple(_command_generator(command_or_generator))
        else:
            commands = (cast(CommandLike, command_or_generator), *commands)
        self._process_commands(_command_like(command)._to_egg_command(self._mod_decls) for command in commands)

    def ruleset(self, name: str) -> Ruleset:
        """
        Adds a new ruleset with the given name, and returns a reference to it.
        """
        self._process_commands([bindings.AddRuleset(name)])
        return Ruleset(name)

    # Overloads support arities 0-4, until variadic generics support mapping from a type to a value
    @overload
    def relation(
        self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], tp4: type[E4], /
    ) -> Callable[[E1, E2, E3, E4], Unit]:
        ...

    @overload
    def relation(self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], /) -> Callable[[E1, E2, E3], Unit]:
        ...

    @overload
    def relation(self, name: str, tp1: type[E1], tp2: type[E2], /) -> Callable[[E1, E2], Unit]:
        ...

    @overload
    def relation(self, name: str, tp1: type[T], /, *, egg_fn: str | None = None) -> Callable[[T], Unit]:
        ...

    @overload
    def relation(self, name: str, /, *, egg_fn: str | None = None) -> Callable[[], Unit]:
        ...

    def relation(self, name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..., Unit]:
        """
        Defines a relation, which is the same as a function which returns unit.
        """
        arg_types = tuple(self._resolve_type_annotation(cast(object, tp), [], None) for tp in tps)
        fn_decl = FunctionDecl(
            arg_types, None, tuple(None for _ in tps), TypeRefWithVars("Unit"), mutates_first_arg=False
        )
        commands = self._mod_decls.register_function_callable(
            FunctionRef(name),
            fn_decl,
            egg_fn,
            cost=None,
            default=None,
            merge=None,
            merge_action=[],
            unextractable=False,
            is_relation=True,
        )
        self._process_commands(commands)
        return cast(Callable[..., Unit], RuntimeFunction(self._mod_decls, name))

    def input(self, fn: Callable[..., String], path: str) -> None:
        """
        Loads a CSV file and sets its rows as the `*inputs, output` of the function.
        """
        fn_name = self._mod_decls.get_egg_fn(_resolve_callable(fn))
        self._process_commands([bindings.Input(fn_name, path)])

    def constant(self, name: str, tp: type[EXPR], egg_name: str | None = None) -> EXPR:
        """
        Defines a named constant of a certain type.

        This is the same as defining a nullary function with a high cost.
        """
        ref = ConstantRef(name)
        type_ref = self._register_constant(ref, tp, egg_name, None)
        return cast(EXPR, RuntimeExpr(self._mod_decls, TypedExprDecl(type_ref, CallDecl(ref))))

    def _register_constant(
        self,
        ref: ConstantRef | ClassVariableRef,
        tp: object,
        egg_name: str | None,
        cls_type_and_name: tuple[type | RuntimeClass, str] | None,
    ) -> JustTypeRef:
        """
        Registers a constant, and returns its type reference.
        """
        type_ref = self._resolve_type_annotation(tp, [], cls_type_and_name).to_just()
        self._process_commands(self._mod_decls.register_constant_callable(ref, type_ref, egg_name))
        return type_ref

    def let(self, name: str, expr: EXPR) -> EXPR:
        """
        Define a new expression in the egraph and return a reference to it.
        """
        typed_expr = expr_parts(expr)
        self._process_commands([bindings.ActionCommand(bindings.Let(name, typed_expr.to_egg(self._mod_decls)))])
        return cast(EXPR, RuntimeExpr(self._mod_decls, TypedExprDecl(typed_expr.tp, VarDecl(name))))


@dataclass
class _Builtins(_BaseModule):
    """
    Module holding the builtin declarations, which are published as `_BUILTIN_DECLS`.
    """

    def __post_init__(self, modules: list[Module]) -> None:
        """
        Register these declarations as builtins, so others can use them.

        This may only happen once; a second builtins module raises a RuntimeError.
        """
        global _BUILTIN_DECLS
        # The builtins cannot depend on any other module
        assert not modules
        super().__post_init__(modules)
        already_initialized = _BUILTIN_DECLS is not None
        if already_initialized:
            msg = "Builtins already initialized"
            raise RuntimeError(msg)
        _BUILTIN_DECLS = self._mod_decls._decl

    def _process_commands(self, cmds: Iterable[bindings._Command]) -> None:
        """
        Intentionally a no-op: the builtins these commands would create are already registered.
        """


@dataclass
class Module(_BaseModule):
    """
    A module which records the egglog commands produced by its registrations.

    The recorded commands are replayed by any EGraph which depends on this module.
    """

    _cmds: list[bindings._Command] = field(default_factory=list, repr=False)

    @property
    def as_egglog_string(self) -> str:
        """
        The recorded commands rendered as egglog source, one command per line.
        """
        return "\n".join(map(str, self._cmds))

    def _process_commands(self, cmds: Iterable[bindings._Command]) -> None:
        """
        Store the commands for later replay; nothing is executed here.
        """
        self._cmds.extend(cmds)

    def unextractable(self) -> Module:
        """
        Returns a copy of this module in which every function is marked as un-extractable.
        """

        def mark_unextractable(decl: bindings.FunctionDecl) -> bindings.FunctionDecl:
            return bindings.FunctionDecl(
                decl.name, decl.schema, decl.default, decl.merge, decl.merge_action, decl.cost, True
            )

        return self._map_functions(mark_unextractable)

    def increase_cost(self, x: int = 10000000) -> Module:
        """
        Returns a copy of this module in which every function cost is increased by `x`.
        """

        def bump_cost(decl: bindings.FunctionDecl) -> bindings.FunctionDecl:
            # A missing (or zero) cost counts as 1 before the increase
            new_cost = (decl.cost or 1) + x
            return bindings.FunctionDecl(
                decl.name, decl.schema, decl.default, decl.merge, decl.merge_action, new_cost, decl.unextractable
            )

        return self._map_functions(bump_cost)

    def without_rules(self) -> Module:
        """
        Returns a copy of this module with every rule, rewrite and birewrite command dropped.
        """
        rule_types = (bindings.RuleCommand, bindings.RewriteCommand, bindings.BiRewriteCommand)
        stripped = copy(self)
        stripped._cmds = [cmd for cmd in stripped._cmds if not isinstance(cmd, rule_types)]
        return stripped

    # def rename_ruleset(self, new_r: str) -> Module:
    #     """
    #     Makes a copy of this module with all default rulsets changed to the new one.
    #     """
    #     new = copy(self)
    #     new._cmds = [
    #         bindings.RuleCommand(c.name, new_r, c.rule)
    #         if isinstance(c, bindings.RuleCommand) and not c.ruleset
    #         else bindings.RewriteCommand(new_r, c.rewrite)
    #         if isinstance(c, bindings.RewriteCommand) and not c.name
    #         else bindings.BiRewriteCommand(new_r, c.rewrite)
    #         if isinstance(c, bindings.BiRewriteCommand) and not c.name
    #         else c
    #         for c in new._cmds
    #     ]
    #     new._cmds.insert(0, bindings.AddRuleset(new_r))
    #     return new

    def _map_functions(self, fn: Callable[[bindings.FunctionDecl], bindings.FunctionDecl]) -> Module:
        """
        Returns a shallow copy in which each function declaration command has been rewritten with `fn`.
        """
        mapped = copy(self)
        new_cmds: list[bindings._Command] = []
        for cmd in mapped._cmds:
            new_cmds.append(bindings.Function(fn(cmd.decl)) if isinstance(cmd, bindings.Function) else cmd)
        mapped._cmds = new_cmds
        return mapped


class GraphvizKwargs(TypedDict, total=False):
    """
    Optional keyword arguments accepted by `EGraph.graphviz`.
    """

    # Forwarded to the e-graph serializer
    max_functions: int | None
    max_calls_per_function: int | None
    # Consumed by `graphviz` itself (defaulting to 0 there) instead of being forwarded
    n_inline_leaves: int
    # Forwarded to the serializer; `graphviz` defaults it to True
    split_primitive_outputs: bool


@dataclass
class EGraph(_BaseModule):
    """
    Represents an EGraph instance at runtime
    """

    seminaive: InitVar[bool] = True
    save_egglog_string: InitVar[bool] = False

    _egraph: bindings.EGraph = field(repr=False, init=False)
    # The current declarations which have been pushed to the stack
    _decl_stack: list[Declarations] = field(default_factory=list, repr=False)
    _token_stack: list[Token[EGraph]] = field(default_factory=list, repr=False)
    _egglog_string: str | None = field(default=None, repr=False, init=False)

    def __post_init__(self, modules: list[Module], seminaive: bool, save_egglog_string: bool) -> None:
        super().__post_init__(modules)
        self._egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive)
        for m in self._flatted_deps:
            self._process_commands(m._cmds)
        if save_egglog_string:
            self._egglog_string = ""

    def _process_commands(self, commands: Iterable[bindings._Command]) -> None:
        self._egraph.run_program(*commands)
        if isinstance(self._egglog_string, str):
            self._egglog_string += "\n".join(str(c) for c in commands) + "\n"

    @property
    def as_egglog_string(self) -> str:
        """
        Returns the egglog string for this module.
        """
        if self._egglog_string is None:
            msg = "Can't get egglog string unless EGraph created with save_egglog_string=True"
            raise ValueError(msg)
        return self._egglog_string

    def _repr_mimebundle_(self, *args, **kwargs):
        """
        Returns the graphviz representation of the e-graph.
        """
        return {"image/svg+xml": self.graphviz().pipe(format="svg", quiet=True, encoding="utf-8")}

    def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
        # By default we want to split primitive outputs
        kwargs.setdefault("split_primitive_outputs", True)
        n_inline = kwargs.pop("n_inline_leaves", 0)
        serialized = self._egraph.serialize(**kwargs)  # type: ignore[misc]
        serialized.map_ops(self._mod_decls.op_mapping())
        for _ in range(n_inline):
            serialized.inline_leaves()
        original = serialized.to_dot()
        # Add link to stylesheet to the graph, so that edges light up on hover
        # https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
        styles = """/* the lines within the edges */
      .edge:active path,
      .edge:hover path {
        stroke: fuchsia;
        stroke-width: 3;
        stroke-opacity: 1;
      }
      /* arrows are typically drawn with a polygon */
      .edge:active polygon,
      .edge:hover polygon {
        stroke: fuchsia;
        stroke-width: 3;
        fill: fuchsia;
        stroke-opacity: 1;
        fill-opacity: 1;
      }
      /* If you happen to have text and want to color that as well... */
      .edge:active text,
      .edge:hover text {
        fill: fuchsia;
      }"""
        p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
        p.write_text(styles)
        with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
        return graphviz.Source(with_stylesheet)

    def graphviz_svg(self, **kwargs: Unpack[GraphvizKwargs]) -> str:
        return self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")

    def _repr_html_(self) -> str:
        """
        Add a _repr_html_ to be an SVG to work with sphinx gallery.

        ala https://github.com/xflr6/graphviz/pull/121
        until this PR is merged and released
        https://github.com/sphinx-gallery/sphinx-gallery/pull/1138
        """
        return self.graphviz_svg()

    def display(self, **kwargs: Unpack[GraphvizKwargs]) -> None:
        """
        Displays the e-graph in the notebook.
        """
        graphviz = self.graphviz(**kwargs)
        if IN_IPYTHON:
            from IPython.display import SVG, display

            display(SVG(self.graphviz_svg(**kwargs)))
        else:
            graphviz.render(view=True, format="svg", quiet=True)

    @overload
    def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR:
        ...

    @overload
    def simplify(self, expr: EXPR, schedule: Schedule, /) -> EXPR:
        ...

    def simplify(
        self, expr: EXPR, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
    ) -> EXPR:
        """
        Simplifies the given expression.
        """
        if isinstance(limit_or_schedule, int):
            limit_or_schedule = run(ruleset, *until) * limit_or_schedule
        typed_expr = expr_parts(expr)
        egg_expr = typed_expr.to_egg(self._mod_decls)
        self._process_commands([bindings.Simplify(egg_expr, limit_or_schedule._to_egg_schedule(self._mod_decls))])
        extract_report = self._egraph.extract_report()
        if not isinstance(extract_report, bindings.Best):
            msg = "No extract report saved"
            raise ValueError(msg)  # noqa: TRY004
        new_typed_expr = TypedExprDecl.from_egg(
            self._egraph, self._mod_decls, bindings.termdag_term_to_expr(extract_report.termdag, extract_report.term)
        )
        return cast(EXPR, RuntimeExpr(self._mod_decls, new_typed_expr))

    def include(self, path: str) -> None:
        """
        Include a file of rules.
        """
        msg = "Not implemented yet, because we don't have a way of registering the types with Python"
        raise NotImplementedError(msg)

    def output(self) -> None:
        msg = "Not imeplemented yet, because there are no examples in the egglog repo"
        raise NotImplementedError(msg)

    @overload
    def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport:
        ...

    @overload
    def run(self, schedule: Schedule, /) -> bindings.RunReport:
        ...

    def run(
        self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
    ) -> bindings.RunReport:
        """
        Run the egraph until the given limit or until the given facts are true.
        """
        if isinstance(limit_or_schedule, int):
            limit_or_schedule = run(ruleset, *until) * limit_or_schedule
        return self._run_schedule(limit_or_schedule)

    def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
        self._process_commands([bindings.RunSchedule(schedule._to_egg_schedule(self._mod_decls))])
        run_report = self._egraph.run_report()
        if not run_report:
            msg = "No run report saved"
            raise ValueError(msg)
        return run_report

    def check(self, *facts: FactLike) -> None:
        """
        Check if a fact is true in the egraph.
        """
        self._process_commands([self._facts_to_check(facts)])

    def check_fail(self, *facts: FactLike) -> None:
        """
        Checks that one of the facts is not true
        """
        self._process_commands([bindings.Fail(self._facts_to_check(facts))])

    def _facts_to_check(self, facts: Iterable[FactLike]) -> bindings.Check:
        egg_facts = [f._to_egg_fact(self._mod_decls) for f in _fact_likes(facts)]
        return bindings.Check(egg_facts)

    @overload
    def extract(self, expr: EXPR, /, include_cost: Literal[False] = False) -> EXPR:
        ...

    @overload
    def extract(self, expr: EXPR, /, include_cost: Literal[True]) -> tuple[EXPR, int]:
        ...

    def extract(self, expr: EXPR, include_cost: bool = False) -> EXPR | tuple[EXPR, int]:
        """
        Extract the lowest cost expression from the egraph.
        """
        typed_expr = expr_parts(expr)
        egg_expr = typed_expr.to_egg(self._mod_decls)
        extract_report = self._run_extract(egg_expr, 0)
        if not isinstance(extract_report, bindings.Best):
            msg = "No extract report saved"
            raise ValueError(msg)  # noqa: TRY004
        new_typed_expr = TypedExprDecl.from_egg(
            self._egraph, self._mod_decls, bindings.termdag_term_to_expr(extract_report.termdag, extract_report.term)
        )
        if new_typed_expr.tp != typed_expr.tp:
            raise RuntimeError(f"Type mismatch: {new_typed_expr.tp} != {typed_expr.tp}")
        res = cast(EXPR, RuntimeExpr(self._mod_decls, new_typed_expr))
        if include_cost:
            return res, extract_report.cost
        return res

    def extract_multiple(self, expr: EXPR, n: int) -> list[EXPR]:
        """
        Extract multiple expressions from the egraph.
        """
        typed_expr = expr_parts(expr)
        egg_expr = typed_expr.to_egg(self._mod_decls)
        extract_report = self._run_extract(egg_expr, n)
        if not isinstance(extract_report, bindings.Variants):
            msg = "Wrong extract report type"
            raise ValueError(msg)  # noqa: TRY004
        new_exprs = [
            TypedExprDecl.from_egg(
                self._egraph, self._mod_decls, bindings.termdag_term_to_expr(extract_report.termdag, term)
            )
            for term in extract_report.terms
        ]
        return [cast(EXPR, RuntimeExpr(self._mod_decls, expr)) for expr in new_exprs]

    def _run_extract(self, expr: bindings._Expr, n: int) -> bindings._ExtractReport:
        self._process_commands([bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n))))])
        extract_report = self._egraph.extract_report()
        if not extract_report:
            msg = "No extract report saved"
            raise ValueError(msg)
        return extract_report

    def push(self) -> None:
        """
        Push the current state of the egraph, so that it can be popped later and reverted back.
        """
        self._process_commands([bindings.Push(1)])
        self._decl_stack.append(self._mod_decls._decl)
        self._decls = deepcopy(self._mod_decls._decl)

    def pop(self) -> None:
        """
        Pop the current state of the egraph, reverting back to the previous state.
        """
        self._process_commands([bindings.Pop(1)])
        self._mod_decls._decl = self._decl_stack.pop()

    def __enter__(self) -> Self:
        """
        Copy the egraph state, so that it can be reverted back to the original state at the end.

        Also sets the current egraph to this one.
        """
        self._token_stack.append(CURRENT_EGRAPH.set(self))
        self.push()
        return self

    def __exit__(self, exc_type, exc, exc_tb) -> None:  # noqa: ANN001
        CURRENT_EGRAPH.reset(self._token_stack.pop())
        self.pop()

    @overload
    def eval(self, expr: i64) -> int:
        ...

    @overload
    def eval(self, expr: f64) -> float:
        ...

    @overload
    def eval(self, expr: Bool) -> bool:
        ...

    @overload
    def eval(self, expr: String) -> str:
        ...

    @overload
    def eval(self, expr: PyObject) -> object:
        ...

    def eval(self, expr: Expr) -> object:
        """
        Evaluates the given expression (which must be a primitive type), returning the result.
        """
        typed_expr = expr_parts(expr)
        egg_expr = typed_expr.to_egg(self._mod_decls)
        match typed_expr.tp:
            case JustTypeRef("i64"):
                return self._egraph.eval_i64(egg_expr)
            case JustTypeRef("f64"):
                return self._egraph.eval_f64(egg_expr)
            case JustTypeRef("Bool"):
                return self._egraph.eval_bool(egg_expr)
            case JustTypeRef("String"):
                return self._egraph.eval_string(egg_expr)
            case JustTypeRef("PyObject"):
                return self._egraph.eval_py_object(egg_expr)
        raise NotImplementedError(f"Eval not implemented for {typed_expr.tp.name}")

    def eval_fn(self, fn: Callable) -> PyObjectFunction:
        """
        Takes a python callable and maps it to a callable which takes and returns PyObjects.

        It translates it to a call which uses `py_eval` to call the function, passing in the
        args as locals, and using the globals from function.
        """
        from .builtins import PyObject, py_eval

        def inner(*__args: PyObject, __fn: Callable = fn) -> PyObject:
            new_kvs: list[object] = []
            eval_str = "__fn("
            for i, arg in enumerate(__args):
                new_kvs.append(f"__arg_{i}")
                new_kvs.append(arg)
                eval_str += f"__arg_{i}, "
            eval_str += ")"
            return py_eval(eval_str, PyObject({"__fn": __fn}).dict_update(*new_kvs), __fn.__globals__)

        return inner

    def saturate(
        self, *, max: int = 1000, performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
    ) -> ipywidgets.Widget:
        from .graphviz_widget import graphviz_widget_with_slider

        dots = [str(self.graphviz(**kwargs))]
        i = 0
        while self.run(1).updated and i < max:
            i += 1
            dots.append(str(self.graphviz(**kwargs)))
        return graphviz_widget_with_slider(dots, performance=performance)

    def saturate_to_html(
        self, file: str = "tmp.html", performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
    ) -> None:
        # raise NotImplementedError("Upstream bugs prevent rendering to HTML")

        # import panel

        # panel.extension("ipywidgets")

        widget = self.saturate(performance=performance, **kwargs)
        # panel.panel(widget).save(file)

        from ipywidgets.embed import embed_minimal_html

        embed_minimal_html("tmp.html", views=[widget], drop_defaults=False)
        # Use panel while this issue persists
        # https://github.com/jupyter-widgets/ipywidgets/issues/3761#issuecomment-1755563436

    @classmethod
    def current(cls) -> EGraph:
        """
        Returns the current egraph, which is the one in the context.
        """
        return CURRENT_EGRAPH.get()


# Set by ``EGraph.__enter__`` (reset by ``__exit__``) and read by ``EGraph.current``.
CURRENT_EGRAPH: ContextVar[EGraph] = ContextVar("CURRENT_EGRAPH")


@dataclass(frozen=True)
class _WrappedMethod(Generic[P, EXPR]):
    """
    Pairs a method (``fn``) with extra options that are held here until the enclosing class is processed.

    Instances are placeholders only: calling one always raises ``NotImplementedError``.
    """

    egg_fn: str | None
    cost: int | None
    default: EXPR | None
    merge: Callable[[EXPR, EXPR], EXPR] | None
    on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None
    fn: Callable[P, EXPR]
    preserve: bool
    mutates_self: bool
    unextractable: bool

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> EXPR:
        error_message = "We should never call a wrapped method. Did you forget to wrap the class?"
        raise NotImplementedError(error_message)


class _ExprMetaclass(type):
    """
    Metaclass of ``Expr``.

    Overrides ``isinstance`` so that every ``RuntimeExpr`` counts as an instance of ``Expr`` (and of any class
    using this metaclass) at runtime.
    """

    def __instancecheck__(cls, instance: object) -> bool:
        # The class being checked against is deliberately ignored; only the runtime representation matters.
        is_runtime_expr = isinstance(instance, RuntimeExpr)
        return is_runtime_expr


class Expr(metaclass=_ExprMetaclass):
    """
    Base class of all expression types; it adds support for ``!=`` to each of them.
    """

    def __ne__(self: EXPR, other_expr: EXPR) -> Unit:  # type: ignore[override, empty-body]
        """
        Compare whether two expressions are not equal.

        :param self: The expression to compare.
        :param other_expr: The other expression to compare to, which must be of the same type.
        :meta public:
        """
        ...

    def __eq__(self, other: NoReturn) -> NoReturn:  # type: ignore[override, empty-body]
        """
        Equality is currently not supported.

        This method exists only so that MyPy reports an error when ``==`` is used on an expression.
        """
        ...


# Shared ``_Builtins`` instance; builtin classes such as ``Unit`` are declared through its ``class_`` decorator.
BUILTINS = _Builtins()


@BUILTINS.class_(egg_sort="Unit")
class Unit(Expr):
    """
    The unit type, mapped to the egglog ``Unit`` sort.

    It is also used to represent whether a value exists, i.e. whether it has been resolved or not.
    """

    def __init__(self) -> None:
        ...


@dataclass(frozen=True)
class Ruleset:
    """
    Immutable, hashable handle naming an egglog ruleset.

    Builders such as ``rewrite``, ``rule`` and ``run`` accept one; when none is given the empty name is used.
    """

    # The ruleset's name as passed through to egglog
    name: str


def _ruleset_name(ruleset: Ruleset | None) -> str:
    return ruleset.name if ruleset else ""


class Command(ABC):
    """
    A command that can be executed in the egg interpreter.

    Only commands that return no result and create no new Python objects are modeled this way. Anything that can
    be passed to a Module's ``register`` function is a Command.
    """

    @abstractmethod
    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command:
        """Translate this command into its low-level ``bindings`` form using the given module declarations."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable representation of this command."""
        raise NotImplementedError


@dataclass
class Rewrite(Command):
    """
    A one-directional rewrite from ``_lhs`` to ``_rhs`` in ruleset ``_ruleset``, with optional ``_conditions``.
    """

    _ruleset: str
    _lhs: RuntimeExpr
    _rhs: RuntimeExpr
    _conditions: tuple[Fact, ...]
    # Name of the Python builder function, shown when the command is printed
    _fn_name: ClassVar[str] = "rewrite"

    def __str__(self) -> str:
        to_args = ", ".join(str(arg) for arg in (self._rhs, *self._conditions))
        return f"{self._fn_name}({self._lhs}).to({to_args})"

    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command:
        egg_rewrite = self._to_egg_rewrite(mod_decls)
        return bindings.RewriteCommand(self._ruleset, egg_rewrite)

    def _to_egg_rewrite(self, mod_decls: ModuleDeclarations) -> bindings.Rewrite:
        """Translate both sides and every condition into a ``bindings.Rewrite``."""
        egg_lhs = self._lhs.__egg_typed_expr__.expr.to_egg(mod_decls)
        egg_rhs = self._rhs.__egg_typed_expr__.expr.to_egg(mod_decls)
        egg_conditions = [cond._to_egg_fact(mod_decls) for cond in self._conditions]
        return bindings.Rewrite(egg_lhs, egg_rhs, egg_conditions)


@dataclass
class BiRewrite(Rewrite):
    """
    A rewrite registered in both directions: ``_lhs`` to ``_rhs`` and vice versa.
    """

    _fn_name: ClassVar[str] = "birewrite"

    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command:
        # The payload is the same as for ``Rewrite``; only the wrapping command differs.
        egg_rewrite = self._to_egg_rewrite(mod_decls)
        return bindings.BiRewriteCommand(self._ruleset, egg_rewrite)


@dataclass
class Fact(ABC):
    """
    An e-graph fact: either an equality between expressions or a single (unit) expression.
    """

    @abstractmethod
    def _to_egg_fact(self, mod_decls: ModuleDeclarations) -> bindings._Fact:
        """Translate this fact into its low-level ``bindings`` form."""
        raise NotImplementedError


@dataclass
class Eq(Fact):
    """A fact stating that all of ``_exprs`` are equal; built by ``eq(first).to(*rest)``."""

    _exprs: list[RuntimeExpr]

    def __str__(self) -> str:
        head, *others = self._exprs
        others_str = ", ".join(str(other) for other in others)
        return f"eq({head}).to({others_str})"

    def _to_egg_fact(self, mod_decls: ModuleDeclarations) -> bindings.Eq:
        egg_exprs = [runtime.__egg_typed_expr__.expr.to_egg(mod_decls) for runtime in self._exprs]
        return bindings.Eq(egg_exprs)


@dataclass
class ExprFact(Fact):
    """A fact consisting of a single expression."""

    _expr: RuntimeExpr

    def __str__(self) -> str:
        return str(self._expr)

    def _to_egg_fact(self, mod_decls: ModuleDeclarations) -> bindings.Fact:
        egg_expr = self._expr.__egg_typed_expr__.expr.to_egg(mod_decls)
        return bindings.Fact(egg_expr)


@dataclass
class Rule(Command):
    """
    A rule built by ``rule(*body).then(*head)``: when the ``body`` facts match, the ``head`` actions apply.
    """

    head: tuple[Action, ...]
    body: tuple[Fact, ...]
    name: str
    ruleset: str

    def __str__(self) -> str:
        head_str = ", ".join(str(action) for action in self.head)
        body_str = ", ".join(str(fact) for fact in self.body)
        return f"rule({body_str}).then({head_str})"

    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings.RuleCommand:
        egg_head = [action._to_egg_action(mod_decls) for action in self.head]
        egg_body = [fact._to_egg_fact(mod_decls) for fact in self.body]
        return bindings.RuleCommand(self.name, self.ruleset, bindings.Rule(egg_head, egg_body))


class Action(Command, ABC):
    """
    A command that can also appear as an action in a rule's head.

    Subclasses implement only ``_to_egg_action``; as a stand-alone command the action is wrapped in
    ``bindings.ActionCommand``.
    """

    @abstractmethod
    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings._Action:
        raise NotImplementedError

    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command:
        egg_action = self._to_egg_action(mod_decls)
        return bindings.ActionCommand(egg_action)


@dataclass
class Let(Action):
    """A ``let`` action binding the name ``_name`` to ``_value``."""

    _name: str
    _value: RuntimeExpr

    def __str__(self) -> str:
        return f"let({self._name}, {self._value})"

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Let:
        egg_value = self._value.__egg_typed_expr__.expr.to_egg(mod_decls)
        return bindings.Let(self._name, egg_value)


@dataclass
class Set(Action):
    """A ``set`` action assigning ``_rhs`` to the function call ``_call``."""

    _call: RuntimeExpr
    _rhs: RuntimeExpr

    def __str__(self) -> str:
        return f"set({self._call}).to({self._rhs})"

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Set:
        """
        Translate to ``bindings.Set``.

        :raises ValueError: If ``_call`` does not translate to a function call.
        """
        egg_call = self._call.__egg_typed_expr__.expr.to_egg(mod_decls)
        if isinstance(egg_call, bindings.Call):
            egg_rhs = self._rhs.__egg_typed_expr__.expr.to_egg(mod_decls)
            return bindings.Set(egg_call.name, egg_call.args, egg_rhs)
        raise ValueError(f"Can only create a set with a call for the lhs, got {self._call}")  # noqa: TRY004


@dataclass
class ExprAction(Action):
    """An action consisting of a single expression."""

    _expr: RuntimeExpr

    def __str__(self) -> str:
        return str(self._expr)

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Expr_:
        egg_expr = self._expr.__egg_typed_expr__.expr.to_egg(mod_decls)
        return bindings.Expr_(egg_expr)


@dataclass
class Delete(Action):
    """A ``delete`` action for the function call ``_call``."""

    _call: RuntimeExpr

    def __str__(self) -> str:
        return f"delete({self._call})"

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Delete:
        """
        Translate to ``bindings.Delete``.

        :raises ValueError: If ``_call`` does not translate to a function call.
        """
        egg_call = self._call.__egg_typed_expr__.expr.to_egg(mod_decls)
        if not isinstance(egg_call, bindings.Call):
            # The message previously said "create a call ... for the lhs" (copied from ``Set``); a delete has no lhs.
            raise ValueError(f"Can only create a delete with a call, got {self._call}")  # noqa: TRY004
        return bindings.Delete(egg_call.name, egg_call.args)


@dataclass
class Union_(Action):  # noqa: N801
    """A ``union`` action between ``_lhs`` and ``_rhs``; built by ``union(lhs).with_(rhs)``."""

    _lhs: RuntimeExpr
    _rhs: RuntimeExpr

    def __str__(self) -> str:
        return f"union({self._lhs}).with_({self._rhs})"

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Union:
        egg_lhs, egg_rhs = (side.__egg_typed_expr__.expr.to_egg(mod_decls) for side in (self._lhs, self._rhs))
        return bindings.Union(egg_lhs, egg_rhs)


@dataclass
class Panic(Action):
    """An action that raises an error with ``message``."""

    message: str

    def __str__(self) -> str:
        return f"panic({self.message})"

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Panic:
        # The message needs no translation, so ``mod_decls`` is not used.
        egg_panic = bindings.Panic(self.message)
        return egg_panic


class Schedule(ABC):
    """
    Base class of run schedules. Schedules compose with ``*`` (repeat), ``+`` (sequence) and ``saturate()``.
    """

    def __mul__(self, length: int) -> Schedule:
        """
        Return a schedule that repeats this one ``length`` times.
        """
        return Repeat(length, self)

    def saturate(self) -> Schedule:
        """
        Return a schedule that runs this one until the e-graph is saturated.
        """
        return Saturate(self)

    def __add__(self, other: Schedule) -> Schedule:
        """
        Return a schedule that runs this one and then ``other``.
        """
        return Sequence((self, other))

    @abstractmethod
    def __str__(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        raise NotImplementedError


@dataclass
class Run(Schedule):
    """Configuration of a run of ``ruleset``, optionally stopping once the ``until`` facts are true."""

    ruleset: str
    until: tuple[Fact, ...]

    def __str__(self) -> str:
        pieces = [str(self.ruleset)]
        pieces.extend(str(fact) for fact in self.until)
        args_str = ", ".join(pieces)
        return f"run({args_str})"

    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        egg_config = self._to_egg_config(mod_decls)
        return bindings.Run(egg_config)

    def _to_egg_config(self, mod_decls: ModuleDeclarations) -> bindings.RunConfig:
        # An empty ``until`` is passed as None rather than as an empty list.
        egg_until = None
        if self.until:
            egg_until = [fact._to_egg_fact(mod_decls) for fact in self.until]
        return bindings.RunConfig(self.ruleset, egg_until)


@dataclass
class Saturate(Schedule):
    """Runs ``schedule`` until the e-graph is saturated."""

    schedule: Schedule

    def __str__(self) -> str:
        return f"{self.schedule}.saturate()"

    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        inner = self.schedule._to_egg_schedule(mod_decls)
        return bindings.Saturate(inner)


@dataclass
class Repeat(Schedule):
    """Runs ``schedule`` ``length`` times."""

    length: int
    schedule: Schedule

    def __str__(self) -> str:
        return f"{self.schedule} * {self.length}"

    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        inner = self.schedule._to_egg_schedule(mod_decls)
        return bindings.Repeat(self.length, inner)


@dataclass
class Sequence(Schedule):
    """Runs each of ``schedules`` in order."""

    schedules: tuple[Schedule, ...]

    def __str__(self) -> str:
        joined = ", ".join(str(sub) for sub in self.schedules)
        return f"sequence({joined})"

    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        egg_schedules = [sub._to_egg_schedule(mod_decls) for sub in self.schedules]
        return bindings.Sequence(egg_schedules)


# We use these builders so that when creating these structures we can type check
# if the arguments are the same type of expression


def rewrite(lhs: EXPR, ruleset: Ruleset | None = None) -> _RewriteBuilder[EXPR]:
    """
    Rewrite the given expression to a new expression.

    Finish the rewrite with ``.to(rhs, *conditions)``; ``ruleset`` selects the ruleset it is added to.
    """
    return _RewriteBuilder(lhs=lhs, ruleset=ruleset)


def birewrite(lhs: EXPR, ruleset: Ruleset | None = None) -> _BirewriteBuilder[EXPR]:
    """
    Rewrite the given expression to a new expression and vice versa.

    Finish it with ``.to(rhs, *conditions)``; ``ruleset`` selects the ruleset it is added to.
    """
    return _BirewriteBuilder(lhs=lhs, ruleset=ruleset)


def eq(expr: EXPR) -> _EqBuilder[EXPR]:
    """Start an equality fact on ``expr``; finish it with ``.to(*others)``."""
    return _EqBuilder(expr=expr)


def panic(message: str) -> Action:
    """Return an action that raises an error with the given message."""
    return Panic(message=message)


def let(name: str, expr: Expr) -> Action:
    """Create a let binding of ``name`` to ``expr``."""
    runtime_value = to_runtime_expr(expr)
    return Let(name, runtime_value)


def expr_action(expr: Expr) -> Action:
    """Wrap a bare expression as an action."""
    runtime_expr = to_runtime_expr(expr)
    return ExprAction(runtime_expr)


def delete(expr: Expr) -> Action:
    """Create a delete action for the function call ``expr``."""
    runtime_call = to_runtime_expr(expr)
    return Delete(runtime_call)


def expr_fact(expr: Expr) -> Fact:
    """Wrap a bare expression as a fact."""
    runtime_expr = to_runtime_expr(expr)
    return ExprFact(runtime_expr)


def union(lhs: EXPR) -> _UnionBuilder[EXPR]:
    """Start a union with ``lhs``; finish it with ``.with_(rhs)``."""
    return _UnionBuilder(lhs)


def set_(lhs: EXPR) -> _SetBuilder[EXPR]:
    """Start setting the function call ``lhs``; finish it with ``.to(rhs)``."""
    return _SetBuilder(lhs)


def rule(*facts: FactLike, ruleset: Ruleset | None = None, name: str | None = None) -> _RuleBuilder:
    """
    Create a rule matching ``facts``; finish it with ``.then(*actions)``.

    Bare expressions among ``facts`` are wrapped as expression facts.
    """
    body = _fact_likes(facts)
    return _RuleBuilder(body, name, ruleset)


def var(name: str, bound: type[EXPR]) -> EXPR:
    """Create a new variable with the given name and type (typed as an instance of ``bound``)."""
    runtime_var = _var(name, bound)
    return cast(EXPR, runtime_var)


def _var(name: str, bound: object) -> RuntimeExpr:
    """
    Create a new variable with the given name and type.

    :raises TypeError: If ``bound`` is neither a ``RuntimeClass`` nor a ``RuntimeParamaterizedClass``.
    """
    if isinstance(bound, (RuntimeClass, RuntimeParamaterizedClass)):
        decls = bound.__egg_decls__
        return RuntimeExpr(decls, TypedExprDecl(class_to_ref(bound), VarDecl(name)))
    raise TypeError(f"Unexpected type {type(bound)}")


def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
    """
    Create variables with the given names and type.

    :param names: Whitespace-separated variable names, e.g. ``"a b c"``. Repeated, leading and trailing
        whitespace is ignored, so no variable with an empty name is ever created.
    :param bound: The expression type of every variable.
    :return: A generator yielding one variable per name, in order.
    """
    # ``split()`` with no separator drops empty fields; ``split(" ")`` produced "" names for "a  b" or "a b ".
    for name in names.split():
        yield var(name, bound)


@dataclass
class _RewriteBuilder(Generic[EXPR]):
    """Intermediate object returned by ``rewrite``; call ``to`` to build the ``Rewrite`` command."""

    lhs: EXPR
    ruleset: Ruleset | None

    def to(self, rhs: EXPR, *conditions: FactLike) -> Command:
        """Finish the rewrite; ``rhs`` is converted to the same type as the left-hand side."""
        runtime_lhs = to_runtime_expr(self.lhs)
        ruleset_name = _ruleset_name(self.ruleset)
        runtime_rhs = convert_to_same_type(rhs, runtime_lhs)
        condition_facts = _fact_likes(conditions)
        return Rewrite(ruleset_name, runtime_lhs, runtime_rhs, condition_facts)

    def __str__(self) -> str:
        return f"rewrite({self.lhs})"


@dataclass
class _BirewriteBuilder(Generic[EXPR]):
    """Intermediate object returned by ``birewrite``; call ``to`` to build the ``BiRewrite`` command."""

    lhs: EXPR
    ruleset: Ruleset | None

    def to(self, rhs: EXPR, *conditions: FactLike) -> Command:
        """Finish the bidirectional rewrite; ``rhs`` is converted to the same type as the left-hand side."""
        runtime_lhs = to_runtime_expr(self.lhs)
        ruleset_name = _ruleset_name(self.ruleset)
        runtime_rhs = convert_to_same_type(rhs, runtime_lhs)
        condition_facts = _fact_likes(conditions)
        return BiRewrite(ruleset_name, runtime_lhs, runtime_rhs, condition_facts)

    def __str__(self) -> str:
        return f"birewrite({self.lhs})"


@dataclass
class _EqBuilder(Generic[EXPR]):
    """Intermediate object returned by ``eq``; call ``to`` to build the ``Eq`` fact."""

    expr: EXPR

    def to(self, *exprs: EXPR) -> Fact:
        """Finish the equality; every expression in ``exprs`` is converted to the type of the first one."""
        first = to_runtime_expr(self.expr)
        all_exprs = [first]
        all_exprs.extend(convert_to_same_type(other, first) for other in exprs)
        return Eq(all_exprs)

    def __str__(self) -> str:
        return f"eq({self.expr})"


@dataclass
class _SetBuilder(Generic[EXPR]):
    """Intermediate object returned by ``set_``; call ``to`` to build the ``Set`` action."""

    # Annotated with the type parameter (previously plain ``Expr``) to match the other builders, so the
    # generic parameter actually ties ``lhs`` to the ``rhs`` accepted by ``to``.
    lhs: EXPR

    def to(self, rhs: EXPR) -> Action:
        """Finish the set; ``rhs`` is converted to the same type as ``lhs``."""
        lhs = to_runtime_expr(self.lhs)
        return Set(lhs, convert_to_same_type(rhs, lhs))

    def __str__(self) -> str:
        return f"set_({self.lhs})"


@dataclass
class _UnionBuilder(Generic[EXPR]):
    """Intermediate object returned by ``union``; call ``with_`` to build the ``Union_`` action."""

    # Annotated with the type parameter (previously plain ``Expr``) to match the other builders, so the
    # generic parameter actually ties ``lhs`` to the ``rhs`` accepted by ``with_``.
    lhs: EXPR

    def with_(self, rhs: EXPR) -> Action:
        """Finish the union; ``rhs`` is converted to the same type as ``lhs``."""
        lhs = to_runtime_expr(self.lhs)
        return Union_(lhs, convert_to_same_type(rhs, lhs))

    def __str__(self) -> str:
        return f"union({self.lhs})"


@dataclass
class _RuleBuilder:
    """Intermediate object returned by ``rule``; call ``then`` to build the ``Rule`` command."""

    facts: tuple[Fact, ...]
    name: str | None
    ruleset: Ruleset | None

    def then(self, *actions: ActionLike) -> Command:
        """Finish the rule with ``actions`` as its head; a missing name becomes the empty string."""
        return Rule(
            head=_action_likes(actions),
            body=self.facts,
            name=self.name or "",
            ruleset=_ruleset_name(self.ruleset),
        )


def expr_parts(expr: Expr) -> TypedExprDecl:
    """
    Returns the underlying type and declaration of the expression. Useful for testing structural equality or debugging.

    :raises TypeError: If ``expr`` is not a ``RuntimeExpr``.
    """
    if isinstance(expr, RuntimeExpr):
        return expr.__egg_typed_expr__
    raise TypeError(f"Expected a RuntimeExpr not {expr}")


def to_runtime_expr(expr: Expr) -> RuntimeExpr:
    """
    Return ``expr`` narrowed to ``RuntimeExpr``.

    :raises TypeError: If ``expr`` is not a ``RuntimeExpr``.
    """
    if isinstance(expr, RuntimeExpr):
        return expr
    raise TypeError(f"Expected a RuntimeExpr not {expr}")


def run(ruleset: Ruleset | None = None, *until: Fact) -> Run:
    """
    Create a run configuration for ``ruleset`` (the empty ruleset name when omitted) with the ``until`` facts.
    """
    return Run(ruleset=_ruleset_name(ruleset), until=tuple(until))


def seq(*schedules: Schedule) -> Schedule:
    """
    Return a schedule that runs each of ``schedules`` in order.
    """
    return Sequence(schedules=tuple(schedules))


# Anything accepted where a command is expected: a ``Command``, or a bare ``Expr`` (normalized by ``_command_like``).
CommandLike = Command | Expr


def _command_like(command_like: CommandLike) -> Command:
    """Normalize to a ``Command``: bare expressions become expression actions, commands pass through."""
    return expr_action(command_like) if isinstance(command_like, Expr) else command_like


# A function returning commands; ``_command_generator`` calls it with one variable per (annotated) parameter.
CommandGenerator = Callable[..., Iterable[Command]]


def _command_generator(gen: CommandGenerator) -> Iterable[Command]:
    """
    Calls the function with variables of the type and name of the arguments.

    For each parameter of ``gen`` a variable is created whose name is the parameter name and whose type is the
    parameter's resolved annotation; ``gen``'s return value is returned unchanged.
    """
    # Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
    # but not in the globals
    current_frame = inspect.currentframe()
    assert current_frame
    # NOTE(review): this walks exactly two frames up (this function -> its caller -> that caller's caller), so it
    # assumes it is called directly by the registering function, which in turn was called from the scope that
    # defines ``gen``. Adding any wrapper in between would resolve hints against the wrong locals — confirm.
    register_frame = current_frame.f_back
    assert register_frame
    original_frame = register_frame.f_back
    assert original_frame
    # Resolve string/forward-ref annotations against gen's globals plus the outer frame's locals.
    hints = get_type_hints(gen, gen.__globals__, original_frame.f_locals)
    # ``hints[p.name]`` raises KeyError for a parameter without an annotation.
    args = (_var(p.name, hints[p.name]) for p in signature(gen).parameters.values())
    return gen(*args)


# Anything accepted where an action is expected: an ``Action``, or a bare ``Expr`` (normalized by ``_action_like``).
ActionLike = Action | Expr


def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
    """Normalize every item with ``_action_like`` and collect the results into a tuple."""
    return tuple(_action_like(item) for item in action_likes)


def _action_like(action_like: ActionLike) -> Action:
    """Wrap a bare expression as an expression action; return actions unchanged."""
    if not isinstance(action_like, Expr):
        return action_like
    return expr_action(action_like)


# Anything accepted where a fact is expected: a ``Fact``, or a bare ``Expr`` (normalized by ``_fact_like``).
FactLike = Fact | Expr


def _fact_likes(fact_likes: Iterable[FactLike]) -> tuple[Fact, ...]:
    """Normalize every item with ``_fact_like`` and collect the results into a tuple."""
    return tuple(_fact_like(item) for item in fact_likes)


def _fact_like(fact_like: FactLike) -> Fact:
    """Wrap a bare expression as an expression fact; return facts unchanged."""
    if not isinstance(fact_like, Expr):
        return fact_like
    return expr_fact(fact_like)
