
"""
Document embeddings query service.  Input is vector, output is an array
of chunks.  Pinecone implementation.
"""

from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig

import uuid
import os

from .... base import DocumentEmbeddingsQueryService

# Identifier handed to Processor.launch() when the service is started.
default_ident = "de-query"

# API key taken from the environment; "not-specified" is the sentinel used
# when PINECONE_API_KEY is unset (Processor refuses to start with it).
default_api_key = os.environ.get("PINECONE_API_KEY", "not-specified")

class Processor(DocumentEmbeddingsQueryService):
    """
    Document embeddings query processor backed by Pinecone.

    Each query vector is looked up in the Pinecone index named
    ``d-<user>-<collection>-<dimension>`` and the ``doc`` metadata field
    of every match is returned.
    """

    def __init__(self, **params):
        """
        Create the Pinecone client and initialise the base service.

        Recognised params:
            url: Pinecone host.  When set, the gRPC client is created
                against this host; otherwise the default client is used.
            api_key: Pinecone API key.  Falls back to the value of the
                PINECONE_API_KEY environment variable.

        All params, with the resolved ``url`` and ``api_key``, are passed
        on to the base class constructor.

        Raises:
            RuntimeError: if no usable API key was supplied.
        """

        self.url = params.get("url", None)
        self.api_key = params.get("api_key", default_api_key)

        # An empty key is as unusable as a missing one, so reject it here
        # rather than handing it to the client.
        if not self.api_key or self.api_key == "not-specified":
            raise RuntimeError("Pinecone API key must be specified")

        if self.url:

            self.pinecone = PineconeGRPC(
                api_key = self.api_key,
                host = self.url
            )

        else:

            self.pinecone = Pinecone(api_key = self.api_key)

        # Resolved values override whatever the caller passed in
        base_params = params | {
            "url": self.url,
            "api_key": self.api_key,
        }

        super().__init__(**base_params)

    async def query_document_embeddings(self, msg):
        """
        Return the document chunks matching the vectors in a request.

        Reads ``msg.vectors``, ``msg.limit``, ``msg.user`` and
        ``msg.collection``.  Every vector is queried with
        ``top_k=msg.limit``, so the result holds up to ``msg.limit``
        chunks per vector, in vector order.  Duplicates across vectors
        are not removed.

        Returns:
            A list of the ``doc`` metadata values of all matches; empty
            when the limit is not positive or there are no vectors.

        Raises:
            Any exception raised while querying is printed and re-raised.
        """

        try:

            # Nothing can be returned for a non-positive limit
            if msg.limit <= 0:
                return []

            chunks = []

            # Index handles keyed by index name: vectors of the same
            # dimension map to the same index, so build each handle once.
            indexes = {}

            # A missing vector list is treated as an empty query
            for vec in msg.vectors or []:

                dim = len(vec)

                index_name = (
                    "d-" + msg.user + "-" + msg.collection + "-" + str(dim)
                )

                index = indexes.get(index_name)

                if index is None:
                    index = self.pinecone.Index(index_name)
                    indexes[index_name] = index

                results = index.query(
                    vector=vec,
                    top_k=msg.limit,
                    include_values=False,
                    include_metadata=True
                )

                chunks.extend(r.metadata["doc"] for r in results.matches)

            return chunks

        except Exception as e:

            print(f"Exception: {e}")

            # Bare raise re-raises the active exception unchanged
            raise

    @staticmethod
    def add_args(parser):
        """
        Register command-line arguments on ``parser``.

        Adds the base service arguments, then ``-a/--api-key`` and
        ``-u/--url`` for the Pinecone connection.
        """

        DocumentEmbeddingsQueryService.add_args(parser)

        parser.add_argument(
            '-a', '--api-key',
            default=default_api_key,
            help='Pinecone API key. (default from PINECONE_API_KEY)'
        )

        parser.add_argument(
            '-u', '--url',
            help='Pinecone URL.  If unspecified, serverless is used'
        )

def run():
    """Launch the processor using the default ident and the module docstring."""

    ident = default_ident
    description = __doc__

    Processor.launch(ident, description)

