from typing import Any, Generator, List, Optional

from tensorlake.functions_sdk.graph import Graph
from tensorlake.functions_sdk.graph_definition import ComputeGraphMetadata

from .http_client import (
    DEFAULT_SERVICE_URL,
    RequestMetadata,
    TensorlakeClient,
    WorkflowEvent,
)


class RemoteGraph:
    """
    Client-side handle to a compute graph registered on a Tensorlake server.

    The graph definition is fetched once when the handle is created; its input
    encoder is used to serialize the inputs passed to run/call/stream.
    """

    def __init__(
        self,
        name: str,
        server_url: Optional[str] = DEFAULT_SERVICE_URL,
        client: Optional[TensorlakeClient] = None,
    ):
        """
        Create a handle to call a RemoteGraph by name.

        Note: Use the class methods RemoteGraph.deploy or RemoteGraph.by_name to create a RemoteGraph object.

        :param name: The name of the graph.
        :param server_url: The URL of the server where the graph is registered.
            Not used if client is provided.
        :param client: The TensorlakeClient used to communicate with the server.
            Preferred over server_url.
        """
        self._name = name
        # Compare against None explicitly so a caller-supplied client is never
        # replaced just because it happens to evaluate as falsy.
        if client is None:
            client = TensorlakeClient(service_url=server_url)
        self._client = client

        # Fetched once and reused for every invocation made through this handle;
        # metadata() returns a freshly fetched copy without updating this one.
        self._graph_definition: ComputeGraphMetadata = self._client.graph(self._name)

    def run(self, block_until_done: bool = False, **kwargs) -> str:
        """
        Run the graph with the given inputs. The input is for the start function of the graph.

        :param block_until_done: If True, the function will block until the graph execution is complete.
        :param kwargs: The input to the start function of the graph. Pass the input as keyword arguments.
        :return: The invocation ID of the graph execution.

        Example:
            @tensorlake_function()
            def foo(x: int) -> int:
                return x + 1

            remote_graph = RemoteGraph.by_name("test")
            invocation_id = remote_graph.run(x=1)
        """
        return self._client.call(
            self._name,
            block_until_done,
            self._graph_definition.get_input_encoder(),
            **kwargs,
        )

    def call(self, **kwargs) -> str:
        """
        Start the graph with the given inputs without blocking until it completes.

        Equivalent to ``run(block_until_done=False, **kwargs)``.

        :param kwargs: The input to the start function of the graph. Pass the input as keyword arguments.
        :return: The invocation ID of the graph execution.
        """
        return self.run(block_until_done=False, **kwargs)

    def stream(self, **kwargs) -> Generator[WorkflowEvent, None, str]:
        """
        Run the graph with the given inputs, streaming the events generated by the invocation.
        The input is for the start function of the graph.

        :param kwargs: The input to the start function of the graph. Pass the input as keyword arguments.
        :return: A generator yielding WorkflowEvent objects. The generator's return value
            (available as ``StopIteration.value`` or via ``yield from``) is the invocation ID.

        Example:
            @tensorlake_function()
            def foo(x: int) -> int:
                return x + 1

            remote_graph = RemoteGraph.by_name("test")
            for event in remote_graph.stream(x=1):
                print(event)
        """
        return self._client.call_stream(
            self._name,
            self._graph_definition.get_input_encoder(),
            **kwargs,
        )

    def metadata(self) -> ComputeGraphMetadata:
        """
        Fetch the current metadata of the graph from the server.

        Note: this does not update the definition cached when the handle was
        created, which is the one run/call/stream use to encode inputs.
        """
        return self._client.graph(self._name)

    def replay_invocations(self):
        """
        Replay all the graph previous runs/invocations on the latest version of the graph.

        This is useful to make all the previous invocations go through
        an updated graph to take advantage of graph improvements.
        """
        self._client.replay_invocations(self._name)

    @classmethod
    def deploy(
        cls,
        graph: Graph,
        code_dir_path: str,
        server_url: Optional[str] = DEFAULT_SERVICE_URL,
        client: Optional[TensorlakeClient] = None,
        upgrade_tasks_to_latest_version: bool = False,
    ) -> "RemoteGraph":
        """
        Create a new RemoteGraph from a local Graph object.

        The graph is validated locally before it is registered with the server.

        :param graph: The local Graph object.
        :param code_dir_path: The path to the directory containing the file where the graph is defined.
        :param server_url: The URL of the server where the graph will be registered.
            Not used if client is provided.
        :param client: The TensorlakeClient used to communicate with the server.
            Preferred over server_url.
        :param upgrade_tasks_to_latest_version: If True, all not started and running invocations
            will be upgraded to run on the latest version of the graph. This requires the graph
            code to be backward compatible between the versions.
        :return: A RemoteGraph handle for the deployed graph.
        """
        graph.validate_graph()
        # Explicit None check: never replace a caller-supplied client.
        if client is None:
            client = TensorlakeClient(service_url=server_url)

        client.register_compute_graph(
            graph=graph,
            code_dir_path=code_dir_path,
            upgrade_tasks_to_latest_version=upgrade_tasks_to_latest_version,
        )
        return cls(name=graph.name, server_url=server_url, client=client)

    @classmethod
    def by_name(
        cls,
        name: str,
        server_url: Optional[str] = DEFAULT_SERVICE_URL,
        client: Optional[TensorlakeClient] = None,
    ) -> "RemoteGraph":
        """
        Create a handle to call a RemoteGraph by name.

        :param name: The name of the graph.
        :param server_url: The URL of the server where the graph is registered.
            Not used if client is provided.
        :param client: The TensorlakeClient used to communicate with the server.
            Preferred over server_url.
        :return: A RemoteGraph object.
        """
        return cls(name=name, server_url=server_url, client=client)

    def output(
        self,
        invocation_id: str,
        fn_name: str,
    ) -> List[Any]:
        """
        Return the outputs produced by one function of the graph for a given invocation.

        :param invocation_id: The ID of the invocation (request) whose outputs are read.
        :param fn_name: The name of the function whose outputs are returned.
        :return: Output of the function.
        """
        return self._client.graph_outputs(
            graph=self._name,
            request_id=invocation_id,
            fn_name=fn_name,
        )

    def get_output(self, request_id: str, fn_name: str) -> Any:
        """
        Same as ``output``, with the invocation ID passed as ``request_id``.

        :param request_id: The ID of the request (invocation) whose outputs are read.
        :param fn_name: The name of the function whose outputs are returned.
        :return: Output of the function.
        """
        return self.output(invocation_id=request_id, fn_name=fn_name)

    def request(self, request_id: str) -> RequestMetadata:
        """
        Fetch the metadata of a single request (invocation) of this graph.

        :param request_id: The ID of the request.
        :return: The request metadata returned by the client.
        """
        return self._client.request(
            graph=self._name,
            request_id=request_id,
        )
