# pyqt_df_sorter.py
import sys
from typing import List
import pandas as pd
from PyQt5.QtWidgets import QPushButton

from PyQt5.QtCore import (
    QAbstractTableModel, Qt, QVariant, pyqtSignal, QModelIndex
)
from PyQt5.QtWidgets import (
    QApplication, QDialog, QVBoxLayout, QHBoxLayout, QPushButton,
    QTableView, QFileDialog, QMessageBox
)
from PyQt5.QtWidgets import (
    QDialog, QVBoxLayout, QHBoxLayout, QTableView,
    QListWidget, QListWidgetItem, QLabel
)

class PandasReorderModel(QAbstractTableModel):
    """
    Read-only Qt table model over a pandas DataFrame with a reorderable view.

    The DataFrame itself is never reordered. Instead the model keeps two
    position maps into it:
      - ``row_order[v]``: DataFrame row position displayed at view row ``v``
      - ``col_order[v]``: DataFrame column position displayed at view column ``v``

    Supported operations:
      - ``sort()``: stable row sort by one view column (the QAbstractItemModel
        hook a view calls when sorting is enabled and a header is clicked)
      - ``move_column_section()``: apply a header drag to ``col_order``
      - ``move_selected_rows()``: shift a set of rows up or down by one
      - ``reset_order()`` / ``set_new_dataframe()``: back to identity order
      - ``current_dataframe_view()``: materialize the displayed order

    Every change to the order is applied with a full model reset and is
    followed by ``sort_changed``.
    """

    sort_changed = pyqtSignal()  # emitted after every sort or manual reordering

    def __init__(self, df: pd.DataFrame):
        super().__init__()
        self._df_original = df.copy()  # copy of the DataFrame as passed in
        self._df = df.copy()  # working copy that the order maps index into
        self.row_order: List[int] = list(range(len(self._df)))
        self.col_order: List[int] = list(range(len(self._df.columns)))
        self._ascending = True
        self._last_sort_col = -1

    # --- Required model methods ---
    def rowCount(self, parent=QModelIndex()):
        # A flat table has no child items: Qt expects 0 for any valid parent.
        if parent.isValid():
            return 0
        return len(self.row_order)

    def columnCount(self, parent=QModelIndex()):
        if parent.isValid():
            return 0
        return len(self.col_order)

    def index_to_df_coords(self, row, col):
        """Map view (row, col) to DataFrame (row position, column position)."""
        return self.row_order[row], self.col_order[col]

    def data(self, index, role=Qt.DisplayRole):
        """Return the cell text for Display/Edit roles ("" for missing values)."""
        if not index.isValid() or role not in (Qt.DisplayRole, Qt.EditRole):
            return QVariant()
        r_df, c_df = self.index_to_df_coords(index.row(), index.column())
        val = self._df.iat[r_df, c_df]
        # is_scalar guard: for list/array cells pd.isna returns an array,
        # whose truth value is ambiguous and would raise in the `if`.
        if pd.api.types.is_scalar(val) and pd.isna(val):
            return ""
        return str(val)

    def headerData(self, section, orientation, role=Qt.DisplayRole):
        """Return the column label (horizontal) or original row label (vertical)."""
        if role != Qt.DisplayRole:
            return QVariant()
        if orientation == Qt.Horizontal:
            return str(self._df.columns[self.col_order[section]])
        # show original index label of the row
        return str(self._df.index[self.row_order[section]])

    def flags(self, index):
        if not index.isValid():
            return Qt.NoItemFlags
        # Read-only cells, selectable and enabled
        return Qt.ItemIsSelectable | Qt.ItemIsEnabled

    # --- Sorting ---
    def sort(self, column: int, order: Qt.SortOrder = Qt.AscendingOrder):
        """
        Stably reorder the rows by the values in view column ``column``.

        The current ``row_order`` is permuted (not the DataFrame's original
        order), so rows with equal keys keep their present relative order,
        including earlier manual moves. Missing values go last in both
        directions. A column outside ``[0, columnCount())`` is ignored: Qt
        uses -1 to mean "no sort column", and ``col_order[-1]`` would
        otherwise silently sort by the last column.
        """
        if not 0 <= column < len(self.col_order):
            return
        ascending = order == Qt.AscendingOrder
        series = self._df.iloc[:, self.col_order[column]]

        current = self.row_order
        # 'pos' (position within the current order) breaks ties, which makes
        # the sort stable. reset_index aligns the keys with 'pos' while
        # keeping the column's dtype (and therefore its sort semantics).
        keyed = pd.DataFrame({
            "key": series.iloc[current].reset_index(drop=True),
            "pos": range(len(current)),
        })
        ranked = keyed.sort_values(
            by=["key", "pos"],
            ascending=[ascending, True],
            na_position="last",
        )
        new_row_order = [current[p] for p in ranked["pos"].tolist()]

        self.beginResetModel()
        self.row_order = new_row_order
        self.endResetModel()

        self._ascending = ascending
        self._last_sort_col = column
        self.sort_changed.emit()

    # --- Column reordering support (e.g. connected to QHeaderView.sectionMoved) ---
    def move_column_section(self, logicalIndex, oldVisualIndex, newVisualIndex):
        """
        Mirror a header drag into ``col_order``.

        Arguments follow QHeaderView.sectionMoved(logicalIndex, oldVisualIndex,
        newVisualIndex). Only the two visual indices are used: the entry at
        ``oldVisualIndex`` is moved to ``newVisualIndex``. This assumes the
        header's visual order matches ``col_order`` when the signal arrives
        (TODO confirm the header mapping returns to identity after each reset).

        The new order is computed before the model reset begins, so a bad
        index raises IndexError without leaving the model stuck mid-reset.
        """
        if oldVisualIndex == newVisualIndex:
            return
        new_order = list(self.col_order)
        new_order.insert(newVisualIndex, new_order.pop(oldVisualIndex))

        self.beginResetModel()
        self.col_order = new_order
        self.endResetModel()
        self.sort_changed.emit()

    # --- Row reordering (Move Up / Move Down) ---
    def move_selected_rows(self, selected_rows_visual: List[int], direction: int):
        """
        Move the given view rows up (``direction == -1``) or down (``+1``) by one.

        ``selected_rows_visual`` may contain duplicates and be in any order.
        Nothing happens for an empty selection, an invalid direction, or when
        the selection already touches the edge it would move towards.

        Raises:
            IndexError: if a row is outside ``[0, rowCount())``. The check runs
                before the model reset, so the model stays consistent.
        """
        if not selected_rows_visual or direction not in (-1, 1):
            return

        rows = sorted(set(selected_rows_visual))
        n = len(self.row_order)
        if direction == -1 and rows[0] == 0:
            return  # top already
        if direction == 1 and rows[-1] == n - 1:
            return  # bottom already
        if rows[0] < 0 or rows[-1] >= n:
            raise IndexError(f"row index out of range for {n} rows: {rows}")

        new_order = list(self.row_order)
        if direction == -1:
            # Walk top-down so the unselected row above the block bubbles
            # past every selected row instead of colliding with them.
            for r in rows:
                new_order[r - 1], new_order[r] = new_order[r], new_order[r - 1]
        else:
            # Mirror image: walk bottom-up.
            for r in reversed(rows):
                new_order[r + 1], new_order[r] = new_order[r], new_order[r + 1]

        self.beginResetModel()
        self.row_order = new_order
        self.endResetModel()
        self.sort_changed.emit()

    # --- Helpers ---
    def _reset_orders(self):
        """Restore identity row/column maps and clear the remembered sort."""
        self.row_order = list(range(len(self._df)))
        self.col_order = list(range(len(self._df.columns)))
        self._ascending = True
        self._last_sort_col = -1

    def reset_order(self):
        """Show the DataFrame in its original row and column order."""
        self.beginResetModel()
        self._reset_orders()
        self.endResetModel()
        self.sort_changed.emit()

    def current_dataframe_view(self) -> pd.DataFrame:
        """
        Return a new DataFrame with rows and columns in the displayed order.

        Positional ``iloc`` selection carries the original index and column
        labels along with the data.
        """
        return self._df.iloc[self.row_order, self.col_order].copy()

    def set_new_dataframe(self, df: pd.DataFrame):
        """Replace the data and reset to identity order."""
        self.beginResetModel()
        self._df_original = df.copy()
        self._df = df.copy()
        self._reset_orders()
        self.endResetModel()
        self.sort_changed.emit()


class DataFrameSorterDialog(QDialog):
    """
    Non-modal, single-instance dialog for choosing the row and column order
    of a table shown by the parent window.

    - The "Rows" list shows the stripped row labels of ``df.index``; the
      "Columns" list shows the first line of every non-blank column label.
    - Both lists can be reordered by drag and drop.
    - "Apply Order" writes the chosen order into the parent window
      (see ``_emit_list_order``).

    Create or re-focus the dialog through the ``open`` classmethod, which
    tracks the current instance in ``_instance``.
    """
    # NOTE(review): declared for external listeners; this class does not emit it.
    dataframeCommitted = pyqtSignal(pd.DataFrame)
    _instance = None  # dialog managed by open(); reset to None when it is destroyed

    def __init__(self, df: pd.DataFrame, parent=None, which="bsm"):
        """
        Build the dialog for ``df``.

        which: the parent table "Apply Order" targets ("bsm" or "tfm").
        """
        super().__init__(parent)

        self.df = df
        self._which = which

        self.setWindowTitle("DataFrame Sort & Reorder")
        self.resize(500, 250)

        # Constructing the layout with `self` already installs it on the dialog.
        main_layout = QVBoxLayout(self)

        self.row_list = self._make_order_list()
        self.col_list = self._make_order_list()
        self._populate_lists(df)

        list_layout = QHBoxLayout()
        list_layout.addWidget(QLabel("Rows"))
        list_layout.addWidget(self.row_list)
        list_layout.addSpacing(20)
        list_layout.addWidget(QLabel("Columns"))
        list_layout.addWidget(self.col_list)
        main_layout.addLayout(list_layout)

        apply_btn = QPushButton("Apply Order")
        # Read self._which at click time so open() can retarget this dialog.
        apply_btn.clicked.connect(lambda: self._emit_list_order(self._which))
        main_layout.addWidget(apply_btn)

        self.setAttribute(Qt.WA_DeleteOnClose, True)
        cls = type(self)
        # Clear the singleton slot only if it still refers to *this* dialog;
        # a newer instance may already have replaced it.
        self.destroyed.connect(lambda *_: cls._forget_instance(self))

    @staticmethod
    def _make_order_list():
        """Create a multi-select QListWidget whose items are reordered by drag and drop."""
        widget = QListWidget()
        widget.setDragDropMode(QListWidget.InternalMove)
        widget.setDefaultDropAction(Qt.MoveAction)
        widget.setEditTriggers(QListWidget.DoubleClicked)
        widget.setSelectionMode(QListWidget.ExtendedSelection)
        return widget

    def _populate_lists(self, df):
        """(Re)fill ``row_names``/``col_names`` and both list widgets from ``df``."""
        # str() first: index labels are not necessarily strings (e.g. RangeIndex).
        self.row_names = [str(name).strip() for name in df.index]
        # Only the first line of each column label is shown; blank ones are skipped.
        first_lines = (str(c).split("\n")[0].strip() for c in df.columns)
        self.col_names = [name for name in first_lines if name]

        self.row_list.clear()
        self.row_list.addItems(self.row_names)
        self.col_list.clear()
        self.col_list.addItems(self.col_names)

    @classmethod
    def _forget_instance(cls, dlg):
        """Reset ``_instance`` to None if it still points at ``dlg``."""
        if cls._instance is dlg:
            cls._instance = None

    @classmethod
    def open(cls, parent, df: pd.DataFrame, which="bsm"):
        """
        Open the singleton dialog, or bring the visible one to the front.

        A visible instance is reused: ``df`` (when not None) replaces its
        ``df``. If ``which`` differs from the table it currently targets, the
        dialog is retargeted and its lists are rebuilt. Otherwise a new
        non-modal instance is created and shown.

        NOTE: as a classmethod this shadows the instance slot QDialog.open();
        use ``show()`` / ``exec_()`` on instances instead.

        Returns the dialog.
        """
        dlg = cls._instance
        if dlg is not None and dlg.isVisible():
            if df is not None:
                dlg.df = df
            if which != dlg._which:
                # The lists must describe the newly targeted table.
                dlg._which = which
                dlg._populate_lists(dlg.df)
            # bring it to front
            dlg.show()
            dlg.raise_()
            dlg.activateWindow()
            return dlg

        dlg = cls(df, parent=parent, which=which)
        cls._instance = dlg
        dlg.show()  # non-modal; use exec_() for a modal dialog
        return dlg

    @staticmethod
    def _list_texts(widget):
        """Return the texts of ``widget``'s items in their current order."""
        return [widget.item(i).text() for i in range(widget.count())]

    def _matching_columns(self, names):
        """
        Return, in order, the entries of ``names`` that occur as a substring
        of at least one column label of ``self.df``.
        """
        # str(): column labels are not necessarily strings.
        labels = [str(label) for label in self.df.columns]
        return [name for name in names if any(name in label for label in labels)]

    def _emit_list_order(self, which):
        """
        Push the order shown in the two lists to the parent window.

        For ``which == "bsm"`` / ``"tfm"`` this reads ``.model().df`` of the
        parent's ``BSMtableView`` / ``TFMtableView`` into ``self.df``. It then
        stores the row and column order in
        ``parent().sorting_tables[which]["rows"/"cols"]`` and calls
        ``gen_balance_matrix()`` / ``gen_transaction_matrix()`` followed by
        ``sort_table(which)``. Any other ``which`` does nothing.
        """
        host = self.parent()
        if which == "bsm":
            self.df = host.BSMtableView.model().df
            # NOTE(review): "\n" is re-appended to every row label, presumably
            # restoring what strip() removed for display - confirm against
            # the parent's row labels.
            row_order = [text + "\n" for text in self._list_texts(self.row_list)]
            col_order = self._matching_columns(self._list_texts(self.col_list))

            host.sorting_tables["bsm"]["rows"] = row_order
            host.sorting_tables["bsm"]["cols"] = col_order

            host.gen_balance_matrix()
            host.sort_table("bsm")

        elif which == "tfm":
            self.df = host.TFMtableView.model().df
            row_order = self._list_texts(self.row_list)
            col_order = self._matching_columns(self._list_texts(self.col_list))

            host.sorting_tables["tfm"]["rows"] = row_order
            host.sorting_tables["tfm"]["cols"] = col_order

            host.gen_transaction_matrix()
            host.sort_table("tfm")