import pytest
from datetime import datetime
from unittest.mock import MagicMock
from typing import List

from vellum.workflows.inputs.base import BaseInputs
from vellum.workflows.inputs.dataset_row import DatasetRow
from vellum.workflows.nodes.bases.base import BaseNode
from vellum.workflows.sandbox import WorkflowSandboxRunner
from vellum.workflows.state.base import BaseState
from vellum.workflows.triggers import ScheduleTrigger
from vellum.workflows.workflows.base import BaseWorkflow


@pytest.fixture
def mock_logger(mocker):
    """Replace ``load_logger`` in the sandbox module with a mock.

    Tests set ``mock_logger.return_value.info.side_effect`` to collect the
    lines the sandbox runner logs.
    """
    patched_load_logger = mocker.patch("vellum.workflows.sandbox.load_logger")
    return patched_load_logger


@pytest.mark.parametrize(
    "run_kwargs, expected_last_log",
    [
        pytest.param({}, "final_results: first", id="default"),
        pytest.param({"index": 1}, "final_results: second", id="specific"),
        pytest.param({"index": -4}, "final_results: first", id="negative"),
        pytest.param({"index": 100}, "final_results: second", id="out_of_bounds"),
    ],
)
def test_sandbox_runner__happy_path(mock_logger, run_kwargs, expected_last_log):
    """Run a one-node workflow over two input rows and check the emitted logs.

    The keyword arguments passed to ``run()`` choose which row is executed.
    The cases cover the default, an explicit index, a negative index and an
    index past the end of the list.
    """
    # GIVEN every info-level log line is collected in order
    captured = []

    def _record(msg):
        captured.append(msg)

    mock_logger.return_value.info.side_effect = _record

    # AND a workflow whose single node echoes the `foo` input
    class Inputs(BaseInputs):
        foo: str

    class StartNode(BaseNode):
        class Outputs(BaseNode.Outputs):
            bar = Inputs.foo

    class Workflow(BaseWorkflow[Inputs, BaseState]):
        graph = StartNode

        class Outputs(BaseWorkflow.Outputs):
            final_results = StartNode.Outputs.bar

    # AND two candidate input rows
    rows: List[Inputs] = [Inputs(foo=value) for value in ("first", "second")]

    # WHEN the sandbox runs with the parametrized keyword arguments
    sandbox = WorkflowSandboxRunner(workflow=Workflow(), inputs=rows)
    sandbox.run(**run_kwargs)

    # THEN the node lifecycle logs appear, followed by the selected row's output
    expected_logs = [
        "Just started Node: StartNode",
        "Just finished Node: StartNode",
        "Workflow fulfilled!",
        "----------------------------------",
    ]
    expected_logs.append(expected_last_log)
    assert captured == expected_logs


def test_sandbox_runner_with_dict_inputs(mock_logger):
    """
    Verify that WorkflowSandboxRunner accepts a DatasetRow whose ``inputs`` is a
    plain dict, and that the dict's value ends up in the workflow output.
    """

    # GIVEN the runner's info logs are collected into a list
    captured = []

    def _record(msg):
        captured.append(msg)

    mock_logger.return_value.info.side_effect = _record

    # AND a workflow that echoes its `message` input as `final_output`
    class Inputs(BaseInputs):
        message: str

    class StartNode(BaseNode):
        class Outputs(BaseNode.Outputs):
            result = Inputs.message

    class Workflow(BaseWorkflow[Inputs, BaseState]):
        graph = StartNode

        class Outputs(BaseWorkflow.Outputs):
            final_output = StartNode.Outputs.result

    # AND a single dataset row whose inputs are a raw dict
    row = DatasetRow(label="test_row", inputs={"message": "Hello from dict"})

    # WHEN the sandbox runs that row
    WorkflowSandboxRunner(workflow=Workflow(), dataset=[row]).run()

    # THEN the dict value reaches the workflow output
    assert captured == [
        "Just started Node: StartNode",
        "Just finished Node: StartNode",
        "Workflow fulfilled!",
        "----------------------------------",
        "final_output: Hello from dict",
    ]


def test_sandbox_runner_with_workflow_trigger(mock_logger):
    """
    Verify that WorkflowSandboxRunner can run a DatasetRow that names a
    ``workflow_trigger``. The row's inputs supply the trigger's attributes, and
    running the workflow must not change the row's trigger.
    """

    # GIVEN the runner's info logs are collected into a list
    captured = []

    def _record(msg):
        captured.append(msg)

    mock_logger.return_value.info.side_effect = _record

    # AND a schedule trigger that fires every minute in UTC
    class MySchedule(ScheduleTrigger):
        class Config(ScheduleTrigger.Config):
            cron = "* * * * *"
            timezone = "UTC"

    # AND a workflow, started by that trigger, that outputs the trigger's current_run_at
    class StartNode(BaseNode):
        class Outputs(BaseNode.Outputs):
            result = MySchedule.current_run_at

    class Workflow(BaseWorkflow):
        graph = MySchedule >> StartNode

        class Outputs(BaseWorkflow.Outputs):
            final_output = StartNode.Outputs.result

    # AND a dataset row bound to the trigger; only current_run_at is asserted on below
    trigger_inputs = {"current_run_at": datetime.min, "next_run_at": datetime.now()}
    row = DatasetRow(label="test_row", inputs=trigger_inputs, workflow_trigger=MySchedule)
    dataset = [row]

    # WHEN the sandbox runs that dataset
    WorkflowSandboxRunner(workflow=Workflow(), dataset=dataset).run()

    # THEN the workflow fulfills and outputs datetime.min as current_run_at
    assert captured == [
        "Just started Node: StartNode",
        "Just finished Node: StartNode",
        "Workflow fulfilled!",
        "----------------------------------",
        "final_output: 0001-01-01 00:00:00",
    ]

    # AND the row still references the trigger class
    assert row.workflow_trigger == MySchedule


def test_sandbox_runner_with_node_output_mocks(mock_logger, mocker):
    """
    Tests that WorkflowSandboxRunner passes node_output_mocks from DatasetRow to workflow.stream().

    ``mock_logger`` is requested only so the sandbox's logger is patched out
    while the run executes.
    """

    # GIVEN a workflow whose single node declares a `result` output with no value bound
    class Inputs(BaseInputs):
        message: str

    class TestNode(BaseNode):
        class Outputs(BaseNode.Outputs):
            result: str

    class Workflow(BaseWorkflow[Inputs, BaseState]):
        graph = TestNode

        class Outputs(BaseWorkflow.Outputs):
            final_output = TestNode.Outputs.result

    # AND a mocked output for that node
    mock_outputs = TestNode.Outputs(result="mocked_result")

    # AND a dataset with node_output_mocks
    dataset = [
        DatasetRow(
            label="test_with_mocks",
            inputs={"message": "test"},
            node_output_mocks=[mock_outputs],
        ),
    ]

    # AND workflow.stream() is spied on. `wraps=` records each call and forwards
    # it to the real bound method, so the workflow runs with exactly the
    # arguments the runner supplies. The previous version returned a stream
    # pre-built with hard-coded inputs and no mocks.
    workflow_instance = Workflow()
    stream_mock = MagicMock(wraps=workflow_instance.stream)
    mocker.patch.object(workflow_instance, "stream", stream_mock)

    # WHEN we run the sandbox with the DatasetRow containing node_output_mocks
    runner = WorkflowSandboxRunner(workflow=workflow_instance, dataset=dataset)
    runner.run()

    # THEN stream() was called exactly once, with the row's node_output_mocks forwarded
    stream_mock.assert_called_once()
    call_kwargs = stream_mock.call_args.kwargs
    assert "node_output_mocks" in call_kwargs
    assert call_kwargs["node_output_mocks"] == [mock_outputs]
