"""Package search service - search indexed packages."""

import asyncio
import hashlib
import logging
import re
from typing import Any

from openai import AsyncOpenAI

from wistx_mcp.config import settings
from wistx_mcp.tools.lib.mongodb_client import MongoDBClient
from wistx_mcp.tools.lib.package_indexing_service import PackageIndexingService
from wistx_mcp.tools.lib.package_registry_integrator import RegistryIntegratorFactory
from wistx_mcp.tools.lib.pattern_templates import PatternTemplates
from wistx_mcp.tools.lib.pattern_validator import PatternValidator

logger = logging.getLogger(__name__)


class PackageSearchService:
    """Service for searching indexed packages.

    Three strategies are offered over the ``packages`` MongoDB collection:

    * semantic search: an OpenAI embedding of the query is matched against a
      Pinecone index, and the hits are loaded from MongoDB;
    * keyword search: a MongoDB ``$text`` query, used as the fallback when
      semantic search is unavailable or fails;
    * regex search: a regex run over package source fetched from the package
      registry, for indexed packages or on demand for unindexed ones.
    """

    def __init__(self, mongodb_client: MongoDBClient):
        """Initialize package search service.

        The OpenAI embedding client is only created when
        ``settings.openai_api_key`` is set. Without it, semantic search falls
        back to keyword search.

        Args:
            mongodb_client: MongoDB client instance
        """
        self.mongodb_client = mongodb_client
        self.indexing_service = PackageIndexingService(mongodb_client)
        self.embedding_client = AsyncOpenAI(api_key=settings.openai_api_key) if settings.openai_api_key else None
        self.pattern_validator = PatternValidator()

        if not self.embedding_client:
            logger.warning("OpenAI API key not set, semantic search will be limited")

    async def semantic_search(
        self,
        query: str,
        registry: str | None = None,
        domain: str | None = None,
        category: str | None = None,
        limit: int = 20,
    ) -> list[dict[str, Any]]:
        """Semantic search across indexed packages.

        The query is embedded and the Pinecone index is searched for the
        nearest package vectors. The matching package documents are then
        loaded from MongoDB. Falls back to :meth:`_keyword_search` when the
        OpenAI client or Pinecone key is missing, or when any step of the
        semantic path raises.

        Args:
            query: Search query
            registry: Optional registry filter
            domain: Optional domain filter (matched against ``domain_tags``)
            category: Optional category filter
            limit: Maximum results; a non-positive value returns ``[]``

        Returns:
            List of matching package documents. Each one is annotated with
            ``vector_score`` and the list is ordered by descending score.
        """
        if limit <= 0:
            return []

        # Check both prerequisites before paying for an embedding API call.
        if not self.embedding_client or not settings.pinecone_api_key:
            return await self._keyword_search(query, registry, domain, category, limit)

        filter_dict: dict[str, Any] = {}
        if registry:
            filter_dict["registry"] = registry
        if domain:
            filter_dict["domain_tags"] = {"$in": [domain]}
        if category:
            filter_dict["category"] = category

        try:
            query_embedding = await self._generate_embedding(query)

            # The Pinecone client is synchronous (index lookup and query both
            # do network I/O), so keep it off the event loop.
            query_response = await asyncio.to_thread(
                self._query_pinecone,
                query_embedding,
                filter_dict or None,
                limit * 2,
            )

            score_map: dict[str, float] = {}
            for match in query_response.matches:
                package_id = (match.metadata or {}).get("package_id", "")
                if not package_id:
                    continue
                score = match.score or 0.0
                # The index may hold several vectors for one package. Keep the
                # best score, and never let a weaker later match overwrite it.
                if package_id not in score_map or score > score_map[package_id]:
                    score_map[package_id] = score

            if not score_map:
                return []

            await self.mongodb_client.connect()
            if self.mongodb_client.database is None:
                return []

            package_ids = list(score_map)
            collection = self.mongodb_client.database.packages
            cursor = collection.find({"package_id": {"$in": package_ids}})
            packages = await cursor.to_list(length=len(package_ids))
        except Exception as e:
            logger.warning("Semantic search failed, falling back to keyword search: %s", e)
            return await self._keyword_search(query, registry, domain, category, limit)

        for package in packages:
            package["vector_score"] = score_map.get(package.get("package_id", ""), 0.0)

        packages.sort(key=lambda pkg: pkg["vector_score"], reverse=True)
        return packages[:limit]

    def _query_pinecone(
        self,
        vector: list[float],
        filter_dict: dict[str, Any] | None,
        top_k: int,
    ) -> Any:
        """Run a blocking Pinecone similarity query.

        This method is synchronous; ``semantic_search`` calls it through
        ``asyncio.to_thread``.

        Args:
            vector: Query embedding
            filter_dict: Pinecone metadata filter, or None for no filter
            top_k: Number of nearest vectors to return

        Returns:
            The Pinecone query response; its ``.matches`` are read by the caller.
        """
        from pinecone import Pinecone

        pc = Pinecone(api_key=settings.pinecone_api_key)
        index = pc.Index(settings.pinecone_index_name)
        return index.query(
            vector=vector,
            filter=filter_dict,
            top_k=top_k,
            include_metadata=True,
        )

    async def regex_search(
        self,
        pattern: str | None = None,
        template: str | None = None,
        registry: str | None = None,
        package_name: str | None = None,
        limit: int = 20,
        allow_unindexed: bool = True,
    ) -> list[dict[str, Any]]:
        """Regex search across package source code.

        Supports both indexed packages and on-demand fetching (no indexing
        required). ``template`` takes precedence over ``pattern``. The pattern
        is compiled with ``re.IGNORECASE | re.MULTILINE``.

        Args:
            pattern: Regex pattern
            template: Pre-built template name
            registry: Optional registry filter (required with ``package_name``)
            package_name: Optional specific package
            limit: Maximum results
            allow_unindexed: If True, fetch packages on-demand if not indexed

        Returns:
            List of at most ``limit`` matches with package and file information

        Raises:
            ValueError: No pattern or template was given, the template is
                unknown, the pattern fails validation, ``package_name`` was
                given without ``registry``, or the package is not indexed and
                ``allow_unindexed`` is False.
            RuntimeError: An unindexed package could not be fetched from its
                registry.
        """
        if template:
            pattern_str = PatternTemplates.get_template(template)
            if not pattern_str:
                raise ValueError(f"Invalid template: {template}")
        elif pattern:
            pattern_str = pattern
        else:
            raise ValueError("Either pattern or template must be provided")

        validation = await self.pattern_validator.validate_pattern(pattern_str)
        if not validation["valid"]:
            raise ValueError(f"Invalid regex pattern: {validation.get('error')}")

        compiled_pattern = re.compile(pattern_str, re.IGNORECASE | re.MULTILINE)

        if package_name:
            packages_to_search = [
                await self._resolve_named_package(registry, package_name, allow_unindexed)
            ]
        else:
            packages_to_search = await self._list_indexed_packages(registry)

        matches: list[dict[str, Any]] = []
        for package_doc in packages_to_search:
            remaining = limit - len(matches)
            if remaining <= 0:
                break
            try:
                matches.extend(
                    await self._search_package_source(package_doc, compiled_pattern, remaining)
                )
            except Exception as e:
                # Best effort: one unreachable or broken package must not fail the whole search.
                logger.warning("Failed to search package %s: %s", package_doc.get("name"), e)

        return matches

    async def _resolve_named_package(
        self,
        registry: str | None,
        package_name: str,
        allow_unindexed: bool,
    ) -> dict[str, Any]:
        """Return the package document for one named package.

        The indexed document is returned when it exists. Otherwise a minimal
        document is built from registry metadata, if ``allow_unindexed`` is set.

        Raises:
            ValueError: ``registry`` is missing, or the package is not indexed
                and ``allow_unindexed`` is False.
            RuntimeError: Registry metadata could not be fetched.
        """
        if not registry:
            raise ValueError("registry is required when package_name is specified")

        await self.mongodb_client.connect()
        package_doc = None
        if self.mongodb_client.database is not None:
            collection = self.mongodb_client.database.packages
            package_doc = await collection.find_one({"name": package_name, "registry": registry})

        if package_doc:
            return package_doc

        if not allow_unindexed:
            raise ValueError(f"Package {package_name} not indexed. Set allow_unindexed=True to fetch on-demand.")

        logger.info("Package %s:%s not indexed, fetching on-demand", registry, package_name)
        try:
            integrator = RegistryIntegratorFactory.create(registry)
            try:
                metadata = await integrator.get_package_metadata(package_name)
            finally:
                # Close even when the metadata fetch fails, so the integrator is not leaked.
                await integrator.close()
        except Exception as e:
            logger.warning("Failed to fetch metadata for unindexed package %s:%s: %s", registry, package_name, e)
            raise RuntimeError(f"Package {package_name} not found in registry {registry}") from e

        return {
            "package_id": f"{registry}:{package_name}",
            "registry": registry,
            "name": package_name,
            "metadata": metadata,
        }

    async def _list_indexed_packages(
        self,
        registry: str | None,
        max_packages: int = 100,
    ) -> list[dict[str, Any]]:
        """Return up to ``max_packages`` indexed package documents, optionally for one registry."""
        await self.mongodb_client.connect()
        if self.mongodb_client.database is None:
            return []

        collection = self.mongodb_client.database.packages
        filter_dict: dict[str, Any] = {"registry": registry} if registry else {}
        cursor = collection.find(filter_dict).limit(max_packages)
        return await cursor.to_list(length=max_packages)

    async def _search_package_source(
        self,
        package_doc: dict[str, Any],
        compiled_pattern: re.Pattern[str],
        max_matches: int,
    ) -> list[dict[str, Any]]:
        """Fetch one package's source and return at most ``max_matches`` regex matches.

        Exceptions from the registry integrator propagate to the caller.
        """
        registry_name = package_doc.get("registry", "pypi")
        package_name_val = package_doc.get("name", "")
        package_id = package_doc.get("package_id", f"{registry_name}:{package_name_val}")

        integrator = RegistryIntegratorFactory.create(registry_name)
        try:
            source_files = await integrator.get_package_source(package_name_val)
        finally:
            await integrator.close()

        if not source_files:
            logger.warning("No source files found for package %s:%s", registry_name, package_name_val)
            return []

        found: list[dict[str, Any]] = []
        for file_path, content in source_files.items():
            file_hash = hashlib.sha256(file_path.encode()).hexdigest()
            for match in compiled_pattern.finditer(content):
                # count() with bounds avoids copying the prefix for every match.
                line_number = content.count("\n", 0, match.start()) + 1
                found.append({
                    "package_id": package_id,
                    "package_name": package_name_val,
                    "registry": registry_name,
                    "file_path": file_path,
                    "filename_sha256": file_hash,
                    "line_number": line_number,
                    "match_text": match.group(),
                    "context": self._extract_context(content, match.start(), match.end()),
                })
                # Stop scanning all remaining files as soon as the budget is reached.
                if len(found) >= max_matches:
                    return found

        return found

    async def hybrid_search(
        self,
        query: str,
        pattern: str | None = None,
        template: str | None = None,
        registry: str | None = None,
        domain: str | None = None,
        category: str | None = None,
        limit: int = 20,
    ) -> dict[str, Any]:
        """Hybrid search combining semantic and regex.

        The regex part only runs when ``pattern`` or ``template`` is given.
        Its failures, including an invalid pattern, are logged and treated as
        "no regex results".

        Args:
            query: Natural language query
            pattern: Optional regex pattern
            template: Optional template name
            registry: Optional registry filter
            domain: Optional domain filter
            category: Optional category filter
            limit: Maximum results

        Returns:
            Dictionary with the combined ``packages`` list and the
            ``semantic_count``, ``regex_count`` and ``total`` counts
        """
        semantic_results = await self.semantic_search(query, registry, domain, category, limit)

        regex_results: list[dict[str, Any]] = []
        if pattern or template:
            try:
                regex_results = await self.regex_search(
                    pattern=pattern,
                    template=template,
                    registry=registry,
                    limit=limit,
                    allow_unindexed=True,
                )
            except Exception as e:
                logger.warning("Regex search failed in hybrid search: %s", e)

        combined_results = self._combine_results(semantic_results, regex_results, limit)

        return {
            "packages": combined_results,
            "semantic_count": len(semantic_results),
            "regex_count": len(regex_results),
            "total": len(combined_results),
        }

    async def _keyword_search(
        self,
        query: str,
        registry: str | None = None,
        domain: str | None = None,
        category: str | None = None,
        limit: int = 20,
    ) -> list[dict[str, Any]]:
        """Keyword search fallback using MongoDB ``$text``.

        Results are ordered by text relevance, and each document carries a
        ``text_score`` field.

        Args:
            query: Search query
            registry: Optional registry filter
            domain: Optional domain filter
            category: Optional category filter
            limit: Maximum results; a non-positive value returns ``[]``
                (MongoDB would treat ``limit(0)`` as "no limit")

        Returns:
            List of matching packages
        """
        if limit <= 0:
            return []

        await self.mongodb_client.connect()
        if self.mongodb_client.database is None:
            return []

        collection = self.mongodb_client.database.packages
        filter_dict: dict[str, Any] = {
            "$text": {"$search": query},
        }

        if registry:
            filter_dict["registry"] = registry
        if domain:
            filter_dict["domain_tags"] = {"$in": [domain]}
        if category:
            filter_dict["category"] = category

        # MongoDB does not sort $text results by relevance by default, so
        # without this the limit would keep arbitrary hits rather than the best.
        score_projection = {"text_score": {"$meta": "textScore"}}
        cursor = (
            collection.find(filter_dict, score_projection)
            .sort([("text_score", {"$meta": "textScore"})])
            .limit(limit)
        )
        return await cursor.to_list(length=limit)

    def _extract_context(self, content: str, start: int, end: int, context_lines: int = 3) -> str:
        """Extract the lines surrounding a match.

        Args:
            content: File content
            start: Match start offset
            end: Match end offset (exclusive)
            context_lines: Number of lines to include before and after the match

        Returns:
            The matched line(s) plus up to ``context_lines`` lines on each side,
            joined with "\\n".
        """
        lines = content.split("\n")
        start_line = content.count("\n", 0, start)
        # Measure up to the match's last character, so a match ending with "\n"
        # is not treated as extending onto the following line.
        end_line = content.count("\n", 0, max(start, end - 1))

        context_start = max(0, start_line - context_lines)
        context_end = min(len(lines), end_line + context_lines + 1)

        return "\n".join(lines[context_start:context_end])

    def _combine_results(
        self,
        semantic_results: list[dict[str, Any]],
        regex_results: list[dict[str, Any]],
        limit: int,
    ) -> list[dict[str, Any]]:
        """Merge semantic and regex results into one list keyed by package_id.

        Each entry gets a ``match_type`` of "semantic", "regex" or "both".
        Regex hits are collected under ``regex_matches``. Entries are ranked
        by ``vector_score``, then by number of regex matches. Input dicts are
        not mutated.

        Args:
            semantic_results: Semantic search results
            regex_results: Regex search results
            limit: Maximum results

        Returns:
            Combined results
        """
        combined: dict[str, dict[str, Any]] = {}
        for result in semantic_results:
            package_id = result.get("package_id", "")
            if package_id:
                combined[package_id] = {
                    **result,
                    "match_type": "semantic",
                }

        for result in regex_results:
            package_id = result.get("package_id", "")
            entry = combined.get(package_id)
            if entry is not None:
                entry["match_type"] = "both"
                entry.setdefault("regex_matches", []).append(result)
            else:
                combined[package_id] = {
                    "package_id": package_id,
                    "package_name": result.get("package_name"),
                    "registry": result.get("registry"),
                    "match_type": "regex",
                    "regex_matches": [result],
                }

        results = list(combined.values())
        # Regex-only entries have no vector_score, and semantic-only entries
        # have no regex_matches, so both default to 0. `or` also guards
        # against None values breaking the comparison.
        results.sort(
            key=lambda x: (
                x.get("vector_score") or 0,
                len(x.get("regex_matches") or []),
            ),
            reverse=True,
        )

        return results[:limit]

    async def _generate_embedding(self, text: str) -> list[float]:
        """Generate an embedding for text with OpenAI ``text-embedding-3-small``.

        Args:
            text: Text to embed

        Returns:
            Embedding vector

        Raises:
            ValueError: No embedding client is configured.
        """
        if not self.embedding_client:
            raise ValueError("Embedding client not available")

        response = await self.embedding_client.embeddings.create(
            model="text-embedding-3-small",
            input=text,
        )
        return response.data[0].embedding

