from __future__ import annotations

import json
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Optional, Dict, Any, List

from upsonic.context.task import turn_task_to_string
from upsonic.context.sources import TaskOutputSource
from upsonic.schemas.data_models import RAGSearchResult

if TYPE_CHECKING:
    from upsonic.agent.agent import Direct
    from upsonic.graph.graph import State
    from upsonic.agent.context_managers.memory_manager import MemoryManager
    from upsonic.tasks.tasks import Task
    from upsonic.knowledge_base.knowledge_base import KnowledgeBase


class ContextManager:
    """
    Builds the dynamic, task-specific context prompt for an agent run.

    The manager aggregates situational data relevant to the current task into
    a single ``<Context>`` block. The sources are:

    - the memory handler's context injection
    - other ``Task`` objects
    - ``KnowledgeBase`` items (retrieved via ``query_async`` when ``rag`` is
      enabled, otherwise rendered with ``markdown()``)
    - free-form strings
    - ``TaskOutputSource`` items resolved against a graph ``State``

    Failures while processing an individual item are reported with ``print``.
    The item is then skipped, or falls back to a simpler rendering, rather than
    aborting the whole build.
    """

    def __init__(self, agent: "Direct", task: "Task", state: Optional[State] = None):
        """
        Initializes the ContextManager.

        Args:
            agent: The parent `Direct` agent instance.
            task: The `Task` object for the current operation.
            state: An optional `State` object from a `Graph` execution.
                   This is essential for resolving `TaskOutputSource` context.
        """
        self.agent = agent
        self.task = task
        self.state = state
        # Populated by `manage_context`; empty until then.
        self.context_prompt: str = ""

    async def _build_context_prompt(self, memory_handler: Optional[MemoryManager]) -> str:
        """
        Asynchronously builds the complete contextual prompt string.

        Sections are emitted in a fixed order:

        1. The memory injection.
        2. ``<Tasks>``.
        3. ``<Knowledge Base>``.
        4. ``<PreviousTaskNodeOutput>`` entries.
        5. ``<Additional Context>``.

        Empty sections are omitted.

        Args:
            memory_handler: Optional memory manager. If it is truthy, its
                ``get_context_injection()`` result is placed first.

        Returns:
            The ``<Context>...</Context>`` string, or ``""`` when there is
            nothing to include.
        """
        final_context_parts: List[str] = []

        if memory_handler:
            context_injection = memory_handler.get_context_injection()
            if context_injection:
                final_context_parts.append(context_injection)

        if self.task.context:
            # These names are TYPE_CHECKING-only at module level (presumably to
            # avoid import cycles). Import them once here, not once per item.
            from upsonic.tasks.tasks import Task
            from upsonic.knowledge_base.knowledge_base import KnowledgeBase

            knowledge_base_parts: List[str] = []
            task_parts: List[str] = []
            previous_task_output_parts: List[str] = []
            additional_parts: List[str] = []

            for item in self.task.context:
                if isinstance(item, Task):
                    task_parts.append(f"Task ID ({item.get_task_id()}): " + turn_task_to_string(item))

                elif isinstance(item, KnowledgeBase):
                    # The current task description is used as the retrieval query.
                    await self._process_knowledge_base_item(
                        item,
                        knowledge_base_parts,
                        self.task.description
                    )

                elif isinstance(item, str):
                    additional_parts.append(item)

                elif isinstance(item, TaskOutputSource):
                    if self.state:
                        await self._process_task_output_source(
                            item,
                            previous_task_output_parts
                        )
                    else:
                        # Previously this item was dropped silently; make that visible.
                        print(f"Warning: TaskOutputSource '{item.task_description_or_id}' ignored because no graph State is available")

            if task_parts:
                final_context_parts.append("<Tasks>\n" + "\n".join(task_parts) + "\n</Tasks>")
            if knowledge_base_parts:
                final_context_parts.append("<Knowledge Base>\n" + "\n".join(knowledge_base_parts) + "\n</Knowledge Base>")
            if previous_task_output_parts:
                final_context_parts.extend(previous_task_output_parts)
            if additional_parts:
                final_context_parts.append("<Additional Context>\n" + "\n".join(additional_parts) + "\n</Additional Context>")

        if not final_context_parts:
            return ""

        return "<Context>\n" + "\n\n".join(final_context_parts) + "\n</Context>"

    async def _process_knowledge_base_item(
        self,
        knowledge_base: "KnowledgeBase",
        knowledge_base_parts: List[str],
        query: str
    ) -> None:
        """
        Renders one KnowledgeBase item and appends the result to ``knowledge_base_parts``.

        When ``knowledge_base.rag`` is truthy, the RAG pipeline is set up and
        queried, and the formatted results are appended (or a warning is
        printed if there are none). Otherwise ``markdown()`` is appended.

        On any exception, the error is printed and ``markdown()`` is attempted
        as a fallback. If the fallback also fails, nothing is appended.

        Args:
            knowledge_base: The KnowledgeBase instance to query.
            knowledge_base_parts: List to append results to (mutated in place).
            query: The query string for retrieval.
        """
        try:
            if knowledge_base.rag:
                await knowledge_base.setup_rag(self.agent)

                rag_results = await knowledge_base.query_async(query)

                if rag_results:
                    formatted_results = self._format_rag_results(rag_results, knowledge_base)
                    knowledge_base_parts.append(formatted_results)
                else:
                    print(f"Warning: No results found for KnowledgeBase '{knowledge_base.name}' with query: '{query}'")
            else:
                knowledge_base_parts.append(knowledge_base.markdown())

        except Exception as e:
            # Best-effort: one failing knowledge base must not break the whole prompt.
            print(f"Error processing KnowledgeBase '{knowledge_base.name}': {str(e)}")
            try:
                knowledge_base_parts.append(knowledge_base.markdown())
            except Exception as fallback_error:
                print(f"Fallback also failed for KnowledgeBase '{knowledge_base.name}': {str(fallback_error)}")

    def _format_rag_results(self, rag_results: List[RAGSearchResult], knowledge_base: "KnowledgeBase") -> str:
        """
        Formats RAG results as a single ``<rag source='...'>`` element.

        Each chunk is rendered as ``[i] [metadata: ...] text`` and chunks are
        joined with spaces. The metadata section appears only when
        ``result.metadata`` is non-empty. Within it:

        - ``source`` is always listed (``'Unknown'`` when missing).
        - ``page`` and ``chunk_id`` are listed only when known.
        - ``score`` is listed only when it is not ``None``.
        - Any remaining metadata keys follow.

        Args:
            rag_results: Objects exposing ``text``, ``metadata``, ``chunk_id``
                and ``score`` (e.g. ``RAGSearchResult``).
            knowledge_base: The KnowledgeBase instance; ``name`` is read, and
                ``get_config_summary()`` is used if present.

        Returns:
            The formatted string, or ``""`` for empty results.
        """
        if not rag_results:
            return ""

        kb_info = f"Source: {knowledge_base.name}"
        if hasattr(knowledge_base, 'get_config_summary'):
            try:
                config = knowledge_base.get_config_summary()
                vector_db_info = config.get('vectordb', {})
                if isinstance(vector_db_info, dict):
                    provider = vector_db_info.get('class', 'Unknown')
                    kb_info += f" (Vector DB: {provider})"
            except Exception:
                # Provider info is purely decorative; never let it break formatting.
                pass

        retrieved_keys = {'source', 'page_number', 'chunk_id'}
        formatted_chunks = []
        for i, result in enumerate(rag_results, 1):
            cleaned_text = result.text.strip()
            metadata_str = ""
            if result.metadata:
                source = result.metadata.get('source', 'Unknown')
                # No 'Unknown' defaults here: the guards below are meant to
                # omit page/chunk_id when they are absent.
                page_number = result.metadata.get('page_number')
                chunk_id = result.chunk_id or result.metadata.get('chunk_id')

                metadata_parts = [f"source: {source}"]
                if page_number is not None:
                    metadata_parts.append(f"page: {page_number}")
                if chunk_id:
                    metadata_parts.append(f"chunk_id: {chunk_id}")
                if result.score is not None:
                    metadata_parts.append(f"score: {result.score:.3f}")

                metadata_parts.extend(
                    f"{k}: {v}" for k, v in result.metadata.items() if k not in retrieved_keys
                )

                metadata_str = f" [metadata: {', '.join(metadata_parts)}]"

            formatted_chunks.append(f"[{i}]{metadata_str} {cleaned_text}")

        return f"<rag source='{kb_info}'>{' '.join(formatted_chunks)}</rag>"

    async def _process_task_output_source(
        self,
        item: TaskOutputSource,
        previous_task_output_parts: List[str]
    ) -> None:
        """
        Resolves a TaskOutputSource against ``self.state`` and appends the result.

        If an output is found, a ``<PreviousTaskNodeOutput id='...'>`` element
        is appended. If no output is found, a warning is printed. Exceptions
        are printed and swallowed.

        Args:
            item: The TaskOutputSource instance.
            previous_task_output_parts: List to append results to (mutated in place).
        """
        try:
            source_output = self.state.get_task_output(item.task_description_or_id)

            if source_output is not None:
                output_str = self._format_task_output(source_output)

                previous_task_output_parts.append(
                    f"<PreviousTaskNodeOutput id='{item.task_description_or_id}'>\n{output_str}\n</PreviousTaskNodeOutput>"
                )
            else:
                print(f"Warning: No output found for task '{item.task_description_or_id}'")

        except Exception as e:
            print(f"Error processing TaskOutputSource '{item.task_description_or_id}': {str(e)}")

    def _format_task_output(self, source_output: Any) -> str:
        """
        Serializes a task output, preferring pretty-printed JSON.

        Formats are tried in this order:

        1. ``model_dump_json``
        2. ``model_dump``
        3. ``to_dict``
        4. A plain ``dict``/``list``
        5. ``__dict__``
        6. ``str()``

        Non-JSON values are stringified via ``default=str``. If serialization
        raises, the error is printed and ``str()`` is returned instead.

        Args:
            source_output: The task output object.

        Returns:
            Formatted string representation.
        """
        try:
            if hasattr(source_output, 'model_dump_json'):
                return source_output.model_dump_json(indent=2)
            if hasattr(source_output, 'model_dump'):
                return json.dumps(source_output.model_dump(), default=str, indent=2)
            if hasattr(source_output, 'to_dict'):
                return json.dumps(source_output.to_dict(), default=str, indent=2)
            if isinstance(source_output, (dict, list)):
                # Plain containers have no __dict__; render them as JSON too
                # rather than falling through to a Python repr.
                return json.dumps(source_output, default=str, indent=2)
            if hasattr(source_output, '__dict__'):
                return json.dumps(source_output.__dict__, default=str, indent=2)
            return str(source_output)
        except Exception as e:
            print(f"Error formatting task output: {str(e)}")
            return str(source_output)

    def get_context_prompt(self) -> str:
        """Public getter to retrieve the constructed context prompt."""
        return self.context_prompt

    @asynccontextmanager
    async def manage_context(self, memory_handler: Optional[MemoryManager] = None):
        """
        Async context manager that builds the task-specific context on entry.

        The prompt is stored on ``self.context_prompt`` and on
        ``self.task.context_formatted``, and ``self`` is yielded. Nothing is
        torn down on exit.
        """
        self.context_prompt = await self._build_context_prompt(memory_handler)
        self.task.context_formatted = self.context_prompt
        yield self

    async def get_knowledge_base_health_status(self) -> Dict[str, Any]:
        """
        Gets the health status of every KnowledgeBase in the task context.

        Returns:
            A mapping from knowledge base ``name`` to the result of
            ``health_check_async()``. If that call raises, the value is
            ``{"healthy": False, "error": <message>}``.
        """
        health_status: Dict[str, Any] = {}
        if not self.task.context:
            return health_status

        # isinstance() needs the real class, not its name as a string.
        from upsonic.knowledge_base.knowledge_base import KnowledgeBase

        for item in self.task.context:
            if not isinstance(item, KnowledgeBase):
                continue
            try:
                health_status[item.name] = await item.health_check_async()
            except Exception as e:
                health_status[item.name] = {
                    "healthy": False,
                    "error": str(e)
                }

        return health_status

    def get_context_summary(self) -> Dict[str, Any]:
        """
        Gets a summary of the current task, agent, state and context items.

        Returns:
            A dict with ``task``, ``context``, ``agent`` and ``state`` sections.
            ``context`` lists knowledge bases, tasks and task-output sources,
            and counts free-form string items.
        """
        summary = {
            "task": {
                "id": self.task.get_task_id() if hasattr(self.task, 'get_task_id') else "unknown",
                "description": self.task.description,
                "attachments": self.task.attachments,
                "response_format": str(self.task.response_format) if self.task.response_format else "str",
                "response_lang": self.task.response_lang,
                "not_main_task": self.task.not_main_task,
                "start_time": self.task.start_time,
                "end_time": self.task.end_time,
                "duration": self.task.duration,
                "price_id": self.task.price_id,
                "total_cost": self.task.total_cost,
                "total_input_tokens": self.task.total_input_token,
                "total_output_tokens": self.task.total_output_token,
                "tool_calls_count": len(self.task.tool_calls) if self.task.tool_calls else 0
            },
            "context": {
                "items_count": len(self.task.context) if self.task.context else 0,
                "knowledge_bases": [],
                "tasks": [],
                "task_output_sources": [],
                "additional_contexts": 0,
                "context_formatted": self.task.context_formatted is not None
            },
            "agent": {
                "id": self.agent.agent_id,
                "name": self.agent.name,
                "debug": self.agent.debug,
                "retry": self.agent.retry,
                "mode": self.agent.mode,
                "show_tool_calls": self.agent.show_tool_calls,
                "tool_call_limit": self.agent.tool_call_limit,
                "enable_thinking_tool": self.agent.enable_thinking_tool,
                "enable_reasoning_tool": self.agent.enable_reasoning_tool,
                "has_memory": self.agent.memory is not None,
                "has_knowledge": self.agent.knowledge is not None,
                "has_canvas": self.agent.canvas is not None
            },
            "state": {
                "available": self.state is not None
            }
        }

        if not self.task.context:
            return summary

        # isinstance() needs the real classes, not their names as strings.
        from upsonic.tasks.tasks import Task
        from upsonic.knowledge_base.knowledge_base import KnowledgeBase

        for item in self.task.context:
            if isinstance(item, KnowledgeBase):
                kb_info = {
                    "name": item.name,
                    "type": "rag" if item.rag else "static",
                    "is_ready": getattr(item, '_is_ready', False),
                    "knowledge_id": getattr(item, 'knowledge_id', 'unknown'),
                    # Tolerate a missing *or* None `sources` attribute.
                    "sources_count": len(getattr(item, 'sources', None) or ())
                }

                if hasattr(item, 'vector_db') and hasattr(item.vector_db, 'get_config_summary'):
                    try:
                        kb_info["vector_db"] = item.vector_db.get_config_summary()
                    except Exception:
                        kb_info["vector_db"] = {"provider": item.vector_db.__class__.__name__}

                # Only reports whether the method exists; it is not called here.
                kb_info["collection_info_available"] = hasattr(item, 'get_collection_info_async')

                summary["context"]["knowledge_bases"].append(kb_info)

            elif isinstance(item, Task):
                summary["context"]["tasks"].append({
                    "id": item.get_task_id() if hasattr(item, 'get_task_id') else "unknown",
                    "description": item.description,
                    "not_main_task": item.not_main_task,
                    "has_response": item.response is not None,
                    "has_attachments": bool(item.attachments),
                    "tools_count": len(item.tools) if item.tools else 0
                })

            elif isinstance(item, TaskOutputSource):
                summary["context"]["task_output_sources"].append({
                    "task_id": item.task_description_or_id,
                    "retrieval_mode": item.retrieval_mode,
                    "enabled": item.enabled,
                    "source_id": item.source_id
                })

            elif isinstance(item, str):
                summary["context"]["additional_contexts"] += 1

        return summary