# Copyright (c) 2025 Krnel
# Points of Contact:
#   - kimmy@krnel.ai

from typing import Any, Annotated
from pydantic import SerializerFunctionWrapHandler, field_serializer
import pytest
from krnel.graph import OpSpec
from krnel.graph.op_spec import find_subclass_of, graph_deserialize, graph_serialize, ExcludeFromUUID


class ExampleDataSource(OpSpec):
    """Example op spec for testing."""
    # Leaf spec used by the fixtures below: two plain string fields and no
    # OpSpec-typed fields. The pinned UUID constants in this module were
    # computed for this exact class, so keep its name and fields stable.
    dataset_name: str
    import_date: str

# Shared fixture: a spec with only plain string fields.
DATA_SOURCE_A = ExampleDataSource(
    dataset_name="test",
    import_date="2023-01-01",
)
# Pinned expected value of DATA_SOURCE_A.uuid (asserted in test_op_spec_uuid).
DATA_SOURCE_A__UUID = "ExampleDataSource_3dd8e10c1751a80fc03dcd280e5e13e5f115515b35eb7abaae2420883eff783c"
# Expected DATA_SOURCE_A.model_dump() output (asserted in test_op_spec_serialization).
DATA_SOURCE_A__JSON_PARTIAL = {
    "type": "ExampleDataSource",
    "dataset_name": "test",
    "import_date": "2023-01-01",
}


class DatasetDoubleSize(OpSpec):
    """Operation that doubles the size of the dataset!"""
    # Spec with one plain field (defaulting to "DOUBLE") and one OpSpec-typed
    # field; the fixtures below use it to build small dependency graphs.
    power_level: str = "DOUBLE"
    source_dataset: ExampleDataSource

# Shared fixture: an operation on DATA_SOURCE_A that relies on the default power_level.
OPERATION_A = DatasetDoubleSize(
    source_dataset=DATA_SOURCE_A,
)
# Pinned expected value of OPERATION_A.uuid (asserted in test_op_spec_uuid).
OPERATION_A__UUID = "DatasetDoubleSize_c904d918e458fedc5710df9dbae0e8c6428742d93c29d36b42b6ad1e89fe55c4"
# Expected OPERATION_A.model_dump() output: the child spec appears as its UUID string.
OPERATION_A__JSON_PARTIAL = {
    "type": "DatasetDoubleSize",
    # Partial serialization for uuid (internal):
    "source_dataset": DATA_SOURCE_A__UUID,
    "power_level": "DOUBLE",
}
# Expected graph_serialize(OPERATION_A) output: the output UUIDs plus one
# partial dump per node, keyed by node UUID.
OPERATION_A__JSON_GRAPH = {
    "outputs": [OPERATION_A__UUID],
    "nodes": {
        OPERATION_A__UUID: OPERATION_A__JSON_PARTIAL,
        DATA_SOURCE_A__UUID: DATA_SOURCE_A__JSON_PARTIAL,
    }
}

# Shared fixture: same source as OPERATION_A but with an explicit, different power_level.
OPERATION_B = DatasetDoubleSize(
    source_dataset=DATA_SOURCE_A,
    power_level="TRIPLE",
)
# Pinned UUID for OPERATION_B (used as the expected op_b UUID in the graph tests).
OPERATION_B__UUID = "DatasetDoubleSize_632f895b6c26d7c170f5c4bf380a1b32921b2539ce6d079be29f810968e6e5e7"
# Expected partial dump for OPERATION_B; used as a node in OPERATION_COMBINE_TWO__JSON_GRAPH.
OPERATION_B__JSON_PARTIAL = {
    "type": "DatasetDoubleSize",
    # Partial serialization for uuid (internal):
    "source_dataset": DATA_SOURCE_A__UUID,
    "power_level": "TRIPLE",
}

class CombineTwoOperations(OpSpec):
    # Spec with two fields typed as the OpSpec base class, so either field can
    # hold any OpSpec instance; used to build a two-branch fixture graph.
    op_a: OpSpec
    op_b: OpSpec
# Shared fixture: a two-branch graph. op_b is built inline with the same
# arguments as OPERATION_B rather than reusing that object; the expected
# partial dump below references it by OPERATION_B__UUID.
OPERATION_COMBINE_TWO = CombineTwoOperations(
    op_a=OPERATION_A,
    op_b=DatasetDoubleSize(source_dataset=DATA_SOURCE_A, power_level="TRIPLE"),
)
# Pinned UUID for OPERATION_COMBINE_TWO (asserted in test_op_spec_graph_deserialization).
OPERATION_COMBINE_TWO__UUID = "CombineTwoOperations_5e2fc5026c4402e9dd496b79af5bb78535d5ddc2c6c9b1363492b5b718b3d2cb"
# Expected partial dump: both children appear as UUID strings.
OPERATION_COMBINE_TWO__JSON_PARTIAL = {
    "type": "CombineTwoOperations",
    "op_a": OPERATION_A__UUID,
    "op_b": OPERATION_B__UUID,
}
# Expected graph_serialize(OPERATION_COMBINE_TWO) output. DATA_SOURCE_A is
# shared by both branches and appears once in "nodes".
OPERATION_COMBINE_TWO__JSON_GRAPH = {
    "outputs": [OPERATION_COMBINE_TWO__UUID],
    "nodes": {
        OPERATION_COMBINE_TWO__UUID: OPERATION_COMBINE_TWO__JSON_PARTIAL,
        OPERATION_A__UUID: OPERATION_A__JSON_PARTIAL,
        OPERATION_B__UUID: OPERATION_B__JSON_PARTIAL,
        DATA_SOURCE_A__UUID: DATA_SOURCE_A__JSON_PARTIAL,
    }
}

def test_op_spec_immutability():
    """Verify that assigning to a field of an OpSpec instance is rejected."""
    from pydantic_core import ValidationError

    frozen_spec = DATA_SOURCE_A
    assert frozen_spec.dataset_name == "test"
    assert frozen_spec.import_date == "2023-01-01"

    # Attribute assignment is expected to raise pydantic's ValidationError.
    with pytest.raises(ValidationError):
        frozen_spec.dataset_name = "changed"


def test_op_spec_uuid():
    """Compare the fixture specs' UUIDs with their pinned expected values."""
    # The pinned values match what the implementation actually generates
    # (with 'type' included).
    pinned = (
        (DATA_SOURCE_A, DATA_SOURCE_A__UUID),
        (OPERATION_A, OPERATION_A__UUID),
    )
    for fixture_spec, expected_uuid in pinned:
        assert fixture_spec.uuid == expected_uuid


def test_op_spec_uuid_consistency():
    """Equal field values give equal UUIDs; a differing field gives a different one."""
    shared_fields = {"dataset_name": "test", "import_date": "2023-01-01"}
    first = ExampleDataSource(**shared_fields)
    twin = ExampleDataSource(**shared_fields)
    renamed = ExampleDataSource(dataset_name="bar", import_date="2023-01-01")

    assert first.uuid == DATA_SOURCE_A__UUID
    assert twin.uuid == first.uuid
    assert renamed.uuid != first.uuid


def test_op_spec_serialization():
    """Check ``model_dump`` of single specs against the expected partial dicts."""
    # Dumping one spec lists its child specs by UUID rather than nesting them.
    leaf_dump = DATA_SOURCE_A.model_dump()
    assert leaf_dump == DATA_SOURCE_A__JSON_PARTIAL

    operation_dump = OPERATION_A.model_dump()
    assert operation_dump == OPERATION_A__JSON_PARTIAL

def test_op_spec_graph_serialization():
    """Test serializing a graph of OpSpec instances.

    ``graph_serialize`` is compared against the pinned graph fixtures, which
    hold the output UUIDs under ``"outputs"`` and one partial dump per node,
    keyed by UUID, under ``"nodes"``.
    """
    # Single-branch graph: the operation plus its data source.
    op_a_graphdict = graph_serialize(OPERATION_A)
    assert op_a_graphdict == OPERATION_A__JSON_GRAPH

    # Two-branch graph: the shared data source must appear only once.
    op_combined_graphdict = graph_serialize(OPERATION_COMBINE_TWO)
    assert op_combined_graphdict == OPERATION_COMBINE_TWO__JSON_GRAPH

def test_op_spec_graph_deserialization():
    """Rebuild the combined fixture graph from its serialized form and inspect it."""
    [restored] = graph_deserialize(OPERATION_COMBINE_TWO__JSON_GRAPH)

    # The root equals the fixture and keeps its UUID and concrete type.
    assert restored == OPERATION_COMBINE_TWO
    assert restored.uuid == OPERATION_COMBINE_TWO__UUID
    assert isinstance(restored, CombineTwoOperations)

    # Both branches come back as DatasetDoubleSize with their pinned UUIDs.
    assert isinstance(restored.op_a, DatasetDoubleSize)
    assert restored.op_a.uuid == OPERATION_A__UUID
    assert isinstance(restored.op_b, DatasetDoubleSize)
    assert restored.op_b.uuid == OPERATION_B__UUID


def test_op_spec_reserialization_fails():
    """Deserializing a node whose ``type`` names no known class raises ValueError."""
    unknown_type_node = {
        "type": "SomeMissingType",
    }
    payload = {
        "outputs": ["abcd"],
        "nodes": {"abcd": unknown_type_node},
    }
    with pytest.raises(ValueError, match="Class with name SomeMissingType not found in OpSpec hierarchy"):
        graph_deserialize(payload)

def test_op_spec_parents():
    """``get_dependencies`` on OPERATION_A lists exactly its source dataset."""
    class SomeComparison(OpSpec):
        op_a: DatasetDoubleSize
        op_b: DatasetDoubleSize

    # Built only to exercise construction with two DatasetDoubleSize fields;
    # nothing below asserts on it.
    second_branch = DatasetDoubleSize(source_dataset=DATA_SOURCE_A)
    some_comparison = SomeComparison(op_a=OPERATION_A, op_b=second_branch)

    direct_dependencies = OPERATION_A.get_dependencies()
    assert direct_dependencies == [DATA_SOURCE_A]  # type: ignore[reportUnhashable]


def test_get_parents_of_type_specific():
    """``get_dependencies`` returns both direct dependencies, ``a`` before ``b``."""
    class ParentA(OpSpec):
        foo: str

    class ParentB(OpSpec):
        bar: int

    class Child(OpSpec):
        a: ParentA
        b: ParentB

    first_parent = ParentA(foo="x")
    second_parent = ParentB(bar=1)
    child = Child(a=first_parent, b=second_parent)

    assert child.get_dependencies() == [first_parent, second_parent]  # type: ignore[reportUnhashable]


def test_get_parents_recursive():
    """``get_dependencies(recursive=True)`` also collects indirect dependencies."""
    class Grandparent(OpSpec):
        foo: str

    class Parent(OpSpec):
        gp: Grandparent

    class Child(OpSpec):
        p: Parent

    root = Grandparent(foo="z")
    middle = Parent(gp=root)
    leaf = Child(p=middle)

    # Order is not checked here, only membership.
    ancestors = leaf.get_dependencies(recursive=True)
    assert set(ancestors) == {middle, root}  # type: ignore[reportUnhashable]


def test_get_parents_no_parents():
    """``get_dependencies`` is an empty list for a spec with no OpSpec fields."""
    class Standalone(OpSpec):
        foo: int

    lone_spec = Standalone(foo=1)
    dependencies = lone_spec.get_dependencies()
    assert dependencies == []


def test_serialize_as_any_annotation_working():
    """Baseline for the pydantic issue below: plain ``SerializeAsAny`` usage.

    https://github.com/pydantic/pydantic/issues/12121
    """
    from pydantic import BaseModel, SerializeAsAny

    class User(BaseModel):
        name: str

    class UserLogin(User):
        password: str

    class OuterModel(BaseModel):
        as_any: SerializeAsAny[User]
        as_user: User

    login = UserLogin(name='pydantic', password='password')
    dumped = OuterModel(as_any=login, as_user=login).model_dump()

    # The SerializeAsAny field keeps the subclass-only field; the field typed
    # as plain User drops it.
    expected = {
        'as_any': {'name': 'pydantic', 'password': 'password'},
        'as_user': {'name': 'pydantic'},
    }
    assert dumped == expected

@pytest.mark.skip(reason="upstream bug in pydantic")
def test_serialize_as_any_annotation_broken():
    """Reproduce the pydantic issue linked below; skipped while it is open.

    Same models and expected dump as the "working" variant of this test, except
    that OuterModel also declares a pass-through wrap serializer for all fields.
    """
    # https://github.com/pydantic/pydantic/issues/12121
    from pydantic import BaseModel, SerializeAsAny
    class User(BaseModel):
        name: str
    class UserLogin(User):
        password: str
    class OuterModel(BaseModel):
        # Wildcard wrap serializer that only delegates to the next handler;
        # it is the sole difference from the "working" variant.
        @field_serializer('*', mode='wrap')
        def custom_field_serializer(v: Any, nxt: SerializerFunctionWrapHandler):
            return nxt(v)
        as_any: SerializeAsAny[User]
        as_user: User

    user = UserLogin(name='pydantic', password='password')
    # Expected: as_any keeps the subclass field, as_user drops it.
    assert OuterModel(as_any=user, as_user=user).model_dump() == {
        'as_any': {'name': 'pydantic', 'password': 'password'},
        'as_user': {'name': 'pydantic'},
    }


def test_serialize_opspec_field_list():
    """A ``list[OpSpec]`` field dumps as a list of UUIDs and survives a graph round trip."""
    class ExampleOp(OpSpec):
        name: str
        datasets: list[OpSpec] = []

    inner = ExampleOp(name="op1")
    outer = ExampleOp(name="op2", datasets=[inner])

    dumped = outer.model_dump()
    assert dumped['name'] == "op2"
    assert len(dumped['datasets']) == 1
    assert dumped['datasets'][0] == inner.uuid

    # Serialize the whole graph and read it back.
    [restored] = graph_deserialize(graph_serialize(outer))
    assert restored == outer
    assert restored.uuid == outer.uuid

def test_serialize_opspec_field_dict():
    """A ``dict[str, OpSpec]`` field dumps values as UUIDs and survives a graph round trip."""
    class ExampleDictOp(OpSpec):
        name: str
        datasets: dict[str, OpSpec] = {}

    inner = ExampleDictOp(name="op1", datasets={})
    outer = ExampleDictOp(name="op2", datasets={"dataset1": inner})

    dumped = outer.model_dump()
    assert dumped['name'] == "op2"
    assert len(dumped['datasets']) == 1
    assert dumped['datasets']['dataset1'] == inner.uuid

    # Serialize the whole graph and read it back.
    [restored] = graph_deserialize(graph_serialize(outer))
    assert restored == outer
    assert restored.uuid == outer.uuid


def test_two_subclasses_same_name_should_fail():
    """Test that looking up a name shared by two OpSpec subclasses raises an error."""
    class BaseOp(OpSpec):
        pass

    # Keep a strong reference to the first class. Rebinding the name below
    # would otherwise leave it alive only until the next garbage-collection
    # pass, and the lookup presumably only sees live subclasses - so the test
    # could fail intermittently.
    first_base_op = BaseOp

    class BaseOp(OpSpec):  # noqa: F811 - the duplicate name is the point of this test
        int_field: int

    # Sanity check: two distinct classes really share the name.
    assert first_base_op is not BaseOp
    assert first_base_op.__name__ == BaseOp.__name__

    with pytest.raises(ValueError, match="Multiple subclasses found for BaseOp"):
        find_subclass_of(OpSpec, "BaseOp")


def test_deserialize_missing_node_uuid_should_fail():
    """An output UUID with no entry under ``nodes`` makes ``graph_deserialize`` raise."""
    present_node = {
        "type": "ExampleDataSource",
        "dataset_name": "test",
        "import_date": "2023-01-01",
    }
    payload = {
        "outputs": ["some-uuid"],
        # Only "other-uuid" is defined; "some-uuid" is deliberately absent.
        "nodes": {"other-uuid": present_node},
    }
    with pytest.raises(ValueError, match="Node with UUID some-uuid not found in graph"):
        graph_deserialize(payload)


def test_deserialize_uuid_mismatch_should_fail():
    """A node stored under a key that is not its real UUID makes deserialization raise."""
    node_fields = {
        "type": "ExampleDataSource",
        "dataset_name": "test",
        "import_date": "2023-01-01",
    }
    payload = {
        "outputs": ["some-uuid"],
        # "some-uuid" is not the UUID this node's content produces.
        "nodes": {"some-uuid": node_fields},
    }
    with pytest.raises(ValueError, match="UUID mismatch"):
        graph_deserialize(payload)


def test_deserialize_cycle_should_fail():
    """Two nodes that reference each other make ``graph_deserialize`` raise."""
    # node-a depends on node-b ...
    node_a = {
        "type": "DatasetDoubleSize",
        "power_level": "DOUBLE",
        "source_dataset": "node-b",
    }
    # ... and node-b depends on node-a, which closes the cycle.
    node_b = {
        "type": "DatasetDoubleSize",
        "power_level": "DOUBLE",
        "source_dataset": "node-a",
    }
    payload = {
        "outputs": ["node-a"],
        "nodes": {"node-a": node_a, "node-b": node_b},
    }
    with pytest.raises(ValueError, match="Cycle detected in graph"):
        graph_deserialize(payload)


def test_deserializing_unreachable_nodes_should_fail():
    """A node that no output reaches makes ``graph_deserialize`` raise."""
    orphan_node = {
        "type": "ExampleDataSource",
        "dataset_name": "unreachable",
        "import_date": "2023-01-01",
    }
    payload = {
        "outputs": [DATA_SOURCE_A__UUID],
        "nodes": {
            DATA_SOURCE_A__UUID: DATA_SOURCE_A__JSON_PARTIAL,
            # Not listed in outputs and not referenced by any other node.
            "unreachable-node": orphan_node,
        },
    }
    with pytest.raises(ValueError, match="Unreachable nodes in graph"):
        graph_deserialize(payload)


def test_op_spec_subs_basic():
    """``subs`` with one keyword returns a changed copy and leaves the source alone."""
    source = DATA_SOURCE_A
    uuid_before = source.uuid

    # Replace a single field.
    changed = source.subs(dataset_name="modified_test")

    # The source spec still has its old value.
    assert source.dataset_name == "test"
    # The copy has the new value and keeps the untouched field.
    assert changed.dataset_name == "modified_test"
    assert changed.import_date == "2023-01-01"
    # Changed content gives a new UUID; the source's UUID is stable.
    assert changed.uuid != uuid_before
    assert source.uuid == uuid_before
    assert isinstance(changed, ExampleDataSource)


def test_op_spec_subs_multiple_fields():
    """``subs`` can replace several fields in one call."""
    source = DATA_SOURCE_A
    uuid_before = source.uuid

    replacements = {"dataset_name": "new_dataset", "import_date": "2024-01-01"}
    changed = source.subs(**replacements)

    # The copy carries both replacements.
    assert changed.dataset_name == "new_dataset"
    assert changed.import_date == "2024-01-01"
    # The source keeps both of its old values.
    assert source.dataset_name == "test"
    assert source.import_date == "2023-01-01"
    # Changed content gives a new UUID; the source's UUID is stable.
    assert changed.uuid != uuid_before
    assert source.uuid == uuid_before


def test_op_spec_subs_no_changes():
    """Calling ``subs`` with no arguments yields a spec equal to the source."""
    source = DATA_SOURCE_A
    copy_of_source = source.subs()

    assert copy_of_source.dataset_name == source.dataset_name
    assert copy_of_source.import_date == source.import_date
    # Identical content, so the UUID is identical too.
    assert copy_of_source.uuid == source.uuid
    assert copy_of_source == source


def test_op_spec_subs_with_opspec_field():
    """``subs`` can replace a field whose value is itself an OpSpec."""
    operation = OPERATION_A
    uuid_before = operation.uuid
    replacement_source = ExampleDataSource(
        dataset_name="different_source",
        import_date="2024-01-01",
    )

    changed = operation.subs(source_dataset=replacement_source)

    assert changed.source_dataset == replacement_source
    # The field that was not named keeps its value.
    assert changed.power_level == "DOUBLE"
    # The source operation still points at the old data source.
    assert operation.source_dataset == DATA_SOURCE_A
    # Changed content gives a new UUID; the source's UUID is stable.
    assert changed.uuid != uuid_before
    assert operation.uuid == uuid_before


def test_op_spec_subs_complex_nested():
    """``subs`` replaces one branch of a two-branch spec and keeps the other."""
    combined = OPERATION_COMBINE_TWO
    uuid_before = combined.uuid
    replacement_branch = DatasetDoubleSize(
        source_dataset=DATA_SOURCE_A,
        power_level="QUADRUPLE",
    )

    changed = combined.subs(op_a=replacement_branch)

    assert changed.op_a == replacement_branch
    # The branch that was not named is carried over.
    assert changed.op_b == combined.op_b
    # Changed content gives a new UUID; the source's UUID is stable.
    assert changed.uuid != uuid_before
    assert combined.uuid == uuid_before
    assert isinstance(changed, CombineTwoOperations)


def test_op_spec_subs_preserves_type():
    """The spec returned by ``subs`` has the same class as the source spec."""
    operation = OPERATION_A
    uuid_before = operation.uuid

    changed = operation.subs(power_level="TRIPLE")

    assert type(changed) is type(operation)
    assert isinstance(changed, DatasetDoubleSize)
    # Changed content gives a new UUID; the source's UUID is stable.
    assert changed.uuid != uuid_before
    assert operation.uuid == uuid_before


def test_op_spec_subs_invalid_field():
    """``subs`` rejects a keyword that is not a field of the spec."""
    source = DATA_SOURCE_A
    expected_message = "Invalid field names for ExampleDataSource: \\['nonexistent_field'\\]"

    with pytest.raises(ValueError, match=expected_message):
        source.subs(nonexistent_field="value")


def test_op_spec_subs_multiple_invalid_fields():
    """``subs`` reports every unknown keyword; the message lists them alphabetically."""
    source = DATA_SOURCE_A
    expected_message = "Invalid field names for ExampleDataSource: \\['bad_field', 'nonexistent_field'\\]"

    # The keywords are passed in the opposite order to the expected message.
    with pytest.raises(ValueError, match=expected_message):
        source.subs(nonexistent_field="value", bad_field="another_value")


def test_op_spec_subs_with_list_field():
    """``subs`` can replace a list-of-OpSpec field with a longer list."""
    class MultiSourceOp(OpSpec):
        name: str
        sources: list[ExampleDataSource] = []

    first = ExampleDataSource(dataset_name="source1", import_date="2023-01-01")
    second = ExampleDataSource(dataset_name="source2", import_date="2023-01-02")

    one_source_op = MultiSourceOp(name="multi", sources=[first])
    uuid_before = one_source_op.uuid
    two_source_op = one_source_op.subs(sources=[first, second])

    # The copy holds both sources, in the order given.
    assert len(two_source_op.sources) == 2
    assert two_source_op.sources[0] == first
    assert two_source_op.sources[1] == second
    # The source spec still holds one.
    assert len(one_source_op.sources) == 1
    # Changed content gives a new UUID; the source's UUID is stable.
    assert two_source_op.uuid != uuid_before
    assert one_source_op.uuid == uuid_before


def test_op_spec_subs_with_dict_field():
    """``subs`` can replace a dict-of-OpSpec field with a larger dict."""
    class DictSourceOp(OpSpec):
        name: str
        source_map: dict[str, ExampleDataSource] = {}

    first = ExampleDataSource(dataset_name="source1", import_date="2023-01-01")
    second = ExampleDataSource(dataset_name="source2", import_date="2023-01-02")

    one_entry_op = DictSourceOp(name="dict_op", source_map={"a": first})
    uuid_before = one_entry_op.uuid
    two_entry_op = one_entry_op.subs(source_map={"a": first, "b": second})

    # The copy holds both entries under their keys.
    assert len(two_entry_op.source_map) == 2
    assert two_entry_op.source_map["a"] == first
    assert two_entry_op.source_map["b"] == second
    # The source spec still holds one.
    assert len(one_entry_op.source_map) == 1
    # Changed content gives a new UUID; the source's UUID is stable.
    assert two_entry_op.uuid != uuid_before
    assert one_entry_op.uuid == uuid_before


def test_op_spec_subs_immutability():
    """``subs`` leaves the instance it is called on untouched."""
    source = DATA_SOURCE_A
    uuid_before = source.uuid
    name_before = source.dataset_name

    changed = source.subs(dataset_name="changed")

    # Nothing about the source moved, and it differs from the copy.
    assert source.uuid == uuid_before
    assert source.dataset_name == name_before
    assert source != changed


def test_op_spec_subs_substitute_single_tuple():
    """``subs(substitute=(old, new))`` swaps one node of the graph for another."""
    before = ExampleDataSource(dataset_name="old", import_date="2023-01-01")
    after = ExampleDataSource(dataset_name="new", import_date="2023-01-01")
    operation = DatasetDoubleSize(source_dataset=before, power_level="DOUBLE")

    # A single (old, new) pair is passed directly, not wrapped in a list.
    swap = (before, after)
    changed = operation.subs(substitute=swap)

    assert changed.source_dataset == after
    assert changed.source_dataset.dataset_name == "new"
    # The operation's own field is untouched.
    assert changed.power_level == "DOUBLE"
    # The source operation still points at the old node.
    assert operation.source_dataset == before
    assert changed.uuid != operation.uuid


def test_op_spec_subs_substitute_list_of_tuples():
    """``subs(substitute=[...])`` applies several (old, new) swaps in one call."""
    old_left = ExampleDataSource(dataset_name="source1", import_date="2023-01-01")
    old_right = ExampleDataSource(dataset_name="source2", import_date="2023-01-02")
    new_left = ExampleDataSource(dataset_name="new_source1", import_date="2023-01-01")
    new_right = ExampleDataSource(dataset_name="new_source2", import_date="2023-01-02")

    left_op = DatasetDoubleSize(source_dataset=old_left, power_level="DOUBLE")
    right_op = DatasetDoubleSize(source_dataset=old_right, power_level="TRIPLE")
    combined = CombineTwoOperations(op_a=left_op, op_b=right_op)

    # Swap the data source under each branch.
    swaps = [(old_left, new_left), (old_right, new_right)]
    changed = combined.subs(substitute=swaps)

    assert changed.op_a.source_dataset == new_left
    assert changed.op_b.source_dataset == new_right
    assert changed.op_a.source_dataset.dataset_name == "new_source1"
    assert changed.op_b.source_dataset.dataset_name == "new_source2"
    # The source graph still points at the old nodes.
    assert combined.op_a.source_dataset == old_left
    assert combined.op_b.source_dataset == old_right
    assert changed.uuid != combined.uuid


def test_op_spec_subs_substitute_single_opspec_with_changes():
    """``subs(node, **changes)`` applies field changes to a node inside the graph."""
    inner_source = ExampleDataSource(dataset_name="old", import_date="2023-01-01")
    operation = DatasetDoubleSize(source_dataset=inner_source, power_level="DOUBLE")

    # The positional argument names the node; the keywords are its new field values.
    changed = operation.subs(
        inner_source,
        dataset_name="updated",
        import_date="2024-01-01",
    )

    assert changed.source_dataset.dataset_name == "updated"
    assert changed.source_dataset.import_date == "2024-01-01"
    # The operation's own field is untouched.
    assert changed.power_level == "DOUBLE"
    # The source graph still has the old value.
    assert operation.source_dataset.dataset_name == "old"
    assert changed.uuid != operation.uuid


def test_op_spec_subs_substitute_nested_dependencies():
    """Swapping a node shared by two branches updates both branches."""
    shared = ExampleDataSource(dataset_name="original", import_date="2023-01-01")
    # Both branches depend on the same data source.
    double_op = DatasetDoubleSize(source_dataset=shared, power_level="DOUBLE")
    triple_op = DatasetDoubleSize(source_dataset=shared, power_level="TRIPLE")
    combined = CombineTwoOperations(op_a=double_op, op_b=triple_op)

    replacement = ExampleDataSource(dataset_name="replacement", import_date="2023-01-01")

    # One swap at the bottom of the graph.
    changed = combined.subs(substitute=(shared, replacement))

    assert changed.op_a.source_dataset == replacement
    assert changed.op_b.source_dataset == replacement
    assert changed.op_a.source_dataset.dataset_name == "replacement"
    assert changed.op_b.source_dataset.dataset_name == "replacement"
    # Each branch keeps its own field.
    assert changed.op_a.power_level == "DOUBLE"
    assert changed.op_b.power_level == "TRIPLE"
    # The source graph still points at the shared node.
    assert combined.op_a.source_dataset == shared
    assert combined.op_b.source_dataset == shared
    assert changed.uuid != combined.uuid


def test_op_spec_subs_substitute_no_match():
    """Asking ``subs`` to swap a node that is not in the graph raises ValueError."""
    in_graph = ExampleDataSource(dataset_name="source", import_date="2023-01-01")
    operation = DatasetDoubleSize(source_dataset=in_graph, power_level="DOUBLE")

    not_in_graph = ExampleDataSource(dataset_name="unrelated", import_date="2023-01-01")
    replacement = ExampleDataSource(dataset_name="new", import_date="2023-01-01")

    # The "old" half of the pair does not occur anywhere under operation.
    with pytest.raises(ValueError, match="Supposed to substitute"):
        operation.subs(substitute=(not_in_graph, replacement))


def test_op_spec_subs_substitute_identity():
    """Swapping a node for an equal-content node leaves every UUID unchanged."""
    original_source = ExampleDataSource(dataset_name="source", import_date="2023-01-01")
    operation = DatasetDoubleSize(source_dataset=original_source, power_level="DOUBLE")

    # A separately constructed spec with the same field values.
    same_content = ExampleDataSource(dataset_name="source", import_date="2023-01-01")
    changed = operation.subs(substitute=(original_source, same_content))

    # Content and UUIDs match at both levels of the graph.
    assert changed.source_dataset == same_content
    assert changed.source_dataset.uuid == original_source.uuid
    assert changed.uuid == operation.uuid


def test_op_spec_subs_substitute_error_conflicting_args():
    """``subs`` rejects a call that passes both ``substitute`` and field changes."""
    old_source = ExampleDataSource(dataset_name="source", import_date="2023-01-01")
    replacement = ExampleDataSource(dataset_name="new", import_date="2023-01-01")
    operation = DatasetDoubleSize(source_dataset=old_source, power_level="DOUBLE")

    swaps = [(old_source, replacement)]
    with pytest.raises(ValueError, match="Cannot provide both substitutions and field changes"):
        operation.subs(substitute=swaps, power_level="TRIPLE")


def test_op_spec_subs_substitute_error_invalid_type():
    """``subs`` rejects malformed ``substitute`` arguments with ValueError."""
    source = ExampleDataSource(dataset_name="source", import_date="2023-01-01")
    operation = DatasetDoubleSize(source_dataset=source, power_level="DOUBLE")

    malformed_arguments = (
        "invalid",            # not a tuple or list at all
        (source,),            # tuple of the wrong length
        (source, "invalid"),  # second element is not an OpSpec
    )
    for malformed in malformed_arguments:
        with pytest.raises(ValueError, match="Invalid substitute argument"):
            operation.subs(substitute=malformed)


def test_op_spec_subs_substitute_preserve_type():
    """A ``substitute`` swap below the root keeps the root's class."""
    old_source = ExampleDataSource(dataset_name="source", import_date="2023-01-01")
    replacement = ExampleDataSource(dataset_name="new", import_date="2023-01-01")
    operation = DatasetDoubleSize(source_dataset=old_source, power_level="DOUBLE")

    changed = operation.subs(substitute=(old_source, replacement))

    assert type(changed) is type(operation)
    assert isinstance(changed, DatasetDoubleSize)


def test_op_spec_subs_substitute_complex_graph():
    """Swapping the root data source of a two-branch graph changes both branches."""
    root = ExampleDataSource(dataset_name="root", import_date="2023-01-01")
    double_op = DatasetDoubleSize(source_dataset=root, power_level="DOUBLE")
    triple_op = DatasetDoubleSize(source_dataset=root, power_level="TRIPLE")
    combined = CombineTwoOperations(op_a=double_op, op_b=triple_op)

    new_root = ExampleDataSource(dataset_name="new_root", import_date="2023-01-01")

    # Both branches hang off the node being swapped.
    changed = combined.subs(substitute=(root, new_root))

    assert changed.op_a.source_dataset == new_root
    assert changed.op_b.source_dataset == new_root
    assert changed.op_a.source_dataset.dataset_name == "new_root"
    assert changed.op_b.source_dataset.dataset_name == "new_root"
    # Each branch keeps its own field.
    assert changed.op_a.power_level == "DOUBLE"
    assert changed.op_b.power_level == "TRIPLE"
    # The source graph still points at the old root.
    assert combined.op_a.source_dataset == root
    assert combined.op_b.source_dataset == root
    assert changed.uuid != combined.uuid


def test_exclude_from_uuid_annotation():
    """A field annotated with ``ExcludeFromUUID`` does not influence the UUID."""

    class TestOpWithExcluded(OpSpec):
        important_field: str
        cache_field: Annotated[str, ExcludeFromUUID()] = ""

    # These two differ only in the excluded field.
    with_cache_1 = TestOpWithExcluded(important_field="test", cache_field="cache1")
    with_cache_2 = TestOpWithExcluded(important_field="test", cache_field="cache2")
    assert with_cache_1.uuid == with_cache_2.uuid

    # A difference in the regular field does change the UUID.
    other_important = TestOpWithExcluded(important_field="different", cache_field="cache1")
    assert other_important.uuid != with_cache_1.uuid


def test_exclude_from_uuid_serialization_still_includes_field():
    """Excluded fields appear in ``model_dump`` but not in the UUID-specific dump."""

    class TestOpWithExcluded(OpSpec):
        important_field: str
        cache_field: Annotated[str, ExcludeFromUUID()] = ""

    spec = TestOpWithExcluded(important_field="test", cache_field="cached_value")

    # The ordinary dump carries every field.
    full_dump = spec.model_dump()
    assert "important_field" in full_dump
    assert "cache_field" in full_dump
    assert full_dump["important_field"] == "test"
    assert full_dump["cache_field"] == "cached_value"

    # The dump used for UUID computation omits the excluded field.
    uuid_dump = spec._model_dump_for_uuid()
    assert "important_field" in uuid_dump
    assert "cache_field" not in uuid_dump
    assert uuid_dump["important_field"] == "test"


def test_exclude_from_uuid_multiple_exclusions():
    """Several ``ExcludeFromUUID`` fields on one spec are all ignored by the UUID."""

    class TestOpWithMultipleExcluded(OpSpec):
        data: str
        cache_ttl: Annotated[int, ExcludeFromUUID()] = 3600
        last_access: Annotated[str, ExcludeFromUUID()] = ""
        debug_info: Annotated[str, ExcludeFromUUID()] = ""

    # Same data, every excluded field different.
    early = TestOpWithMultipleExcluded(
        data="important",
        cache_ttl=1800,
        last_access="2023-01-01",
        debug_info="trace1",
    )
    late = TestOpWithMultipleExcluded(
        data="important",
        cache_ttl=7200,
        last_access="2023-12-31",
        debug_info="trace2",
    )
    assert early.uuid == late.uuid

    # Different data, excluded fields left at their defaults.
    other_data = TestOpWithMultipleExcluded(data="different")
    assert other_data.uuid != early.uuid


def test_exclude_from_uuid_graph_serialization():
    """Excluded fields leave graph UUIDs alone but still appear in the node data."""

    class TestOpWithExcluded(OpSpec):
        name: str
        cache_info: Annotated[str, ExcludeFromUUID()] = ""

    # Same name, different value in the excluded field.
    first = TestOpWithExcluded(name="op1", cache_info="cache1")
    second = TestOpWithExcluded(name="op1", cache_info="cache2")

    first_graph = graph_serialize(first)
    second_graph = graph_serialize(second)

    # Both graphs use the same UUIDs for outputs and node keys.
    assert first_graph["outputs"] == second_graph["outputs"]
    assert list(first_graph["nodes"]) == list(second_graph["nodes"])

    # Each graph's node data keeps its own value of the excluded field.
    shared_uuid = first.uuid
    assert first_graph["nodes"][shared_uuid]["cache_info"] == "cache1"
    assert second_graph["nodes"][shared_uuid]["cache_info"] == "cache2"


def test_exclude_from_uuid_with_opspec_dependencies():
    """UUID exclusion holds when the spec with excluded fields is a dependency."""

    class TestParent(OpSpec):
        data: str
        metadata: Annotated[str, ExcludeFromUUID()] = ""

    class TestChild(OpSpec):
        parent: TestParent
        child_data: str
        child_cache: Annotated[str, ExcludeFromUUID()] = ""

    # The parents differ only in their excluded field.
    parent_meta_1 = TestParent(data="parent_data", metadata="meta1")
    parent_meta_2 = TestParent(data="parent_data", metadata="meta2")
    assert parent_meta_1.uuid == parent_meta_2.uuid

    # The children differ in their own excluded field and in which parent they hold.
    child_of_1 = TestChild(parent=parent_meta_1, child_data="child", child_cache="cache1")
    child_of_2 = TestChild(parent=parent_meta_2, child_data="child", child_cache="cache2")
    assert child_of_1.uuid == child_of_2.uuid

    # A difference in a regular child field does change the UUID.
    other_child = TestChild(parent=parent_meta_1, child_data="different", child_cache="cache1")
    assert other_child.uuid != child_of_1.uuid


def test_graph_of_multiple_types():
    """Round-trip specs whose dependency field is a union of two OpSpec subclasses."""
    class SomeSourceA(OpSpec):
        name: str | int

    class SomeSourceB(OpSpec):
        name2: int

    class SomeSourceC(OpSpec):
        in_src: SomeSourceA | SomeSourceB

    source_a = SomeSourceA(name="a")
    source_b = SomeSourceB(name2=1)
    # One spec per member of the union; both are built before any round trip.
    variants = (SomeSourceC(in_src=source_a), SomeSourceC(in_src=source_b))

    for variant in variants:
        [restored] = graph_deserialize(graph_serialize(variant))
        assert restored == variant
        assert restored.uuid == variant.uuid

def test_invalid_graph_of_multiple_types():
    """Test that a spec whose field mixes OpSpec types with ``str`` fails to round-trip.

    Only the serialize/deserialize calls sit inside ``pytest.raises``. With
    ``assert`` statements inside the block, a failed assertion (AssertionError
    is an Exception) would satisfy ``raises`` and hide a round trip that
    silently returned the wrong result.

    NOTE(review): which of the two calls raises, and with what exception type,
    is not pinned here - tighten ``Exception`` once that is confirmed.
    """
    class SomeSourceA(OpSpec):
        name: str
    class SomeSourceB(OpSpec):
        name2: int
    class SomeSourceC(OpSpec):
        in_src: SomeSourceA | SomeSourceB | str
    a = SomeSourceA(name="a")
    b = SomeSourceB(name2=1)
    c1 = SomeSourceC(in_src=a)
    # Construction with the other union member is still exercised. Its round
    # trip is not checked: any statement after the first raising call inside a
    # ``raises`` block never runs.
    SomeSourceC(in_src=b)

    with pytest.raises(Exception):
        [_reserialized] = graph_deserialize(graph_serialize(c1))

def test_valid_graph_of_multiple_types():
    """Test round-tripping specs whose fields are unions of OpSpec subclasses and None.

    Every combination below is serialized with ``graph_serialize``, read back
    with ``graph_deserialize``, and compared with the source spec by equality
    and by UUID.
    """
    class SomeSourceD(OpSpec):
        name: str
    class SomeSourceE(OpSpec):
        name2: int
    class SomeSourceF(OpSpec):
        in_src: SomeSourceD | SomeSourceE | None
        in_src_2: SomeSourceD | SomeSourceE | None = None
    a = SomeSourceD(name="a")
    b = SomeSourceE(name2=1)

    # Each case gets its own list slot so none is overwritten before it is
    # checked, including the spec with no dependencies at all.
    cases = [
        SomeSourceF(in_src=a),
        SomeSourceF(in_src=b),
        SomeSourceF(in_src=None),
        SomeSourceF(in_src=None, in_src_2=a),
        SomeSourceF(in_src=b, in_src_2=a),
        SomeSourceF(in_src=a, in_src_2=a),
    ]

    def _check(grph):
        # Round-trip one spec and compare it with the source.
        graph_serialized = graph_serialize(grph)
        [reserialized] = graph_deserialize(graph_serialized)
        assert reserialized == grph
        assert reserialized.uuid == grph.uuid

    for case in cases:
        _check(case)