"""
The flow matrix (TFM) contains the flows between agents and the stock-flow consistency. It has the following structure:
"""

__version__ = "0.3"
__author__ = "Thomas Baldauf"
__email__ = "thomas.baldauf@dlr.de"
__license__ = "MIT"
__birthdate__ = '15.11.2021'
__status__ = 'prod'  # options are: dev, test, prod

from .singleton import Singleton
from .world import World
from collections import defaultdict
import pandas as pd
import warnings
from enum import IntEnum, Enum
import numpy as np
# import seaborn as sns
import networkx as nx
import re
from typing import Optional, Type


class Accounts(Enum):
    """
    Account types of the flow matrix.

    Every agent column of the flow matrix is split into two accounts:

    * ``CA`` (value 0): the current account
    * ``KA`` (value 1): the capital account

    The integer values are used as indices into the per-account flow storage of ``FlowMatrix``.
    """

    # NOTE: the declaration order (KA first) determines the iteration order of the enum; keep it.
    KA = 1  # capital account
    CA = 0  # current account


class FlowMatrix(Singleton):
    """
    The flow matrix records the flows between agents and is used to check stock-flow consistency.

    Every agent (or, in the grouped view, every agent class) owns two accounts:
    CA (current account) and KA (capital account). Rows are transaction subjects,
    columns are (agent, account) pairs, and an additional ``Total`` row and column
    hold the sums. As a consistency check, the sum of every row and every column
    should be zero at the end of each simulation period (see ``check_consistency``).

    The general layout is:

    +--------------------------------------------------------------------------------------+
    | Flow Matrix                                                                          |
    +===============+=========================+==================================+=========+
    |               |     Agent 1             |           Agent 2                |  TOTAL  |
    +---------------+-----------+-------------+----------------+-----------------+---------+
    |               |     CA    | KA          |           CA   | KA              |         |
    +---------------+-----------+-------------+----------------+-----------------+---------+
    |               |           |             |                |                 |         |
    | Flows         |    -x     |             |                |  +x             |   0     |
    +---------------+-----------+-------------+----------------+-----------------+---------+
    |               |           |             |                |                 |         |
    | d(Stocks)     |     +x    |             |                |  -x             |   0     |
    +---------------+-----------+-------------+----------------+-----------------+---------+
    | TOTAL         |     0     |  0          |            0   |    0            |   0     |
    +---------------+-----------+-------------+----------------+-----------------+---------+

    For more information, see

    - Investopedia article https://www.investopedia.com/ask/answers/031615/whats-difference-between-current-account-and-capital-account.asp
    - A nice introduction by Grasselli (OECD) https://www.oecd.org/naec/new-economic-policymaking/grasselli_OECD_masterclass_2019.pdf
    """

    # number of resets counted so far; stops increasing at 2 (see reset())
    reset_count = 0
    # global switch: while False, log_flow() ignores every flow
    enabled = True

    def __init__(self) -> None:
        """
        Initialize the flow matrix storage.

        The ``initialized`` flag turns every call after the first into a no-op, so data that
        has already been logged is kept (presumably because ``Singleton`` hands back the same
        instance on every ``FlowMatrix()`` call).
        """
        if hasattr(self, "initialized"):  # only initialize once
            return

        self.initialized = True  # initialized flag

        # NOTE: maybe think about replacing the check above with an 'if self.do_init:' check?
        # Keeping it for now because it works...

        # [CA, KA] -> subject -> agent -> accumulated quantity (populated by reset())
        self._flow_data = []
        self._stock_data = {}  # stock name -> defaultdict(float), see add_stock()
        self.stocks = []

        self.reset()
        self.linear_log = []  # linear transaction log

    def reset(self, verbose=False):
        """
        Clear all logged flows and stocks.

        The class-level ``reset_count`` is incremented until it reaches 2. From then on (i.e. from
        the third reset, counting the one performed by ``__init__``) a warning is issued on every
        reset if ``verbose`` is True.

        :param verbose: boolean (default False), triggers a reset warning if True
        """
        # one nested map per account: [CA/KA] -> [subject] -> [agent] -> quantity
        self._flow_data = [defaultdict(lambda: defaultdict(float)),
                           defaultdict(lambda: defaultdict(float))]
        self._stock_data = {}

        if self.__class__.reset_count > 1:
            # warn the user about every further reset
            if verbose:
                warnings.warn("FlowMatrix has been reset")
        else:
            self.__class__.reset_count += 1

    def add_stock(self, name):
        """
        Add another stock for logging (no-op if the stock is already known).

        :param name: key of the stock
        """
        if name not in self._stock_data:
            self._stock_data[name] = defaultdict(float)

    @staticmethod
    def _account_index(account):
        """Map an account given as ``Accounts`` member or plain int (0 = CA, 1 = KA) to its storage index."""
        index = getattr(account, "value", account)
        if index not in (0, 1):
            raise ValueError("Invalid account %r, expected Accounts.CA/Accounts.KA or 0/1" % (account,))
        return index

    def log_flow(self, direction, quantity, agent_from, agent_to, subject, price=None, invert=False):
        """
        Registers a flow at the flow matrix. This is no method for the user in most cases. It is automatically generated in the qattune gui.

        The (possibly price-converted and inverted) quantity Q is subtracted from the sender's entry
        in the from-account and added to the receiver's entry in the to-account of row ``subject``.

        :param direction: tuple (from_account, to_account), each either an ``Accounts`` member or 0 (CA) / 1 (KA),
            e.g. (0, 0) or (Accounts.CA, Accounts.KA)
        :param quantity: weight (quantity) of the flow
        :param agent_from: Agent instance (optionally str), sender agent
        :param agent_to: Agent instance (optionally str), receiver agent
        :param subject: name of the transaction (row of the flow matrix)
        :param price: float or None (default). If not None, a price conversion factor is applied. Should not be used in most cases
        :param invert: reverse sign of the transferred quantity? Default False, should not be used in most cases
        :raises ValueError: if an account is neither an ``Accounts`` member nor 0/1

        Example

        .. code-block:: python

            from sfctools import FlowMatrix, Accounts
            CA = Accounts.CA
            KA = Accounts.KA

            flow = (CA,KA) # from account -> to account
            FlowMatrix().log_flow(flow, 42.0, agent1, agent2, subject="my_subject")
            # ...

            df = FlowMatrix().to_dataframe(group=True)
            print(FlowMatrix().to_string(group=True))
            FlowMatrix().check_consistency() # ok if no error is raised

        """
        if not self.__class__.enabled:
            return

        Q = quantity * price if price is not None else quantity
        if invert:
            Q = -Q

        # accept Accounts members as well as plain ints, as documented above
        from_account = self._account_index(direction[0])
        to_account = self._account_index(direction[1])

        self._flow_data[from_account][subject][agent_from] -= Q
        self._flow_data[to_account][subject][agent_to] += Q

    def check_consistency(self):
        """
        Checks the consistency of the transaction flow matrix.

        Every row total and every column total of the aggregated matrix must be zero, up to
        rounding at four orders of magnitude below the largest absolute entry.

        :raises RuntimeError: if a row or a column does not sum up to zero
        """
        df = self.to_dataframe(group=True)

        if df.empty:
            return

        null_sym = "   .-   "  # zero placeholder inserted by to_dataframe()
        df2 = df.replace(null_sym, 0.0).astype(float)

        # avoid numerical errors:
        # round to at max 4 orders of magnitude less than the max order of magnitude
        scale = np.abs(df2.to_numpy()).max()
        if scale == 0.0:
            # all-zero matrix: trivially consistent (and log10(0) would be undefined)
            return
        order_magnitude = int(np.ceil(np.log10(scale)))
        df2 = df2.round(-order_magnitude + 4)

        if np.array(df2["Total"]).any():
            raise RuntimeError(
                "Inconsistent Row In Flow Sheet: \n%s" % df.to_string())

        if np.array(df2.T["Total"]).any():
            raise RuntimeError(
                "Inconsistent Column In Flow Sheet:\n %s" % df.to_string())

    @staticmethod
    def _get_colormap(name):
        """Look up a matplotlib colormap by name, across matplotlib versions."""
        import matplotlib
        try:
            return matplotlib.colormaps[name]  # matplotlib >= 3.5
        except AttributeError:
            # older matplotlib (cm.get_cmap was removed in matplotlib 3.9)
            from matplotlib import cm
            return cm.get_cmap(name)

    def to_string(self, group=True, round=2, replace_zeros=True, justify="right", use_sep_lines=True, sorting=None,
                  sorting_transactions=None, merge_accounts=None, color_negative=False, color_values=False):
        """
        Converts the flow matrix to a (table-like) string representation.

        :param group: aggregate agents of the same class, see to_dataframe()
        :param round: number of decimals; for round <= 0 the values are shown as integers
        :param replace_zeros: show zero entries as ' .- ' (blank if round <= 0)
        :param justify: justification of the column labels, passed to DataFrame.to_string
        :param use_sep_lines: draw table borders and separator lines (default True)
        :param sorting: optional list of column (agent) names defining the column order; other columns follow
        :param sorting_transactions: optional list of row names defining the row order, see to_dataframe()
        :param merge_accounts: optional list of column names whose CA and KA columns are merged into one
        :param color_negative: color negative numbers red (requires colorama)
        :param color_values: True or a matplotlib colormap name: color all numbers by value
            (requires colorama and matplotlib; default colormap 'coolwarm')
        :return: str
        """
        df = self.to_dataframe(group=group, merge_accounts=merge_accounts,
                               sorting_transactions=sorting_transactions).round(round)
        null_text = " .- " if round > 0 else ""  # placeholder for zero entries

        if not use_sep_lines:
            if replace_zeros:
                df = df.replace(0.0, null_text)
            return df.to_string()  # always a string, as documented

        if replace_zeros:
            df = df.replace(0.0, null_text)

        df = df.rename(columns={"Total": "Total "})

        if sorting is not None:
            sorted_columns = []
            for s in sorting:
                if merge_accounts and (s in merge_accounts):
                    sorted_columns.append((s, '$|'))
                else:
                    sorted_columns += [(s, 'CA'), (s, 'KA')]
            sorted_columns += [col for col in df.columns if col[0] not in sorting]
            sorted_columns = [col for col in sorted_columns if col in df.columns]
            df = df[sorted_columns]

        # placeholders become NaN here and are turned back into placeholders below
        df = df.apply(pd.to_numeric, errors='coerce')
        df = df.round(round if round > 0 else 0)

        if replace_zeros:
            df = df.fillna(null_text)

        s = df.to_string(justify=justify)
        if round <= 0:
            # print integral floats without their trailing '.0'
            s = re.sub(r"(-?\d+)\.0\b", r"  \1", s)

        def convert_line(line, idx, r=" | "):
            # insert the separator r right after every character position listed in idx
            return "".join(ch + r if j in idx else ch for j, ch in enumerate(line))

        # add separator lines
        lines = s.split("\n")
        # a column separator goes right behind every 'KA' (or merged '$|') header label
        idx = [m.start() + 1 for m in re.finditer(r"(KA|\$\|)", lines[1])]

        # also separate the 'Total' column, unless a separator is already close by
        split_pos = lines[0].rfind("Total") - 1
        if all(abs(split_pos - i) > 2 for i in idx):  # avoid duplicates
            idx.append(split_pos)

        try:
            # separator between the row labels and the first column
            idx.append(max(len(str(i)) for i in df.index))
            width = len(lines[0])
            hline = [convert_line("-" * width, idx, r="-+-")]
            hrule = [convert_line("-" * width, idx, r="---")]

            mylines = [convert_line(line, idx) for line in lines[:2]]
            mylines += hline
            mylines += [convert_line(line, idx) for line in lines[2:-1]]
            mylines += hline
            mylines.append(convert_line(lines[-1], idx))
            mylines += hrule
            mylines = ["|" + line + "|" for line in hrule + mylines]

            # extra horizontal line in front of the first stock-change row (label starts with 'Δ')
            break_idx = next((i for i, line in enumerate(mylines) if line.startswith("|Δ")), None)
            if break_idx is not None:
                hline = ["|" + convert_line("-" * width + "|", idx, r="-+-")]
                mylines = mylines[:break_idx] + hline + mylines[break_idx:]

            return_str = "\n".join(mylines).replace("$|", "  ")
        except Exception:
            # best effort: e.g. max() of an empty index when there is nothing to show
            return_str = "<Empty FlowMatrix>"

        if color_negative:
            if not color_values:
                from colorama import Fore, Style

                # color any numbers with negative sign
                pattern = r"-\d+(\.\d+)?"
                return_str = re.sub(
                    pattern, lambda x: f"{Fore.RED}{x.group(0)}{Style.RESET_ALL}", return_str)
            else:
                warnings.warn(
                    "Cannot use color_negative and color_values at the same time")

        if color_values:
            from matplotlib.colors import Normalize
            from colorama import Style, init
            init(autoreset=True)

            pattern = r"-?\d+(\.\d+)?"
            numbers = [float(match.group())
                       for match in re.finditer(pattern, return_str)]

            if numbers:  # min()/max() of an empty list would fail
                norm = Normalize(vmin=min(numbers), vmax=max(numbers))
                cmap_name = color_values if isinstance(color_values, str) else "coolwarm"
                cmap = self._get_colormap(cmap_name)

                def colorize(match):
                    # map the number to a 24-bit terminal color according to its size
                    number_str = str(match.group())
                    r, g, b, _ = [int(255 * c) for c in cmap(norm(float(number_str)))]
                    return f"\033[38;2;{r};{g};{b}m{number_str}{Style.RESET_ALL}"

                return_str = re.sub(pattern, colorize, return_str)

        return return_str

    def merge_accounts(self, df, merge_accounts, round=5):
        """
        Merge the CA and KA columns of the given entities into a single ``(entity, '$|')`` column.

        Entities that cannot be merged (e.g. missing columns) are skipped with a warning.

        :param df: dataframe as produced by to_dataframe()
        :param merge_accounts: list of level-0 column names to merge
        :param round: decimals the account values are rounded to before summing
        :return: new dataframe with sorted columns
        """
        combined_df = df.copy()
        for entity in merge_accounts:
            try:
                ca_values = pd.to_numeric(
                    df[(entity, "CA")], errors='coerce').fillna(0).round(round)
                ka_values = pd.to_numeric(
                    df[(entity, "KA")], errors='coerce').fillna(0).round(round)

                combined_df[(entity, "$|")] = ca_values + ka_values
                combined_df.drop(
                    columns=[(entity, "CA"), (entity, "KA")], inplace=True)
            except Exception as e:
                warnings.warn("Warning: %s" % (str(e)))
        return combined_df.sort_index(axis=1)

    def to_dataframe(self, group=True, insert_nullsym=True, merge_accounts=None, sorting_transactions=None):
        """
        Converts the data structure to a human-readable dataframe format.
        WARNING this is slow

        Rows are transaction subjects plus a ``Total`` row; columns are (agent, account) pairs
        (agent classes if grouped) plus a ``Total`` column.

        :param group: boolean switch (default True), if True it will group the agents of the same class together
        :param insert_nullsym: (grouped view only) insert the '.-' symbol instead of zero, default True
        :param merge_accounts: optional list of column names whose CA and KA columns are merged, see merge_accounts()
        :param sorting_transactions: optional list of row names defining the row order, see sort_transactions()
        :return: pandas dataframe object (empty if there is nothing to group)
        :raises RuntimeError: if the grouped columns cannot be ordered
        """
        df_current = pd.DataFrame(self._flow_data[0]).T
        df_capital = pd.DataFrame(self._flow_data[1]).T

        df_merge = pd.concat([df_current, df_capital], axis=1, keys=[
                             'CA', 'KA']).swaplevel(0, 1, axis=1)
        try:
            df_merge = df_merge.sort_index(axis=1)
        except Exception as e:
            # presumably non-comparable column labels; keep them unsorted
            warnings.warn(str(e))

        df = df_merge.fillna(0.0).sort_index()
        agent_types = World().get_agent_types()

        if not group:
            df.loc["Total"] = df.sum()
            df["Total"] = df.T.sum()

            if merge_accounts:
                df = self.merge_accounts(df, merge_accounts)
            if sorting_transactions:
                df = self.sort_transactions(df, sorting_transactions)

            return df.round(4)

        # aggregated view: one (CA, KA) column pair per agent class
        data = {}
        renamer = {}
        for a in agent_types:
            my_group = [col[0] for col in df.columns if isinstance(col[0], a)]

            if my_group:
                my_df = df.loc[:, df.columns.get_level_values(0).isin(my_group)]
                my_CA = my_df.loc[:, my_df.columns.get_level_values(1).isin({"CA"})].sum(axis=1)
                my_KA = my_df.loc[:, my_df.columns.get_level_values(1).isin({"KA"})].sum(axis=1)
                data[a] = pd.concat([my_CA, my_KA], axis=1, keys=['CA', 'KA'])

            # pluralize the class name if several agents of this class exist
            if len(my_group) > 1 and len(World().get_agents_of_type(a.__name__)) > 1:
                if not a.__name__.endswith("y"):
                    renamer[a] = a.__name__ + "s"
                else:
                    renamer[a] = a.__name__[:-1] + "ies"
            else:
                renamer[a] = a.__name__

        if not data:
            return pd.DataFrame()

        df2 = pd.concat(data, axis=1)
        df2 = df2.rename(columns=renamer).sort_index()

        df2.loc["Total"] = df2.sum()
        df2["Total"] = df2.T.sum()

        df2 = df2.round(5)
        # cut one digit to obtain a 'consistently rounded table'
        df2 = df2.round(4)

        if insert_nullsym:
            null_sym = "   .-   "
            df2 = df2.replace(0.0, null_sym)

        try:
            df2 = df2.reindex(sorted(df2.columns), axis=1)
        except ValueError:  # non-unique multi-indices here? bad sign... try to remove the non-unique part
            try:
                # deduplicate but keep the columns sorted
                df2 = df2.reindex(sorted(set(df2.columns)), axis=1)

                warnings.warn(
                    "Encountered non-unique multi-index in FlowMatrix. Something might be wrong")

            except Exception as err:  # other error?
                raise RuntimeError(
                    "Something went wrong when aggregating the FlowMatrix as dataframe. You might be able to fix this error by cross-checking the naming of your agents and transactions.") from err

        # move the Total column to the very end
        cols = list(df2.columns.values)
        cols.remove(("Total", ""))
        df2 = df2[cols + [("Total", "")]]

        if merge_accounts:
            df2 = self.merge_accounts(df2, merge_accounts)
        if sorting_transactions:
            df2 = self.sort_transactions(df2, sorting_transactions)
        return df2

    def sort_transactions(self, df, sorting):
        """
        Reorder the rows of df: rows listed in ``sorting`` come first (in that order), all others follow
        in their current relative order.

        :param df: dataframe as produced by to_dataframe()
        :param sorting: list of row names
        :return: reindexed dataframe
        """
        row_index = sorted(df.index, key=lambda x: (
            sorting.index(x) if x in sorting else len(sorting)))
        return df.reindex(row_index)

    def plot(self, group=True):
        """Show the flow matrix as a colored heat map (shortcut for plot_colored(show_plot=True))."""
        self.plot_colored(show_plot=True, group=group)

    def plot_colored(self, show_plot=True, group=True, cmap='coolwarm', show_values=True):
        """
        Plots the flow matrix as a nice colored heat map.
        This will open up a matplotlib window if show_plot is True.

        :param show_plot: show the plot as window? default True
        :param group: aggregated view?
        :param cmap: colormap to use (default 'coolwarm')
        :param show_values: (default True) show numerical values above colored tiles
        :return fig: figure object (returned in both cases)
        """
        import matplotlib.pyplot as plt
        # TODO nicer formatting
        # TODO more plotting options
        import seaborn as sns
        df = self.to_dataframe(insert_nullsym=False, group=group)

        fig = plt.figure(figsize=(5, 3))
        sns.heatmap(df, annot=show_values, cmap=cmap,)
        ax = plt.gca()
        ax.set_ylabel('')
        ax.set_xlabel('')

        plt.tight_layout()
        if show_plot:
            plt.show()

        return fig

    def convert_sankey_df(self, sorting=None):
        """
        Converts the FlowMatrix data to a sankey-plottable dataframe format.
        This is meant to increase the flexibility in data retrieval and plotting capabilities. To directly plot the FlowMatrix, see plot_sankey

        Negative entries (agent pays) become edges agent -> transaction in the source frame,
        positive entries (agent receives) become edges transaction -> agent in the sink frame.
        The 'Total' row and stock-change rows (labels starting with 'Δ') are skipped.

        :param sorting: optional list of transaction names. If given, only these rows (in this order),
            the stock-change rows and the 'Total' row are kept
        :return: tuple of dataframes - sankey_source, sankey_sink
        """
        df = self.to_dataframe(insert_nullsym=False)

        if sorting is not None and not df.empty:
            # rows selected (and ordered) by 'sorting'
            df_sorting = df[df.index.isin(sorting)]
            # stock-change rows and the Total row are always kept
            df_other = df[df.index.str.startswith("Δ") | (df.index == "Total")]

            df_sorting.index = pd.CategoricalIndex(
                df_sorting.index, categories=sorting, ordered=True)
            df_sorting = df_sorting.sort_index()

            df = pd.concat([df_sorting, df_other]).reset_index(drop=False).set_index("index")

        source, target, types, value = [], [], [], []  # agent pays
        source2, target2, types2, value2 = [], [], [], []  # agent receives

        if not df.empty:  # an empty result has no (agent, account) column levels
            # row position of each transaction, used as color id
            subjects = {str(k): i for i, k in enumerate(df.index)}
            account_level = df.columns.get_level_values(1)

            for account in ("CA", "KA"):  # current account, then capital account
                df_i = df.iloc[:, account_level == account]
                df_i.columns = df_i.columns.droplevel(1)

                for index, row in df_i.iterrows():
                    if index == "Total" or str(index).startswith("Δ"):
                        continue

                    for column in df_i.columns:
                        val = float(row[column])
                        if val < 0.0:  # is a source
                            source.append(str(column) + "$[source]")
                            target.append(str(index))
                            value.append(-val)
                            types.append(subjects[str(index)])
                        elif val > 0.0:  # is a sink
                            source2.append(str(column) + "$[sink]")
                            target2.append(str(index))
                            value2.append(val)
                            types2.append(subjects[str(index)])

        my_sankey_source = pd.DataFrame({"from": source,
                                         "to": target,
                                         "color_id": types,
                                         "value": value}).round(2)
        my_sankey_sink = pd.DataFrame({"to": source2,
                                       "from": target2,
                                       "color_id": types2,
                                       "value": value2}).round(2)

        return my_sankey_source, my_sankey_sink

    def plot_sankey(self, show_values=True, show_plot=True, colors=None, sorting_transactions=None, sorting_from=None, sorting_to=None, label_rot=0.0, norm_factor=1.0, dx_space=-30.1,
                    curvature=0.8, space_layout=False, band_space=3.2, fontsize=12, filename=None,
                    use_legend=False, alpha=0.8, flip_labels=False, add_braces=True, dy_labels=2.1):
        """
        plots a sankey diagram of the flow matrix

        :param show_values: boolean switch to plot the values of the edge weights
        :param sorting_transactions: list of transaction (row) names to sort/filter to
        :param sorting_from: list of agent names (senders) to sort/filter
        :param sorting_to: list of agent names (receivers) to sort/filter
        :param show_plot: show the plot as window? default True. If False, figure is returned instead.
        :param colors: (optional) a list of colors (if None, default colors are chosen)
        :param label_rot: rotation of the labels
        :param space_layout: if True, the layout is less flat and has more space
        :param filename: filename, if given, the plot will be stored as a file
        :param use_legend: use a legend in the plot instead of direct labeling
        :param norm_factor, dx_space, curvature, band_space, fontsize, alpha, flip_labels, add_braces, dy_labels:
            layout options passed through to the sankey plotting helper
        :return fig: figure object (None if show_plot is True)
        """
        import matplotlib.pyplot as plt
        from ..misc.mpl_plotting import plot_sankey as sankey

        my_sankey_source, my_sankey_sink = self.convert_sankey_df(
            sorting=sorting_transactions)

        def convert_sort_df(df):
            # order edges by the position of their agent prefix (text before '$') in sorting_from / sorting_to
            if sorting_from is not None:
                df['from_prefix'] = df['from'].str.split('$').str[0]
                df['sort_key_from'] = df['from_prefix'].apply(
                    lambda x: sorting_from.index(x) if x in sorting_from else float('inf'))
                df = df.sort_values('sort_key_from').drop(
                    columns=['from_prefix', 'sort_key_from']).reset_index(drop=True)

            if sorting_to is not None:
                df['to_prefix'] = df['to'].str.split('$').str[0]
                df['sort_key_to'] = df['to_prefix'].apply(
                    lambda x: sorting_to.index(x) if x in sorting_to else float('inf'))
                df = df.sort_values('sort_key_to').drop(
                    columns=['to_prefix', 'sort_key_to']).reset_index(drop=True)
            return df

        my_sankey_source = convert_sort_df(my_sankey_source)
        my_sankey_sink = convert_sort_df(my_sankey_sink)

        fig = sankey([my_sankey_source, my_sankey_sink], show_values=show_values, show_plot=show_plot, colors=colors, label_rot=label_rot, filling_fraction=.4,
                     norm_factor=norm_factor, dx_space=dx_space, dx_left=15.5, dy_labels=dy_labels, fontsize=fontsize, curvature=curvature, space_layout=space_layout, band_space=band_space,
                     filename=filename, use_legend=use_legend, alpha=alpha, flip_labels=flip_labels, add_braces=add_braces)

        if show_plot:
            return

        if use_legend:
            plt.legend(loc=(1.04, 0))

        plt.tight_layout()
        return fig

    @property
    def capital_flow_data(self):
        """'shortcut' property. only returns the data from the capital account"""
        return self._flow_data[1]

    @property
    def capital_stock_data(self):
        """The logged stock data (stock name -> data), see add_stock()."""
        return self._stock_data

    @property
    def current_flow_data(self):
        """'shortcut' property. only returns the data from the current account"""
        return self._flow_data[0]

    def get_data(self):
        """
        Retrieve pickle-able data. This can later be used in from_data()

        The lambda-based nested defaultdicts of the flow data cannot be pickled, so they are
        converted into plain (nested) dicts.

        :return: dict with keys 'flow_data' (list [CA, KA] of nested dicts) and 'stock_data'
        """
        def to_plain_dict(d):
            # defaultdict is a dict subclass, so this covers both
            if isinstance(d, dict):
                return {k: to_plain_dict(v) for k, v in d.items()}
            return d

        return {"flow_data": [to_plain_dict(d) for d in self._flow_data], "stock_data": dict(self._stock_data)}

    def from_data(self, data, reset=True):
        """
        Applies data to the current instance

        The flow and stock data are rebuilt with the same defaultdict structure that reset()
        and add_stock() create, so flows can be logged on top of the loaded data.

        :param data: the data to apply, should have keys 'flow_data' and 'stock_data', see get_data()
        :param reset: boolean switch to tell if the flowmatrix is reset before loading the new data (default True)
        """
        if reset:
            self.reset()

        def to_account(account_data):
            # [subject] -> [agent] -> quantity; missing agents must default to 0.0, not to a dict
            account = defaultdict(lambda: defaultdict(float))
            for subject, entries in account_data.items():
                account[subject] = defaultdict(float, entries)
            return account

        flow_data = [to_account(fd) for fd in data.get("flow_data", [])]
        # make sure both accounts (CA and KA) exist, even for incomplete data
        while len(flow_data) < 2:
            flow_data.append(defaultdict(lambda: defaultdict(float)))
        self._flow_data = flow_data

        self._stock_data = {name: defaultdict(float, values) if isinstance(values, dict) else values
                            for name, values in data.get("stock_data", {}).items()}