#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
:Purpose: This module provides a localised wrapper and specialised
functionality around the
``langchain_community.vectorstores.Chroma`` class, for
interacting with a Chroma database.
:Platform: Linux/Windows | Python 3.10+
:Developer: J Berendt
:Email: development@s3dev.uk
:Comments: This module uses the
``langchain_community.vectorstores.Chroma`` wrapper class,
rather than the base ``chromadb`` library as it provides the
``add_texts`` method, which supports GPU processing and
parallelisation and is wrapped by this module's
:meth:`~ChromaDB.add_documents` method.
"""
# pylint: disable=import-error
# pylint: disable=wrong-import-order
from __future__ import annotations
import chromadb
import os
import torch
from glob import glob
from hashlib import md5
from langchain_huggingface import HuggingFaceEmbeddings
# langchain's Chroma is used rather than the base chromadb, as it provides
# the add_texts method, which supports GPU processing and parallelisation.
from langchain_community.vectorstores import Chroma as _Chroma
class ChromaDB(_Chroma):
    """Wrapper class around the ``chromadb`` library.

    The class subclasses ``langchain_community.vectorstores.Chroma``
    and, on construction, creates a persistent ``chromadb`` client for
    the given directory, builds the embedding function and obtains (or
    creates) the named collection.

    Args:
        path (str): Path to the chroma database's *directory*.
        collection (str): Collection name.
        offline (bool, optional): Remain offline, use the cached
            embedding function model rather than obtaining one online.
            Defaults to False.

    """
    # pylint: disable=line-too-long

    # Metadata used when the collection is created.
    _COLLECTION_METADATA = {'hnsw:space': 'cosine'}
    # The '.cache' directory, located one level above this module's directory.
    _MODEL_CACHE = os.path.join(os.path.dirname(os.path.dirname(os.path.realpath(__file__))), '.cache')
    # Installing torch is a huge overhead, just for this. However, torch
    # will already be installed as part of the sentence-transformers library,
    # so we'll use it here.
    _MODEL_KWARGS = {'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
    # TODO: Add this to a config file.
    _MODEL_NAME = 'all-MiniLM-L6-v2'

    def __init__(self, path: str, collection: str, offline: bool=False):
        """Chroma database class initialiser."""
        self._path = os.path.realpath(path)
        self._cname = collection
        self._offline = offline
        self._client = None  # Database 'client' object
        self._dbc = None     # Database 'collection' object.
        self._embfn = None   # Embedding function object.
        self._set_client()
        self._set_embedding_fn()
        # The collection metadata is passed to the base class so that a
        # *new* collection is created with the intended distance function
        # from the outset, rather than being created with the defaults and
        # only then re-requested with the metadata by _set_collection.
        # NOTE(review): Chroma's docs indicate the distance function is
        # fixed at collection creation - confirm for the pinned version.
        super().__init__(client=self._client,
                         collection_name=self._cname,
                         embedding_function=self._embfn,
                         persist_directory=self._path,
                         collection_metadata=dict(self._COLLECTION_METADATA))
        self._set_collection()

    @property
    def client(self):
        """Accessor to the :class:`chromadb.PersistentClient` class."""
        return self._client

    @property
    def collection(self):
        """Accessor to the chromadb client's collection object."""
        return self._dbc

    @property
    def embedding_function(self):
        """Accessor to the embedding function used."""
        return self._embfn

    @property
    def path(self) -> str:
        """Accessor to the database's path."""
        return self._path

    def add_documents(self, docs: list[langchain_core.documents.base.Document]):  # noqa  # pylint: disable=undefined-variable
        """Add multiple documents to the collection.

        This method overrides the base class' ``add_documents`` method
        to enable local ID derivation. Knowing *how* the IDs are derived
        gives us greater understanding and querying ability of the
        documents in the database. Each ID is derived locally by the
        :meth:`_preproc` method from the document's ``source`` metadata
        value, page number and page content.

        Additionally, this method wraps the
        :func:`langchain_community.vectorstores.Chroma.add_texts`
        method which supports GPU processing and parallelisation.

        Args:
            docs (list): A list of ``langchain_core.documents.base.Document``
                document objects. A tuple of documents, or a single
                document object, is also accepted.

        Raises:
            KeyError: Raised (by :meth:`_preproc`) if a document's
                metadata does not contain a ``'source'`` key.

        Returns:
            list: The value returned by ``add_texts`` (the IDs of the
            added documents), or an empty list if ``docs`` is empty.

        """
        # pylint: disable=arguments-differ
        # pylint: disable=arguments-renamed
        if not isinstance(docs, (list, tuple)):
            # A single document was passed; treat it as a batch of one.
            docs = [docs]
        if not docs:
            # Nothing to add; do not call into the embedding model or
            # the database with an empty batch.
            return []
        ids_, docs_, meta_ = self._preproc(docs=docs)
        return self.add_texts(ids=ids_, texts=docs_, metadatas=meta_)

    def show_all(self):
        """Return the entire contents of the collection.

        This is an alias around ``.collection.get()``.

        """
        return self._dbc.get()

    def _get_embedding_function_model(self) -> str:
        """Derive the path to the embedding function model.

        :Note:
            If ``offline=True`` was passed into the class constructor,
            the model cache is used, if available - otherwise an error
            is raised, advising the user to enable online usage for
            the first run.

            If online usage is allowed, the model is obtained by the
            means defined by the embedding function constructor.

        Raises:
            FileNotFoundError: Raised in offline mode if either the
                model cache directory does not exist (in which case it
                is created, ready for the first online run), or the
                cache does not contain the named model.

        Returns:
            str: The name of the model. Or, if offline, the path to the
            model's cache to be passed into the embedding function
            constructor is returned.

        """
        if not self._offline:
            return self._MODEL_NAME
        if not os.path.exists(self._MODEL_CACHE):
            os.makedirs(self._MODEL_CACHE)
            msg = ('Offline mode has been chosen, yet the embedding function model cache does not exist. '
                   'Therefore, a model must be downloaded. Please enable online usage for the first run '
                   'so a model can be downloaded and stored into the cache for future (offline) use.')
            raise FileNotFoundError(msg)
        # Find the cache directory containing the named model, this enables offline use.
        pattern = os.path.join(self._MODEL_CACHE, f'*{self._MODEL_NAME}*', '**')
        configs = [p for p in glob(pattern, recursive=True) if 'config.json' in p]
        if not configs:
            # The cache directory exists, but the model is not in it. Without
            # this guard, commonpath fails on the empty sequence with an
            # error which gives the user no indication of the actual cause.
            msg = (f'Offline mode has been chosen, yet the embedding function model ({self._MODEL_NAME!r}) '
                   f'was not found in the model cache: {self._MODEL_CACHE}. Please enable online usage '
                   'for the first run so the model can be downloaded and stored into the cache for '
                   'future (offline) use.')
            raise FileNotFoundError(msg)
        # The deepest directory shared by all matched paths is taken as
        # the model's location.
        return os.path.commonpath(configs)

    @staticmethod
    def _preproc(docs: list):
        """Pre-process the document objects to create the IDs.

        Parse the ``Document`` object into its parts for storage.
        Additionally, create the ID from a hash of the document's
        ``source`` metadata value, the page number and a hash of the
        page content, in the form::

            id_<md5(source)>_<pageno, zero-padded to 4>_<md5(content)>

        :Note:
            The ``source`` value is hashed exactly as found in the
            metadata; it is *not* reduced to a basename here.

        Args:
            docs (list): Document objects, each providing the
                ``page_content`` and ``metadata`` attributes.

        Raises:
            KeyError: Raised if a document's metadata does not contain
                a ``'source'`` key.

        Returns:
            tuple: A tuple of three equal-length lists as
            ``(ids, texts, metadatas)``.

        """
        ids = []
        txts = []
        metas = []
        for doc in docs:
            pc = doc.page_content
            m = doc.metadata
            pc_, src_ = map(str.encode, (pc, m['source']))
            # The page number defaults to 0 if not present in the metadata.
            pg_ = str(m.get('pageno', 0)).zfill(4)
            # md5 is used for ID derivation only (not for security); the
            # flag leaves the digest unchanged and permits use on
            # FIPS-restricted Python builds.
            src_hash = md5(src_, usedforsecurity=False).hexdigest()
            pc_hash = md5(pc_, usedforsecurity=False).hexdigest()
            ids.append(f'id_{src_hash}_{pg_}_{pc_hash}')
            txts.append(pc)
            metas.append(m)
        return ids, txts, metas

    def _set_client(self):
        """Set the database client object."""
        settings = chromadb.Settings(anonymized_telemetry=False)
        self._client = chromadb.PersistentClient(path=self._path,
                                                 settings=settings)

    def _set_collection(self):
        """Set the database collection object."""
        self._dbc = self._client.get_or_create_collection(self._cname,
                                                          metadata=dict(self._COLLECTION_METADATA))

    def _set_embedding_fn(self):
        """Set the embeddings function object."""
        model_name = self._get_embedding_function_model()
        self._embfn = HuggingFaceEmbeddings(model_name=model_name,
                                            model_kwargs=self._MODEL_KWARGS,
                                            cache_folder=self._MODEL_CACHE)