from typing import TYPE_CHECKING, Tuple, Union, Optional, cast
from inspect import getmodule
from logging import getLogger
import pandas as pd
import numpy as np
from numpy.typing import NDArray

if TYPE_CHECKING:
    from graphistry.Plottable import Plottable

from .feature_utils import FeatureMixin
from .ai_utils import search_to_df, FaissVectorSearch
from .constants import WEIGHT, DISTANCE


# Base class for the mixin. Type checkers see FeatureMixin, so members used
# below (e.g. _get_feature, transform, _nodes) resolve during static analysis.
# At runtime the mixin derives from plain ``object``; presumably the concrete
# graphistry class composes it with FeatureMixin to supply those members.
if TYPE_CHECKING:
    MIXIN_BASE = FeatureMixin
else:
    MIXIN_BASE = object

logger = getLogger(__name__)


# Type of an encoded query vector: a float32 or float64 NumPy array.
QueryVector = Union[NDArray[np.float32], NDArray[np.float64]]


class SearchToGraphMixin(MIXIN_BASE):
    """Natural-language search over featurized nodes, returning tables or subgraphs.

    The mixin reads members supplied by the host class (``_get_feature``,
    ``transform``, ``_node_encoder``, ``_nodes``, ``_edges``, ``bind`` ...);
    under static type checking those come from ``FeatureMixin``.
    """

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)

    def assert_fitted(self):
        """Assert that a node feature matrix has been computed.

        Raises:
            AssertionError: if no node features are available yet.
        """
        assert self._get_feature("nodes") is not None, (
            "Graphistry Instance is not fit, run g.featurize(kind='nodes', ..) to fit a model "
            "if you have nodes & edges dataframe or g.umap(kind='nodes', ..) "
            "if you only have nodes dataframe"
        )

    def assert_features_line_up_with_nodes(self):
        """Assert that the node feature matrix has exactly one row per node.

        Raises:
            AssertionError: if the row counts of ``self._nodes`` and the node
                feature matrix differ (e.g. nodes were mutated after fitting).
        """
        ndf = self._nodes
        X = self._get_feature("nodes")
        a, b = ndf.shape[0], X.shape[0]
        assert a == b, (
            "Nodes dataframe and feature vectors are not same size, "
            f"found nodes: {a}, feats: {b}. Did you mutate nodes between fit?"
        )

    def build_index(self, angular=False, n_trees=None):
        """Build (or rebuild) the local vector search index over node features.

        The index is stored in ``self.search_index``. cuDF features are
        converted to pandas first.

        Args:
            angular: not used by the current implementation.
            n_trees: not used by the current implementation.

        Raises:
            AssertionError: if nodes are not featurized, or the feature matrix
                does not line up with the nodes dataframe.
        """
        self.assert_fitted()
        self.assert_features_line_up_with_nodes()
        X = self._get_feature("nodes")
        if 'cudf' in str(getmodule(X)):
            X = X.to_pandas()
        self.search_index = FaissVectorSearch(X.values)

    def _query_from_dataframe(self, qdf: pd.DataFrame, top_n: int, thresh: float) -> Tuple[pd.DataFrame, QueryVector]:
        """Encode ``qdf`` with the fitted node featurizers and search the index.

        Args:
            qdf: query row(s) laid out like the nodes dataframe.
            top_n: maximum number of neighbours to retrieve.
            thresh: only results with ``DISTANCE < thresh`` are kept.

        Returns:
            The dataframe returned by ``search_df`` after thresholding
            (presumably node rows plus a distance column), and the encoded
            query vector.
        """
        # Use the loaded featurizers to transform the dataframe
        result = self.transform(qdf, None, kind="nodes", return_graph=False)
        assert isinstance(result, tuple), "transform with return_graph=False should return tuple"
        vect, _ = result

        nodes = self._nodes
        if 'cudf' in str(getmodule(nodes)):
            nodes = nodes.to_pandas()

        # Search against the (possibly pandas-converted) nodes computed above.
        results = self.search_index.search_df(vect, nodes, top_n)
        results = results.query(f"{DISTANCE} < {thresh}")

        return results, cast(QueryVector, vect)

    def _query(self, query: str, top_n: int, thresh: float) -> Tuple[pd.DataFrame, Optional[QueryVector]]:
        """Build a one-row query dataframe from ``query`` and search the index.

        The query text goes into the first text column; the remaining text
        columns are blank. Non-text encoder columns cannot be derived from free
        text. They are filled with placeholders: a randomly sampled existing
        value for string/object/category columns, or the column mean for
        numeric ones.

        Returns:
            ``(results, query_vector)``, or ``(empty dataframe, None)`` when the
            node encoder has no text columns.
        """
        # Build the index lazily on first search.
        if not hasattr(self, "search_index"):
            self.build_index()

        qdf = pd.DataFrame([])

        cols_text = self._node_encoder.text_cols  # type: ignore

        if len(cols_text) == 0:
            logger.warning(
                "** Querying is only possible using Transformer/Ngrams embeddings"
            )
            return pd.DataFrame([]), None

        qdf[cols_text[0]] = [query]
        for col in cols_text[1:]:
            qdf[col] = [""]

        # this is hookey and needs to be fixed on skrub side (with errors='ignore')
        # if however min_words = 0, all columns will be textual,
        # and no other data_encoder will be generated
        if hasattr(self._node_encoder.data_encoder, "columns_"):  # type: ignore

            other_cols = self._node_encoder.data_encoder.columns_  # type: ignore

            if other_cols is not None and len(other_cols):
                logger.warning(
                    "** There is no easy way to encode categorical or other features at query time. "
                    f"Set `thresh` to a large value if no results show up.\ncolumns: {other_cols}"
                )
                numeric_dtype_names = (
                    "int",
                    "float",
                    "float64",
                    "float32",
                    "float16",
                    "int64",
                    "int32",
                    "int16",
                    "uint64",
                    "uint32",
                    "uint16",
                )
                df = self._nodes
                dt = df[other_cols].dtypes
                for col, v in zip(other_cols, dt.values):
                    dtype_name = str(v)
                    if dtype_name in ("string", "object", "category"):
                        qdf[col] = df.sample(1)[col].values  # placeholder: a random existing value
                    elif dtype_name in numeric_dtype_names:
                        qdf[col] = df[col].mean()

        return self._query_from_dataframe(qdf, thresh=thresh, top_n=top_n)

    def search(
        self,
        query: str,
        cols=None,
        thresh: float = 5000,
        fuzzy: bool = True,
        top_n: int = 10,
    ) -> Tuple[pd.DataFrame, Optional[QueryVector]]:
        """Natural language query over nodes that returns a dataframe of results sorted by relevance column "distance".

            If node data is not yet feature-encoded (and explicit edges are given),
            run automatic feature engineering:
            ::

                g2 = g.featurize(kind='nodes', X=['text_col_1', ..],
                min_words=0 # forces all named columns are textually encoded
                )

            If edges do not yet exist, generate them via
            ::

                g2 = g.umap(kind='nodes', X=['text_col_1', ..],
                min_words=0 # forces all named columns are textually encoded
                )

            If an index is not yet built, it is generated `g2.build_index()` on the fly at search time.
            Otherwise, can set `g2.build_index()` to build it ahead of time.

        Args:
            :query (str): natural language query.
            :cols (list or str, optional): if fuzzy=False, select which column(s) to query.
                                            Defaults to None since fuzzy=True by default.
            :thresh (float, optional): distance threshold from query vector to returned results.
                                        Defaults to 5000, set large just in case,
                                        but could be as low as 10. Only used when fuzzy=True.
            :fuzzy (bool, optional): if True, uses embedding + vector index for recall,
                                        otherwise does string matching over given `cols`
                                        Defaults to True.
            :top_n (int, optional): how many results to return. Defaults to 10.
                                    Only used when fuzzy=True.

        Returns:
            **pd.DataFrame, vector_encoding_of_query:**
            rank ordered dataframe of results matching query

            vector encoding of query via given transformer/ngrams model if fuzzy=True else None

        Raises:
            ValueError: if fuzzy=False and no `cols` are given.
        """
        if not fuzzy:
            if cols is None:
                msg = (
                    f"Columns to search for `{query}` "
                    f"need to be given when fuzzy=False, found {cols}"
                )
                logger.error(msg)
                raise ValueError(msg)
            # A single column name is documented as valid; don't iterate its characters.
            if isinstance(cols, str):
                cols = [cols]

            logger.info(f"-- Word Match: [[ {query} ]]")
            return (
                pd.concat(
                    [
                        search_to_df(query, col, self._nodes, as_string=True)
                        for col in cols
                    ]
                ),
                None,
            )

        logger.info(f"-- Search: [[ {query} ]]")
        return self._query(query, thresh=thresh, top_n=top_n)

    def search_graph(
        self,
        query: str,
        scale: float = 0.5,
        top_n: int = 100,
        thresh: float = 5000,
        broader: bool = False,
        inplace: bool = False,
    ) -> "Plottable":
        """Input a natural language query and return a graph of results.
            See help(g.search) for more information

        Args:
            :query (str): query input eg "coding best practices". An empty string
                            skips the search and filters the entire graph.
            :scale (float, optional): edge weight threshold; edges are kept when their
                                        weight exceeds it (edges without a weight column
                                        are kept as-is). Defaults to 0.5.
            :top_n (int, optional): how many results to return. Defaults to 100.
            :thresh (float, optional): distance threshold from query vector to returned results.
                                        Defaults to 5000, set large just in case,
                                        but could be as low as 10.
            :broader (bool, optional): if True, will retrieve entities connected via an edge
                that were not necessarily bubbled up in the results_dataframe. Defaults to False.
            :inplace (bool, optional): whether to return new instance (default) or mutate self.
                                        Defaults to False.

        Returns:
            graphistry Instance: g
        """
        res = self if inplace else cast('SearchToGraphMixin', self.bind())

        edf = edges = res._edges
        rdf = df = res._nodes

        # When edges live in cuDF, move the nodes there as well so frames mix cleanly.
        cudf_coercion = 'cudf' in str(getmodule(edges))
        if cudf_coercion:
            import cudf

            if not isinstance(rdf, cudf.DataFrame):
                rdf = cudf.from_pandas(rdf)
                df = rdf
            concat = cudf.concat
        else:
            concat = pd.concat

        node = res._node
        src = res._source
        dst = res._destination
        # rdf is a cuDF frame whenever cudf_coercion, so this is already a cuDF series then.
        indices = rdf[node]

        if query != "":
            # run a real query, else return entire graph
            rdf, _ = res.search(query, thresh=thresh, fuzzy=True, top_n=top_n)
            if rdf.empty:
                logger.warning(
                    "**No results found due to empty DataFrame, returning original graph"
                )
                return res
            if cudf_coercion and not isinstance(rdf, cudf.DataFrame):
                rdf = cudf.from_pandas(rdf)
            indices = rdf[node]

            src_hit = edf[src].isin(indices)
            dst_hit = edf[dst].isin(indices)
            if broader:  # broader graph: any edge touching a result (src OR dst)
                edges = edf[src_hit | dst_hit]
            else:  # default smaller graph: only edges between results, if they exist
                edges = edf[src_hit & dst_hit]

        try:  # for umap'd edges (weighted)
            edges = edges.query(f"{WEIGHT} > {scale}")
        except Exception:  # for explicit edges, which may have no weight column
            pass

        found_indices = concat([edges[src], edges[dst], indices], axis=0).unique()

        node_feats = res._node_features
        node_emb = res._node_embedding
        if cudf_coercion:
            if not isinstance(node_feats, cudf.DataFrame):
                node_feats = cudf.from_pandas(node_feats)
            if res._umap is not None and not isinstance(node_emb, cudf.DataFrame):
                node_emb = cudf.from_pandas(node_emb)

        emb = None
        try:
            # Assumes node ids coincide with row positions.
            tdf = rdf.iloc[found_indices]
            feats = node_feats.iloc[found_indices]  # type: ignore
            if res._umap is not None:
                emb = node_emb.iloc[found_indices]  # type: ignore
        except Exception:  # for explicit relabeled nodes
            # NOTE(review): after a query, ``rdf`` holds search results while
            # ``df`` holds all nodes, so this selection mixes the two frames --
            # confirm this is the intended behavior.
            in_found = df[node].isin(found_indices)
            tdf = rdf[in_found]
            feats = node_feats.loc[tdf.index]  # type: ignore
            if res._umap is not None:
                emb = node_emb[in_found]  # type: ignore
        logger.info(f" - Returning edge dataframe of size {edges.shape[0]}")
        # get all the unique nodes
        logger.info(
            f" - Returning {tdf.shape[0]} unique nodes given scale {scale} and thresh {thresh}"
        )

        g = res.edges(edges, src, dst).nodes(tdf, node)
        # add them back so they sync with .dbscan etc calls
        g._node_features = feats
        g._node_embedding = emb

        name = f"query:{query}" if g._name is None else f"{g._name}-query:{query}"
        g = g.name(name)  # type: ignore
        return g

    def save_search_instance(self, savepath):
        """Rebuild the search index, then persist this instance with joblib.

        The search index is detached before pickling (the original note says
        the index cannot be pickled). It is restored afterwards, even if
        saving fails.
        """
        from joblib import dump  # type: ignore   # need to make this onnx or similar

        self.build_index()
        search = self.search_index
        del self.search_index  # index is not picklable; detach before dumping
        try:
            dump(self, savepath)
        finally:
            self.search_index = search  # add it back, even if dump failed
        logger.info(f"Saved: {savepath}")

    @classmethod
    def load_search_instance(cls, savepath):
        """Load an instance saved by ``save_search_instance`` and rebuild its index.

        SECURITY: ``joblib.load`` unpickles arbitrary objects; only load files
        from trusted sources.
        """
        from joblib import load  # type: ignore   # need to make this onnx or similar

        instance = load(savepath)
        instance.build_index()
        return instance
