__version__ = "0.4"
__author__ = "Thomas Baldauf"
__email__ = "thomas.baldauf@dlr.de"
__license__ = "MIT"
__birthdate__ = '15.11.2021'
__status__ = 'prod'  # options are: dev, test, prod

"""
Small plotting library (built on matplotlib) for bar plots, line plots and Sankey diagrams. Using it is optional.
"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pylab import rcParams
import matplotlib.patches as mpatches
import matplotlib as mpl
from functools import partial
import itertools


def add_value_labels(ax, spacing=.9, fmt="{0:f}", stacked=False, legend="best"):
    """Add labels with the (formatted) bar height to each bar in a bar chart.

    Bars whose absolute height is at most 1.5 % of the y-axis span get no
    label. Bars up to 5 % of the span get a label rotated by 33 degrees,
    placed just right of the bar near its base. Taller bars get a centered
    label inside the bar on a faint white background.

    Arguments:
        ax (matplotlib.axes.Axes): The matplotlib object containing the axes
            of the plot to annotate; every entry of ``ax.patches`` is treated
            as a bar.
        spacing (float): Vertical offset in points between the anchor position
            and the label. It is negated for labels anchored below zero.
        fmt (str): ``str.format`` template applied to the bar height, e.g.
            ``"{0:f}"`` or ``"{0:d}"``. With an integer presentation type, an
            integral float height is formatted as ``int``.
        stacked (bool): Unused; kept for backward compatibility.
        legend (str): Unused; kept for backward compatibility.
    """
    y_low, y_high = ax.get_ylim()
    span = abs(y_high - y_low)
    min_height = span * 0.015    # bars this small get no label at all
    small_height = span * 0.05   # bars this small get their label outside

    for rect in ax.patches:
        height = rect.get_height()

        if abs(height) <= min_height:
            continue

        outside = abs(height) <= small_height
        if outside:
            # anchor just right of the bar, close to its base
            y_value = rect.get_y() + height * 0.1
            x_value = rect.get_x() + rect.get_width() * 1.05
        else:
            # anchor inside the bar, horizontally centered
            y_value = rect.get_y() + height * 0.41
            x_value = rect.get_x() + rect.get_width() * 0.50

        # Labels anchored below zero (negative bars) are shifted downwards
        # and aligned at their top edge instead.
        if y_value < 0:
            space = -spacing
            va = 'top'
        else:
            space = spacing
            va = 'bottom'

        if height == 0.0:
            label = ""
        else:
            try:
                label = fmt.format(height)
            except ValueError:
                # Integer presentation types such as "{0:d}" reject floats
                # (bar heights are floats); retry with an int when that is
                # lossless, otherwise re-raise the original error.
                if float(height).is_integer():
                    label = fmt.format(int(height))
                else:
                    raise

        if outside:
            rot, ha = 33, "left"
        else:
            rot, ha = 0, "center"

        t = ax.annotate(
            label,
            (x_value, y_value),           # anchor position in data coordinates
            xytext=(0, space),            # shift the label vertically by `space`
            textcoords="offset points",   # interpret `xytext` as offset in points
            ha=ha,
            va=va,
            rotation=rot)

        if not outside:
            t.set_bbox(dict(facecolor='white', alpha=0.175,
                            edgecolor='white', linewidth=0.0))


def matplotlib_barplot(data, xlabel, ylabel, title, color="indigo", hatches=None, size=(6, 5), tight_layout=True,
                       fmt="{0:f}", stacked=False, show_labels=True, legend="best", show=True, yerr=None, ax=None):
    """
    Creates a bar plot in matplotlib. The 'macro' plotting style will be used if it is set up

    :param data: a pandas DataFrame (one bar series per column), or anything the
        ``pd.DataFrame`` constructor accepts as a single column named after ``ylabel``
    :param xlabel: x axis label
    :param ylabel: y axis label (becomes "Data" if None and ``data`` is not a DataFrame)
    :param title: plot title
    :param color: plot color
    :param hatches: hatch patterns, cycled so that each bar series gets one pattern
    :param size: figure size (written to the global ``rcParams['figure.figsize']``)
    :param tight_layout: re-format plot
    :param fmt: number format of the value labels, {0:d} or {0:f} for example
    :param stacked: stack columns? or show side-by-side (boolean)
    :param show_labels: annotate each bar with its value? (boolean)
    :param legend: location of legend, or 'off'
    :param show: if False, only the figure is returned. Else plot is shown
    :param yerr: data for error bar (if any)
    :param ax: axis to plot on (defaults to the current axis)
    :return: the figure containing ``ax`` if ``show`` is False, otherwise None
    """
    # NOTE: both settings are global rcParams and persist after this call
    plt.rcParams['axes.ymargin'] = .4
    rcParams['figure.figsize'] = size

    try:
        # if 'macro' plotting style is installed -> use it!
        plt.style.use("macro")
    except (OSError, ValueError):
        # style not installed / not readable: keep the current style
        pass

    # Resolve the target axis once and use it everywhere below, so that an
    # explicitly passed ``ax`` is decorated instead of the current axis.
    if ax is None:
        ax = plt.gca()
    ax.grid(alpha=0.0)

    if not isinstance(data, pd.DataFrame):
        if ylabel is None:
            ylabel = "Data"
        data = pd.DataFrame({ylabel: data})

    data.plot.bar(color=color, stacked=stacked, ax=ax,
                  yerr=yerr, capsize=6, ecolor="black")

    if hatches is not None:
        # The bars are assumed to be ordered series by series, len(data) bars
        # per series; switch to the next pattern at the start of each series.
        bars_per_series = max(len(data), 1)
        hatch_cycle = itertools.cycle(hatches)
        for i, bar in enumerate(ax.patches):
            if i % bars_per_series == 0:
                hatch = next(hatch_cycle)
            bar.set_hatch(hatch)

    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    if show_labels:
        add_value_labels(ax, fmt=fmt)

    if legend == "off":
        old_legend = ax.get_legend()
        if old_legend is not None:
            old_legend.remove()
    else:
        # fix reverse labeling: list the legend entries in reverse order, see
        # https://stackoverflow.com/questions/46908085/bar-plot-does-not-respect-order-of-the-legend-text-in-matplotlib
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles[::-1], labels[::-1], loc=legend, framealpha=0.0)

    fig = ax.figure
    if tight_layout:
        fig.tight_layout()

    if show:
        plt.show()
    else:
        return fig


def matplotlib_lineplot(data, xlabel=None, ylabel=None, title="", xlim=None, ylim=None, color="indigo", legend="best", marker=None, show=False, ax=None):
    """
    creates a line plot in matplotlib. The 'macro' plotting style will be used if it is set up

    :param data: a pandas DataFrame (one line per column), or anything the
        ``pd.DataFrame`` constructor accepts as a single column named after ``ylabel``
    :param xlabel: x axis label (left unchanged if None)
    :param ylabel: y axis label (left unchanged if None; becomes "Data" if None
        and ``data`` is not a DataFrame)
    :param title: plot title
    :param xlim, ylim: tuples of plot limits (left untouched / autoscaled if None)
    :param color: plot color
    :param legend: location of legend, or 'off', or 'outside' (right of the axis)
    :param marker: matplotlib marker style for the data points
    :param show: if False, only the figure is returned. Else plot is shown
    :param ax: axis to plot on (defaults to the current axis)
    :return: the figure containing ``ax`` if ``show`` is False, otherwise None

    when overlapping barplot and lineplot, make sure you add an additional axis, see
    https://stackoverflow.com/questions/42948576/pandas-plot-does-not-overlay

    """
    try:
        plt.style.use("macro")
    except (OSError, ValueError):
        # style not installed / not readable: keep the current style
        pass

    if not isinstance(data, pd.DataFrame):
        if ylabel is None:
            ylabel = "Data"
        data = pd.DataFrame({ylabel: data})

    # Resolve the target axis once and use it everywhere below, so that an
    # explicitly passed ``ax`` is decorated instead of the current axis.
    if ax is None:
        ax = plt.gca()

    for col in data.columns:
        ax.plot(data[col], color=color, marker=marker, label=col)

    # Only touch the limits when requested: setting them (even to None)
    # switches off autoscaling for later plots on the same axis.
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_title(title)

    if legend == "off":
        old_legend = ax.get_legend()
        if old_legend is not None:
            old_legend.remove()
    elif legend == "outside":
        # anchor the legend's lower-left corner just right of the axis
        ax.legend(loc=(1.02, 0))
    else:
        ax.legend(loc=legend)

    if xlabel is not None:
        ax.set_xlabel(xlabel)

    if ylabel is not None:
        ax.set_ylabel(ylabel)

    # NOTE: global setting; it only affects axes autoscaled after this call
    plt.rcParams['axes.ymargin'] = .4

    if show:
        plt.show()
    else:
        return ax.figure


class Point:
    """
    A minimal 2-D point whose reported y coordinate includes an adjustable
    vertical shift ``offset_y`` (used to offset bands).
    """

    def __init__(self, x, y):
        self._x, self._y = x, y
        self.offset_y = 0  # extra vertical shift added to the stored y

    @property
    def x(self):
        """The x coordinate as given at construction."""
        return self._x

    @property
    def y(self):
        """The stored y coordinate plus the current ``offset_y``."""
        shifted = self._y + self.offset_y
        return shifted


class Label:
    """
    A minimal node label: position, text, orientation flag and the index of
    the layer it belongs to.
    """

    # every distinct label text created so far, shared by all instances
    names = []

    def __init__(self, x, y, text, flipped, layer_idx):
        self.x = x
        self.y = y
        self.text = text
        self.flipped = flipped
        self.layer_idx = layer_idx
        # record each text once; a repeated text is silently accepted
        known = self.__class__.names
        if text not in known:
            known.append(text)

    def shift_up(self):
        """Move the label up by 1.8 and left by 0.2 (to keep it centered)."""
        self.x -= 0.2
        self.y += 1.8


def casteljau(t, b0, b1, b2, b3):
    """
    Evaluate the cubic Bezier curve with control points b0..b3 at parameter t.

    The curve is evaluated in its expanded polynomial (power-basis) form
    B(t) = a3*t^3 + a2*t^2 + a1*t + b0 rather than by repeated linear
    interpolation; both describe the same curve.

    :param b0-b3: the control points of the bezier curve as [x, y] arrays
    :param t: parameter for position on the curve (0 -> b0, 1 -> b3), float
    :return: the point on the curve, shaped like the control points
    """
    cubic = -b0 + 3*b1 - 3*b2 + b3
    quadratic = 3*b0 - 6*b1 + 3*b2
    linear = -3*b0 + 3*b1

    point = np.multiply(cubic, np.power(t, 3))
    point = point + np.multiply(quadratic, np.power(t, 2))
    point = point + np.multiply(linear, t)
    return point + b0


def bezier(pstart: Point, pend: Point, npoints=100, curv=None):
    """
    Sample a cubic bezier curve that joins two points horizontally.

    The two points are ordered left to right. The inner control points sit at
    the same heights as the end points, ``curv`` to the right of the left
    point and ``curv`` to the left of the right point, so the curve leaves and
    enters horizontally.

    :param pstart: one end point of the bezier curve
    :param pend: the other end point of the bezier curve
    :param npoints: number of discrete points to return (evenly spaced in the curve parameter)
    :param curv: optionally you can give a curvature parameter here (float);
        by default 27 % of the horizontal distance between the points

    :return: x,y coordinate tuple of ndarrays of shape (npoints,)
    """
    ts = np.linspace(0, 1, npoints)

    left, right = (pstart, pend) if pstart.x < pend.x else (pend, pstart)

    if curv is None:
        curv = .27*(right.x-left.x)

    control = (
        np.array([left.x, left.y]),
        np.array([left.x+curv, left.y]),
        np.array([right.x-curv, right.y]),
        np.array([right.x, right.y]),
    )
    evaluate = partial(casteljau, b0=control[0], b1=control[1],
                       b2=control[2], b3=control[3])

    samples = np.array([evaluate(t) for t in ts])
    return samples[:, 0], samples[:, 1]


def plot_band(pstart: Point, pend: Point, npoints=100, curv=None, weight=2):
    """
    Compute a band bounded by two bezier curves (nothing is drawn here).

    The lower curve runs from ``pstart`` to ``pend``. The upper curve is the
    same curve shifted up by ``weight``. The caller fills between them.

    :param pstart: starting point of the lower bezier curve
    :param pend: end point of the lower bezier curve
    :param npoints: number of discrete points per curve (forwarded to ``bezier``)
    :param curv: optionally you can give a curvature parameter here (float),
        forwarded to ``bezier``; default None
    :param weight: width of the band

    :return: x,y1,y2 coordinate 3-tuple of the band (x samples of the lower
        curve, lower y values, upper y values)
    """
    w = weight  # width of the band
    d = 0  # horizontal shift of the two curves relative to each other
    e = 0  # distance of the '<' and '>' arcs

    pstart1 = Point(pstart.x-d+.2*e, pstart.y)
    pend1 = Point(pend.x-d-e, pend.y)

    pstart2 = Point(pstart.x+d+.2*e, pstart.y+w)
    pend2 = Point(pend.x+d-e, pend.y+w)

    # forward npoints/curv: previously both were documented but silently ignored
    x, y1 = bezier(pstart1, pend1, npoints=npoints, curv=curv)
    _, y2 = bezier(pstart2, pend2, npoints=npoints, curv=curv)

    return np.array(x), np.array(y1), np.array(y2)


def plot_sankey(data, title="", show_plot=True,
                show_values=False,
                round_values=0,
                colors=None,
                label_rot=0,
                fontsize=12,
                # determines how 'bent' the bands look (in relative vertical offset)
                curvature=0.8,  # 1.8,
                separation=0.8,
                norm_factor=1.0,
                min_width=0.1,
                dy=5,
                dx=55,
                dx_left=1.0,  # 15.5,  # 0.2,  # 1.0, # 5.2
                dx_space=1.0,  # -0.5,  # -30.1,  # -0.5,
                dy_space=0.05,  # 25.0,
                dy_labels=2.1,  # 1.3,
                alpha=0.8,
                filling_fraction=0.8,  # 0.3,
                space_layout=False, band_space=3.2,
                filename=None,
                use_legend=False,
                flip_labels=True,
                add_braces=False):
    """
    plots a sankey diagram from data

    :param data: list of pandas dataframes, of the following format. length at least two.
    :param title: title of the plot
    :param show_plot: boolean switch to show plot window (default True). If False, figure is returned
    :param show_values: boolean switch to print numerical values of data as text label (default True)
    :param colors: (optional) a list of colors (if None, default colors are chosen)
    :param label_rot: rotation of the labels (default 65 degrees)
    :param fontsize: size of label font
    :param separation: separation point of the label along the bands, between 0 and 1,
    :param norm_factor: normalization factor for width
    :param min_width: minimum width of bands
    :param dy: vertical distance of bands
    :param dx: horizontal distance of bands
    :param dx_space: space between layers
    :param filling_fraction: fraction by which the bands align (scaling with dy)

    :return fig: None or matplotlib figure
    """

    """
    Example:

    +----------+----------+----------+-------------+
    |  from    |   to     |  value   | color_id    |
    +----------+----------+----------+-------------+
    |    A     |  C       | 1.0      |    0        |
    +----------+----------+----------+-------------+
    |    A     |  D       | 2.0      |    1        |
    +----------+----------+----------+-------------+
    |    B     |  D       | 3.0      |    2        |
    +----------+----------+----------+-------------+

    +----------+----------+----------+-------------+
    |  from    |   to     |  value   | color_id    |
    +----------+----------+----------+-------------+
    |    C     |  E       | 5.0      |    1        |
    +----------+----------+----------+-------------+
    |    C     |  F       | 6.0      |    2        |
    +----------+----------+----------+-------------+
    |    D     |  G       | 8.0      |    3        |
    +----------+----------+----------+-------------+

    """

    Label.names = []  # reset labels
    if not isinstance(data, list):
        data = [data]  # convert to list

    # define default colors
    if colors is None:
        colors = [
            "silver",
            "lightsteelblue",
            "yellowgreen",
            "darkkhaki",
            "darksalmon",
            "skyblue",
            "lightgray",
            "wheat",
            "lightgreen",
            "lavender",
            "thistle",
            'burlywood'
        ]

    # 1. normalize the values to reasonable width of bands
    maxval = 0.0  # stores maximum band width
    maxrows = 2  # stores number of rows in the data
    for layer in data:
        maxrows = max(maxrows, len(list(layer.iterrows())))
        for idx, row in layer.iterrows():
            val = abs(float(row["value"]))
            maxval = max(val, maxval)
    norm_factor = 1.0 / maxval if maxval > 1 else 1.0

    # compute global dy_min for consistent band widths
    H = 300  # height constant
    all_dy = []
    for layer in data:
        dy1 = H / len(layer["from"].unique())
        dy2 = H / len(layer["to"].unique())
        all_dy.append(max(dy1, dy2))
    global_dy_min = max(all_dy)

    def compute_band_width(value):
        return value * norm_factor * (0.8 * global_dy_min) * (0.5 * filling_fraction)

    # prepare figure
    figscale = 0.4
    # fig = plt.figure(figsize=(figscale*11.4*len(data),
    #                  figscale*1.5*(3+maxrows)))
    fig = plt.figure(figsize=(figscale*15.4*len(data),
                     figscale*1.2*(3+maxrows)))

    # 2. plot the bezier patches
    i = 10  #
    j0 = 6  # < starting position (i,j0)

    # define some points for the unique nodes
    my_points = {}  # keys will be names, value will be Points
    my_labels = {}  # keys will be names, value will be Labels

    class Band:
        """
        A sankey band class
        """
        instances = []

        def __init__(self, n_from, n_to, w, value=None, color="steelblue", barheight=None, name=None, reversed=False, from_name=None, to_name=None):
            """
            Creates a new band

            Args:
                n_from (point): starting point
                n_to (point): end point
                w (float): width
                color (str, optional): band color. Defaults to "steelblue".
                barheight (float, optional): height of auxiliary bar. Defaults to None.
                name (str, optional): name of the band. Defaults to None.
            """
            self.n_from = n_from
            self.n_to = n_to
            self.shift_y = 0.0
            self.width = 2 * w
            self.osy1 = 0.0
            self.osy2 = 0.0
            self.barheight = barheight
            self.__class__.instances.append(self)
            self.name = name
            self.reverse = reversed

            self.vbound = None
            # from_point = n_from
            # to_point = n_to

            from_point = Point(fromnode.x, fromnode.y)
            to_point = Point(tonode.x, tonode.y)

            # print("n_from, n_to", n_from, n_to)
            self.vbound = ((from_point.x, to_point.x),
                           (from_point.y + self.width, to_point.y + self.width),
                           (from_point.y, to_point.y))

            self.value = value
            self.from_name = from_name
            self.to_name = to_name

            self.color = color
            # if reversed:
            #     self.color = "yellow" # < for debugging

            self.mpl_obj = None

        def plot(self, show_frame=None):
            """
            plots the band

            Args:
                show_frame (str): default  None, or 'left'/'right'/'both'
            """

            fromnode = self.n_from
            tonode = self.n_to

            if self.reverse:
                dosy = self.width*1.2
            else:
                dosy = 0
            osy1 = self.osy1
            osy2 = self.osy2

            w = self.width
            point_lower1 = Point(fromnode.x, fromnode.y+osy1 - dosy)
            point_lower2 = Point(tonode.x, tonode.y + osy2 - dosy)

            # make the curves more smooth
            x, vals_lower, vals_upper = plot_band(
                point_lower1, point_lower2, weight=w)
            sorted_indices = np.argsort(x)
            x = np.array(x)[sorted_indices]
            vals_lower = np.array(vals_lower)[sorted_indices]
            vals_upper = np.array(vals_upper)[sorted_indices]

            argmin_x = np.argmin(x)
            argmax_x = np.argmax(x)
            bordercolor = "darkgray"

            if not self.reverse:
                self.mpl_obj = plt.fill_between(
                    x, vals_lower, vals_upper, color=self.color, alpha=alpha, edgecolor='none')  # , alpha=0.1)
                # self.color = self.mpl_obj.get_facecolor()[0]
            else:

                from matplotlib.patches import PathPatch
                from matplotlib.path import Path

                band_top = list(zip(x, vals_upper))
                # reverse to follow counterclockwise
                band_bottom = list(zip(x[::-1], vals_lower[::-1]))

                # 2. Compute arc edges (left + right)
                dy = abs(vals_upper[0] - vals_lower[0]) + dosy
                dw = abs(vals_upper[0] - vals_lower[0])
                w = 0.25*self.width
                width1 = 4 + w*0.7  # outer radius
                width2 = .5 + w*0.05  # inner radius

                # LEFT arc (at fromnode)
                pnt_left = Point(fromnode.x, fromnode.y + osy1 - dosy)
                theta1 = np.linspace(np.radians(270), np.radians(90), 30)
                arc_left_outer = [(pnt_left.x + width1 * np.cos(t),
                                   pnt_left.y + 0.5 * dy + 0.5 * dy * np.sin(t)) for t in theta1]

                theta2 = np.linspace(np.radians(90), np.radians(270), 30)
                arc_left_inner = [(pnt_left.x + width2 * np.cos(t),
                                   pnt_left.y + 0.5 * (dosy + dw) + 0.5 * (dosy - dw) * np.sin(t)) for t in theta2]

                # RIGHT arc (at tonode)
                pnt_right = Point(tonode.x, tonode.y + osy2 - dosy)
                theta3 = np.linspace(np.radians(90), np.radians(270), 30)
                arc_right_outer = [(pnt_right.x - width1 * np.cos(t),
                                    pnt_right.y + 0.5 * dy + 0.5 * dy * np.sin(t)) for t in theta3]

                theta4 = np.linspace(np.radians(270), np.radians(90), 30)
                arc_right_inner = [(pnt_right.x - width2 * np.cos(t),
                                    pnt_right.y + 0.5 * (dosy + dw) + 0.5 * (dosy - dw) * np.sin(t)) for t in theta4]

                # left arc outer → band top → right arc outer → right arc inner → band bottom → left arc inner
                # + arc_left_inner[::-1] + band_top
                vertices = (
                    arc_left_outer +
                    band_top +
                    arc_right_outer +
                    arc_right_inner +
                    band_bottom +
                    arc_left_inner +
                    [arc_left_outer[0]]  # close
                )

                codes = [Path.MOVETO] + [Path.LINETO] * \
                    (len(vertices) - 2) + [Path.CLOSEPOLY]

                path = Path(vertices, codes)
                patch = PathPatch(path, facecolor=self.color,
                                  edgecolor='none', lw=0, alpha=alpha, zorder=0)
                plt.gca().add_patch(patch)

                self.barheight = dw * 1.001

            # plot auxiliary bars
            # NOTE not used: add rectangles in black color
            # start_rect = patches.Rectangle((x[argmin_x], vals_lower[argmin_x]), 0.5, w, color="red",alpha=.5 )
            # plt.gca().add_patch(start_rect)

            def draw_frame_bar(x_idx, lower=None, upper=None, x_center=None):
                ddx = 2
                if self.barheight is None:
                    self.barheight = (
                        vals_upper[x_idx] - vals_lower[x_idx]) * 1.001

                if lower is None:
                    lower = vals_lower[x_idx] - dosy
                if upper is None:
                    upper = lower + self.barheight - dosy
                if x_center is None:
                    x_center = x[x_idx]

                plt.fill_between([x_center - ddx, x_center + ddx],
                                 [lower]*2, [upper]*2, color="gray")
                plt.plot([x_center - ddx, x_center + ddx], [lower]
                         * 2, color=bordercolor, linewidth=1)
                plt.plot([x_center - ddx, x_center + ddx], [upper]
                         * 2, color=bordercolor, linewidth=1)

            if not self.reverse:
                pass
                # if show_frame in ("left", "both"):
                #     draw_frame_bar(argmin_x)

                # if show_frame in ("right", "both"):
                #     draw_frame_bar(argmax_x)

                # return the (min, max), (min, max) of lower and upper line of the band
                self.vbound = (x[argmin_x], x[argmax_x]
                               ), (vals_upper[argmin_x], vals_upper[argmax_x]), (vals_lower[argmin_x], vals_lower[argmax_x])

            else:
                if show_frame in ("left", "both"):
                    x_center = np.max([p[0] for p in arc_left_outer])
                    lower = np.max([p[1] for p in arc_left_inner])
                    upper = np.max([p[1] for p in arc_left_outer])
                    # draw_frame_bar(None, lower, upper, x_center)

                if show_frame in ("right", "both"):
                    x_center2 = np.min([p[0] for p in arc_right_outer])
                    lower2 = np.max([p[1] for p in arc_right_inner])
                    upper2 = np.max([p[1] for p in arc_right_outer])
                    # draw_frame_bar(None, lower2, upper2, x_center)

                # return (x_center, x_center2), (band_top[argmin_x][1], band_top[argmax_x][1])
                self.vbound = (x_center, x_center2), (upper,
                                                      upper2), (lower, lower2)

            return self.vbound

        def maxy(self):
            w = 2 * self.width
            return max([self.n_from.y + self.osy1 + w, self.n_to.y + self.osy2 + w])

    from collections import defaultdict
    bands = defaultdict(lambda: [])
    H = 300  # constant height parameter

    for indx, layer in enumerate(data):
        offset_wy = 0
        j = j0 + offset_wy

        k = 0
        dy1 = H/len(layer["from"].unique())
        dy2 = H/len(layer["to"].unique())
        dy_min = max(dy1, dy2)

        for name in layer["from"].unique():

            if name not in my_points:
                new_point = Point(i+dx_space, j)
                my_points[name] = new_point

                band_width = 0
                # for _, row in layer.iterrows():
                #     if row["from"] == name:
                #         band_width += float(abs(row["value"])) * \
                #             norm_factor * (0.8*dy_min) * filling_fraction
                for indx2, layer2 in enumerate(data):
                    for _, row in layer2.iterrows():
                        if row["from"] == name:
                            #     for inst in Band.instances:
                            #         if inst.n_from == name:
                            #             band_width += Band.barheight
                            band_width += float(abs(row["value"]))
                        elif row["to"] == name and row["value"] < 0:
                            band_width += float(abs(row["value"]))

                band_width *= norm_factor * \
                    (0.8*global_dy_min) * filling_fraction

                # draw a bracket
                if add_braces:
                    dddx = .75
                    wdddx = 1.72
                    if not flip_labels:
                        dddx = - dddx
                        wdddx = - wdddx
                    bracket_curve_x = [new_point.x-wdddx, new_point.x -
                                       dddx-wdddx, new_point.x-dddx-wdddx, new_point.x-wdddx]
                    bracket_curve_y = [
                        new_point.y, new_point.y, new_point.y + band_width, new_point.y + band_width]
                    plt.plot(bracket_curve_x, bracket_curve_y,
                             color="gray", linewidth=1.5)

                if space_layout:
                    j += min(global_dy_min, band_width + 2)
                else:
                    # min(dy_min, band_width + 2) # + 20 # max(dy1, band_width * 1.02)
                    j += band_width + band_space

            myname = name
            if name not in my_labels:
                my_labels[name] = Label(
                    new_point.x-dx_left, new_point.y-.1-dy_labels+0.44*band_width, myname, flipped=flip_labels,
                    layer_idx=indx)
            else:
                my_labels[name].shift_up()

            k += 1

        if indx < len(data)-1:
            offset_wy = -curvature*int(len(layer["to"].unique())/2)
        else:
            offset_wy = 0

        j = j0 + offset_wy
        k = 0

        for idx, name in enumerate(layer["to"].unique()):

            if name not in my_points:
                new_point = Point(i+dx, j)
                my_points[name] = new_point

                # band_width = 0
                # for _, row in layer.iterrows():
                #     if row["to"] == name:
                #         band_width += float(abs(row["value"]))*norm_factor * \
                #             (0.8*dy_min) * filling_fraction
                band_width = 0
                for indx2, layer2 in enumerate(data):
                    for _, row in layer2.iterrows():
                        if row["to"] == name:
                            #     for inst in Band.instances:
                            #         if inst.n_from == name:
                            #             band_width += Band.barheight
                            band_width += float(abs(row["value"]))
                        elif row["from"] == name and row["value"] < 0:
                            band_width += float(abs(row["value"]))

                band_width *= norm_factor * \
                    (0.8*global_dy_min) * filling_fraction

                # draw a bracket
                if add_braces:
                    dddx = -.75
                    wdddx = -1.72
                    if not flip_labels:
                        dddx = - dddx
                        wdddx = - wdddx

                    bracket_curve_x = [new_point.x-wdddx, new_point.x -
                                       dddx-wdddx, new_point.x-dddx-wdddx, new_point.x-wdddx]
                    bracket_curve_y = [
                        new_point.y, new_point.y, new_point.y + band_width, new_point.y + band_width]
                    plt.plot(bracket_curve_x, bracket_curve_y,
                             color="gray", linewidth=1.5)

                if space_layout:
                    j += min(global_dy_min, band_width + 2)
                else:
                    # min(dy_min, band_width + 2) # + 20 # max(dy1, band_width * 1.02)
                    j += band_width + band_space

            myname = name  # use alias myname
            if name not in my_labels:
                # plt.annotate(name,(new_point.x+.5*dx_space,new_point.y-.1))
                my_labels[name] = Label(
                    new_point.x+.5*dx_space, new_point.y-.1-dy_labels+0.44*band_width, myname, flipped=flip_labels,
                    layer_idx=indx+1)
            else:
                my_labels[name].shift_up()

            k += 1

        for idx, row in layer.iterrows():

            color_idx = row["color_id"]

            val = float(row["value"])
            is_reversed = val < 0
            value = abs(val)

            from_name, to_name = row["from"], row["to"]

            fromnode = my_points[from_name]
            tonode = my_points[to_name]

            name = None
            if idx < len(layer.index) - 1:
                name = row["to"]

            my_band = Band(fromnode, tonode,  w=compute_band_width(value), value=val, color=colors[color_idx % len(colors)], barheight=None,
                           name=name, reversed=is_reversed, from_name=from_name, to_name=to_name)

            bands[fromnode].append(my_band)
            bands[tonode].append(my_band)

        i += dx

        # iterate through the dataframe rows
        for point, b in bands.items():
            offset1 = 0.0
            offset2 = 0.0
            for bi in b:
                if point == bi.n_from:
                    bi.osy1 = offset1
                    offset1 += bi.width

                elif point == bi.n_to:
                    bi.osy2 = offset2
                    offset2 += bi.width
        j = -offset_wy

    # Step 1: Find the minimum x-coordinate among all labels
    min_x = min(label.x for label in my_labels.values())
    Delta_x = 13  # Define the amount to shift left-aligned labels by
    if add_braces:
        Delta_x += 2
    Delta_y = .15  # 10+2.2 # 40*2.2
    Delta_y2 = 0   # 10+1.2 # 40*1.2

    # Step 2: align the labels in a better way
    def align_labels(labels, min_distance_y=0.0001, max_distance_x=2.5):
        return labels

        # NOTE below code is depricated. remove in future versions
        # # Sort labels by y position to process them in vertical order
        # sorted_labels = sorted(labels.items(), key=lambda item: item[1].y)

        # # Adjust y positions to avoid overlap when x positions are close
        # for name, label in sorted_labels:
        #     for name2, label2 in sorted_labels:

        #         # Check if labels are close in both x and y positions
        #         if abs(label.x - label2.x) < 1.1*max_distance_x and (label.y - label2.y) < 1.1*min_distance_y and label.y > label2.y:
        #             # Move the current label upwards to create enough space
        #             # print("move", label.text)
        #             label.y = label2.y + min_distance_y
        #             label2.y = label2.y - 0.1*min_distance_y
        # return {key: label for key, label in sorted_labels}

    my_labels = align_labels(my_labels)

    # Step 3: Plot labels with different alignments based on position
    min_x = float("inf")
    max_x = 0
    for label in my_labels.values():
        min_x = min(min_x, label.x)
        max_x = max(max_x, label.x)

    for label in my_labels.values():

        is_source_sink = "$[source]" in label.text or "$[sink]" in label.text
        if use_legend and not is_source_sink:
            label.text = ""
            continue

        # label.text += str("IDX ") + str(label.layer_idx)
        cond = label.layer_idx == 0
        # if flip_labels:
        #     cond = label.layer_idx > 0
        # else:
        #     cond = False  # label.layer_idx

        if cond:
            label.flipped = True

            # right-align and shift leftmost labels by Delta_x inside the diagram
            plt.gca().text(
                label.x + 4 + Delta_x,  # Shift right by Delta_x
                label.y + Delta_y + 0.8*dy_labels if is_source_sink else label.y + Delta_y2,
                label.text.replace("$[source]", "").replace("$[sink]", ""),
                rotation=label_rot,
                fontsize=fontsize+2 if is_source_sink else fontsize,
                weight="bold" if is_source_sink else 'normal',
                ha="left"  # Align text to the left
            )
            label.x = label.x + 4 + Delta_x

        else:

            # left-align other labels, ending at label.x
            plt.gca().text(
                label.x + 27 - Delta_x,
                label.y + Delta_y + 0.8*dy_labels if is_source_sink else label.y + Delta_y2,
                label.text.replace("$[source]", "").replace("$[sink]", ""),
                rotation=label_rot,
                fontsize=fontsize+2 if is_source_sink else fontsize,
                weight="bold" if is_source_sink else 'normal',
                ha="right"  # Right-align so text ends at label.x
            )
            label.x = label.x + 27 - Delta_x
    # 3. return the figure or show window
    plt.axis("off")

    # Band = type("Band", (), {})
    # Band.instances = []

    max_x = 0
    max_y = 0
    for point in my_points.values():
        max_x = max(max_x, point.x)
        max_y = max(max_y, point.y)
    for label in my_labels.values():
        max_x = max(max_x, label.x)
        max_y = max(max_y, label.y)
    maxy = max([b.maxy() for b in Band.instances])
    plt.xlim([3, max_x+5])
    plt.ylim([3, maxy+3])

    # plot all bands
    min_x = -1
    max_x = 1
    min_y = -1
    max_y = 1

    N_b = len(Band.instances)
    for b, band in enumerate(Band.instances[::-1]):
        bounds_x, bounds_y, _ = band.plot(show_frame="both")
        min_x = min(min_x, bounds_x[0])
        max_x = max(max_x, bounds_x[1])
        min_y = min(min_y, bounds_y[0])
        max_y = max(max_y, bounds_y[1])

    # plot helper bars and braces for each band

    my_points = {}
    i = 0
    j_from = 0
    j_to = defaultdict(int)
    D1 = 5
    D2 = 50
    dW = 100
    for layer in data:
        for name in layer["from"].unique():
            if name not in my_points:
                my_points[name] = Point(i, j_from)
                j_from += D1  # increased spacing to reduce overlap
        i += dW
        for name in layer["to"].unique():
            if name not in my_points:
                my_points[name] = Point(i, j_to[i])
                j_to[i] += D2  # increased spacing to reduce overlap
        dW += 400

    for layer in data:
        for _, row in layer.iterrows():
            from_name = row["from"]
            to_name = row["to"]
            value = float(row["value"])
            from_point = my_points[from_name]
            to_point = my_points[to_name]
            Band(from_name, to_name, value, from_point, to_point)

    ddx = 1
    used_y_ranges = defaultdict(list)
    for k in my_points.keys():
        lowest = float("inf")
        highest = 0
        x_center = None

        for b in Band.instances:
            match = False
            if b.from_name == k:
                match = True
                x_center = b.vbound[0][0]
                lower = b.vbound[2][0]
                upper = b.vbound[1][0]
            elif b.to_name == k:
                match = True
                x_center = b.vbound[0][1]
                lower = b.vbound[2][1]
                upper = b.vbound[1][1]

            if match:
                lowest = min(lowest, lower)
                highest = max(highest, upper)

        if x_center is not None:
            overlap_found = True
            while overlap_found:
                overlap_found = False
                for (lo, hi) in used_y_ranges[x_center]:
                    if not (highest < lo or lowest > hi):
                        lowest += 5
                        highest += 5
                        overlap_found = True
                        break
            used_y_ranges[x_center].append((lowest, highest))

            plt.fill_between([x_center - ddx, x_center + ddx],
                             [lowest]*2, [highest]*2, color="gray", alpha=1.0)
            plt.plot([x_center - ddx, x_center + ddx],
                     [lowest]*2, color="gray", linewidth=1)
            plt.plot([x_center - ddx, x_center + ddx],
                     [highest]*2, color="gray", linewidth=1)

        if add_braces and x_center is not None:
            # dddx = .75
            # wdddx = 2.1  # -1.1825
            # bracket_curve_x = [x_center-wdddx, x_center -
            #                    dddx-wdddx, x_center-dddx-wdddx, x_center-wdddx]
            # bracket_curve_y = [lowest, lowest, highest, highest]
            # plt.plot(bracket_curve_x, bracket_curve_y,
            #          color="gray", linewidth=1.5)
            pass

    if show_values:

        positions = []
        for b in Band.instances:

            x_center1, lower1, upper1 = b.vbound[0][0], b.vbound[2][0], b.vbound[1][0]
            x_center2, lower2, upper2 = b.vbound[0][1], b.vbound[2][1], b.vbound[1][1]

            dosx = -1
            dosy = 1
            if b.reverse:
                dosy = b.width*1.2

            try:
                str_val = f"{abs(b.value):.{round_values}f}"
            except:
                str_val = ""

            share = 0.49
            new_x = share*(x_center1 +
                           x_center2) + dosx
            new_y = share*0.5*(lower1 + upper1) + (1-share) * \
                0.5*(lower2 + upper2)-dosy

            eps = 5.0
            overlapping = True
            while overlapping:
                overlapping = False
                for p in positions:
                    if np.sqrt((p[0]-new_x)**2 + (p[1]-new_y)**2) < eps:
                        new_x += np.random.normal(2.2, 0.3)
                        new_y += np.random.normal(.2, 0.03)
                        overlapping = True
            positions.append((new_x, new_y))
            plt.annotate(str_val, (new_x, new_y))

    if use_legend:
        color_dict = {}
        for b, band in enumerate(Band.instances[::-1]):
            if band.name and ("$[sink]" not in band.name) and ("$[source]" not in band.name):
                color_dict[band.name] = band.color

        patches = [mpatches.Patch(color=color, label=name)
                   for name, color in color_dict.items()]
        plt.legend(handles=patches, loc='lower center', bbox_to_anchor=(
            0.5, -0.2), ncol=min(len(color_dict), 3), frameon=False, fontsize=fontsize+3)
        plt.subplots_adjust(bottom=0.5)

    plt.axis('tight')  # Adjusts the axes limits to fit the data snugly
    plt.tight_layout()
    if filename:
        show_plot = False
        plt.savefig(filename, bbox_inches="tight")
        plt.close()
    if show_plot:
        plt.tight_layout()
        plt.show()
    return fig
