"""Adapters for integrating ChromaSQL with existing ChromaDB clients.

This module provides concrete implementations of ChromaSQL's multi-collection
abstractions for common client types. These adapters let you use ChromaSQL
with existing infrastructure without modifying your client code.

Example: Using with AsyncMultiCollectionQueryClient
----------------------------------------------------

    from pathlib import Path

    from chromasql.adapters import (
        MetadataFieldRouter,
    )
    from chromasql.multi_collection import execute_multi_collection
    from idxr.vectorize_lib.query_client import AsyncMultiCollectionQueryClient
    from idxr.query_lib.async_multi_collection_adapter import AsyncMultiCollectionAdapter

    # The code below uses ``await``, so run it inside an async function.
    # Set up your existing client
    client = AsyncMultiCollectionQueryClient(
        config_path=Path("query_config.json"),
        client_type="cloud",
        cloud_api_key=api_key,
    )
    await client.connect()

    # Create adapter and router
    adapter = AsyncMultiCollectionAdapter(client)
    router = MetadataFieldRouter(
        client._query_config,
        field_path=("model",),  # Your discriminator field
    )

    # Execute ChromaSQL query with fan-out
    result = await execute_multi_collection(
        query_str="SELECT * FROM demo WHERE metadata.model IN ('Table', 'Field');",
        router=router,
        collection_provider=adapter,
        embed_fn=my_embed_fn,
    )

    await client.close()
"""

from __future__ import annotations

import logging
from typing import Any, Dict, Optional, Sequence, Set

from .analysis import extract_metadata_values
from ._ast_nodes import Query

# Module-level logger; MetadataFieldRouter.route reports routing decisions
# (INFO) and unknown filter values / empty routing results (WARNING) here.
logger = logging.getLogger(__name__)


class MetadataFieldRouter:
    """Generic metadata-based router for multi-collection queries.

    This router extracts the values a query's WHERE clause constrains for a
    given metadata field path, then uses the ``model_to_collections`` mapping
    of a query config to determine which collections contain those values.
    It's designed to work with query configs generated by the vectorize_lib
    tooling.

    Parameters
    ----------
    query_config:
        Query configuration dict. Must contain ``model_to_collections`` (each
        value maps to an entry with a ``collections`` list) and
        ``collection_to_models``.
    field_path:
        Field names representing the metadata path to extract.
        For example, ``("model",)`` for ``metadata.model`` or
        ``("tenant", "id")`` for ``metadata.tenant.id``. A plain string is
        also accepted and split on dots, so ``"tenant.id"`` is equivalent to
        ``("tenant", "id")``.
    fallback_to_all:
        If True (default), return ``None`` (query all collections) when the
        field is not constrained in the WHERE clause, or when none of its
        values map to a known collection. If False, raise ``ValueError`` in
        the first case and return an empty list in the second.

    Example
    -------
    >>> config = load_query_config(Path("query_config.json"))
    >>> router = MetadataFieldRouter(config, field_path=("model",))
    >>> query = parse("SELECT * FROM demo WHERE metadata.model = 'Table';")
    >>> collections = router.route(query)
    >>> print(collections)
    ['collection_00001', 'collection_00003']
    """

    def __init__(
        self,
        query_config: Dict[str, Any],
        field_path: Sequence[str],
        *,
        fallback_to_all: bool = True,
    ):
        """Initialize the router.

        Parameters
        ----------
        query_config:
            Query configuration dict (from load_query_config).
        field_path:
            Metadata field path to extract, as a sequence of field names or a
            dotted string.
        fallback_to_all:
            Whether to query all collections when field is not filtered.

        Raises
        ------
        ValueError:
            If ``query_config`` lacks a required key, or ``field_path`` is
            empty or contains an empty field name.
        """
        # Validate config structure
        required_keys = ["model_to_collections", "collection_to_models"]
        for key in required_keys:
            if key not in query_config:
                raise ValueError(
                    f"Query config missing required key: {key}. "
                    f"Expected structure from vectorize_lib.query_config"
                )

        # A bare string is itself a sequence of characters; without this,
        # "model" would be treated as the path ("m", "o", "d", "e", "l").
        if isinstance(field_path, str):
            field_path = field_path.split(".")
        normalized_path = tuple(field_path)
        if not normalized_path or not all(normalized_path):
            raise ValueError(
                "field_path must contain at least one non-empty field name"
            )

        self.query_config = query_config
        self.field_path = normalized_path
        self.fallback_to_all = fallback_to_all

    def route(self, query: Query) -> Optional[Sequence[str]]:
        """Determine which collections to query based on metadata field.

        Parameters
        ----------
        query:
            Parsed ChromaSQL query.

        Returns
        -------
        Optional[Sequence[str]]
            Sorted collection names to query; ``None`` to query all
            collections; or an empty list when none of the filtered values
            map to a collection and ``fallback_to_all`` is False.

        Raises
        ------
        ValueError:
            If field is not constrained and fallback_to_all is False.
        """
        field_label = ".".join(self.field_path)
        field_values = extract_metadata_values(query, field_path=self.field_path)

        if field_values is None:
            # Field not constrained in WHERE clause
            if not self.fallback_to_all:
                raise ValueError(
                    f"Query must filter on metadata.{field_label} "
                    f"(fallback_to_all is disabled)"
                )
            logger.info(
                "Metadata field %s not constrained; querying all collections",
                field_label,
            )
            return None  # Query all collections

        collections = self._collections_for(field_values)

        if not collections:
            if self.fallback_to_all:
                logger.warning(
                    "No collections found for %s values: %s; falling back to all",
                    field_label,
                    field_values,
                )
                return None
            # With fallback disabled, an unmatched filter means there is
            # nothing to query; say so instead of claiming a fallback.
            logger.warning(
                "No collections found for %s values: %s; fallback_to_all is "
                "disabled, querying no collections",
                field_label,
                field_values,
            )
            return []

        logger.info(
            "Routed to %d collection(s) for %s IN %s",
            len(collections),
            field_label,
            field_values,
        )

        return sorted(collections)

    def _collections_for(self, field_values: Any) -> Set[str]:
        """Return the union of collections listed for each of *field_values*.

        Values absent from ``model_to_collections`` are logged and skipped.
        """
        collections: Set[str] = set()
        model_to_collections = self.query_config["model_to_collections"]

        for value in field_values:
            if value not in model_to_collections:
                logger.warning(
                    "Value %r for field %s not found in query config; skipping",
                    value,
                    ".".join(self.field_path),
                )
                continue

            # Add all collections that contain this value
            collections.update(model_to_collections[value]["collections"])

        return collections


class SimpleAsyncClientAdapter:
    """Adapter for simple AsyncHttpClient or AsyncCloudClient usage.

    This adapter wraps a raw ChromaDB async client (without the multi-collection
    query client infrastructure). Useful for developers who want multi-collection
    routing with ChromaSQL but don't have a query_config.json setup.

    Parameters
    ----------
    client:
        AsyncHttpClient or AsyncCloudClient from chromadb.
    collection_names:
        All available collection names to query. A single string is treated
        as one collection name.

    Example
    -------
    >>> import chromadb
    >>> client = await chromadb.AsyncHttpClient(host="localhost", port=8000)
    >>> adapter = SimpleAsyncClientAdapter(
    ...     client=client,
    ...     collection_names=["coll_1", "coll_2", "coll_3"],
    ... )
    """

    def __init__(self, client: Any, collection_names: Sequence[str]):
        """Initialize the adapter.

        Parameters
        ----------
        client:
            Async ChromaDB client; only its ``get_collection(name=...)``
            coroutine is used.
        collection_names:
            Collection names available for querying, or a single name.
        """
        # A bare string is a sequence of characters; without this guard
        # "coll" would become the four names ["c", "o", "l", "l"].
        if isinstance(collection_names, str):
            collection_names = [collection_names]
        self.client = client
        self.collection_names = list(collection_names)
        # Collections fetched so far, keyed by name, so each name is only
        # looked up on the client once per adapter.
        self._collection_cache: Dict[str, Any] = {}

    async def get_collection(self, name: str) -> Any:
        """Get a collection by name, with caching.

        Parameters
        ----------
        name:
            Collection name.

        Returns
        -------
        ChromaDB collection object, as returned by the wrapped client.
        """
        if name in self._collection_cache:
            return self._collection_cache[name]

        collection = await self.client.get_collection(name=name)
        self._collection_cache[name] = collection
        return collection

    async def list_collection_names(self) -> Sequence[str]:
        """List all available collection names.

        Returns
        -------
        Sequence[str]
            A copy of the collection names provided at initialization, so
            callers cannot mutate the adapter's own list.
        """
        return list(self.collection_names)


# Public API of this module: the router and the raw-client adapter.
__all__ = ["MetadataFieldRouter", "SimpleAsyncClientAdapter"]
