"""Module for the Compex class """

# pylint: disable=no-absolute-import, consider-using-from-import

# Import external packages
import numpy as np


# Local imports
from InDelsTopo.filtration import Filtration, _convert_words_to_blocks
import InDelsTopo.graphics as graphics
from InDelsTopo.utils import _combine_blocks_alphabet


# Complex Class
# -------------------------------------------------------
class Complex:
    """
    Represents an Insertion Chain Complex C[W] for a set of words W.

    The class stores the blocks of the complex in each dimension and provides
    methods to compute topological invariants such as the Euler characteristic
    and Betti numbers (over Z_2 or Z when using a SageMath kernel). It also
    supports graphical visualization of the complex in low dimensions.

    The k-dimensional blocks can be accessed via indexing syntax:
    `K[k]` returns the list of k-blocks.

    Attributes:
        dim (int): Maximum dimension of the complex.
        complex_dict (dict[int, list[Block]]): Maps each dimension to its corresponding
            list of blocks.
        height (float | None): Height value associated with the complex.
        _alphabet (Alphabet): Alphabet object containing all symbols used in W.
        _prod_symbol (str): Product symbol used in the blocks.
        _positions_dict (dict | None): Stores vertex positions for graphical visualization.

    Notes:
        - This class can be used to build and analyze Insertion Chain Complexes directly,
          or as a sublevel complex of a filtration.
        - Homology computations over Z require SageMath; otherwise, only Z_2 computations
          are available using SciPy.
        - It initializes as an empty complex. It can be made into the insertion chain complex
            of a set of words `W` by using the method `compute_d_skeleton(W)`.
        - Blocks can be added or removed by using the methods `add_blocks` and `remove_blocks`.


    Example:
    ```python
    >>> W = ["ab", "aab", "abb"]
    >>> K = Complex() # Creates an empty complex
    >>> K.compute_d_skeleton(W) # makes K = C[W]
    >>> K[1]  # Access 1-dimensional blocks
    [a(1,a)b, ab(1,b)]
    ```
    """

    def __init__(
        self, alphabet=None, prod_symbol=None, *, complex_dict=None, height=None
    ):
        """
        Initialize an Insertion Chain Complex.

        The complex starts empty (dimension -1, no blocks) unless a dictionary
        of blocks is supplied. A height value can be attached to the complex,
        which is useful when it represents one level of a filtration.

        Args:
            alphabet (Alphabet | None, optional): The alphabet containing all symbols
                to be used. Stored as given. Defaults to None.
            prod_symbol (str | None, optional): Product symbol used for block
                construction, expected to be one of {'', '*', '.'}. It is stored
                as given here; no inference or validation happens in the
                constructor. Defaults to None.
            complex_dict (dict[int, list[Block]] | None, optional): Mapping from
                dimension `d` to the list of d-dimensional blocks. If provided, it
                is stored by reference (not copied) and the dimension of the
                complex is set to its largest key. Defaults to None.
            height (float | None, optional): Height value associated with the complex,
                used when part of a filtration. It is stored whether or not
                `complex_dict` is given. Defaults to None.
        """
        self._alphabet = alphabet
        self._prod_symbol = prod_symbol
        self.dim = -1
        self.complex_dict = {}
        # Previously the height was discarded unless complex_dict was also
        # passed, so `Complex(height=h)` lost its height.
        self.height = height
        self._positions_dict = None
        # Slot read by get_homology_sage(); declared here so it always exists.
        self._sage_chain_complex = None

        if complex_dict is not None:
            self.dim = max(complex_dict, default=-1)
            self.complex_dict = complex_dict

    def compute_d_skeleton(
        self,
        W,
        height=None,
        max_dim=10,
        alphabet=None,
        prod_symbol=None,
        check_duplicates=True,
        already_blocks=False,
        verbose=False,
    ):
        """
        Compute the d-skeleton of the Insertion Chain Complex generated by a set of words, C[W].

        This method replaces any existing blocks in the Complex with a new complex
        supported on `W`. The construction is delegated to a temporary
        `Filtration` (built without heights); the blocks it produces in each
        dimension become the new `complex_dict`, and the dimension and alphabet
        of the complex are taken from that filtration.

        Args:
            W (list of str or Block): List of words (or blocks, if `already_blocks=True`)
                forming the base of the complex.
            height (float, optional): Height value associated with the complex,
                used when part of a filtration. Defaults to None.
            max_dim (int, optional): Maximum dimension of the skeleton to compute. Defaults to 10.
            alphabet (Alphabet, optional): Alphabet forwarded to the underlying
                `Filtration.compute_d_skeleton` call. After the computation the
                alphabet of this complex is replaced by the one reported by that
                filtration. Defaults to None.
            prod_symbol (str, optional): Product symbol for block construction ('*', '.', or ''),
                forwarded to the underlying filtration. Defaults to None.
            check_duplicates (bool, optional): Whether to verify that input words are unique.
                Defaults to True.
            already_blocks (bool, optional): If True, assumes the input `W` is already a list of
                `Block` objects instead of strings. Defaults to False.
            verbose (bool, optional): If True, prints progress information during computation.

        Example:
        ```python
        >>> W = ["ab", "aab", "abb"]
        >>> K = Complex()
        >>> K.compute_d_skeleton(W)
        >>> K[1]
        [a(1,a)b, ab(1,b)]
        ```
        """
        skeleton = Filtration()
        skeleton.compute_d_skeleton(
            W,
            heights=None,
            max_dim=max_dim,
            alphabet=alphabet,
            prod_symbol=prod_symbol,
            check_duplicates=check_duplicates,
            already_blocks=already_blocks,
            verbose=verbose,
        )

        # Each level of the filtration is a mapping keyed by block; only the
        # blocks themselves are kept here.
        self.complex_dict = {
            dim: list(skeleton[dim].keys()) for dim in skeleton.filtration_dict
        }
        self.height = height
        self.dim = skeleton.dim
        self._alphabet = skeleton.get_alphabet()

        # A chain complex saved by get_homology_sage() described the previous
        # blocks; drop it so it cannot be reused for the new complex.
        self._sage_chain_complex = None

    def get_maximal_blocks(self):
        """
        Return the maximal blocks of the complex, grouped by dimension.

        A block is reported as maximal when it is not a facet of any block one
        dimension higher (facets are obtained with `block.get_all_facets()`).
        All blocks of the top dimension are maximal.

        Returns:
            dict[int, list[Block]]: A dictionary mapping each dimension to its list
            of maximal blocks. The top dimension comes first, followed by the
            lower dimensions in increasing order. Within a dimension the blocks
            keep the order they have in `complex_dict`.

        Notes:
            - Empty dimensions are removed from the output dictionary.
            - An empty complex (dimension -1) yields an empty dictionary.
        """
        top_dim = self.dim
        if top_dim < 0:
            return {}

        # covered[d] collects every d-block that is a facet of some (d+1)-block.
        covered = {dim: set() for dim in range(top_dim)}
        for dim in range(1, top_dim + 1):
            for block in self.complex_dict.get(dim, ()):
                covered[dim - 1].update(block.get_all_facets())

        maximal_dict = {}
        top_blocks = list(self.complex_dict.get(top_dim, ()))
        if top_blocks:
            maximal_dict[top_dim] = top_blocks

        for dim in range(top_dim):
            # dict.fromkeys drops repeated blocks while keeping the stored order,
            # which makes the output deterministic (a plain set difference is not).
            survivors = list(
                dict.fromkeys(
                    block
                    for block in self.complex_dict.get(dim, ())
                    if block not in covered[dim]
                )
            )
            if survivors:
                maximal_dict[dim] = survivors
        return maximal_dict

    def get_complex(self, max_dim=None):
        """
        Return the subcomplex made of the blocks of dimension at most `max_dim`.

        When `max_dim` is None, or is not smaller than the dimension of the
        complex, the complex itself (the same object, not a copy) is returned.
        Otherwise a new `Complex` is built from copies of the block lists of
        dimensions 0..`max_dim`; trailing empty dimensions are dropped so the
        dimension of the result is that of its highest non-empty level.

        Args:
            max_dim (int or None, optional): Maximum dimension of blocks to include
                in the returned complex. If None or greater than or equal to the
                complex dimension, the full complex is returned.

        Returns:
            Complex: A subcomplex containing blocks up to dimension `max_dim`,
            sharing this complex's alphabet, product symbol and height.
        """
        if max_dim is None or max_dim >= self.dim:
            return self

        truncated = {}
        for level in range(max_dim + 1):
            truncated[level] = list(self.complex_dict[level])

        # Peel off empty levels from the top until a non-empty one is found.
        top = max_dim
        while top >= 0 and len(truncated[top]) == 0:
            del truncated[top]
            top -= 1

        return Complex(
            alphabet=self._alphabet,
            prod_symbol=self._prod_symbol,
            complex_dict=truncated,
            height=self.height,
        )

    def add_blocks(self, list_blocks, prod_symbol=None, already_blocks=False):
        """
        Add new blocks to the Complex.

        Extends the current Complex by inserting additional blocks and their faces.
        Intended for expert use only, since the resulting structure may not be a full
        Insertion Chain Complex C[W], but rather a subcomplex if some supported blocks
        are missing.

        Args:
            list_blocks (list[Block] or list[str]):
                List of blocks to be added to the Complex. If ``already_blocks`` is
                False (default), the elements are assumed to be strings representing
                blocks and will be converted. If True, they are assumed to be existing
                ``Block`` objects.

            prod_symbol (str or None, optional):
                Product symbol used in block representation ('*', '.', or '').
                It is only forwarded to the string-to-block conversion. The
                symbol finally stored on the complex is chosen among the symbols
                carried by the new blocks and the current one: '*' if any of
                them uses '*', else '.' if any uses '.', else ''.

            already_blocks (bool, optional):
                If True, elements of ``list_blocks`` are assumed to be valid ``Block``
                objects. If False (default), the method attempts to convert the input
                into blocks.

        Notes:
            The internal alphabet and product symbol are updated to ensure
            consistency, and the product symbol of every block in ``list_blocks``
            is overwritten with the one chosen for the complex.
        """
        # Convert into blocks if needed
        if already_blocks:
            alphabet = _combine_blocks_alphabet(list_blocks, self._alphabet)
        else:
            list_blocks, alphabet, prod_symbol = _convert_words_to_blocks(
                list_blocks, prod_symbol=prod_symbol, alphabet=self._alphabet
            )
        self._alphabet = alphabet

        # Make the product symbols uniform. pylint: disable=protected-access
        new_prods = [blk._prod_symbol for blk in list_blocks] + [self._prod_symbol]
        if "*" in new_prods:
            prod_symbol = "*"
        elif "." in new_prods:
            prod_symbol = "."
        else:
            prod_symbol = ""
        self._prod_symbol = prod_symbol
        for blk in list_blocks:
            blk._prod_symbol = self._prod_symbol

        # Per-dimension sets give constant-time "already present?" checks; the
        # lists in complex_dict remain the stored, ordered representation.
        known = {
            dimension: set(blocks) for dimension, blocks in self.complex_dict.items()
        }

        # Add blocks and faces to complex_dict
        for block in list_blocks:
            for face in block.get_all_faces(True):
                seen = known.setdefault(face.dim, set())
                if face not in seen:
                    seen.add(face)
                    self.complex_dict.setdefault(face.dim, []).append(face)

        # Recompute dimension
        self.dim = max(self.complex_dict, default=-1)

        # A chain complex saved by get_homology_sage() no longer matches the blocks.
        self._sage_chain_complex = None

    def remove_blocks(
        self, list_blocks, prod_symbol=None, include_upfaces=True, already_blocks=False
    ):
        """
        Remove blocks from the Complex.

        Deletes specified blocks and optionally their super-faces from the Complex.
        Intended for expert use only, since the resulting structure may not be a full
        Insertion Chain Complex C[W], but rather a subcomplex if some supported blocks
        are missing.

        Args:
            list_blocks (list of Block or string): A list of blocks to remove. If
                `already_blocks` is False (default), the elements are assumed to be strings
                representing blocks and will be converted. If True, they are assumed to be
                existing Block objects. Blocks that are not in the complex are ignored.
            prod_symbol (str or None, optional): Product symbol used in block
                representation ( '*', '.', or ''). Only used when converting
                strings into blocks.
            include_upfaces (bool, optional): If True, all super faces of the specified blocks
                are also removed, so the result is a subcomplex. Default is True.
            already_blocks (bool, optional): If True, the elements of `list_blocks`
                are assumed to be valid Block objects. If False (default), the method
                attempts to convert the input into blocks.

        Raises:
            TypeError: If `list_blocks` is not a list.
        """
        # Make sure it is a list
        if not isinstance(list_blocks, list):
            raise TypeError("list_blocks must be a list")

        # Convert into blocks if needed
        if not already_blocks:
            list_blocks, _, _ = _convert_words_to_blocks(
                list_blocks, prod_symbol=prod_symbol, alphabet=self._alphabet
            )

        # Blocks scheduled for removal, grouped by dimension. Sets keep the
        # facet checks below constant-time.
        doomed = {}
        for block in list_blocks:
            if block in self.complex_dict.get(block.dim, ()):
                doomed.setdefault(block.dim, set()).add(block)

        # Find super-faces if needed. Dimensions are visited in increasing
        # order so that a removal cascades upwards: a block is removed as soon
        # as one of its facets is.
        if include_upfaces:
            for dimension in sorted(d for d in self.complex_dict if d >= 1):
                below = doomed.get(dimension - 1)
                if not below:
                    continue
                for block in self.complex_dict[dimension]:
                    if not below.isdisjoint(block.get_all_facets()):
                        doomed.setdefault(dimension, set()).add(block)

        # Remove blocks (in place, so existing references to the lists stay valid)
        for dimension, blocks in doomed.items():
            level = self.complex_dict[dimension]
            level[:] = [block for block in level if block not in blocks]

        # Drop empty dimensions from the top down, stopping at the first
        # non-empty one.
        for dimension in sorted(self.complex_dict, reverse=True):
            if self.complex_dict[dimension]:
                break
            del self.complex_dict[dimension]

        # Recompute dimension
        self.dim = max(self.complex_dict, default=-1)

        # A chain complex saved by get_homology_sage() no longer matches the blocks.
        self._sage_chain_complex = None

    def euler_characteristic(self):
        """
        Compute the Euler characteristic of the complex.

        Returns:
            int: The alternating sum, over the dimensions present in
            `complex_dict`, of the number of blocks in each dimension
            (even dimensions count positively, odd ones negatively).
        """
        total = 0
        for dimension, blocks in self.complex_dict.items():
            sign = (-1) ** dimension
            total += sign * len(blocks)
        return int(total)

    def __getitem__(self, key):
        """
        Return the list of `key`-dimensional blocks.

        Args:
            key (int): Dimension of the requested blocks.

        Returns:
            list[Block]: The stored list for that dimension, or an empty list
            when `key` is outside `0..dim` or no list is stored for it. The
            return value is always a list, so callers can treat every
            dimension uniformly.
        """
        if 0 <= key <= self.dim:
            return self.complex_dict.get(key, [])
        return []

    def get_graph(
        self,
        show_labels=True,
        max_dim=5,
        positions=None,
        initial_positions=None,
        fixed=None,
        recompute=False,
        colors_by_dim=None,
        ax=None,
    ):
        """
        Generate a graphical representation of the complex up to a specified dimension.

        Positions of vertices can be computed automatically or provided manually.
        Only accurate for low-dimensional complexes (typically dim <= 3).

        Args:
            show_labels (bool, optional): Whether to display labels on the vertices. Defaults to True.
            max_dim (int, optional): Maximum dimension of blocks to include in the graph. Defaults to 5.
            positions (dict, optional): Dictionary of vertex positions. When given
                (and `recompute` is False) it is used as is for drawing.
                If None, positions are computed automatically. Once computed, they are reused
                every time this method is called, unless `recompute` is set to True.
            initial_positions (dict, optional): Initial positions used to seed the
                automatic layout algorithm.
            fixed (list or None, optional): List of vertex keys to fix in place when computing positions.
                Defaults to None.
            recompute (bool, optional): Whether to recompute vertex positions even
                if already stored (or passed through `positions`). Defaults to False.
            colors_by_dim (list of str, optional): List of colors to use for each dimension.
                If None, defaults to ['black', 'gray', 'yellow', 'red', 'blue', 'purple'].
            ax (matplotlib.axes._subplots.Axes3DSubplot, optional): A Matplotlib Axes
                object to draw the plot on. If None, a new figure and axes are created.
                Defaults to None.

        Returns:
            matplotlib.axes.Axes or None: Matplotlib axes object containing the
            drawn graph, or None when the complex is empty.
        """
        if self.dim == -1:
            return None

        if positions is None or recompute:
            if self._positions_dict is None or recompute:
                self._positions_dict = graphics.compute_vertex_positions(
                    self, pos0=initial_positions, fixed=fixed
                )
            # Make sure the stored positions include all vertices
            elif any(vertex not in self._positions_dict for vertex in self[0]):
                # Seed with any user-supplied initial positions
                if isinstance(initial_positions, dict):
                    self._positions_dict.update(initial_positions)
                # Lay out the missing vertices while keeping the known ones fixed
                self._positions_dict = graphics.compute_vertex_positions(
                    self,
                    pos0=self._positions_dict,
                    fixed=list(self._positions_dict),
                )
            positions = self._positions_dict

        # Draw unconditionally. This call used to sit inside the branch above,
        # so passing `positions` explicitly produced no drawing at all.
        return graphics.make_graph(
            self,
            pos=positions,
            show_labels=show_labels,
            max_dim=max_dim,
            height=None,
            already_complex=True,
            colors_by_dim=colors_by_dim,
            ax=ax,
        )

    def get_betti_numbers_z2(self, max_dim=None):
        """
        Return the Betti numbers with Z_2 coefficients up to dimension `max_dim`.

        The boundary matrix of the complex is built over Z_2 (one column per
        block, with an entry for each facet returned by `get_all_facets()`) and
        reduced column by column. A block contributes to the Betti number of
        its dimension when its reduced column is zero and its index is not the
        pivot of any other column.

        Args:
            max_dim (int or None, optional): Highest dimension for which a Betti
                number is returned. If None, or not smaller than the dimension of
                the complex, all dimensions are computed. Otherwise blocks of
                dimension `max_dim + 1` are also loaded, since they are needed
                as boundaries, and that extra dimension is dropped from the
                output.

        Returns:
            dict[int, int]: Maps each dimension to its Betti number over Z_2.
        """
        # Import csc_matrix
        from scipy.sparse import csc_matrix

        if max_dim is None or max_dim >= self.dim:
            max_dim = self.dim
            using_skeleton = False
        else:
            # One extra dimension is required to detect which max_dim-cycles
            # are boundaries; it is removed from the result at the end.
            max_dim += 1
            using_skeleton = True

        # Order the blocks according to their dimension
        ordered_blocks = []
        for d in range(max_dim + 1):
            ordered_blocks += self[d]
        ordered_blocks.sort(key=lambda B: B.dim)

        ordered_blocks_dict = {block: i for i, block in enumerate(ordered_blocks)}

        # Construct the boundary matrix
        cols = []
        rows = []
        data = []
        for id_col, block in enumerate(ordered_blocks):
            for face in block.get_all_facets():
                cols.append(id_col)
                rows.append(ordered_blocks_dict[face])
                data.append(1)

        N = len(ordered_blocks)
        boundary_matrix = csc_matrix((data, (rows, cols)), shape=(N, N)).tolil()

        # Perform column reduction (we follow the algorithm from
        # https://arxiv.org/pdf/1506.08903)
        low = []  # low[j]: largest row index of reduced column j, or -1 if zero
        pivot_column = {}  # pivot row index -> column owning it (O(1) lookup)
        betti_numbers = {dim: 0 for dim in range(max_dim + 1)}

        for j in range(N):
            col = boundary_matrix.getcol(j)
            if col.nnz > 0:
                # Set initial value for low(j)
                low_j = col.nonzero()[0][-1]
                while low_j in pivot_column:
                    i = pivot_column[low_j]
                    # add column i to column j
                    col = col + boundary_matrix.getcol(i)
                    col.data = col.data % 2  # Make addition modulo 2
                    col.eliminate_zeros()
                    if col.nnz > 0:
                        low_j = col.nonzero()[0][-1]  # update low_j value
                    else:
                        low_j = -1
                        break
                boundary_matrix[:, j] = col  # update column j in the matrix
                low.append(low_j)  # Save value for low_j
                if low_j >= 0:
                    pivot_column[low_j] = j
            else:
                # Set -1, for undefined low(j)
                low.append(-1)

        # Extract surviving cycles
        for j, low_j in enumerate(low):
            if low_j < 0 and j not in pivot_column:
                betti_numbers[ordered_blocks[j].dim] += 1

        if using_skeleton:
            betti_numbers.pop(max_dim, None)

        return betti_numbers

    def get_chain_complex_sage(self, get_ordered_blocks=False):
        """
        Build the chain complex of this complex as a SageMath object.

        Meant to be run in a SageMath kernel. The work is delegated to
        `InDelsTopo.homology_sagemath.create_chain_complex`, which receives the
        block dictionary of the complex and the `get_ordered_blocks` flag.

        Args:
            get_ordered_blocks (bool, optional): If True, the blocks ordered
                within each dimension are returned together with the chain
                complex. Defaults to False.

        Returns:
            ChainComplex or tuple[ChainComplex, dict[int, list[Block]]] or None:
                - If `get_ordered_blocks` is False, the chain complex.
                - If `get_ordered_blocks` is True, a tuple
                  `(ChainComplex, Blocks_ordered)`.
                - None if the SageMath helper module cannot be imported or any
                  other error occurs during construction.

        Notes:
            - Requires SageMath to be installed and accessible in the current environment.
            - Failures are not raised: a message is printed and None is returned.
        """
        result = None
        try:
            from InDelsTopo import homology_sagemath

            result = homology_sagemath.create_chain_complex(
                self.complex_dict, get_ordered_blocks
            )
        except ImportError:
            # The helper module (or something it imports) is unavailable
            print(
                "Could not find SageMath functions. "
                "This module requires SageMath to run and cannot be "
                "executed in a standard Python environment."
            )
        except Exception as error:
            # Best effort by design: report the problem and fall through to None
            print(f"An error occurred: {error}")
        return result

    def get_homology_sage(
        self, save_chain_complex=False, used_saved_chain_complex=True, **kwargs
    ):
        """
        Return the homology of the associated chain complex using SageMath.

        This method should be run in a SageMath kernel. It obtains a SageMath
        chain complex (from the saved attribute or from
        `get_chain_complex_sage()`) and calls its `homology()` method.

        Parameters:
            save_chain_complex (bool): If True, the chain complex is saved as an attribute
                                       (`self._sage_chain_complex`) to speed up future computations.
            used_saved_chain_complex (bool): If True, it will attempt to use the saved chain_complex
                                            attribute, otherwise, it will be computed from scratch.
                                            Pass False to force a rebuild, for instance after
                                            the complex has been modified.
            **kwargs: Additional keyword arguments passed directly to SageMath's `homology()` method.

        Returns:
            The homology object returned by SageMath, or None when the chain
            complex could not be built (for example because SageMath is not
            available; `get_chain_complex_sage()` prints the reason).
        """
        chain_complex = None
        if used_saved_chain_complex:
            chain_complex = getattr(self, "_sage_chain_complex", None)
        if chain_complex is None:
            chain_complex = self.get_chain_complex_sage()

        # get_chain_complex_sage() reports failure by returning None; calling
        # .homology on it would raise an unrelated AttributeError.
        if chain_complex is None:
            return None

        if save_chain_complex:
            self._sage_chain_complex = chain_complex

        return chain_complex.homology(**kwargs)

    def __str__(self):
        """Return a multi-line summary: alphabet, height (if set), dimension and block counts."""
        total_blocks = sum(len(self[k]) for k in range(self.dim + 1))

        lines = ["Insertion Chain Complex ", f"alphabet: {self._alphabet!s}."]
        # The height line is only shown when a height has been assigned
        if self.height is not None:
            lines.append(f"height: {self.height!s}.")
        lines.append(f"dimension: {self.dim!s}.")
        lines.append(f"vertices: {len(self[0])!s}.")
        lines.append(f"blocks: {total_blocks!s}.")
        return "\n".join(lines)

    def __repr__(self):
        """Return the same summary text as `__str__`."""
        return str(self)

    def get_alphabet(self):
        """
        Return the alphabet currently stored on the complex.

        Returns:
            The value of the internal `_alphabet` attribute; this is whatever
            was passed to the constructor or last set by a method that rebuilds
            or extends the complex (it may be None).
        """
        return self._alphabet
