import os
import pytest
import pathlib
from typing import Optional, Union
from pydantic import BaseModel

import cognee
from cognee.low_level import setup, DataPoint
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever


class TestAnswer(BaseModel):
    # Response schema handed to ``get_structured_completion`` via ``response_model``.
    # Comments are used instead of a docstring so the generated model schema is unchanged.

    # The class name matches pytest's default ``Test*`` collection pattern; opting out
    # keeps pytest from trying to collect this data model and warning about its
    # ``__init__`` constructor. Dunder names are not treated as model fields.
    __test__ = False

    answer: str
    explanation: str


class TestGraphCompletionCoTRetriever:
    """Tests for ``GraphCompletionCotRetriever`` context retrieval and completions.

    Every test points cognee at its own system/data directories (named after the
    test), prunes existing state, and runs ``setup()`` before touching the graph.
    """

    @staticmethod
    async def _reset_environment(test_name: str) -> None:
        """Configure per-test storage roots, prune existing state and run ``setup()``.

        Args:
            test_name: Directory name used under ``.cognee_system`` and
                ``.data_storage`` next to this file.
        """
        tests_dir = pathlib.Path(__file__).parent

        cognee.config.system_root_directory(
            os.path.join(tests_dir, f".cognee_system/{test_name}")
        )
        cognee.config.data_root_directory(
            os.path.join(tests_dir, f".data_storage/{test_name}")
        )

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)
        await setup()

    @staticmethod
    def _assert_non_empty_string_list(answer: object) -> None:
        """Assert that ``answer`` is a list whose items are all non-blank strings."""
        assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
        assert all(isinstance(item, str) and item.strip() for item in answer), (
            "Answer must contain only non-empty strings"
        )

    @pytest.mark.asyncio
    async def test_graph_completion_cot_context_simple(self):
        """Context for a two-company graph contains both Canva ``works_for`` edges."""
        await self._reset_environment("test_graph_completion_cot_context_simple")

        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")
        person1 = Person(name="Steve Rodger", works_for=company1)
        person2 = Person(name="Ike Loma", works_for=company1)
        person3 = Person(name="Jason Statham", works_for=company1)
        person4 = Person(name="Mike Broski", works_for=company2)
        person5 = Person(name="Christina Mayer", works_for=company2)

        entities = [company1, company2, person1, person2, person3, person4, person5]

        await add_data_points(entities)

        retriever = GraphCompletionCotRetriever()

        context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))

        assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
        assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"

        answer = await retriever.get_completion("Who works at Canva?")

        self._assert_non_empty_string_list(answer)

    @pytest.mark.asyncio
    async def test_graph_completion_cot_context_complex(self):
        """Context for a graph with cars, homes and locations contains all Figma edges."""
        await self._reset_environment("test_graph_completion_cot_context_complex")

        class Company(DataPoint):
            name: str
            metadata: dict = {"index_fields": ["name"]}

        class Car(DataPoint):
            brand: str
            model: str
            year: int

        class Location(DataPoint):
            country: str
            city: str

        class Home(DataPoint):
            location: Location
            rooms: int
            sqm: int

        class Person(DataPoint):
            name: str
            works_for: Company
            owns: Optional[list[Union[Car, Home]]] = None

        company1 = Company(name="Figma")
        company2 = Company(name="Canva")

        person1 = Person(name="Mike Rodger", works_for=company1)
        person1.owns = [Car(brand="Toyota", model="Camry", year=2020)]

        person2 = Person(name="Ike Loma", works_for=company1)
        person2.owns = [
            Car(brand="Tesla", model="Model S", year=2021),
            Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4),
        ]

        person3 = Person(name="Jason Statham", works_for=company1)

        person4 = Person(name="Mike Broski", works_for=company2)
        person4.owns = [Car(brand="Ford", model="Mustang", year=1978)]

        person5 = Person(name="Christina Mayer", works_for=company2)
        person5.owns = [Car(brand="Honda", model="Civic", year=2023)]

        entities = [company1, company2, person1, person2, person3, person4, person5]

        await add_data_points(entities)

        retriever = GraphCompletionCotRetriever(top_k=20)

        context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))

        # On failure pytest reports the operands of ``in`` alongside the message,
        # so the retrieved context does not need to be printed separately.
        assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger"
        assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma"
        assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham"

        answer = await retriever.get_completion("Who works at Figma?")

        self._assert_non_empty_string_list(answer)

    @pytest.mark.asyncio
    async def test_get_graph_completion_cot_context_on_empty_graph(self):
        """With no data points added, ``get_context`` returns an empty list."""
        await self._reset_environment("test_get_graph_completion_cot_context_on_empty_graph")

        retriever = GraphCompletionCotRetriever()

        context = await retriever.get_context("Who works at Figma?")
        assert context == [], "Context should be empty on an empty graph"

        answer = await retriever.get_completion("Who works at Figma?")

        self._assert_non_empty_string_list(answer)

    @pytest.mark.asyncio
    async def test_get_structured_completion(self):
        """``get_structured_completion`` returns ``str`` by default, or the given model."""
        await self._reset_environment("test_get_structured_completion")

        class Company(DataPoint):
            name: str

        class Person(DataPoint):
            name: str
            works_for: Company

        company1 = Company(name="Figma")
        person1 = Person(name="Steve Rodger", works_for=company1)

        entities = [company1, person1]
        await add_data_points(entities)

        retriever = GraphCompletionCotRetriever()

        # Test with string response model (default)
        string_answer = await retriever.get_structured_completion("Who works at Figma?")
        assert isinstance(string_answer, str), f"Expected str, got {type(string_answer).__name__}"
        assert string_answer.strip(), "Answer should not be empty"

        # Test with structured response model
        structured_answer = await retriever.get_structured_completion(
            "Who works at Figma?", response_model=TestAnswer
        )
        assert isinstance(structured_answer, TestAnswer), (
            f"Expected TestAnswer, got {type(structured_answer).__name__}"
        )
        assert structured_answer.answer.strip(), "Answer field should not be empty"
        assert structured_answer.explanation.strip(), "Explanation field should not be empty"
