##
# File:    OeSearchMoleculeProvider.py
# Author:  J. Westbrook
# Date:    4-Mar-2020
#
# Updates:
#
##
"""
Utilities to deliver OE molecule data for searchable chemical component data.
"""
__docformat__ = "restructuredtext en"
__author__ = "John Westbrook"
__email__ = "john.westbrook@rcsb.org"
__license__ = "Apache 2.0"

import logging
import os
import time

from rcsb.utils.chem.ChemCompSearchIndexProvider import ChemCompSearchIndexProvider
from rcsb.utils.chem.OeIoUtils import OeIoUtils
from rcsb.utils.io.MarshalUtil import MarshalUtil

# from rcsb.utils.io.SingletonClass import SingletonClass

logger = logging.getLogger(__name__)


class OeSearchMoleculeProvider(object):
    """Build and deliver OE molecules and derived search databases for search applications.

    Source molecular definitions are taken from the search index delivered by
    ChemCompSearchIndexProvider().  All generated files are stored in the
    "oe_mol" subdirectory of the cache path.
    """

    def __init__(self, **kwargs):
        """Configure file naming and cache locations, then build or reuse the cached data files.

        Args:
            cachePath (str, optional): path to the directory containing cache files (default: '.')
            ccFileNamePrefix (str, optional): file name prefix for chemical component search index (default: "cc")
            oeFileNamePrefix (str, optional): file name prefix for all generated databases (default: "oe")
            limitPerceptions (bool, optional): when True the suffix "-limit" is appended to the
                                               generated file name prefix (default: False)

        All keyword arguments are also forwarded to the private reload step, which reads
        useCache, numProc, molLimit, fpTypeList, screenTypeList, suppressHydrogens, quietFlag,
        logSizes and fpDbType.
        """
        # Database file names will be prefixed with the base prefix plus the perception option
        oeFileNamePrefixBase = kwargs.get("oeFileNamePrefix", "oe")
        # Stored for reference only; this class does not read it back.
        self.__ccFileNamePrefix = kwargs.get("ccFileNamePrefix", "cc")
        limitPerceptions = kwargs.get("limitPerceptions", False)
        self.__oeFileNamePrefix = oeFileNamePrefixBase + "-limit" if limitPerceptions else oeFileNamePrefixBase
        #
        cachePath = kwargs.get("cachePath", ".")
        self.__dirPath = os.path.join(cachePath, "oe_mol")
        #
        # Lazily populated in-memory caches
        self.__fpDbD = {}
        self.__ssDb = None
        # Screen type of the database currently held in self.__ssDb
        self.__ssDbScreenType = None
        self.__oeMolD = {}
        self.__oeMolDb = None
        self.__oeMolDbTitleD = None
        #
        self.__mU = MarshalUtil(workPath=self.__dirPath)
        self.__reload(**kwargs)

    def testCache(self):
        """Return True if both the search molecule file and the molecule database file exist."""
        searchMolPath = self.__getFilePath(self.__getOeSearchMolFileName())
        molDbPath = self.__getFilePath(self.__getOeMolDbFileName())
        return self.__mU.exists(searchMolPath) and self.__mU.exists(molDbPath)

    def getSubSearchDb(self, screenType="SMARTS", numProc=1, forceRefresh=False):
        """Return the screened substructure search database for the input screen type.

        Only one database is held at a time.  It is (re)loaded when nothing is cached,
        when forceRefresh is set, or when the requested screen type differs from the
        cached one (otherwise a database of the wrong screen type would be returned).

        Args:
            screenType (str, optional): screen type selecting the database file (default: "SMARTS")
            numProc (int, optional): number of processors passed to the loader (default: 1)
            forceRefresh (bool, optional): reload even if a database is cached (default: False)

        Returns:
            (object): whatever OeIoUtils.loadOeSubSearchDatabase() returns
        """
        if forceRefresh or not self.__ssDb or self.__ssDbScreenType != screenType:
            oeIo = OeIoUtils()
            fp = self.__getFilePath(self.__getSubSearchFileName(screenType))
            logger.info("Opening screened substructure search database %r", fp)
            self.__ssDb = oeIo.loadOeSubSearchDatabase(fp, screenType, numProc=numProc)
            self.__ssDbScreenType = screenType
        return self.__ssDb

    def getFingerPrintDb(self, fpType, fpDbType="STANDARD", rebuild=False):
        """Return the fingerprint database for the input fingerprint type.

        Databases are cached per fingerprint type.  If a (re)load fails, any previously
        cached database for the type is retained and returned.

        Args:
            fpType (str): fingerprint type (e.g. TREE, PATH, MACCS, CIRCULAR, LINGO)
            fpDbType (str, optional): fingerprint database type passed to the loader (default: "STANDARD")
            rebuild (bool, optional): reload even if a database is cached (default: False)

        Returns:
            (object): fingerprint database or None if no database could be loaded
        """
        if rebuild or fpType not in self.__fpDbD:
            oeIo = OeIoUtils()
            fastFpDbPath = self.__getFilePath(self.__getFastFpDbFileName(fpType))
            oeMolDbFilePath = self.__getFilePath(self.__getOeMolDbFileName())
            fpDb = oeIo.loadOeFingerPrintDatabase(oeMolDbFilePath, fastFpDbPath, inMemory=True, fpType=fpType, fpDbType=fpDbType)
            if fpDb:
                self.__fpDbD[fpType] = fpDb
            else:
                logger.error("Unable to load fingerprint database for fpType %r (fpDbType %r)", fpType, fpDbType)
        #
        # Avoid a bare KeyError when the load failed and nothing is cached for this type
        return self.__fpDbD.get(fpType)

    def __getOeMolDbTitleIndex(self):
        """Return a dictionary mapping molecule titles to their index in the molecule database.

        Any failure is logged and the (possibly empty) dictionary is returned.
        """
        oeMolDbTitleD = {}
        try:
            molDb = self.__oeMolDb
            oeMolDbTitleD = {molDb.GetTitle(idx): idx for idx in range(molDb.GetMaxMolIdx())}
        except Exception as e:
            logger.exception("Failing with %s", str(e))
        return oeMolDbTitleD

    def getOeMolDatabase(self):
        """Return the OE molecule database and its title index, loading both on first use.

        Returns:
            (object, dict): molecule database, dictionary of {title: database index}
        """
        if not self.__oeMolDb:
            oeIo = OeIoUtils()
            self.__oeMolDb = oeIo.loadOeBinaryDatabaseAndIndex(self.__getFilePath(self.__getOeMolDbFileName()))
            self.__oeMolDbTitleD = self.__getOeMolDbTitleIndex()
        return self.__oeMolDb, self.__oeMolDbTitleD

    def __loadOeMolD(self):
        """Return the molecule dictionary, reading the binary molecule cache on first use.

        Exceptions propagate to the caller (the public accessors log and return None).
        """
        if not self.__oeMolD:
            startTime = time.time()
            oeIo = OeIoUtils()
            self.__oeMolD = oeIo.readOeBinaryMolCache(self.__getFilePath(self.__getOeSearchMolFileName()))
            logger.info("Loading OE binary molecule cache length %d (%.4f seconds)", len(self.__oeMolD), time.time() - startTime)
        return self.__oeMolD

    def getOeMolD(self):
        """Return the dictionary of cached OE molecules or None on failure."""
        try:
            return self.__loadOeMolD()
        except Exception as e:
            logger.exception("Failing with %s", str(e))
        return None

    def getMol(self, searchCcId):
        """Return the OE molecule stored under the input search identifier.

        Args:
            searchCcId (str): key in the cached molecule dictionary

        Returns:
            (object): the stored molecule or None if the key is missing or loading fails
        """
        try:
            return self.__loadOeMolD()[searchCcId]
        except Exception as e:
            logger.exception("Get molecule %r failing with %s", searchCcId, str(e))
        return None

    def __getFilePath(self, fileName):
        """Return the path of the input file name within the OE cache directory."""
        return os.path.join(self.__dirPath, fileName)

    def __needsBuild(self, filePath, useCache):
        """Return True if the file must be (re)generated: caching is disabled or the file is missing."""
        return not useCache or not self.__mU.exists(filePath)

    def __getFastFpDbFileName(self, fpType):
        """Return the file name of the fast fingerprint database for the input fingerprint type."""
        return "%s-si-fast-fp-database-%s.fpbin" % (self.__oeFileNamePrefix, fpType)

    def __getSubSearchFileName(self, screenType):
        """Return the file name of the screened substructure database for the input screen type."""
        return "%s-si-ss-database-%s.oeb" % (self.__oeFileNamePrefix, screenType)

    def __getOeMolDbFileName(self):
        """Return the file name of the indexed OE molecule database."""
        return "%s-si-mol-db-components.oeb" % self.__oeFileNamePrefix

    def __getOeSearchMolFileName(self):
        """Raw binary files of OE molecules in the search index.

        Returns:
            str: file name
        """
        return "%s-si-search-mol-components.oeb" % self.__oeFileNamePrefix

    def __reload(self, **kwargs):
        """Build (or reuse) the OE molecule store and the search databases derived from it.

        Each artifact is regenerated when useCache is False or when its file is missing.

        Args:
            limitPerceptions (bool, optional): passed to the OE molecule builder (default: False)
            fpTypeList (list, optional): fingerprint types (default: TREE, PATH, MACCS, CIRCULAR, LINGO)
            fpDbType (str, optional): fast fingerprint databases are generated only when
                                      this is "FAST" (default: "STANDARD")
            screenTypeList (list, optional): fast sub search screen types (MOLECULE, SMARTS, MDL, ... );
                                             no screened databases are built when empty (default: None)
            useCache (bool, optional): flag to use cached files. Defaults to True.
            cachePath (str, optional): path to the top cache directory. Defaults to '.'.
            numProc (int, optional): number of processors to engage in screened substructure
                                     search database generation. Defaults to 2.
            suppressHydrogens (bool, optional): passed to the OE molecule builder (default: False)
            molLimit (int, optional): passed to the search index provider (default: None)
            quietFlag (bool, optional): passed to OeIoUtils and the OE molecule builder (default: True)
            logSizes (bool, optional): passed to the search index cache test (default: False)

        Returns:
            (bool): True for success or False otherwise
        """
        try:
            useCache = kwargs.get("useCache", True)
            cachePath = kwargs.get("cachePath", ".")
            numProc = kwargs.get("numProc", 2)
            molLimit = kwargs.get("molLimit", None)
            fpTypeList = kwargs.get("fpTypeList", ["TREE", "PATH", "MACCS", "CIRCULAR", "LINGO"])
            screenTypeList = kwargs.get("screenTypeList", None)
            limitPerceptions = kwargs.get("limitPerceptions", False)
            suppressHydrogens = kwargs.get("suppressHydrogens", False)
            quietFlag = kwargs.get("quietFlag", True)
            logSizes = kwargs.get("logSizes", False)
            fpDbType = kwargs.get("fpDbType", "STANDARD")
            #
            status = True
            # Number of molecules built in this run; None when the molecule store was reused
            oeCount = None
            oeIo = OeIoUtils(quietFlag=quietFlag)
            # -------- OE molecule store built from the search index
            oeSearchMolFilePath = self.__getFilePath(self.__getOeSearchMolFileName())
            if self.__needsBuild(oeSearchMolFilePath, useCache):
                cmpKwargs = {k: v for k, v in kwargs.items() if k not in ["cachePath", "useCache", "molLimit"]}
                ccsiP = ChemCompSearchIndexProvider(cachePath=cachePath, useCache=True, molLimit=molLimit, **cmpKwargs)
                ok = ccsiP.testCache(minCount=molLimit, logSizes=logSizes)
                if not ok:
                    logger.warning("Search index cache test failed - building the OE molecule store from an empty index")
                ccIdxD = ccsiP.getIndex() if ok else {}
                idxCount = len(ccIdxD)
                #
                startTime = time.time()
                oeCount, errCount, failIdList = oeIo.buildOeBinaryMolCacheFromIndex(
                    oeSearchMolFilePath, ccIdxD, quietFlag=quietFlag, fpTypeList=fpTypeList, limitPerceptions=limitPerceptions, suppressHydrogens=suppressHydrogens
                )
                if failIdList:
                    logger.info("failures %r", failIdList)
                logger.info("Constructed %d/%d cached oeMols  (unconverted %d) (%.4f seconds)", oeCount, idxCount, errCount, time.time() - startTime)
            # -------- Indexed OE molecule database
            oeMolDbFilePath = self.__getFilePath(self.__getOeMolDbFileName())
            if self.__needsBuild(oeMolDbFilePath, useCache):
                startTime = time.time()
                molCount = oeIo.createOeBinaryDatabaseAndIndex(oeSearchMolFilePath, oeMolDbFilePath)
                logger.info("Created and stored %d indexed oeMols in OE database format (%.4f seconds)", molCount, time.time() - startTime)
            # -------- Fast fingerprint search databases (optional)
            if fpDbType == "FAST":
                for fpType in fpTypeList or []:
                    fpPath = self.__getFilePath(self.__getFastFpDbFileName(fpType))
                    if self.__needsBuild(fpPath, useCache):
                        startTime = time.time()
                        oeIo.createOeFingerPrintDatabase(oeMolDbFilePath, fpPath, fpType=fpType)
                        logger.info("Created and stored %s fingerprint database (%.4f seconds)", fpType, time.time() - startTime)
            # -------- Screened substructure search databases (optional)
            for screenType in screenTypeList or []:
                fp = self.__getFilePath(self.__getSubSearchFileName(screenType))
                if not self.__needsBuild(fp, useCache):
                    continue
                startTime = time.time()
                ok = oeIo.createOeSubSearchDatabase(oeSearchMolFilePath, fp, screenType=screenType, numProc=numProc)
                logger.info("Constructed screened substructure database (status %r) with screenType %s (%.4f seconds)", ok, screenType, time.time() - startTime)
                # Read the new database back as a sanity check
                ssDb = oeIo.loadOeSubSearchDatabase(fp, screenType=screenType, numProc=numProc)
                if ssDb is None:
                    logger.error("Unable to load screened substructure database %r", fp)
                    status = False
                elif oeCount is not None and ssDb.NumMolecules() != oeCount:
                    # The size comparison is only meaningful when the molecule store was built in this run
                    logger.warning("Screened substructure database %r has %d molecules (expected %d)", fp, ssDb.NumMolecules(), oeCount)
            #
            return status
        except Exception as e:
            logger.exception("Failing with %s", str(e))
        return False