import io
import time
from contextlib import redirect_stderr, redirect_stdout
from typing import Any, Optional

from tensorlake.functions_sdk.functions import (
    FunctionCallResult,
    GraphRequestContext,
    ProgressReporter,
    TensorlakeFunctionWrapper,
)
from tensorlake.functions_sdk.graph_definition import ComputeGraphMetadata
from tensorlake.functions_sdk.invocation_state.invocation_state import RequestState

from ...proto.function_executor_pb2 import (
    Task,
    TaskResult,
)
from ...std_outputs_capture import flush_logs, read_till_the_end
from .function_inputs_loader import FunctionInputs, FunctionInputsLoader
from .response_helper import ResponseHelper


class Handler:
    """Runs a single task's function and converts the outcome into a TaskResult.

    The function's stdout/stderr are redirected into caller-provided StringIO buffers
    so that they don't end up in this process's own stdout/stderr. The captured text
    is attached to the returned TaskResult.
    """

    def __init__(
        self,
        task: Task,
        invocation_state: RequestState,
        function_wrapper: TensorlakeFunctionWrapper,
        function_stdout: io.StringIO,
        function_stderr: io.StringIO,
        graph_metadata: ComputeGraphMetadata,
        progress_reporter: Optional[ProgressReporter],
        logger: Any,
    ):
        """Prepares the handler for running `task`.

        Args:
            task: The task to run. Its invocation, task and allocation ids are bound
                to the logger, and it is used to build the FunctionInputsLoader.
            invocation_state: Request state exposed to the function through
                GraphRequestContext.
            function_wrapper: Wrapper whose `invoke_fn_ser` runs the function.
            function_stdout: Buffer that receives the function's stdout. Capture
                starts at the buffer's current position, so earlier content is
                excluded from this task's output.
            function_stderr: Same as `function_stdout`, but for stderr.
            graph_metadata: Passed to ResponseHelper for building the response.
            progress_reporter: Optional reporter exposed to the function through
                GraphRequestContext.
            logger: Structured logger supporting `.bind(**kwargs)`.
        """
        self._task: Task = task
        self._invocation_state: RequestState = invocation_state
        self._logger = logger.bind(
            module=__name__,
            invocation_id=task.graph_invocation_id,
            task_id=task.task_id,
            allocation_id=task.allocation_id,
        )
        self._function_wrapper: TensorlakeFunctionWrapper = function_wrapper
        self._function_stdout: io.StringIO = function_stdout
        self._function_stderr: io.StringIO = function_stderr
        self._input_loader = FunctionInputsLoader(task)
        self._response_helper = ResponseHelper(
            function_name=task.function_name,
            graph_metadata=graph_metadata,
            logger=self._logger,
        )
        self._progress_reporter = progress_reporter

    def run(self) -> TaskResult:
        """Runs the task.

        Raises an exception if our own code failed, customer function failure doesn't result in any exception.
        Details of customer function failure are returned in the response.
        """
        self._logger.info("running function")
        start_time = time.monotonic()
        # Input loading is outside the capture/try in _run_task, so a failure here
        # propagates to the caller instead of being reported as a function failure.
        inputs: FunctionInputs = self._input_loader.load()
        result: TaskResult = self._run_task(inputs)
        self._logger.info(
            "function finished",
            duration_sec=f"{time.monotonic() - start_time:.3f}",
        )
        return result

    def _run_task(self, inputs: FunctionInputs) -> TaskResult:
        """Runs the customer function while capturing what happened in it.

        Function stdout and stderr are captured so they don't get into Function Executor process stdout
        and stderr. Raises an exception if our own code failed, customer function failure doesn't result in any exception.
        Details of customer function failure are returned in the response.

        Any BaseException raised inside the capture, including SystemExit and
        KeyboardInterrupt raised by the function, is converted into a failure response.

        NOTE(review): flush_logs and the response building also run inside the `try`.
        A failure in that code of ours is therefore reported as a function failure,
        not raised as the docstring above describes. Confirm this is intended.
        """
        # Flush any logs buffered in memory before doing stdout, stderr capture.
        # Otherwise our logs logged before this point will end up in the function's stdout capture.
        flush_logs(self._function_stdout, self._function_stderr)
        # Capture is relative to the current buffer positions; earlier content is excluded.
        stdout_start: int = self._function_stdout.tell()
        stderr_start: int = self._function_stderr.tell()

        try:
            with redirect_stdout(self._function_stdout), redirect_stderr(
                self._function_stderr
            ):
                result: FunctionCallResult = self._run_func(inputs)
                # Ensure that whatever outputted by the function gets captured.
                flush_logs(self._function_stdout, self._function_stderr)
                return self._response_helper.from_function_call(
                    result=result,
                    is_reducer=_function_is_reducer(self._function_wrapper),
                    stdout=read_till_the_end(self._function_stdout, stdout_start),
                    stderr=read_till_the_end(self._function_stderr, stderr_start),
                )
        except BaseException as e:
            # Flush on the failure path too, matching the success path above.
            # Otherwise output the function produced right before raising can be
            # missing from the captured stdout/stderr sent back in the response.
            flush_logs(self._function_stdout, self._function_stderr)
            return self._response_helper.from_function_exception(
                exception=e,
                stdout=read_till_the_end(self._function_stdout, stdout_start),
                stderr=read_till_the_end(self._function_stderr, stderr_start),
                metrics=None,
            )

    def _run_func(self, inputs: FunctionInputs) -> FunctionCallResult:
        """Builds the request context and invokes the wrapped function.

        Args:
            inputs: Loaded function inputs; `input` and `init_value` are forwarded
                to `invoke_fn_ser`.

        Returns:
            The FunctionCallResult returned by the function wrapper.
        """
        ctx: GraphRequestContext = GraphRequestContext(
            request_id=self._task.graph_invocation_id,
            graph_name=self._task.graph_name,
            graph_version=self._task.graph_version,
            request_state=self._invocation_state,
            progress_reporter=self._progress_reporter,
        )
        return self._function_wrapper.invoke_fn_ser(
            ctx, inputs.input, inputs.init_value
        )


def _function_is_reducer(func_wrapper: TensorlakeFunctionWrapper) -> bool:
    """Tells whether the wrapped function is a reducer.

    A function is treated as a reducer when its `accumulate` attribute is set
    (anything other than None).
    """
    accumulator = func_wrapper.indexify_function.accumulate
    return accumulator is not None
