"""Vector search using Pinecone and MongoDB with hybrid retrieval and reranking."""

import logging
from typing import Any

from openai import AsyncOpenAI
from pinecone import Pinecone

from wistx_mcp.config import settings
from wistx_mcp.tools.lib.mongodb_client import MongoDBClient

# Module-level logger named after this module (``__name__``), so its level and
# handlers can be configured independently of other loggers.
logger = logging.getLogger(__name__)


class VectorSearch:
    """Hybrid search (Pinecone vectors + BM25) over MongoDB-backed collections.

    Each public ``search_*`` method:

    1. embeds the query with OpenAI and queries Pinecone for candidate IDs,
    2. optionally gathers lexical (BM25) candidates,
    3. fetches the full documents for the union of candidate IDs from MongoDB,
    4. attaches ``vector_score``, ``bm25_score`` and ``hybrid_score`` to every
       document and sorts by ``hybrid_score`` (descending), and
    5. optionally reranks with the reranking service before truncating to
       ``limit``.
    """

    # Model used to embed queries; presumably must match the model used when
    # the Pinecone index was built (TODO confirm against the indexer).
    _EMBEDDING_MODEL = "text-embedding-3-small"

    # hybrid_score = vector_score * 0.7 + (bm25_score / 10) * 0.3.
    # The BM25 divisor is a fixed heuristic, not a true normalization: raw
    # BM25 scores are unbounded, so the BM25 term can exceed 0.3.
    _VECTOR_WEIGHT = 0.7
    _BM25_WEIGHT = 0.3
    _BM25_SCORE_SCALE = 10.0

    # Largest accepted ``limit`` for the validated search methods.
    _MAX_LIMIT = 50000
    # Upper bound on Pinecone ``top_k`` and on the number of BM25 candidates kept.
    _MAX_CANDIDATES = 10000
    # Minimum number of MongoDB documents loaded to build the in-memory BM25 corpus.
    _BM25_MIN_CORPUS = 50000
    # Result-set size above which a performance warning is logged.
    _LARGE_RESULT_THRESHOLD = 10000

    _VALID_SEVERITIES = ("CRITICAL", "HIGH", "MEDIUM", "LOW")
    _VALID_CLOUD_PROVIDERS = ("aws", "gcp", "azure", "oracle", "alibaba")

    def __init__(self, mongodb_client: MongoDBClient, openai_api_key: str | None = None):
        """Initialize vector search with hybrid retrieval and reranking.

        The BM25 and reranking services are optional: if importing or
        constructing either fails, a warning is logged and the corresponding
        attribute is set to ``None``, which disables that stage.

        Args:
            mongodb_client: MongoDB client instance
            openai_api_key: OpenAI API key for embeddings; falls back to
                ``settings.openai_api_key`` when not given.

        Raises:
            ValueError: If no OpenAI API key or no Pinecone API key is configured.
        """
        self.mongodb_client = mongodb_client
        self.openai_api_key = openai_api_key or settings.openai_api_key

        if not self.openai_api_key:
            raise ValueError("OpenAI API key is required for vector search")

        if not settings.pinecone_api_key:
            raise ValueError("Pinecone API key is required for vector search. Set PINECONE_API_KEY environment variable.")

        self.embedding_client = AsyncOpenAI(api_key=self.openai_api_key)

        pc = Pinecone(api_key=settings.pinecone_api_key)
        self.index = pc.Index(settings.pinecone_index_name)

        try:
            from api.services.bm25_service import BM25Service
            self.bm25_service = BM25Service(mongodb_client)
        except Exception as e:
            logger.warning("BM25 service not available: %s", e)
            self.bm25_service = None

        try:
            from api.services.reranking_service import RerankingService
            self.reranking_service = RerankingService()
        except Exception as e:
            logger.warning("Reranking service not available: %s", e)
            self.reranking_service = None

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #

    async def _get_query_embedding(self, query: str) -> list[float]:
        """Embed *query* with the OpenAI embeddings API.

        Args:
            query: Query string

        Returns:
            Embedding vector

        Raises:
            ValueError: If query is empty or whitespace-only
            RuntimeError: If the API call fails or returns no embeddings
        """
        if not query or not query.strip():
            raise ValueError("Query cannot be empty")

        try:
            response = await self.embedding_client.embeddings.create(
                model=self._EMBEDDING_MODEL,
                input=query,
            )
        except Exception as e:
            logger.error("Failed to generate embedding for query: %s", e, exc_info=True)
            raise RuntimeError(f"Failed to generate embedding: {e}") from e

        # Checked outside the try so this error is not re-wrapped and logged
        # as if the API call itself had raised.
        if not response.data:
            logger.error("Empty embedding response from OpenAI")
            raise RuntimeError("Empty embedding response from OpenAI")
        return response.data[0].embedding

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """Lower-case *text* and split it into ASCII alphanumeric word tokens."""
        import re

        return re.findall(r"\b[a-z0-9]+\b", text.lower())

    @staticmethod
    def _collect_vector_scores(query_response: Any, id_field: str) -> dict[str, float]:
        """Map each document ID in Pinecone match metadata to its best score.

        Several vectors may carry the same document ID (presumably chunks of
        one document). Keeping the maximum prevents a weaker match from
        overwriting a stronger one. Matches with no metadata or no *id_field*
        value are skipped. The dict preserves first-seen order.

        Args:
            query_response: Pinecone query response exposing ``.matches``.
            id_field: Metadata key holding the MongoDB document ID.

        Returns:
            Mapping of document ID to its highest similarity score.
        """
        scores: dict[str, float] = {}
        for match in query_response.matches:
            metadata = match.metadata
            if not metadata:
                continue
            doc_id = metadata.get(id_field)
            if not doc_id:
                continue
            if doc_id not in scores or match.score > scores[doc_id]:
                scores[doc_id] = match.score
        return scores

    @classmethod
    def _bm25_rank(
        cls,
        documents: list[dict[str, Any]],
        id_field: str,
        text_fields: tuple[str, ...],
        query: str,
        top_n: int,
        max_field_chars: dict[str, int] | None = None,
    ) -> list[tuple[str, float]]:
        """Rank *documents* against *query* with an in-memory BM25 (Okapi) index.

        A document's searchable text is its ``contextual_description`` (when
        present) followed by *text_fields*. Missing or null fields contribute
        nothing. Documents without an *id_field* value, or without any tokens,
        are left out of the corpus.

        Args:
            documents: MongoDB documents forming the BM25 corpus.
            id_field: Field holding the document's ID.
            text_fields: Fields concatenated into the searchable text.
            query: Raw query string, tokenized the same way as the documents.
            top_n: Maximum number of ``(id, score)`` pairs to return.
            max_field_chars: Optional per-field character cap applied before
                tokenizing (e.g. to bound how much source code is indexed).

        Returns:
            ``(id, score)`` pairs with a positive score, best first, at most
            one per ID.

        Raises:
            ImportError: If ``rank_bm25`` is not installed (callers treat BM25
                as best-effort and catch this).
        """
        if not documents:
            return []
        query_tokens = cls._tokenize(query)
        if not query_tokens:
            return []

        from rank_bm25 import BM25Okapi

        char_limits = max_field_chars or {}
        corpus: list[list[str]] = []
        corpus_ids: list[str] = []
        for doc in documents:
            doc_id = doc.get(id_field)
            if not doc_id:
                continue

            parts: list[str] = []
            contextual_desc = doc.get("contextual_description")
            if contextual_desc:
                parts.append(str(contextual_desc))
            for field in text_fields:
                # ``or ""`` maps null fields to empty text instead of failing
                # on slicing or indexing the literal word "none".
                text = str(doc.get(field) or "")
                cap = char_limits.get(field)
                parts.append(text if cap is None else text[:cap])

            tokens = cls._tokenize(" ".join(parts))
            if tokens:
                corpus.append(tokens)
                corpus_ids.append(doc_id)

        if not corpus:
            return []

        scores = BM25Okapi(corpus).get_scores(query_tokens)

        # Keep only positive scores, and the best one per ID in case the
        # corpus contains duplicate IDs.
        best: dict[str, float] = {}
        for doc_id, raw_score in zip(corpus_ids, scores):
            score = float(raw_score)
            if score > best.get(doc_id, 0.0):
                best[doc_id] = score

        ranked = sorted(best.items(), key=lambda item: item[1], reverse=True)
        return ranked[:top_n]

    @classmethod
    def _apply_hybrid_scores(
        cls,
        results: list[dict[str, Any]],
        id_field: str,
        vector_score_map: dict[str, float],
        bm25_score_map: dict[str, float],
    ) -> None:
        """Attach vector/BM25/hybrid scores to *results* and sort them in place.

        IDs missing from a score map score 0.0 for that signal. The list is
        sorted by ``hybrid_score``, highest first.
        """
        for result in results:
            doc_id = result.get(id_field, "")
            vector_score = vector_score_map.get(doc_id, 0.0)
            bm25_score = bm25_score_map.get(doc_id, 0.0)
            normalized_bm25 = bm25_score / cls._BM25_SCORE_SCALE if bm25_score > 0 else 0.0

            result["hybrid_score"] = (vector_score * cls._VECTOR_WEIGHT) + (normalized_bm25 * cls._BM25_WEIGHT)
            result["vector_score"] = vector_score
            result["bm25_score"] = bm25_score

        results.sort(key=lambda doc: doc.get("hybrid_score", 0), reverse=True)

    def _rerank_and_truncate(
        self,
        query: str,
        results: list[dict[str, Any]],
        limit: int,
        *,
        rerank_failure_msg: str,
        large_result_msg: str,
    ) -> list[dict[str, Any]]:
        """Optionally rerank *results*, then return at most *limit* of them.

        Reranking is best-effort. If the service is unavailable or raises, the
        incoming (hybrid-score) order is kept and the failure is logged with
        *rerank_failure_msg*, a %-format string taking the exception.
        *large_result_msg* is a %-format string taking ``(count, limit)``. It
        is logged when the truncated result set is unusually large.
        """
        if self.reranking_service and results:
            try:
                results = self.reranking_service.rerank(
                    query=query,
                    articles=results,
                    top_k=limit,
                )
            except Exception as e:
                logger.warning(rerank_failure_msg, e)

        final_results = results[:limit]
        if len(final_results) > self._LARGE_RESULT_THRESHOLD:
            logger.warning(large_result_msg, len(final_results), limit)
        return final_results

    # ------------------------------------------------------------------ #
    # Public search API
    # ------------------------------------------------------------------ #

    async def search_compliance(
        self,
        query: str,
        standards: list[str] | None = None,
        severity: str | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Search compliance controls using hybrid retrieval (vector + BM25).

        Args:
            query: Search query
            standards: Filter by compliance standards
            severity: Filter by severity level (CRITICAL, HIGH, MEDIUM or LOW)
            limit: Maximum number of results (1..50000)

        Returns:
            List of compliance controls with full document data plus
            ``vector_score``, ``bm25_score`` and ``hybrid_score`` fields

        Raises:
            ValueError: If query, limit, standards or severity is invalid
            RuntimeError: If embedding, the Pinecone query or the MongoDB fetch fails
        """
        if not query or not query.strip():
            raise ValueError("Query cannot be empty")

        if limit <= 0 or limit > self._MAX_LIMIT:
            raise ValueError("Limit must be between 1 and 50000")

        try:
            query_embedding = await self._get_query_embedding(query)
        except (RuntimeError, ValueError) as e:
            logger.error("Failed to generate embedding: %s", e)
            raise

        filter_dict: dict[str, Any] = {"collection": "compliance_controls"}

        if standards:
            if not isinstance(standards, list):
                raise ValueError("Standards must be a list")
            filter_dict["standard"] = {"$in": standards}

        if severity:
            if severity not in self._VALID_SEVERITIES:
                raise ValueError(f"Invalid severity: {severity}")
            filter_dict["severity"] = severity

        # Over-fetch candidates so hybrid scoring / reranking has room to reorder.
        initial_limit = min(limit * 2, self._MAX_CANDIDATES)

        try:
            query_response = self.index.query(
                vector=query_embedding,
                filter=filter_dict,
                top_k=initial_limit,
                include_metadata=True,
            )
        except Exception as e:
            logger.error("Pinecone query failed: %s", e, exc_info=True)
            raise RuntimeError(f"Pinecone query failed: {e}") from e

        if not query_response or not hasattr(query_response, "matches"):
            logger.warning("Invalid Pinecone response structure")
            return []

        vector_score_map = self._collect_vector_scores(query_response, "control_id")

        # BM25 is best-effort: any failure leaves the vector candidates only.
        bm25_results: list[tuple[str, float]] = []
        if self.bm25_service:
            try:
                await self.mongodb_client.connect()
                if self.mongodb_client.database is None:
                    logger.warning("MongoDB database not available for BM25 compliance search")
                else:
                    mongo_filter: dict[str, Any] = {}
                    if standards:
                        mongo_filter["standard"] = {"$in": standards}
                    if severity:
                        mongo_filter["severity"] = severity

                    cursor = self.mongodb_client.database.compliance_controls.find(mongo_filter)
                    # Loads at least _BM25_MIN_CORPUS documents per query to
                    # build the in-memory index.
                    controls_for_bm25 = await cursor.to_list(length=max(limit * 2, self._BM25_MIN_CORPUS))
                    bm25_results = self._bm25_rank(
                        controls_for_bm25,
                        id_field="control_id",
                        text_fields=("title", "description", "requirement"),
                        query=query,
                        top_n=initial_limit,
                    )
            except Exception as e:
                logger.warning("BM25 search for compliance controls failed: %s", e)

        bm25_score_map = dict(bm25_results)
        # Deduplicated union, vector candidates first, in a deterministic order.
        combined_control_ids = list(dict.fromkeys([*vector_score_map, *bm25_score_map]))

        if not combined_control_ids:
            logger.info("No control IDs found in search results for query: %s", query[:50])
            return []

        try:
            await self.mongodb_client.connect()

            if self.mongodb_client.database is None:
                logger.error("MongoDB database is None after connection")
                raise RuntimeError("MongoDB database connection failed")

            cursor = self.mongodb_client.database.compliance_controls.find(
                {"control_id": {"$in": combined_control_ids}}
            )
            results = await cursor.to_list(length=len(combined_control_ids))

            if len(results) != len(combined_control_ids):
                logger.warning(
                    "Mismatch between search results (%d) and MongoDB results (%d)",
                    len(combined_control_ids),
                    len(results),
                )

            self._apply_hybrid_scores(results, "control_id", vector_score_map, bm25_score_map)
            return self._rerank_and_truncate(
                query,
                results,
                limit,
                rerank_failure_msg="Reranking failed for compliance controls: %s",
                large_result_msg="Large compliance result set: %d results (limit: %d). Performance may be impacted.",
            )
        except Exception as e:
            logger.error("MongoDB query failed: %s", e, exc_info=True)
            raise RuntimeError(f"MongoDB query failed: {e}") from e

    async def search_knowledge_articles(
        self,
        query: str,
        domains: list[str] | None = None,
        content_types: list[str] | None = None,
        user_id: str | None = None,
        organization_id: str | None = None,
        include_global: bool = True,
        limit: int = 1000,
    ) -> list[dict[str, Any]]:
        """Search knowledge articles using hybrid retrieval (vector + BM25) with reranking.

        Args:
            query: Search query
            domains: Filter by domains
            content_types: Filter by content types
            user_id: User ID for user-specific content
            organization_id: Organization ID for org-shared content
            include_global: Include global/shared content in results
            limit: Maximum number of results

        Returns:
            List of knowledge articles with full document data, reranked by relevance.
            Returns an empty list when nothing matches or MongoDB is unavailable.

        Raises:
            ValueError: If query is empty
            RuntimeError: If embedding generation fails
        """
        query_embedding = await self._get_query_embedding(query)

        filter_dict: dict[str, Any] = {"collection": "knowledge_articles"}

        if domains:
            filter_dict["domain"] = {"$in": domains}

        if content_types:
            filter_dict["content_type"] = {"$in": content_types}

        # Pinecone is narrowed by visibility class (plus an owner ID only when
        # global content is excluded); the MongoDB query below pins ownership.
        visibility_filters = []
        if include_global:
            visibility_filters.append("global")
        if user_id:
            visibility_filters.append("user")
            if not include_global:
                filter_dict["user_id"] = user_id
        if organization_id:
            visibility_filters.append("organization")
            if not include_global and not user_id:
                filter_dict["organization_id"] = organization_id

        if visibility_filters:
            filter_dict["visibility"] = {"$in": visibility_filters}

        initial_limit = min(limit * 2, self._MAX_CANDIDATES)

        query_response = self.index.query(
            vector=query_embedding,
            filter=filter_dict,
            top_k=initial_limit,
            include_metadata=True,
        )

        if not query_response or not hasattr(query_response, "matches"):
            logger.warning("Invalid Pinecone response structure")
            return []

        vector_score_map = self._collect_vector_scores(query_response, "article_id")

        bm25_score_map: dict[str, float] = {}
        if self.bm25_service:
            try:
                bm25_results = await self.bm25_service.search(query, limit=initial_limit)
                bm25_score_map = {aid: score for aid, score in bm25_results}
            except Exception as e:
                logger.warning("BM25 search failed: %s", e)

        # NOTE(review): BM25 candidates are not filtered by ``domains`` or
        # ``content_types`` (neither here nor in the MongoDB query below), so
        # they can fall outside those filters. Confirm whether that is intended.
        combined_article_ids = list(dict.fromkeys([*vector_score_map, *bm25_score_map]))

        if not combined_article_ids:
            return []

        await self.mongodb_client.connect()

        if self.mongodb_client.database is None:
            logger.warning("MongoDB database not available for knowledge article search")
            return []

        mongo_filter: dict[str, Any] = {"article_id": {"$in": combined_article_ids}}

        visibility_query = []
        if include_global:
            visibility_query.append({"visibility": "global", "user_id": None})
        if user_id:
            visibility_query.append({"visibility": "user", "user_id": user_id})
        if organization_id:
            visibility_query.append({"visibility": "organization", "organization_id": organization_id})

        # NOTE(review): with include_global=False and neither user_id nor
        # organization_id, no visibility clause is added here or in Pinecone,
        # so articles of every visibility are returned. Confirm this is intended.
        if len(visibility_query) > 1:
            mongo_filter["$or"] = visibility_query
        elif visibility_query:
            mongo_filter.update(visibility_query[0])

        cursor = self.mongodb_client.database.knowledge_articles.find(mongo_filter)
        results = await cursor.to_list(length=len(combined_article_ids))

        self._apply_hybrid_scores(results, "article_id", vector_score_map, bm25_score_map)
        return self._rerank_and_truncate(
            query,
            results,
            limit,
            rerank_failure_msg="Reranking failed: %s",
            large_result_msg="Large knowledge articles result set: %d results (limit: %d). Performance may be impacted.",
        )

    async def search_code_examples(
        self,
        query: str,
        code_types: list[str] | None = None,
        cloud_provider: str | None = None,
        services: list[str] | None = None,
        min_quality_score: int | None = None,
        compliance_standard: str | None = None,
        limit: int = 1000,
    ) -> list[dict[str, Any]]:
        """Search code examples using hybrid retrieval (vector + BM25) with reranking.

        Args:
            query: Search query
            code_types: Filter by code types (terraform, kubernetes, docker, etc.)
            cloud_provider: Filter by cloud provider (aws, gcp, azure, oracle, alibaba)
            services: Filter by cloud services (rds, s3, ec2, etc.)
            min_quality_score: Minimum quality score (0-100)
            compliance_standard: Keep only examples with an "implemented"
                mapping to this standard (best-effort, see note below)
            limit: Maximum number of results (1..50000)

        Returns:
            List of code examples with full document data, reranked by relevance

        Raises:
            ValueError: If query, limit, code_types, cloud_provider or services is invalid
            RuntimeError: If embedding, the Pinecone query or the MongoDB fetch fails
        """
        if not query or not query.strip():
            raise ValueError("Query cannot be empty")

        if limit <= 0 or limit > self._MAX_LIMIT:
            raise ValueError("Limit must be between 1 and 50000")

        try:
            query_embedding = await self._get_query_embedding(query)
        except (RuntimeError, ValueError) as e:
            logger.error("Failed to generate embedding: %s", e)
            raise

        filter_dict: dict[str, Any] = {"collection": "code_examples"}

        if code_types:
            if not isinstance(code_types, list):
                raise ValueError("code_types must be a list")
            filter_dict["infrastructure_type"] = {"$in": code_types}

        if cloud_provider:
            if cloud_provider not in self._VALID_CLOUD_PROVIDERS:
                raise ValueError(f"Invalid cloud provider: {cloud_provider}")
            filter_dict["cloud_provider"] = cloud_provider

        if services:
            if not isinstance(services, list):
                raise ValueError("services must be a list")
            filter_dict["services"] = {"$in": services}

        if min_quality_score is not None:
            filter_dict["quality_score"] = {"$gte": min_quality_score}

        initial_limit = min(limit * 2, self._MAX_CANDIDATES)

        try:
            query_response = self.index.query(
                vector=query_embedding,
                filter=filter_dict,
                top_k=initial_limit,
                include_metadata=True,
            )
        except Exception as e:
            logger.error("Pinecone query failed: %s", e, exc_info=True)
            raise RuntimeError(f"Pinecone query failed: {e}") from e

        if not query_response or not hasattr(query_response, "matches"):
            logger.warning("Invalid Pinecone response structure")
            return []

        vector_score_map = self._collect_vector_scores(query_response, "example_id")

        # BM25 is best-effort: any failure leaves the vector candidates only.
        bm25_results: list[tuple[str, float]] = []
        if self.bm25_service:
            try:
                await self.mongodb_client.connect()
                if self.mongodb_client.database is None:
                    logger.warning("MongoDB database not available for BM25 code examples search")
                else:
                    mongo_filter: dict[str, Any] = {}
                    if code_types:
                        # NOTE(review): Pinecone filters code types on
                        # "infrastructure_type" but MongoDB on "code_type";
                        # confirm both field names are intended.
                        mongo_filter["code_type"] = {"$in": code_types}
                    if cloud_provider:
                        mongo_filter["cloud_provider"] = cloud_provider
                    if services:
                        mongo_filter["services"] = {"$in": services}
                    if min_quality_score is not None:
                        mongo_filter["quality_score"] = {"$gte": min_quality_score}

                    cursor = self.mongodb_client.database.code_examples.find(mongo_filter)
                    # Loads at least _BM25_MIN_CORPUS documents per query to
                    # build the in-memory index.
                    examples_for_bm25 = await cursor.to_list(length=max(limit * 2, self._BM25_MIN_CORPUS))
                    bm25_results = self._bm25_rank(
                        examples_for_bm25,
                        id_field="example_id",
                        text_fields=("title", "description", "code"),
                        query=query,
                        top_n=initial_limit,
                        max_field_chars={"code": 2000},
                    )
            except Exception as e:
                logger.warning("BM25 search for code examples failed: %s", e)

        bm25_score_map = dict(bm25_results)
        combined_example_ids = list(dict.fromkeys([*vector_score_map, *bm25_score_map]))

        # NOTE(review): this filter fails open. If the mapping lookup raises,
        # the unfiltered candidates are returned. Confirm that is acceptable
        # for compliance filtering.
        if compliance_standard and combined_example_ids:
            try:
                await self.mongodb_client.connect()
                if self.mongodb_client.database is not None:
                    mappings_collection = self.mongodb_client.database.code_example_compliance_mappings
                    compliance_filter = {
                        "standard": compliance_standard,
                        "implementation_status": "implemented",
                        "example_id": {"$in": combined_example_ids},
                    }
                    cursor = mappings_collection.find(compliance_filter)
                    compliant_mappings = await cursor.to_list(length=None)
                    compliant_ids = {m.get("example_id") for m in compliant_mappings if m.get("example_id")}
                    combined_example_ids = [eid for eid in combined_example_ids if eid in compliant_ids]
            except Exception as e:
                logger.warning("Compliance filtering failed: %s", e)

        if not combined_example_ids:
            logger.info("No example IDs found in search results for query: %s", query[:50])
            return []

        try:
            await self.mongodb_client.connect()

            if self.mongodb_client.database is None:
                logger.error("MongoDB database is None after connection")
                raise RuntimeError("MongoDB database connection failed")

            cursor = self.mongodb_client.database.code_examples.find(
                {"example_id": {"$in": combined_example_ids}}
            )
            results = await cursor.to_list(length=len(combined_example_ids))

            if len(results) != len(combined_example_ids):
                logger.warning(
                    "Mismatch between search results (%d) and MongoDB results (%d)",
                    len(combined_example_ids),
                    len(results),
                )

            self._apply_hybrid_scores(results, "example_id", vector_score_map, bm25_score_map)
            return self._rerank_and_truncate(
                query,
                results,
                limit,
                rerank_failure_msg="Reranking failed for code examples: %s",
                large_result_msg="Large code examples result set: %d results (limit: %d). Performance may be impacted.",
            )
        except Exception as e:
            logger.error("MongoDB query failed: %s", e, exc_info=True)
            raise RuntimeError(f"MongoDB query failed: {e}") from e
