from .base import BaseMachine, root


class GraphMachine(BaseMachine):
    """Machine whose rendered graph also lists transition callbacks.

    Conditions and state attributes are always shown, and every edge label
    is extended with the ``before``/``after`` callbacks of its transition.
    """

    def __init__(self, **kwargs):
        """Initialise the machine with conditions and state attributes shown.

        Args:
            **kwargs: Forwarded unchanged to ``BaseMachine.__init__``.
                ``show_conditions`` and ``show_state_attributes`` are set
                here and must not be passed by the caller.
        """
        # Draw "active" and "previous" nodes with the "inactive" style so the
        # rendered graph does not highlight any state.
        # NOTE(review): this writes into ``self.style_attributes`` before
        # ``super().__init__`` runs, so the mapping presumably lives on the
        # class and the change is shared by every machine using it -- confirm.
        node_styles = self.style_attributes["node"]
        node_styles["active"] = node_styles["inactive"]
        node_styles["previous"] = node_styles["inactive"]
        super().__init__(
            show_conditions=True,
            show_state_attributes=True,
            **kwargs,
        )

    def _init_graphviz_engine(self, graph_engine):
        """Return a subclass of the graph class chosen for ``graph_engine``.

        The parent implementation selects the graph class; the subclass built
        here only overrides how transition labels are rendered.
        """

        class Graph(super()._init_graphviz_engine(graph_engine)):
            def _transition_label(self, trans):
                """Append the ``before`` and ``after`` callbacks of ``trans``
                to the label produced by the parent class.

                Missing or empty callback lists add nothing to the label.
                """
                label = super()._transition_label(trans)
                # Each callback list goes on its own line of the label
                # ("\l" is Graphviz's left-justified line break).
                if "before" in trans and trans["before"]:
                    label += r"\lbefore: {}".format(", ".join(trans["before"]))
                if "after" in trans and trans["after"]:
                    label += r"\lafter: {}".format(", ".join(trans["after"]))
                return label

        return Graph

    def get_graph(self, *args, **kwargs):
        """Return the graph of this machine.

        Calls ``self._add_model()`` and then delegates to the model's own
        ``get_graph``, passing all arguments through.

        Returns:
            Whatever ``self._model.get_graph`` returns (the graph object).
        """
        self._add_model()
        # The result must be returned: ``save_graph`` calls ``render`` on it.
        return self._model.get_graph(*args, **kwargs)

    def save_graph(self, path):
        """Save a graph of the hierarchical state machine to ``path``.

        The output format is taken from the file extension of ``path``
        (case-insensitive), e.g. ``"machine.PNG"`` renders as ``png``.

        Args:
            path: Destination file, as a string or ``os.PathLike`` object.

        Raises:
            ValueError: If the file name in ``path`` has no extension, so no
                output format can be derived from it.
        """
        import os

        path = os.fspath(path)
        # Look only at the final path component so that a dot in a directory
        # name (e.g. "build.v2/out") is not mistaken for an extension.
        _, dot, extension = os.path.basename(path).rpartition(".")
        if not dot or not extension:
            raise ValueError(
                "cannot determine graph format: {!r} has no file "
                "extension".format(path)
            )
        self.get_graph().render(
            format=extension.lower(), cleanup=True, outfile=path
        )
