
from .singleton import Singleton
from .world import World

from enum import Enum 
from collections import defaultdict 
import pandas as pd 

import warnings
import re 

class BMEntry:
    """Selectors for the two sides of the balance matrix.

    ``A`` selects the asset side and ``L`` the liability side; both are used
    as keys into the balance matrix's internal data tables.
    """

    A, L = 0, 1

class BalanceMatrix(Singleton):
    """
    The balance matrix contains the flows between agents and the stock-flow consistency. It has the following structure:

    +----------------------------------------------------------------------------+
    | Balance Matrix                                                             |
    +===============+=========================+==================================+
    |               |     Agent 1             |           Agent 2                |
    +---------------+-----------+-------------+----------------+-----------------+
    |               |     A     |  L          |            A   | L               |
    +---------------+-----------+-------------+----------------+-----------------+
    | d(Deposits)   |    dx1    |             |          dx2   |                 |
    +---------------+-----------+-------------+----------------+-----------------+
    | d(Loans)      |           |  dy1        |                |                 |
    +---------------+-----------+-------------+----------------+-----------------+
    | d(Net Worth)  |           |  dnw1       |                |  dnw2           |
    +---------------+-----------+-------------+----------------+-----------------+
    | TOTAL         |    dx1    |  dy1 + dnw1 |           dx2  |  dnw2           |
    +---------------+-----------+-------------+----------------+-----------------+
    | Residual      |           | d(A1)-d(L1) |                | d(A2)-d(L2)     |
    +---------------+-----------+-------------+----------------+-----------------+

    Typical use: call init_data() to record the starting balances, let the
    model run, call fill_data() to record the current balances, then render
    the differences with to_dataframe() or to_string().
    """

    # number of resets counted so far (shared by the class); reset() uses it
    # to decide whether a reset warning may be emitted
    reset_count = 0

    def __init__(self):
        """Set up the instance state the first time; later calls return immediately."""
        # the "initialized" guard means repeated BalanceMatrix() calls do not
        # wipe the existing state
        if hasattr(self, "initialized"):
            return

        self.initialized = True  # initialized flag

        self.agents = []  # column keys: class names (group=True) or str(agent)
        self.rows = []    # balance-sheet item names seen so far

        self.group = None
        # filter used when registering agent columns; set by init_data() /
        # fill_data(). Created here so fill_data() can run before init_data().
        self.filter = None

        # both tables are keyed by BMEntry.A (assets) / BMEntry.L (liabilities),
        # then by row name, then by column key; filled by reset()
        self._data = {}        # current balances
        self._start_data = {}  # starting balances

        self.reset()

    @staticmethod
    def _clear_entries(table):
        """Replace the asset and liability sub-tables of *table* with empty ones (missing keys read as 0.0)."""
        for entry in (BMEntry.A, BMEntry.L):
            table[entry] = defaultdict(lambda: defaultdict(float))

    def reset(self, verbose=False):
        """
        Reset the starting and current balance data.

        :param verbose: boolean (default False); if True, warn about the reset
            once ``reset_count`` has already counted more than one reset
        """
        self._clear_entries(self._start_data)
        self._clear_entries(self._data)

        if self.__class__.reset_count > 1:
            # warn the user every time this is being reset!
            if verbose:
                warnings.warn("BalanceMatrix has been reset")
        else:
            self.__class__.reset_count += 1

    def _collect_balances(self, target, agent_filter, verbose=False, warn_on_error=True):
        """
        Add the balance-sheet entries of every registered agent into *target*.

        :param target: ``self._start_data`` or ``self._data``
        :param agent_filter: None, or a container of agent class names; only
            matching agents are registered as columns (balances of all agents
            are recorded regardless)
        :param verbose: print every scanned agent
        :param warn_on_error: if True, exceptions raised while reading an agent
            are reported via ``warnings.warn``; otherwise they are ignored
        """
        for registered in World().agent_registry.values():
            for agent in registered:
                if verbose:
                    print(" --> scan", agent)

                # group=True aggregates agents by class name, otherwise one column per agent
                classname = agent.__class__.__name__ if self.group else str(agent)

                try:
                    for name, vals in dict(agent.balance_sheet.raw_data).items():
                        # entry 0 is booked on the asset side; entries 1 and 2
                        # are both booked on the liability side
                        target[BMEntry.A][name][classname] += vals[0]
                        target[BMEntry.L][name][classname] += vals[1]
                        target[BMEntry.L][name][classname] += vals[2]

                        if name not in self.rows:
                            self.rows.append(name)

                    # with group=True classname already is the class name, so a
                    # single class-name test covers both grouping modes
                    if classname not in self.agents and (
                        agent_filter is None or agent.__class__.__name__ in agent_filter
                    ):
                        self.agents.append(classname)
                except Exception as e:
                    if warn_on_error:
                        warnings.warn("An exception occured: %s" % (str(e)))
                    # otherwise: deliberately best-effort (e.g. registry entries
                    # without a usable balance sheet)

    def init_data(self, group=True, filter=None, verbose=False):
        """
        Record the starting balances (= balances at the current point in time).

        Previously recorded starting balances are discarded first, so calling
        this again takes a fresh snapshot instead of adding to the old one.

        :param group: if True, aggregate agents by class name; otherwise one
            column per agent (``str(agent)``)
        :param filter: None, or a container of agent class names restricting
            which agents appear as columns
        :param verbose: print progress information
        """
        if verbose:
            print("BALANCE MATRIX CONSTRUCTION....")

        self.filter = filter
        self.group = group

        self._clear_entries(self._start_data)
        self._collect_balances(self._start_data, filter, verbose=verbose, warn_on_error=True)

    def fill_data(self, filter=None):
        """
        Record the balances at the current point in time.

        Previously recorded current balances are discarded first, so repeated
        calls do not double-count.

        :param filter: None to reuse the filter from the last call, or a
            container of agent class names (which then becomes the stored filter)
        """
        if filter is None:
            filter = self.filter
        else:
            self.filter = filter

        self._clear_entries(self._data)
        # errors while reading agents are ignored here (best effort)
        self._collect_balances(self._data, filter, warn_on_error=False)

    def to_dataframe(self, add_total_row=False, add_total_col=True, add_residual=True, residual_label="Net Wealth",
                     row_sorting=None, col_sorting=None, merge_accounts=None):
        """
        Return the balance matrix as a pandas DataFrame.

        Each cell holds the change ``current - starting`` balance. Columns form
        a two-level index ``(agent, "A" | "L")``.

        :param add_total_row: append a "Total" row of column sums (default False)
        :param add_total_col: append a "Total" column with sum(A) - sum(L) per row (default True)
        :param add_residual: append, per agent, a residual row holding the
            asset change minus the liability change on the L side (default True)
        :param residual_label: (str) label of the residual row (default "Net Wealth")
        :param row_sorting: optional list of row labels giving the preferred
            order; unlisted rows keep their relative order and go last
        :param col_sorting: optional list of agent labels giving the preferred
            order; unlisted columns go last
        :param merge_accounts: currently unused; kept for API compatibility
        :return: pandas DataFrame
        :raises RuntimeError: if no agent columns have been registered yet
        """
        # alphabetical rows, but any net-worth row is moved to the bottom
        self.rows = sorted(self.rows)
        for label in ["NetWorth", "Net Worth", "NetWealth", "Net Wealth"]:
            if label in self.rows:
                self.rows.remove(label)
                self.rows.append(label)

        if not self.agents:
            raise RuntimeError("No data found. Have you executed init_data() and fill_data() yet?")

        frames = []
        for agent in self.agents:
            # column 0 = change on the asset side, column 1 = change on the liability side
            changes = [
                [
                    self._data[BMEntry.A][row][agent] - self._start_data[BMEntry.A][row][agent],
                    self._data[BMEntry.L][row][agent] - self._start_data[BMEntry.L][row][agent],
                ]
                for row in self.rows
            ]
            # explicit columns keep the frame well-formed even when there are no rows
            frame = pd.DataFrame(changes, index=list(self.rows), columns=[0, 1], dtype=float)

            if add_residual:
                # booked on the L side so that the agent's A and L totals match
                frame.loc[residual_label] = [0.0, frame[0].sum() - frame[1].sum()]

            frames.append(frame)

        df = pd.concat(frames, axis=1, keys=self.agents)
        # the inner column level holds 0/1 -> label it A/L
        df.columns = df.columns.set_levels(["A", "L"], level=1)

        side = df.columns.get_level_values(1)
        sum_A = df.iloc[:, side == "A"].sum(axis=1)
        sum_L = df.iloc[:, side == "L"].sum(axis=1)

        if add_total_col:
            df["Total"] = sum_A - sum_L
        if add_total_row:
            df.loc["Total"] = df.sum(axis=0)

        if row_sorting:
            def row_rank(label):
                return row_sorting.index(label) if label in row_sorting else len(row_sorting)
            df = df.reindex(sorted(df.index, key=row_rank))

        if col_sorting:
            def col_rank(col):
                agent_rank = col_sorting.index(col[0]) if col[0] in col_sorting else len(col_sorting)
                return (agent_rank, col[1])
            df = df[sorted(df.columns, key=col_rank)]

        return df

    def to_string(self, replace_zeros=True, justify="right", add_total_row=False, add_total_col=True, add_residual=True, round=2, flip=False,
                  residual_label="Net Wealth", row_sorting=None, col_sorting=None, color_values=False):
        """
        Return a text rendering of the balance matrix with separator lines.

        :param replace_zeros: replace exact zeros (after rounding) with " .- "
            (round > 0) or blanks (round <= 0)
        :param justify: column-label justification passed to DataFrame.to_string
        :param add_total_row: forwarded to to_dataframe()
        :param add_total_col: forwarded to to_dataframe()
        :param add_residual: forwarded to to_dataframe()
        :param round: number of decimals to round to (shadows the builtin here)
        :param flip: transpose the table
        :param residual_label: forwarded to to_dataframe()
        :param row_sorting: forwarded to to_dataframe()
        :param col_sorting: forwarded to to_dataframe()
        :param color_values: if truthy, colour numbers with 24-bit ANSI escape
            codes taken from a matplotlib colormap (a string selects the
            colormap, otherwise "coolwarm"); requires matplotlib and colorama
        :return: str
        """
        df = self.to_dataframe(add_total_row=add_total_row, add_total_col=add_total_col, add_residual=add_residual,
                               residual_label=residual_label, row_sorting=row_sorting, col_sorting=col_sorting)
        df = df.round(round)

        if flip:
            df = df.T

        if replace_zeros:
            df = df.replace(0.0, " .- " if round > 0 else "  ")

        df = df.rename(columns={"Total": "Total "})
        text = df.to_string(justify=justify)

        def convert_line(line, idx, sep=" | "):
            """Insert *sep* after every character position in *idx*, except after the last character."""
            cut = set(idx)
            last = len(line) - 1
            return "".join(ch + sep if (j in cut and j < last) else ch for j, ch in enumerate(line))

        lines = text.split("\n")

        if flip:
            rule = "-" * len(lines[0])

            # column boundaries: starts of whitespace runs before a word in the header
            idx = [m.start() for m in re.finditer(r"\s+(?=\w)", lines[0])][1:]
            idx.append(max([len(str(i)) - 4 for i in df.T.index]))

            body = ["   " + convert_line(lines[0], idx[:-1])]
            body.append(convert_line(rule, idx, sep="---"))
            for i in range(1, len(lines) - 2):
                body.append(convert_line(lines[i], idx))
                if i % 2 == 0:
                    body.append(convert_line(rule, idx, sep="-+-"))
            body.append(convert_line(rule, idx, sep="---"))

            top = convert_line(rule, idx, sep="---")
            return_str = "\n".join("|" + line + "|" for line in [top] + body)

        else:
            # column boundaries: positions of the "L" sub-headers, plus the index width
            idx = [m.start() for m in re.finditer("L", lines[1])]
            idx.append(max([len(str(i)) for i in df.index]))

            hline = convert_line("-" * len(lines[0]), idx, sep="-+-")
            hrule = convert_line("-" * len(lines[0]), idx, sep="---")

            body = [convert_line(line, idx) for line in lines[:2]]
            body.append(hline)
            body += [convert_line(line, idx) for line in lines[2:-1]]
            if add_total_row or add_residual:
                body.append(hline)
            body.append(convert_line(lines[-1], idx))
            body.append(hrule)

            return_str = "\n".join("|" + line + "|" for line in [hrule] + body)

        if color_values:
            from matplotlib.colors import Normalize
            from colorama import Style, init as colorama_init
            colorama_init(autoreset=True)

            pattern = r"-?\d+(\.\d+)?"
            numbers = [float(match.group()) for match in re.finditer(pattern, return_str)]

            # nothing to colour (e.g. every value was replaced) -> min()/max() would fail
            if numbers:
                norm = Normalize(vmin=min(numbers), vmax=max(numbers))
                cmap_name = color_values if isinstance(color_values, str) else "coolwarm"

                try:
                    # matplotlib >= 3.5; cm.get_cmap was removed in 3.9
                    from matplotlib import colormaps
                    cmap = colormaps[cmap_name]
                except ImportError:
                    from matplotlib import cm
                    cmap = cm.get_cmap(cmap_name)

                def colorize(match):
                    """Wrap a matched number in an ANSI colour code chosen by its magnitude."""
                    number_str = str(match.group())
                    color = cmap(norm(float(number_str)))
                    r, g, b, _ = [int(255 * c) for c in color]
                    return f"\033[38;2;{r};{g};{b}m{number_str}{Style.RESET_ALL}"

                return_str = re.sub(pattern, colorize, return_str)

        return return_str