######################################################################################################################
# Copyright (C) 2017-2022 Spine project consortium
# Copyright Spine Database API contributors
# This file is part of Spine Database API.
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
""" Contains export mappings for database items such as entities, entity classes and parameter values."""
from __future__ import annotations
from collections.abc import Callable, Iterator
from contextlib import suppress
from dataclasses import dataclass
from itertools import cycle, dropwhile, islice
from typing import Any, ClassVar, Optional
from sqlalchemy import and_
from sqlalchemy.engine import Row
from sqlalchemy.orm import Query
from sqlalchemy.sql.expression import CacheKey
from .. import DatabaseMapping
from ..mapping import Mapping, Position, is_pivoted, is_regular, unflatten
from ..parameter_value import (
    IndexedValue,
    convert_containers_to_maps,
    from_database,
    from_database_to_dimension_count,
    from_database_to_single_value,
    type_for_scalar,
)
from .group_functions import NoGroup


class _MappingWithLeafMixin:
    """Mixin that supplies a ``current_leaf`` attribute, ``None`` by default."""

    # Class-level default; presumably replaced per instance while indexed data is expanded
    # (child mappings read it through ``parent.current_leaf``) — TODO confirm against the expander helpers.
    current_leaf = None


class ExportMapping(Mapping):
    """Base class for export mappings.

    Export mappings form a hierarchy through their ``child`` attribute. Together they build a
    database query and turn the queried rows into output rows, table titles and header rows.
    Subclasses usually set :attr:`name_field` and :attr:`id_field` and reimplement
    :meth:`build_query_columns` and :meth:`filter_query`.
    """

    # Separator placed between the title parts yielded by nested mappings.
    _TITLE_SEP = ","
    # Name of the query column holding this mapping's data; read by the default _data().
    name_field: ClassVar[Optional[str]] = None
    # Name of the query column stored in the title state; read by the default _title_state().
    id_field: ClassVar[Optional[str]] = None

    def __init__(self, position, value=None, header="", filter_re=""):
        """
        Args:
            position (int or Position): column index or Position
            value (Any, optional): A fixed value
            header (str): A string column header that's yielded as 'first row', if not empty.
                The default is an empty string (so it's not yielded).
            filter_re (str): A regular expression to filter the mapped values by
        """
        super().__init__(position, value, filter_re)
        self._ignorable = False
        self.header = header
        # Optional callable applied to every yielded value; installed by replace_data().
        self._convert_data = None

    def __eq__(self, other):
        if not isinstance(other, ExportMapping):
            return NotImplemented
        if not super().__eq__(other):
            return False
        return self._ignorable == other._ignorable and self.header == other.header

    def check_validity(self):
        """Checks if mapping is valid.

        Returns:
            list: a list of issues
        """
        issues = []
        if self.is_effective_leaf() and is_pivoted(self.position):
            issues.append("Cannot be pivoted.")
        return issues

    def replace_data(self, data):
        """
        Replaces the data generated by this item by user given data.

        If data is exhausted, it gets cycled again from the beginning.

        Args:
            data (Iterable): user data; must yield at least one item, otherwise fetching
                a replacement value fails once the mapping yields data
        """
        data_iterator = cycle(data)
        self._convert_data = lambda _: next(data_iterator)

    @staticmethod
    def is_buddy(parent):
        """Checks if mapping uses a parent's state for its data.

        Args:
            parent (ExportMapping): a parent mapping

        Returns:
            bool: True if parent's state affects what a mapping yields
        """
        return False

    def is_ignorable(self):
        """Returns True if the mapping is ignorable, False otherwise.

        Returns:
            bool: True if mapping is ignorable, False otherwise
        """
        return self._ignorable

    def set_ignorable(self, ignorable):
        """
        Sets mapping as ignorable.

        Mappings that are ignorable map to None if there is no other data to yield.
        This allows 'incomplete' rows if child mappings do not depend on the ignored mapping.

        Args:
            ignorable (bool): True to set mapping ignorable, False to unset
        """
        self._ignorable = ignorable

    def to_dict(self):
        """
        Serializes mapping into dict.

        Returns:
            dict: serialized mapping
        """
        mapping_dict = super().to_dict()
        if self._ignorable:
            mapping_dict["ignorable"] = True
        if self.header:
            mapping_dict["header"] = self.header
        return mapping_dict

    @classmethod
    def reconstruct(cls, position, value, header, filter_re, ignorable, mapping_dict):
        """
        Reconstructs mapping.

        Args:
            position (int or Position, optional): mapping's position
            value (Any): fixed value
            header (str, optional): column header
            filter_re (str): filter regular expression
            ignorable (bool): ignorable flag
            mapping_dict (dict): mapping dict

        Returns:
            Mapping: reconstructed mapping
        """
        mapping = cls(position, value, header, filter_re)
        mapping.set_ignorable(ignorable)
        return mapping

    def build_query_columns(self, db_map, columns):
        """Appends columns needed to query the mapping's data into a list.

        The base class implementation adds nothing.

        Args:
            db_map (DatabaseMapping): database mapping
            columns (list of Column): list of columns to append to
        """

    def filter_query(self, db_map, query):
        """Filters the mapping query if needed, and returns the new query.

        The base class implementation just returns the same query without applying any new filters.

        Args:
            db_map (DatabaseMapping): database mapping
            query (Subquery or dict): query to filter

        Returns:
            Subquery: filtered query, or the same if nothing to add.
        """
        return query

    def filter_query_by_title(self, query, title_state):
        """Filters the query according to the given title state.

        Note that ``_build_query()`` does some default filtering on the title state after calling this method.
        Therefore, if a subclass reimplements this method, it needs to delete the consumed keys from ``title_state``
        so they aren't consumed again by ``_build_query()``.

        The base class implementation just returns the unaltered query.

        Args:
            query (Query or _FilteredQuery): query to filter
            title_state (dict): title state; consumed keys should be deleted from it

        Returns:
            Query or _FilteredQuery
        """
        return query

    @staticmethod
    def _query_from_mappings(db_map, mappings, title_state, row_cache):
        """Builds a query out of the given mappings' columns, filters and a title state.

        Args:
            db_map (DatabaseMapping): database mapping
            mappings (list of ExportMapping): flattened mappings that take part in the query
            title_state (dict, optional): title state; the given dict is left unmodified
            row_cache (dict): cache for queried database rows

        Returns:
            Query or _FilteredQuery: the query, or None if the mappings contribute no columns
        """
        columns = []
        for m in mappings:
            m.build_query_columns(db_map, columns)
        if not columns:
            return None
        qry = db_map.query(*columns)
        for m in mappings:
            qry = m.filter_query(db_map, qry)
        # filter_query_by_title() implementations delete the keys they consume.
        # Work on a private copy so a title state reused for another query keeps all of its keys.
        remaining_state = dict(title_state) if title_state else {}
        # Apply special title filters (first, so we clean up the state)
        for m in mappings:
            qry = m.filter_query_by_title(qry, remaining_state)
        if not remaining_state:
            return qry
        # Snapshot the criteria so the row predicate is immune to later changes of any dict.
        criteria = tuple(remaining_state.items())
        # Use a _FilteredQuery, since building a subquery to query it again leads to parser stack overflow
        return _FilteredQuery(
            db_map,
            qry,
            (lambda db_row: all(getattr(db_row, key) == value for key, value in criteria)),
            row_cache,
        )

    def _build_query(
        self, db_map: DatabaseMapping, title_state: dict[str, Any], row_cache: dict[CacheKey, list[Row]]
    ) -> Optional[Query | _FilteredQuery]:
        """Builds and returns the query to run for this mapping hierarchy.

        Args:
            db_map: database mapping
            title_state: title state to filter by; not modified
            row_cache: cache for queried database rows

        Returns:
            query, or None if no mapping in the hierarchy contributes query columns
        """
        return self._query_from_mappings(db_map, self.flatten(), title_state, row_cache)

    def _build_title_query(self, db_map: DatabaseMapping) -> Optional[Query]:
        """Builds and returns the query to get titles for this mapping hierarchy.

        Args:
            db_map: database mapping

        Returns:
            title query, or None if the relevant mappings contribute no query columns
        """
        mappings = self.flatten()
        # Only mappings up to and including the last table-name mapping matter for titles.
        while mappings and mappings[-1].position != Position.table_name:
            mappings.pop()
        columns = []
        for m in mappings:
            m.build_query_columns(db_map, columns)
        if not columns:
            return None
        qry = db_map.query(*columns)
        # Apply filters
        for m in mappings:
            qry = m.filter_query(db_map, qry)
        return qry

    def _build_header_query(
        self,
        db_map: DatabaseMapping,
        title_state: dict[str, Any],
        buddies: list[tuple[ExportMapping, ExportMapping]],
        row_cache: dict[CacheKey, list[Row]],
    ) -> Optional[Query | _FilteredQuery]:
        """Builds the header query for this mapping hierarchy.

        Args:
            db_map: database mapping
            title_state: title state; not modified
            buddies: pairs of buddy mappings
            row_cache: cache for queried database rows

        Returns:
            header query, or None if the relevant mappings contribute no query columns
        """
        mappings = self.flatten()
        flat_buddies = [b for pair in buddies for b in pair]
        # Trailing mappings that are neither header/title mappings nor buddies cannot affect the header.
        while mappings:
            last = mappings[-1]
            if last.position in (Position.header, Position.table_name) or last in flat_buddies:
                break
            mappings.pop()
        return self._query_from_mappings(db_map, mappings, title_state, row_cache)

    def _data(self, row):
        """Returns the data relevant to this mapping from given database row.

        The base class implementation returns the field given by ``name_field``.

        Args:
            row (Row)

        Returns:
            any
        """
        return getattr(row, self.name_field, None)

    def _expand_data(self, data):
        """Takes data from an individual field in the db and yields all data generated by this mapping.

        The base class implementation simply yields the given data.
        Reimplement in subclasses that need to expand the data into multiple elements (e.g., indexed value mappings).

        Args:
            data (any)

        Returns:
            generator(any)
        """
        yield data

    def _get_data_iterator(self, data):
        """Applies regexp filtering and data conversion on the output of ``_expand_data()`` to produce the final data
        iterator for this mapping.

        Args:
            data (any)

        Returns:
            generator(any)
        """
        data_iterator = self._expand_data(data)
        if self._filter_re is not None:
            data_iterator = (x for x in data_iterator if self._filter_re.search(str(x)))
        if self._convert_data is not None:
            data_iterator = (self._convert_data(x) for x in data_iterator)
        return data_iterator

    def _get_rows(self, db_row):
        """Yields rows issued by this mapping for given database row.

        Table-name mappings yield a single empty row; their data goes to titles instead.

        Args:
            db_row (Row)

        Returns:
            generator(dict)
        """
        if self.position == Position.table_name:
            yield {}
            return
        data = self._data(db_row)
        if data is None and not self._ignorable:
            return
        data_iterator = self._get_data_iterator(data)
        for data in data_iterator:
            yield {self.position: data}

    def get_rows_recursive(self, db_row: Row) -> Iterator[dict[int | Position, Any]]:
        """Takes a database row and yields rows issued by this mapping and its children combined."""
        if self.child is None:
            yield from self._get_rows(db_row)
            return
        for row in self._get_rows(db_row):
            for child_row in self.child.get_rows_recursive(db_row):
                yield {**row, **child_row}

    def rows(
        self, db_map: DatabaseMapping, title_state: dict[str, Any], row_cache: dict[CacheKey, list[Row]]
    ) -> Iterator[dict[int | Position, Any]]:
        """Yields rows issued by this mapping and its children combined.

        Database rows of plain (non-filtered) queries are cached in ``row_cache`` under the
        statement's SQLAlchemy cache key. A cache entry is stored only after the query has been
        fully consumed, so abandoning the iteration never leaves a truncated entry behind.

        Args:
            db_map: database mapping
            title_state: title state; not modified
            row_cache: cache for queried database rows

        Yields:
            rows mapping positions to data
        """
        qry = self._build_query(db_map, title_state, row_cache)
        if qry is None:
            yield {}
            return
        if isinstance(qry, _FilteredQuery):
            for db_row in qry:
                yield from self.get_rows_recursive(db_row)
            return
        cache_key_object = qry.statement._generate_cache_key()
        if cache_key_object is None:
            # SQLAlchemy could not produce a cache key for the statement; run it uncached.
            for db_row in qry:
                yield from self.get_rows_recursive(db_row)
            return
        cache_key = cache_key_object.key
        cached_rows = row_cache.get(cache_key)
        if cached_rows is not None:
            for db_row in cached_rows:
                yield from self.get_rows_recursive(db_row)
            return
        fetched_rows = []
        for db_row in qry:
            fetched_rows.append(db_row)
            yield from self.get_rows_recursive(db_row)
        # Reached only when the query was consumed completely.
        row_cache[cache_key] = fetched_rows

    def has_titles(self):
        """Returns True if this mapping or one of its children generates titles.

        Returns:
            bool: True if mappings generate titles, False otherwise
        """
        if self.position == Position.table_name:
            return True
        if self.child is not None:
            return self.child.has_titles()
        return False

    def _title_state(self, db_row):
        """Returns the title state associated to this mapping from given database row.

        The base class implementation returns a dict mapping the ``id_field`` class attribute
        to the corresponding field from the row, or an empty dict if ``id_field`` is None.

        Args:
            db_row (Row)

        Returns:
            dict
        """
        id_field = self.id_field
        if id_field is None:
            return {}
        return {id_field: getattr(db_row, id_field)}

    def _get_titles(self, db_row, limit=None):
        """Yields pairs (title, title state) issued by this mapping for given database row.

        Mappings not positioned as table name yield a single empty title with an empty state.

        Args:
            db_row (Row)
            limit (int, optional): yield only this many items

        Returns:
            generator(str,dict)
        """
        if self.position != Position.table_name:
            yield "", {}
            return
        data = self._data(db_row)
        title_state = self._title_state(db_row)
        data_iterator = self._get_data_iterator(data)
        if limit is not None:
            data_iterator = islice(data_iterator, limit)
        for data in data_iterator:
            if data is None:
                data = ""
            yield data, title_state

    def get_titles_recursive(self, db_row, limit=None):
        """Takes a database row and yields pairs (title, title state) issued by this mapping and its children combined.

        Args:
            db_row (Row)
            limit (int, optional): yield only this many items

        Returns:
            generator(str,dict)
        """
        if self.child is None:
            yield from self._get_titles(db_row, limit=limit)
            return
        for title, title_state in self._get_titles(db_row, limit=limit):
            for child_title, child_title_state in self.child.get_titles_recursive(db_row, limit=limit):
                # Only separate when both parts are non-empty.
                title_sep = self._TITLE_SEP if title and child_title else ""
                final_title = title + title_sep + child_title
                yield final_title, {**title_state, **child_title_state}

    def _non_unique_titles(self, db_map, limit=None):
        """Yields all titles, not necessarily unique, and associated state dictionaries.

        Args:
            db_map (DatabaseMapping): a database map
            limit (int, optional): yield only this many items

        Yields:
            tuple(str,dict): title, and associated title state dictionary
        """
        qry = self._build_title_query(db_map)
        if qry is None:
            yield from self.get_titles_recursive((), limit=limit)
            return
        for db_row in qry:
            yield from self.get_titles_recursive(db_row, limit=limit)

    def titles(self, db_map, limit=None):
        """Yields unique titles and associated state dictionaries.

        States of duplicate titles are merged into one dictionary.

        Args:
            db_map (DatabaseMapping): a database map
            limit (int, optional): yield only this many items

        Yields:
            tuple(str,dict): unique title, and associated title state dictionary
        """
        titles = {}
        for title, title_state in self._non_unique_titles(db_map, limit=limit):
            titles.setdefault(title, {}).update(title_state)
        yield from titles.items()

    def has_header(self):
        """Recursively checks if mapping would create a header row.

        Returns:
            bool: True if make_header() would return something useful
        """
        if self.header or self.position == Position.header:
            return True
        if self.child is None:
            return False
        return self.child.has_header()

    def make_header_recursive(self, query, buddies):
        """Builds the header recursively.

        Args:
            query (_Rewindable, optional): export query
            buddies (list of tuple): buddy mappings

        Returns:
            dict: a mapping from column index to string header
        """
        if self.child is None:
            if not is_regular(self.position):
                return {}
            return {self.position: self.header}
        header = self.child.make_header_recursive(query, buddies)
        if self.position == Position.header and query is not None:
            buddy = find_my_buddy(self, buddies)
            if buddy is not None:
                query.rewind()
                # The buddy's column header is the first truthy value this mapping yields.
                header[buddy.position] = next(
                    (x for db_row in query for x in self._get_data_iterator(self._data(db_row)) if x), ""
                )
        else:
            header[self.position] = self.header
        return header

    def make_header(
        self,
        db_map: DatabaseMapping,
        title_state: dict[str, Any],
        buddies: list[tuple[ExportMapping, ExportMapping]],
        row_cache: dict[CacheKey, list[Row]],
    ) -> dict[int | Position, str]:
        """Returns the header for this mapping.

        Args:
            db_map: database map
            title_state: title state; not modified
            buddies: buddy mappings
            row_cache: cache for fetched database rows

        Returns:
            a mapping from column index to string header
        """
        query = self._build_header_query(db_map, title_state, buddies, row_cache)
        if query is not None:
            query = _Rewindable(query)
        return self.make_header_recursive(query, buddies)


def drop_non_positioned_tail(root_mapping):
    """Rebuilds a mapping hierarchy with its trailing hidden mappings removed.

    Only hidden mappings without a value filter are dropped, and only from the end of the
    hierarchy. This enables pivot tables to work correctly in certain situations.

    Args:
        root_mapping (Mapping): root mapping

    Returns:
        Mapping: modified mapping hierarchy
    """
    mappings = root_mapping.flatten()
    end = len(mappings)
    while end > 0:
        tail = mappings[end - 1]
        # A filtered hidden mapping still affects the output, so it must stay.
        if tail.position != Position.hidden or tail.filter_re:
            break
        end -= 1
    return unflatten(mappings[:end])


class FixedValueMapping(ExportMapping):
    """Mapping that always yields a fixed value.

    May serve as the root of a mapping hierarchy.
    """

    MAP_TYPE = "FixedValue"

    def __init__(self, position, value, header="", filter_re=""):
        """
        Args:
            position (int or Position, optional): mapping's position
            value (Any): the fixed value to yield
            header (str, optional): column header yielded as the 'first row' when non-empty;
                defaults to an empty string, meaning no header is yielded
            filter_re (str, optional): regular expression the mapped values are filtered by
        """
        super().__init__(position, value=value, header=header, filter_re=filter_re)


class EntityClassMapping(ExportMapping):
    """Maps entity classes.

    Can be used as the topmost mapping.
    """

    MAP_TYPE = "EntityClass"
    name_field = "entity_class_name"
    # The class name doubles as the title-state id for the sake of the standard Excel export.
    id_field = "entity_class_name"

    def __init__(self, position, value=None, header="", filter_re="", highlight_position=None):
        """
        Args:
            position (int or Position): mapping's position
            value (Any, optional): fixed value
            header (str, optional): column header
            filter_re (str, optional): regular expression to filter the mapped values by
            highlight_position (int, optional): dimension position whose id is selected
                as ``highlighted_dimension_id``; None disables highlighting
        """
        super().__init__(position, value, header, filter_re)
        self.highlight_position = highlight_position

    def build_query_columns(self, db_map, columns):
        class_sq = db_map.wide_entity_class_sq
        columns.extend(
            (
                class_sq.c.id.label("entity_class_id"),
                class_sq.c.name.label("entity_class_name"),
                class_sq.c.dimension_id_list.label("dimension_id_list"),
                class_sq.c.dimension_name_list.label("dimension_name_list"),
            )
        )
        if self.highlight_position is None:
            return
        columns.append(db_map.entity_class_dimension_sq.c.dimension_id.label("highlighted_dimension_id"))

    def filter_query(self, db_map, query):
        class_sq = db_map.wide_entity_class_sq
        # Multidimensional classes are wanted exactly when the hierarchy maps dimensions or elements.
        wants_dimensions = any(isinstance(m, (DimensionMapping, ElementMapping)) for m in self.flatten())
        dimension_ids = class_sq.c.dimension_id_list
        if wants_dimensions:
            query = query.filter(dimension_ids != None)  # noqa: E711 (SQL IS NOT NULL)
        else:
            query = query.filter(dimension_ids == None)  # noqa: E711 (SQL IS NULL)
        if self.highlight_position is None:
            return query
        dimension_sq = db_map.entity_class_dimension_sq
        query = query.outerjoin(dimension_sq, dimension_sq.c.entity_class_id == class_sq.c.id)
        return query.filter(dimension_sq.c.position == self.highlight_position)

    def query_parents(self, what):
        if what == "highlight_position":
            return self.highlight_position
        if what == "dimension":
            return -1
        return super().query_parents(what)

    def _title_state(self, db_row):
        return {**super()._title_state(db_row), "dimension_id_list": db_row.dimension_id_list}

    def to_dict(self):
        serialized = super().to_dict()
        if self.highlight_position is None:
            return serialized
        serialized["highlight_position"] = self.highlight_position
        return serialized

    @classmethod
    def reconstruct(cls, position, value, header, filter_re, ignorable, mapping_dict):
        mapping = cls(position, value, header, filter_re, mapping_dict.get("highlight_position"))
        mapping.set_ignorable(ignorable)
        return mapping


class EntityMapping(ExportMapping):
    """Maps entities.

    Cannot be used as the topmost mapping; one of the parents must be :class:`EntityClassMapping`.
    """

    MAP_TYPE = "Entity"
    name_field = "entity_name"
    id_field = "entity_id"

    def build_query_columns(self, db_map, columns):
        entity_sq = db_map.wide_entity_sq
        columns.extend(
            (
                entity_sq.c.id.label("entity_id"),
                entity_sq.c.name.label("entity_name"),
                entity_sq.c.element_id_list,
                entity_sq.c.element_name_list,
            )
        )
        if self.query_parents("highlight_position") is None:
            return
        columns.append(db_map.entity_element_sq.c.element_id.label("highlighted_element_id"))

    def filter_query(self, db_map, query):
        entity_sq = db_map.wide_entity_sq
        query = query.outerjoin(entity_sq, entity_sq.c.class_id == db_map.wide_entity_class_sq.c.id)
        highlight_position = self.query_parents("highlight_position")
        if highlight_position is None:
            return query
        element_sq = db_map.entity_element_sq
        joined = query.outerjoin(element_sq, element_sq.c.entity_id == entity_sq.c.id)
        return joined.filter(element_sq.c.position == highlight_position)

    def query_parents(self, what):
        return -1 if what == "dimension" else super().query_parents(what)

    def _title_state(self, db_row):
        return {**super()._title_state(db_row), "element_id_list": db_row.element_id_list}

    @staticmethod
    def is_buddy(parent):
        return isinstance(parent, EntityClassMapping)


class EntityGroupMapping(ExportMapping):
    """Maps entity groups.

    Cannot be used as the topmost mapping; one of the parents must be :class:`EntityClassMapping`.
    """

    MAP_TYPE = "EntityGroup"
    name_field = "group_name"
    id_field = "group_id"

    def build_query_columns(self, db_map, columns):
        group_sq = db_map.ext_entity_group_sq
        columns.extend((group_sq.c.group_id, group_sq.c.group_name))

    def filter_query(self, db_map, query):
        group_sq = db_map.ext_entity_group_sq
        joined = query.outerjoin(group_sq, group_sq.c.class_id == db_map.wide_entity_class_sq.c.id)
        # One group has many member rows; collapse them to one row per group.
        return joined.distinct()

    @staticmethod
    def is_buddy(parent):
        return isinstance(parent, EntityClassMapping)


class EntityGroupEntityMapping(ExportMapping):
    """Maps the member entities of entity groups.

    Cannot be used as the topmost mapping; one of the parents must be :class:`EntityGroupMapping`.
    """

    MAP_TYPE = "EntityGroupEntity"
    name_field = "entity_name"
    id_field = "entity_id"

    def build_query_columns(self, db_map, columns):
        entity_sq = db_map.wide_entity_sq
        columns.extend((entity_sq.c.id.label("entity_id"), entity_sq.c.name.label("entity_name")))

    def filter_query(self, db_map, query):
        member_id = db_map.ext_entity_group_sq.c.member_id
        return query.filter(member_id == db_map.wide_entity_sq.c.id)

    @staticmethod
    def is_buddy(parent):
        return isinstance(parent, EntityGroupMapping)


class DimensionMapping(ExportMapping):
    """Maps dimensions.

    Cannot be used as the topmost mapping; one of the parents must be :class:`EntityClassMapping`.
    """

    MAP_TYPE = "Dimension"
    name_field = "dimension_name_list"
    id_field = "dimension_id_list"
    # Index of this mapping's dimension; resolved lazily from the parents on first use.
    _cached_dimension = None

    def _data(self, row):
        joined_names = super()._data(row)
        if joined_names is None:
            return None
        names = joined_names.split(",")
        if self._cached_dimension is None:
            self._cached_dimension = self.query_parents("dimension")
        try:
            name = names[self._cached_dimension]
        except IndexError:
            # The class has fewer dimensions than this mapping's depth.
            return ""
        else:
            return name

    def query_parents(self, what):
        if what == "dimension":
            # Each nested dimension mapping selects the next dimension.
            return self.parent.query_parents(what) + 1
        return super().query_parents(what)

    @staticmethod
    def is_buddy(parent):
        return isinstance(parent, EntityClassMapping)


class ElementMapping(ExportMapping):
    """Maps elements.

    Cannot be used as the topmost mapping; must have :class:`EntityClassMapping` and :class:`EntityMapping`
    as parents.
    """

    MAP_TYPE = "Element"
    name_field = "element_name_list"
    id_field = "element_id_list"
    # Index of this mapping's element; resolved lazily from the parents on first use.
    _cached_dimension = None

    def _data(self, row):
        joined_names = super()._data(row)
        if joined_names is None:
            return None
        names = joined_names.split(",")
        if self._cached_dimension is None:
            self._cached_dimension = self.query_parents("dimension")
        try:
            name = names[self._cached_dimension]
        except IndexError:
            # The entity has fewer elements than this mapping's depth.
            return ""
        else:
            return name

    def query_parents(self, what):
        if what == "dimension":
            # Each nested element mapping selects the next element.
            return self.parent.query_parents(what) + 1
        return super().query_parents(what)

    @staticmethod
    def is_buddy(parent):
        return isinstance(parent, DimensionMapping)


class ParameterDefinitionMapping(ExportMapping):
    """Maps parameter definitions.

    Cannot be used as the topmost mapping; must have an entity class mapping as one of parents.
    """

    MAP_TYPE = "ParameterDefinition"
    name_field = "parameter_definition_name"
    id_field = "parameter_definition_id"

    def build_query_columns(self, db_map, columns):
        definition_sq = db_map.parameter_definition_sq
        columns.extend(
            (
                definition_sq.c.id.label("parameter_definition_id"),
                definition_sq.c.name.label("parameter_definition_name"),
            )
        )

    def filter_query(self, db_map, query):
        # With highlighting, definitions belong to the highlighted dimension's class instead of the class itself.
        if self.query_parents("highlight_position") is None:
            owner_class_id = db_map.wide_entity_class_sq.c.id
        else:
            owner_class_id = db_map.entity_class_dimension_sq.c.dimension_id
        definition_sq = db_map.parameter_definition_sq
        return query.outerjoin(definition_sq, definition_sq.c.entity_class_id == owner_class_id)


class ParameterDefaultValueMapping(ExportMapping):
    """Maps scalar (non-indexed) default values.

    Cannot be used as the topmost mapping; must have a :class:`ParameterDefinitionMapping` as parent.
    """

    MAP_TYPE = "ParameterDefaultValue"

    def build_query_columns(self, db_map, columns):
        definition_sq = db_map.parameter_definition_sq
        columns.extend((definition_sq.c.default_value, definition_sq.c.default_type))

    def _data(self, row):
        raw_value, value_type = row.default_value, row.default_type
        return from_database_to_single_value(raw_value, value_type)

    @staticmethod
    def is_buddy(parent):
        return isinstance(parent, ParameterDefinitionMapping)


class ParameterDefaultValueTypeMapping(ParameterDefaultValueMapping):
    """Maps parameter default value types.

    Cannot be used as the topmost mapping; must have a :class:`ParameterDefinitionMapping` as parent.
    """

    MAP_TYPE = "ParameterDefaultValueType"

    def __init__(self, position, value=None, header="", filter_re=""):
        """
        Args:
            position (int or Position): mapping's position
            value (Any, optional): fixed value
            header (str, optional): column header
            filter_re (str, optional): regular expression to filter the type names by;
                the term ``single_value`` is expanded to match any of ``float``, ``str`` or ``bool``
        """
        # Guard against an empty/None filter, which has nothing to expand.
        if filter_re:
            filter_re = filter_re.replace("single_value", "float|str|bool")
        super().__init__(position, value, header, filter_re)

    def _data(self, row):
        """Returns the default value's type name, e.g. ``2d_map`` for maps or the scalar type for untyped values."""
        type_ = row.default_type
        if type_ == "map":
            return f"{from_database_to_dimension_count(row.default_value, type_)}d_map"
        if type_:
            return type_
        # Untyped values are scalars; derive the type name from the parsed value.
        return type_for_scalar(from_database(row.default_value, type_))

    def _title_state(self, db_row):
        return {
            "type_and_dimensions": (
                db_row.default_type,
                from_database_to_dimension_count(db_row.default_value, db_row.default_type),
            )
        }

    def filter_query_by_title(self, query, title_state):
        # Consume the key so it is not used by the generic attribute-based title filtering.
        title_state.pop("type_and_dimensions", None)
        return query


class DefaultValueIndexNameMapping(_MappingWithLeafMixin, ParameterDefaultValueMapping):
    """Maps the index names of parameter default values.

    Cannot be used as the topmost mapping; must have a :class:`ParameterDefinitionMapping` as a parent.
    """

    MAP_TYPE = "DefaultValueIndexName"

    def _expand_data(self, data):
        # Index-name expansion is delegated to the module-level helper.
        yield from _expand_index_names(data, self)

    def _data(self, row):
        # The raw database value and its type are both needed to parse the indexed value.
        return (row.default_value, row.default_type)


class ParameterDefaultValueIndexMapping(_MappingWithLeafMixin, ExportMapping):
    """Maps the indexes of parameter default values.

    Must not be the topmost mapping; a :class:`ParameterDefinitionMapping` is required as parent.
    """

    MAP_TYPE = "ParameterDefaultValueIndex"

    def build_query_columns(self, db_map, columns):
        # Add the default value columns only if no other mapping has selected them yet.
        if not any(column.name == "default_value" for column in columns):
            definition_sq = db_map.parameter_definition_sq
            columns += [definition_sq.c.default_value, definition_sq.c.default_type]

    def _expand_data(self, data):
        yield from _expand_indexed_data(data, self)

    def _data(self, row):
        raw_pair = (row.default_value, row.default_type)
        return raw_pair

    @staticmethod
    def is_buddy(parent):
        """Returns True if ``parent`` is a :class:`DefaultValueIndexNameMapping`."""
        return isinstance(parent, DefaultValueIndexNameMapping)


class ExpandedParameterDefaultValueMapping(ExportMapping):
    """Maps indexed default values.

    As a child of :class:`ParameterDefaultValueIndexMapping`, yields the individual values of indexed
    parameters; values that are themselves indexed are reported by their ``VALUE_TYPE``.

    Must not be the topmost mapping; a :class:`ParameterDefinitionMapping` is required as parent.
    """

    MAP_TYPE = "ExpandedDefaultValue"
    name_field = "default_value"
    id_field = "default_value"

    def _data(self, row):
        leaf = self.parent.current_leaf
        if isinstance(leaf, IndexedValue):
            return leaf.VALUE_TYPE
        return leaf


class ParameterValueMapping(ExportMapping):
    """Maps scalar (non-indexed) parameter values.

    Must not be the topmost mapping; a :class:`ParameterDefinitionMapping`, an entity mapping and
    an :class:`AlternativeMapping` are required as parents.
    """

    MAP_TYPE = "ParameterValue"
    # Set to True by build_query_columns when this mapping is the one that added the value columns.
    _selects_value = False

    def build_query_columns(self, db_map, columns):
        value_already_selected = any(column.name == "value" for column in columns)
        if not value_already_selected:
            self._selects_value = True
            value_sq = db_map.parameter_value_sq
            columns += [value_sq.c.value, value_sq.c.type]

    def filter_query(self, db_map, query):
        # Only the mapping that selected the value columns adds the join conditions.
        if not self._selects_value:
            return query
        value_sq = db_map.parameter_value_sq
        if self.query_parents("highlight_position") is None:
            entity_id_column = db_map.wide_entity_sq.c.id
        else:
            entity_id_column = db_map.entity_element_sq.c.element_id
        return query.filter(
            and_(
                value_sq.c.entity_id == entity_id_column,
                value_sq.c.parameter_definition_id == db_map.parameter_definition_sq.c.id,
                value_sq.c.alternative_id == db_map.alternative_sq.c.id,
            )
        )

    def _data(self, row):
        raw_value, value_type = row.value, row.type
        return from_database_to_single_value(raw_value, value_type)

    @staticmethod
    def is_buddy(parent):
        """Returns True if ``parent`` is a parameter definition, entity or alternative mapping."""
        buddy_classes = (ParameterDefinitionMapping, EntityMapping, AlternativeMapping)
        return isinstance(parent, buddy_classes)


class ParameterValueTypeMapping(ParameterValueMapping):
    """Maps parameter value types.

    Maps are reported as ``"<N>d_map"`` where N is the dimension count, other typed values report their
    database type and untyped values report the type name given by ``type_for_scalar``.

    Must not be the topmost mapping; a :class:`ParameterDefinitionMapping`, an entity mapping and
    an :class:`AlternativeMapping` are required as parents.
    """

    MAP_TYPE = "ParameterValueType"

    def __init__(self, position, value=None, header="", filter_re=""):
        # "single_value" in the filter is shorthand for the scalar type names.
        expanded_filter = filter_re.replace("single_value", "float|str|bool")
        super().__init__(position, value, header, expanded_filter)

    def _data(self, row):
        value_type = row.type
        if value_type == "map":
            dimension_count = from_database_to_dimension_count(row.value, value_type)
            return f"{dimension_count}d_map"
        if value_type:
            return value_type
        return type_for_scalar(from_database(row.value, value_type))

    def _title_state(self, db_row):
        value_type = db_row.type
        dimension_count = from_database_to_dimension_count(db_row.value, value_type)
        return {"type_and_dimensions": (value_type, dimension_count)}

    def filter_query_by_title(self, query, title_state):
        # Consume the type/dimension state without filtering the query.
        title_state.pop("type_and_dimensions", None)
        return query


class IndexNameMapping(_MappingWithLeafMixin, ParameterValueMapping):
    """Maps the index names of parameter values.

    Must not be the topmost mapping; a :class:`ParameterDefinitionMapping`, an entity mapping and
    an :class:`AlternativeMapping` are required as parents.
    """

    MAP_TYPE = "IndexName"

    def _data(self, row):
        # Raw (value, type) pair; decoded in _expand_data.
        raw_pair = (row.value, row.type)
        return raw_pair

    def _expand_data(self, data):
        yield from _expand_index_names(data, self)


class ParameterValueIndexMapping(_MappingWithLeafMixin, ParameterValueMapping):
    """Maps the indexes of parameter values.

    Must not be the topmost mapping; a :class:`ParameterDefinitionMapping`, an entity mapping and
    an :class:`AlternativeMapping` are required as parents.
    """

    MAP_TYPE = "ParameterValueIndex"

    def _data(self, row):
        raw_pair = (row.value, row.type)
        return raw_pair

    def _expand_data(self, data):
        yield from _expand_indexed_data(data, self)

    @staticmethod
    def is_buddy(parent):
        """Returns True if ``parent`` is an :class:`IndexNameMapping`."""
        buddy_class = IndexNameMapping
        return isinstance(parent, buddy_class)


class ExpandedParameterValueMapping(ExportMapping):
    """Maps parameter values.

    As a child of :class:`ParameterValueIndexMapping`, yields the individual values of indexed parameters;
    values that are themselves indexed are reported by their ``VALUE_TYPE``.

    Must not be the topmost mapping; a :class:`ParameterDefinitionMapping`, an entity mapping and
    an :class:`AlternativeMapping` are required as parents.
    """

    MAP_TYPE = "ExpandedValue"
    name_field = "value"
    id_field = "value"

    def _data(self, row):
        leaf = self.parent.current_leaf
        if isinstance(leaf, IndexedValue):
            return leaf.VALUE_TYPE
        return leaf


class ParameterValueListMapping(ExportMapping):
    """Maps parameter value list names.

    May be the topmost mapping. When it has a parent (e.g. a :class:`ParameterDefinitionMapping`), the value
    list is outer-joined to the parameter definitions, yielding each definition's value list name.
    """

    MAP_TYPE = "ParameterValueList"
    name_field = "parameter_value_list_name"
    id_field = "parameter_value_list_id"

    def build_query_columns(self, db_map, columns):
        list_sq = db_map.parameter_value_list_sq
        columns += [
            list_sq.c.id.label("parameter_value_list_id"),
            list_sq.c.name.label("parameter_value_list_name"),
        ]

    def filter_query(self, db_map, query):
        if self.parent is not None:
            list_sq = db_map.parameter_value_list_sq
            return query.outerjoin(
                list_sq, list_sq.c.id == db_map.parameter_definition_sq.c.parameter_value_list_id
            )
        return query


class ParameterValueListValueMapping(ExportMapping):
    """Maps parameter value list values.

    Must not be the topmost mapping; a :class:`ParameterValueListMapping` is required as parent.
    """

    MAP_TYPE = "ParameterValueListValue"

    def build_query_columns(self, db_map, columns):
        list_value_sq = db_map.ord_list_value_sq
        columns += [list_value_sq.c.value, list_value_sq.c.type]

    def filter_query(self, db_map, query):
        belongs_to_list = db_map.ord_list_value_sq.c.parameter_value_list_id == db_map.parameter_value_list_sq.c.id
        return query.filter(belongs_to_list)

    def _data(self, row):
        raw_value, value_type = row.value, row.type
        return from_database_to_single_value(raw_value, value_type)

    @staticmethod
    def is_buddy(parent):
        """Returns True if ``parent`` is a :class:`ParameterValueListMapping`."""
        return isinstance(parent, ParameterValueListMapping)


class AlternativeMapping(ExportMapping):
    """Maps alternatives.

    May be used as the topmost mapping.
    """

    MAP_TYPE = "Alternative"
    name_field = "alternative_name"
    id_field = "alternative_id"

    def build_query_columns(self, db_map, columns):
        alternative_sq = db_map.alternative_sq
        columns += [
            alternative_sq.c.id.label("alternative_id"),
            alternative_sq.c.name.label("alternative_name"),
            alternative_sq.c.description.label("description"),
        ]

    def filter_query(self, db_map, query):
        # Walk up the ancestors; if any of them is a parameter definition mapping,
        # restrict alternatives to those referenced by parameter values.
        ancestor = self.parent
        while ancestor is not None and not isinstance(ancestor, ParameterDefinitionMapping):
            ancestor = ancestor.parent
        if ancestor is None:
            return query
        return query.filter(db_map.alternative_sq.c.id == db_map.parameter_value_sq.c.alternative_id)


class ScenarioMapping(ExportMapping):
    """Maps scenarios.

    May be used as the topmost mapping.
    """

    MAP_TYPE = "Scenario"
    name_field = "scenario_name"
    id_field = "scenario_id"

    def build_query_columns(self, db_map, columns):
        scenario_columns = db_map.scenario_sq.c
        columns += [
            scenario_columns.id.label("scenario_id"),
            scenario_columns.name.label("scenario_name"),
            scenario_columns.description.label("description"),
        ]


class ScenarioAlternativeMapping(ExportMapping):
    """Maps scenario alternatives.

    Without a child mapping, alternatives come from ``ext_scenario_sq`` ordered by scenario name and rank.
    With a child, the legacy ``ext_linked_scenario_alternative_sq`` is used instead.

    Must not be the topmost mapping; a :class:`ScenarioMapping` is required as parent.
    """

    MAP_TYPE = "ScenarioAlternative"
    name_field = "alternative_name"
    id_field = "alternative_id"

    def build_query_columns(self, db_map, columns):
        if self._child is not None:
            # Legacy: expecting child to be ScenarioBeforeAlternativeMapping
            linked_sq = db_map.ext_linked_scenario_alternative_sq
            columns += [linked_sq.c.alternative_id, linked_sq.c.alternative_name]
            return
        scenario_sq = db_map.ext_scenario_sq
        columns += [scenario_sq.c.alternative_id, scenario_sq.c.alternative_name, scenario_sq.c.rank]

    def filter_query(self, db_map, query):
        if self._child is not None:
            # Legacy: expecting child to be ScenarioBeforeAlternativeMapping
            linked_sq = db_map.ext_linked_scenario_alternative_sq
            return query.outerjoin(linked_sq, linked_sq.c.scenario_id == db_map.scenario_sq.c.id)
        scenario_sq = db_map.ext_scenario_sq
        joined = query.outerjoin(scenario_sq, scenario_sq.c.id == db_map.scenario_sq.c.id)
        return joined.order_by(scenario_sq.c.name, scenario_sq.c.rank)

    @staticmethod
    def is_buddy(parent):
        """Returns True if ``parent`` is a :class:`ScenarioMapping`."""
        return isinstance(parent, ScenarioMapping)


class ScenarioBeforeAlternativeMapping(ExportMapping):
    """Maps the 'before' alternatives of scenario alternatives.

    Must not be the topmost mapping; a :class:`ScenarioAlternativeMapping` is required as parent.
    """

    MAP_TYPE = "ScenarioBeforeAlternative"
    name_field = "before_alternative_name"
    id_field = "before_alternative_id"

    def build_query_columns(self, db_map, columns):
        linked_columns = db_map.ext_linked_scenario_alternative_sq.c
        columns += [linked_columns.before_alternative_id, linked_columns.before_alternative_name]

    @staticmethod
    def is_buddy(parent):
        """Returns True if ``parent`` is a :class:`ScenarioAlternativeMapping`."""
        return isinstance(parent, ScenarioAlternativeMapping)


class _DescriptionMappingBase(ExportMapping):
    """Common base for mappings that yield the ``description`` column."""

    MAP_TYPE = "Description"
    # Both name and id are read from the same column.
    id_field = "description"
    name_field = "description"


class AlternativeDescriptionMapping(_DescriptionMappingBase):
    """Maps the descriptions of alternatives.

    Must not be the topmost mapping; an :class:`AlternativeMapping` is required as parent.
    """

    MAP_TYPE = "AlternativeDescription"


class ScenarioDescriptionMapping(_DescriptionMappingBase):
    """Maps the descriptions of scenarios.

    Must not be the topmost mapping; a :class:`ScenarioMapping` is required as parent.
    """

    MAP_TYPE = "ScenarioDescription"


class _FilteredQuery:
    """Helper class to define non-standard query filters."""

    def __init__(
        self,
        db_map: DatabaseMapping,
        query: Query,
        condition: Callable[[Any], bool],
        row_cache: dict[CacheKey, list[Row]],
    ):
        """
        Args:
            db_map: database mapping instance
            query: a query to filter
            condition: the filter condition
            row_cache: cache for fetched database rows
        """
        self._db_map = db_map
        self._query = query
        self._condition = condition
        self._row_cache = row_cache

    @property
    def statement(self):
        return self._query.statement

    def filter(self, *args):
        return _FilteredQuery(self._db_map, self._query.filter(*args), self._condition, self._row_cache)

    def __iter__(self):
        cache_key = self._query.statement._generate_cache_key().key
        if cache_key in self._row_cache:
            cache = self._row_cache[cache_key]
            yield from (db_row for db_row in cache if self._condition(db_row))
        else:
            self._row_cache[cache_key] = cache = []
            for db_row in self._query:
                cache.append(db_row)
                if self._condition(db_row):
                    yield db_row


class _Rewindable:
    def __init__(self, it):
        self._it = iter(it)
        self._seen = []
        self._seen_it = iter(self._seen)

    def rewind(self):
        self._seen_it = iter(self._seen)

    def __next__(self):
        try:
            return next(self._seen_it)
        except StopIteration:
            pass
        item = next(self._it)
        self._seen.append(item)
        return item

    def __iter__(self):
        return self


def pair_header_buddies(root_mapping):
    """Pairs mappings that have Position.header to their 'buddy' child mappings.

    Mappings are visited in flattened order. Each mapping positioned as header is paired with the first
    later mapping whose ``is_buddy()`` accepts it and which has not already been paired.

    Args:
        root_mapping (ExportMapping): root mapping

    Returns:
        list of tuple: pairs of parent mapping - buddy child mapping
    """
    flattened = list(root_mapping.flatten())
    taken = set()  # positions (in flattened) of mappings that already have a header parent
    buddies = []
    for parent_index, parent in enumerate(flattened):
        if parent.position != Position.header:
            continue
        later_mappings = enumerate(flattened[parent_index + 1 :], start=parent_index + 1)
        for child_index, child in later_mappings:
            if child.is_buddy(parent) and child_index not in taken:
                buddies.append((parent, child))
                taken.add(child_index)
                break
    return buddies


def find_my_buddy(mapping, buddies):
    """Looks up the buddy paired with ``mapping``.

    Pairs are matched by identity of their first element; the first match wins.

    Args:
        mapping (ExportMapping): a mapping
        buddies (list of tuple): list of mapping - buddy mapping pairs

    Returns:
        ExportMapping: buddy mapping or None if not found
    """
    return next((buddy for parent, buddy in buddies if parent is mapping), None)


def from_dict(serialized):
    """
    Deserializes mappings.

    Legacy map types are translated to their current counterparts, the legacy ``highlight_dimension`` key is
    read as ``highlight_position``, and ``ScenarioActiveFlag`` mappings, which are no longer supported,
    are skipped. The input dictionaries are not modified.

    Args:
        serialized (list): serialized mappings

    Returns:
        ExportMapping: root mapping
    """
    mappings = {
        klass.MAP_TYPE: klass
        for klass in (
            AlternativeDescriptionMapping,
            AlternativeMapping,
            DefaultValueIndexNameMapping,
            DimensionMapping,
            ElementMapping,
            ExpandedParameterDefaultValueMapping,
            ExpandedParameterValueMapping,
            FixedValueMapping,
            IndexNameMapping,
            EntityClassMapping,
            EntityGroupMapping,
            EntityGroupEntityMapping,
            EntityMapping,
            ParameterDefaultValueIndexMapping,
            ParameterDefaultValueMapping,
            ParameterDefaultValueTypeMapping,
            ParameterDefinitionMapping,
            ParameterValueIndexMapping,
            ParameterValueListMapping,
            ParameterValueListValueMapping,
            ParameterValueMapping,
            ParameterValueTypeMapping,
            ScenarioAlternativeMapping,
            ScenarioBeforeAlternativeMapping,
            ScenarioDescriptionMapping,
            ScenarioMapping,
            # FIXME
            # FeatureEntityClassMapping,
            # FeatureParameterDefinitionMapping,
            # ToolMapping,
            # ToolFeatureEntityClassMapping,
            # ToolFeatureParameterDefinitionMapping,
            # ToolFeatureRequiredFlagMapping,
            # ToolFeatureMethodEntityClassMapping,
            # ToolFeatureMethodParameterDefinitionMapping,
        )
    }
    legacy_mappings = {
        "ParameterIndex": ParameterValueIndexMapping,
        "ObjectClass": EntityClassMapping,
        "ObjectGroup": EntityGroupMapping,
        "ObjectGroupObject": EntityGroupEntityMapping,
        "Object": EntityMapping,
        "RelationshipClass": EntityClassMapping,
        "RelationshipClassObjectClass": DimensionMapping,
        "Relationship": EntityMapping,
        "RelationshipObject": ElementMapping,
        "RelationshipClassObjectHighlightingMapping": EntityClassMapping,
        "RelationshipObjectHighlightingMapping": ElementMapping,
    }
    mappings.update(legacy_mappings)
    flattened = []
    for mapping_dict in serialized:
        if mapping_dict["map_type"] == "ScenarioActiveFlag":
            # We don't support active flag exporting anymore.
            continue
        if (highlight_position := mapping_dict.get("highlight_dimension")) is not None:
            # Legacy key; work on a copy so the caller's serialized data stays untouched.
            mapping_dict = {**mapping_dict, "highlight_position": highlight_position}
        position = mapping_dict["position"]
        if isinstance(position, str):
            position = Position(position)
        ignorable = mapping_dict.get("ignorable", False)
        value = mapping_dict.get("value")
        header = mapping_dict.get("header", "")
        filter_re = mapping_dict.get("filter_re", "")
        flattened.append(
            mappings[mapping_dict["map_type"]].reconstruct(position, value, header, filter_re, ignorable, mapping_dict)
        )
    return unflatten(flattened)


def legacy_group_fn_from_dict(serialized):
    """Recovers the legacy ``group_fn`` attribute from serialized mappings.

    Export mappings no longer carry ``group_fn``; this exists only for backwards compatibility.

    Args:
        serialized (list): serialized mappings

    Returns:
        str: first non-None ``group_fn`` found in the serialized mappings, or NoGroup's name if none is found
    """
    candidates = (mapping_dict.get("group_fn") for mapping_dict in serialized)
    first_found = next((group_fn for group_fn in candidates if group_fn is not None), None)
    return first_found if first_found is not None else NoGroup.NAME


def _expand_indexed_data(data, mapping):
    """Expands indexed data and updates the current_leaf attribute.

    The value to expand is taken from the parent's ``current_leaf`` when the parent carries leaves;
    otherwise it is decoded from ``data`` (a value/type pair), converting containers to maps for map types.
    Non-indexed values yield a single None; indexed values yield their indexes, with
    ``mapping.current_leaf`` set to the matching value before each yield so children can read it.

    Args:
        data (Any): data to expand
        mapping (ExportMapping): mapping whose data is being expanded

    Yields:
        Any: parameter value index
    """
    parent = mapping.parent
    if isinstance(parent, _MappingWithLeafMixin):
        leaf = parent.current_leaf
    else:
        raw_value, value_type = data[0], data[1]
        leaf = from_database(raw_value, value_type)
        if value_type == "map":
            leaf = convert_containers_to_maps(leaf)
    if isinstance(leaf, IndexedValue):
        for index, value in zip(leaf.indexes, leaf.values):
            mapping.current_leaf = value
            yield index
    else:
        mapping.current_leaf = leaf
        yield None


def _expand_index_names(data, mapping):
    """Expands index names and updates the current_leaf attribute.

    The leaf comes from the parent's ``current_leaf`` when the parent carries leaves; otherwise it is decoded
    from ``data`` (a value/type pair). The leaf is stored on ``mapping.current_leaf`` and exactly one item
    is yielded: the leaf's ``index_name`` for indexed values, None otherwise.

    Args:
        data (Any): data to expand
        mapping (ExportMapping): mapping whose data is being expanded

    Yields:
        str: index name
    """
    parent = mapping.parent
    if isinstance(parent, _MappingWithLeafMixin):
        leaf = parent.current_leaf
    else:
        raw_value, value_type = data[0], data[1]
        leaf = from_database(raw_value, value_type)
        if value_type == "map":
            leaf = convert_containers_to_maps(leaf)
    mapping.current_leaf = leaf
    if isinstance(leaf, IndexedValue):
        yield leaf.index_name
    else:
        yield None
