"""See the [`litellm` documention](https://docs.litellm.ai/docs/embedding/supported_embedding)."""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../pts/api/llm/08_embeddings.pct.py.

# %% auto 0
# NOTE(review): `sig` is the scratch variable used below to patch the wrappers'
# signatures; it seems to be exported only because this list is autogenerated.
__all__ = ['embedding', 'sig', 'async_embedding', 'batch_embeddings', 'async_batch_embeddings']

# %% ../../../pts/api/llm/08_embeddings.pct.py 3
import inspect
from inspect import Parameter
import asyncio
try:
    import litellm
    import functools
    from adulib.llm._utils import _llm_func_factory, _llm_async_func_factory
    from adulib.llm.tokens import token_counter
except ImportError as e:
    raise ImportError(f"Install adulib[llm] to use this API.") from e

# %% ../../../pts/api/llm/08_embeddings.pct.py 7
def _embedding_inputs(func_kwargs):
    """Return the `input` of an embedding call as a list of texts.

    A single string is wrapped in a list so that it is counted as one text
    instead of being iterated character by character.
    """
    inputs = func_kwargs['input']
    return [inputs] if isinstance(inputs, str) else inputs


def _retrieve_embedding_log_data(model, func_kwargs, response, cache_args):
    """Build the log record shared by `embedding` and `async_embedding`.

    Args:
        model: The embedding model the call was made with.
        func_kwargs: Keyword arguments of the call; must contain `input`.
        response: The `litellm` response; the cost is read from its `_hidden_params`.
        cache_args: Extra keyword arguments forwarded to `token_counter`.

    Returns:
        dict: The log record with method, input/output token counts and cost.
    """
    return {
        "method": "embedding",
        "input_tokens": sum(
            token_counter(model=model, text=text, **cache_args)
            for text in _embedding_inputs(func_kwargs)
        ),
        # No output tokens are recorded for embedding calls.
        "output_tokens": None,
        "cost": response._hidden_params['response_cost'],
    }


_EMBEDDING_DOC = """
This function is a wrapper around a corresponding function in the `litellm` library, see [this](https://docs.litellm.ai/docs/embedding/supported_embedding) for a full list of the available arguments.
""".strip()


def _expose_model_and_input(func):
    """Attach the shared docstring to `func` and prepend `model`/`input` to its signature.

    The extra parameters make `model` and `input` visible to introspection
    (e.g. `help()`/`inspect.signature`). Returns the new signature, which is
    also stored as `func.__signature__`.
    """
    func.__doc__ = _EMBEDDING_DOC
    signature = inspect.signature(func)
    signature = signature.replace(parameters=[
        Parameter("model", Parameter.POSITIONAL_OR_KEYWORD, annotation=str),
        Parameter("input", Parameter.POSITIONAL_OR_KEYWORD, annotation=list[str]),
        *signature.parameters.values(),
    ])
    func.__signature__ = signature
    return signature


embedding = _llm_func_factory(
    func=litellm.embedding,
    func_name="embedding",
    func_cache_name="embedding",
    retrieve_log_data=_retrieve_embedding_log_data,
)
# `sig` is kept as a module-level name because it is listed in `__all__`.
sig = _expose_model_and_input(embedding)

# %% ../../../pts/api/llm/08_embeddings.pct.py 10
@functools.wraps(litellm.embedding)  # 'litellm.aembedding' lacks the right signature
async def _aembedding(*args, **kwargs):
    """Await `litellm.aembedding`, exposing the metadata/signature of `litellm.embedding`.

    A separate coroutine function is used so that `functools.wraps` does not
    overwrite `__name__`, `__doc__` and `__wrapped__` on litellm's own
    `aembedding` function object.
    """
    return await litellm.aembedding(*args, **kwargs)


async_embedding = _llm_async_func_factory(
    func=_aembedding,
    func_name="async_embedding",
    func_cache_name="embedding",
    retrieve_log_data=_retrieve_embedding_log_data,
)
# After this line `sig` holds the async signature, as in the original cell order.
sig = _expose_model_and_input(async_embedding)

# %% ../../../pts/api/llm/08_embeddings.pct.py 13
def batch_embeddings(
    model: str,
    input: list[str] = None,
    batch_size: int = 1000,
    verbose: bool = False,
    **kwargs
):
    """
    Compute embeddings for a list of input strings in batches synchronously.

    The inputs are split into consecutive chunks of at most `batch_size` strings.
    Each chunk is sent to `embedding` in turn, and the resulting vectors are
    concatenated batch by batch.

    Args:
        model (str): The embedding model to use.
        input (list[str]): List of input strings to embed. A single string is
            treated as a one-element list.
        batch_size (int): Number of inputs per batch. Must be at least 1.
        verbose (bool): If True, display a progress bar.
        **kwargs: Additional keyword arguments passed to `embedding`.

    Returns:
        list: List of embedding vectors for each input string.

    Raises:
        TypeError: If `input` is not provided.
        ValueError: If `batch_size` is smaller than 1.
    """
    if input is None:
        raise TypeError("batch_embeddings() missing required argument: 'input'")
    if isinstance(input, str):
        # Slicing a plain string would split the text itself into character chunks.
        input = [input]
    if batch_size < 1:
        # A negative step makes `range` empty and would silently return no embeddings.
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    batches = [input[i:i + batch_size] for i in range(0, len(input), batch_size)]

    if verbose:
        from tqdm import tqdm
        batches = tqdm(batches, desc="Processing embedding batches")

    responses = [embedding(model=model, input=batch, **kwargs) for batch in batches]
    return [d['embedding'] for response in responses for d in response.data]

# %% ../../../pts/api/llm/08_embeddings.pct.py 16
async def async_batch_embeddings(
    model: str,
    input: list[str] = None,
    batch_size: int = 1000,
    verbose: bool = False,
    **kwargs
):
    """
    Compute embeddings for a list of input strings in batches asynchronously.

    The inputs are split into consecutive chunks of at most `batch_size` strings.
    One `async_embedding` call is created per chunk and all calls are awaited
    together. The resulting vectors are concatenated batch by batch.

    Args:
        model (str): The embedding model to use.
        input (list[str]): List of input strings to embed. A single string is
            treated as a one-element list.
        batch_size (int): Number of inputs per batch. Must be at least 1.
        verbose (bool): If True, display a progress bar.
        **kwargs: Additional keyword arguments passed to `async_embedding`.

    Returns:
        list: List of embedding vectors for each input string.

    Raises:
        TypeError: If `input` is not provided.
        ValueError: If `batch_size` is smaller than 1.
    """
    # Validate before creating any coroutine, so none is left un-awaited on error.
    if input is None:
        raise TypeError("async_batch_embeddings() missing required argument: 'input'")
    if isinstance(input, str):
        # Slicing a plain string would split the text itself into character chunks.
        input = [input]
    if batch_size < 1:
        # A negative step makes `range` empty and would silently return no embeddings.
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    embedding_tasks = [
        async_embedding(model=model, input=input[i:i + batch_size], **kwargs)
        for i in range(0, len(input), batch_size)
    ]

    if verbose:
        from tqdm.asyncio import tqdm_asyncio
        responses = await tqdm_asyncio.gather(*embedding_tasks, desc="Processing embedding batches", total=len(embedding_tasks))
    else:
        responses = await asyncio.gather(*embedding_tasks)

    return [d['embedding'] for response in responses for d in response.data]
