import textwrap
from abc import abstractmethod
from collections import OrderedDict
from collections.abc import Iterable
from enum import Enum
from typing import Optional, Union

import relationalai.early_access.builder as qb
from relationalai.early_access.builder import define, where, annotations
from relationalai.early_access.dsl.bindings.common import Binding, BindableAttribute, IdentifierConceptBinding, \
    PartitioningConceptBinding, BindableTable, PrimitiveFilterBy, FilterBy, ReferentConceptBinding, Symbol
from relationalai.early_access.dsl.codegen.helpers import reference_entity, construct_entity
from relationalai.early_access.dsl.orm.constraints import Unique
from relationalai.early_access.dsl.orm.relationships import Role
from relationalai.early_access.rel.rel_utils import DECIMAL64_SCALE, DECIMAL128_SCALE


class GeneratedRelation(qb.Relationship):
    """
    Base class for relationships that are generated from a DSL model.
    """
    def __init__(self, madlib, model, name):
        """
        Create the relationship on the query-builder model behind ``model``.

        :param madlib: reading string handed to ``qb.Relationship`` unchanged.
        :param model: object exposing ``qb_model()``; the result is passed as ``model``.
        :param name: passed to ``qb.Relationship`` as ``short_name``.
        """
        builder_model = model.qb_model()
        super().__init__(madlib, short_name=name, model=builder_model)

    def __repr__(self):
        return f'GeneratedRelation({self._name})'


class InternallyGeneratedRelation(GeneratedRelation):
    """
    A generated relation that is produced by analyzing the model.

    Adds no behavior of its own; construction is delegated to ``GeneratedRelation``.
    """
    def __init__(self, madlib, dsl_model, name):
        super().__init__(madlib=madlib, model=dsl_model, name=name)


def filtering_view(row, filter_by: Optional[FilterBy], column_ref=None):
    """
    Generates filtering atoms for the map if the binding has any filters, which must be simple
    filters on the attribute view of the same table.

    :param row: row variable shared with the map body; used by expression filters.
    :param filter_by: ``None`` (no filtering), a primitive value compared against ``column_ref``,
        a single ``qb.Expression``, or an iterable of ``qb.Expression``.
    :param column_ref: variable holding the column value; required for primitive filters.
    :return: a ``where(...)`` fragment combining all filtering atoms (empty when there is no filter).
    :raises TypeError: if ``filter_by`` is none of the supported kinds.
    """
    # Only a missing filter means "no filtering". A truthiness test here would also discard
    # legitimate primitive filters such as 0, False or '' and silently map every row.
    if filter_by is None:
        return where()
    if isinstance(filter_by, PrimitiveFilterBy):
        assert column_ref is not None, 'Attribute must be provided if filter_by is a PrimitiveFilterBy'
        atoms = [_primitive_filtering_atom(column_ref, filter_by)]
    elif isinstance(filter_by, qb.Expression):
        atoms = [_expr_filtering_atom(row, filter_by)]
    elif isinstance(filter_by, Iterable):
        # An empty iterable yields no atoms, i.e. the same empty fragment as "no filter".
        atoms = [_expr_filtering_atom(row, filter_expr) for filter_expr in filter_by]
    else:
        raise TypeError(f'Expected a PrimitiveFilterBy, Expression, or Iterable of Expressions, got {type(filter_by)}')
    return where(*atoms)

def _primitive_filtering_atom(column_ref, condition: PrimitiveFilterBy):
    """
    Build the atom that keeps only rows whose column value equals a primitive filter.

    Enum members are compared by their ``name``; any other value is compared as given.
    """
    expected = condition.name if isinstance(condition, Enum) else condition
    return where(column_ref == expected)

def _expr_filtering_atom(row, condition: qb.Expression):
    """
    Build the filtering atom for a binary ``column <op> value`` expression.

    The column is looked up for ``row`` into a fresh variable, and the original operator is
    re-applied to that variable and the comparison value (Enum members contribute their ``name``).

    :raises TypeError: if ``condition`` is not a ``qb.Expression``.
    :raises ValueError: if the expression does not have exactly two parameters.
    """
    if not isinstance(condition, qb.Expression):
        raise TypeError(f'Expected an Expression, got {type(condition)}')

    operands = condition._params
    operand_count = len(operands)
    if operand_count != 2:
        raise ValueError(f'Expected a condition with two parameters, got {operand_count}: {condition}')

    column, compared_to = operands
    column_value = column.type().ref()
    column_symbol = Symbol(column.physical_name())
    if isinstance(compared_to, Enum):
        compared_to = compared_to.name

    lookup = column(column_symbol, row, column_value)
    comparison = qb.Expression(condition._op, column_value, compared_to)
    return where(lookup, comparison)


class AbstractMap(GeneratedRelation):
    """
    Common base class of the generated map relations in this module.
    """
    def __init__(self, madlib, dsl_model, name):
        super().__init__(madlib=madlib, model=dsl_model, name=name)

    @staticmethod
    def _filtering_atoms(conditions, row, column_ref=None):
        """
        Delegate to ``filtering_view`` using the argument order convenient for map bodies.
        """
        return filtering_view(row, filter_by=conditions, column_ref=column_ref)


class RoleMap(AbstractMap):
    """
    Base class for map relations that are tied to a single role.
    """

    def __init__(self, madlib, model, name, role, functional: bool = False):
        # `functional` is accepted but not used by this class.
        super().__init__(madlib, model, name)
        self._role = role

    def role(self):
        """
        Returns the role this map was created for.
        """
        return self._role

    def value_player(self):
        """
        Returns the player of the role this map was created for.
        """
        role = self._role
        return role.player()

    @abstractmethod
    def column(self) -> BindableAttribute:
        """
        Returns the bindable column associated with this role map.

        Subclasses are expected to override this; the base implementation always raises.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def table(self) -> BindableTable:
        """
        Returns the table that owns the column of this role map.
        """
        bound_column = self.column()
        return bound_column.table

    def __repr__(self):
        return f'RoleMap({self._name})'


class ValueMap(RoleMap):
    """
    A map relation from table rows to values of a primitive (value type) role player.
    """

    def __init__(self, model, binding: Binding, role, value_converter: Optional['ValueConverter']=None):
        madlib, name = self._handle_params(binding, role)
        super().__init__(madlib, model, name, role)
        self._value_converter = value_converter
        self._binding = binding
        self._generate_body()

    @staticmethod
    def _handle_params(binding, role):
        """
        Derive the relation name and madlib from the binding's table and the role player.

        :raises TypeError: if the role player is not a primitive concept.
        """
        concept = role.player()
        if not concept._is_primitive():
            raise TypeError(f'Cannot construct a value map for {role}, concept is not a value type')

        table_name = binding.column.table.physical_name()
        name = f'{table_name}_row_to_{concept}'
        return f'{name} {{row:Integer}} {{val:{concept}}}', name

    def _generate_body(self):
        """
        Define the rule populating the map, routing the column value through the
        converter when one was supplied.
        """
        source_column = self._binding.column
        convert = self._value_converter
        row = qb.Integer.ref()
        target_type = self.value_player()
        raw = source_column.type().ref()

        if not convert:
            # no conversion: the column value itself is wrapped in the target type
            filters = self._filtering_atoms(self._binding.filter_by, row, raw)
            head = self(row, target_type(raw))
            define(head).where(
                source_column(row, raw),
                filters
            )
            return

        converted = target_type.ref()
        filters = self._filtering_atoms(self._binding.filter_by, row, raw)
        head = self(row, target_type(converted))
        define(head).where(
            where(
                source_column(row, raw),
                convert(raw, converted),
                filters
            )
        )

    def binding(self) -> Binding:
        """
        Returns the binding this value map was built from.
        """
        return self._binding

    def column(self) -> BindableAttribute:
        """
        Returns the column of the binding this value map was built from.
        """
        return self._binding.column

    def __repr__(self):
        return f'ValueMap({self._name})'


class SimpleConstructorEntityMap(RoleMap):
    """
    A map relation from table rows to entities identified through role maps.
    """

    def __init__(self, model, binding: Binding, role: Role, identifier_to_role_map: OrderedDict[Unique, 'RoleMap']):
        madlib, name = self._handle_params(binding, role)
        super().__init__(madlib, model, name, role)
        self._binding = binding
        self._identifier_to_role_map = identifier_to_role_map
        # the role map stored last is the one used to reference the entity
        self._reference_role_map = list(identifier_to_role_map.values())[-1]
        self._generate_body()

    @staticmethod
    def _handle_params(binding, role):
        """
        Derive the relation name and madlib from the binding's table and the role player.

        :raises TypeError: if the role player is a primitive concept.
        """
        concept = role.player()
        if concept._is_primitive():
            raise TypeError(f'Cannot construct an entity map for {role}, concept is not an entity type')

        table_name = binding.column.table.physical_name()
        name = f'{table_name}_row_to_{concept}'
        return f'{name} {{row:Integer}} {{val:{concept}}}', name

    def _generate_body(self):
        """
        Define the entity construction rule (unless the map only references existing
        entities) and the rule populating the map itself.
        """
        concept = self.role().player()
        row = qb.Integer.ref()

        if not self._should_reference():
            # construct the entities from the values of every identifier role map
            role_maps = self._identifier_to_role_map.values()
            values = [role_map.value_player().ref() for role_map in role_maps]
            value_atoms = [role_map(row, value) for role_map, value in zip(role_maps, values)]
            where(*value_atoms).define(construct_entity(concept, *values))

        # populate the entity map through the reference role map
        ref_value = self._reference_role_map.value_player().ref()
        row_to_value = self._reference_role_map(row, ref_value)
        entity = reference_entity(concept, ref_value)
        population = self._concept_population_atom(entity)
        where(row_to_value, entity, population).define(self(row, entity))

    def _concept_population_atom(self, entity):
        """
        Generates the *optional* population atom for the entity map.
        Only non-empty when the reference column references another column (FK).
        """
        if not self._should_reference():
            return where()
        concept = self.role().player()
        return where(concept(entity))

    def _should_reference(self) -> bool:
        """
        Tells whether the reference role map's column has a ``references_column`` set.
        """
        reference_column = self._reference_role_map.column()
        return reference_column.references_column is not None

    def binding(self) -> Binding:
        """
        Returns the binding this entity map was built from.
        """
        return self._binding

    def column(self) -> BindableAttribute:
        """
        Returns the column of the binding this entity map was built from.
        """
        return self._binding.column

    def __repr__(self):
        return f'CtorEntityMap({self._name})'


class ReferentEntityMap(RoleMap):
    """
    A map relation from table rows to already-populated entities that are looked up
    through the values of another role map.
    """

    def __init__(self, model, binding: Binding, role: Role, constructing_role_map: 'RoleMap'):
        madlib, name = self._handle_params(binding, role)
        super().__init__(madlib, model, name, role)
        self._constructing_role_map = constructing_role_map
        self._binding = binding
        self._generate_body()

    @staticmethod
    def _handle_params(binding, role):
        """
        Derive the relation name and madlib from the binding's table and the role player.

        :raises TypeError: if the role player is a primitive concept.
        """
        concept = role.player()
        if concept._is_primitive():
            raise TypeError(f'Cannot construct an entity map for {role}, concept is not an entity type')

        table_name = binding.column.table.physical_name()
        name = f'{table_name}_row_ref_to_{concept}'
        return f'{name} {{row:Integer}} {{val:{concept}}}', name

    def _generate_body(self):
        """
        Define the rule populating the map: a row maps to the entity referenced by the
        constructing role map's value, provided that entity is in the concept's population.
        """
        concept = self.role().player()

        row = qb.Integer.ref()
        value = self._constructing_role_map.value_player().ref()

        row_to_value = self._constructing_role_map(row, value)
        entity = reference_entity(concept, value)
        in_population = concept(entity)
        where(row_to_value, entity, in_population).define(self(row, entity))

    def binding(self) -> Binding:
        """
        Returns the binding this entity map was built from.
        """
        return self._binding

    def column(self) -> BindableAttribute:
        """
        Returns the column of the binding this entity map was built from.
        """
        return self._binding.column

    def __repr__(self):
        return f'ReferentEntityMap({self._name})'

class CompositeEntityMap(AbstractMap):
    """
    A class representing a composite entity map relation.

    Takes a sequence of role maps (at least two) and constructs a composite entity type
    from the values they yield for the same row.
    """

    def __init__(self, model, entity_concept: qb.Concept, *role_maps: 'RoleMap'):
        madlib, name = self._handle_params(entity_concept, *role_maps)
        super().__init__(madlib, model, name)
        self._entity_concept = entity_concept
        self._role_maps = role_maps
        self._generate_body()

    def _handle_params(self, entity_concept: qb.Concept, *role_maps: 'RoleMap'):
        """
        Validate the inputs, remember the table, and derive the relation name and madlib.

        :raises TypeError: if ``entity_concept`` is a primitive concept.
        :raises ValueError: if fewer than two role maps are given.
        """
        if entity_concept._is_primitive():
            raise TypeError(f'Cannot construct a composite entity map for {entity_concept},'
                            f' concept is not an entity type')
        if len(role_maps) < 2:
            raise ValueError('CompositeEntityMap requires at least two EntityMaps')

        # Use table() rather than binding(): RoleMap declares table() but not binding(), and
        # the composite/union maps in this module expose table() only.
        table = role_maps[0].table()
        self._table = table
        name = f'{table.physical_name()}_row_to_{entity_concept}'
        madlib = f'{name} {{row:Integer}} {{val:{entity_concept}}}'
        return madlib, name

    def _generate_body(self):
        """
        Define the rule constructing the composite entities and the rule populating the map.
        """
        row = qb.Integer.ref()
        values = [role_map.value_player().ref() for role_map in self._role_maps]

        # construct entities
        where(
            self._body_formula(row, *values),
            entity := construct_entity(self._entity_concept, *values),
        ).define(entity)

        # populate the entity map
        where(
            self._body_formula(row, *values),
            entity := reference_entity(self._entity_concept, *values),
        ).define(self(row, entity))

    def _body_formula(self, row, *vars):
        """
        Conjunction applying every role map to ``row`` and its corresponding value variable.
        """
        return where(
            *[role_map(row, var) for role_map, var in zip(self._role_maps, vars)],
        )

    def value_player(self):
        """
        Returns the value player of the composite entity map.
        """
        return self._entity_concept

    def table(self):
        """
        Returns the table associated with the composite entity map.
        """
        return self._table

    def __repr__(self):
        return f'CompositeEntityMap({self._name})'


# Type alias covering the entity map classes of this module; the members are given as
# strings because some of the classes are defined further down in the file.
AbstractEntityMap = Union['SimpleConstructorEntityMap', 'ReferentEntityMap', 'CompositeEntityMap', 'UnionEntityMap']

class EntitySubtypeMap(AbstractMap):
    """
    A class representing an entity subtype map relation.

    Rows are taken from an existing entity map (``ctor_entity_map``); the mapped entities are
    added to the subtype's population and to this map, optionally restricted by the
    partitioning value of the binding.
    """

    def __init__(self, model, binding: Union[IdentifierConceptBinding, PartitioningConceptBinding, ReferentConceptBinding],
                 ctor_entity_map: 'AbstractEntityMap'):
        madlib, name = self._handle_params(binding)
        super().__init__(madlib, model, name)
        self._binding = binding
        self._ctor_entity_map = ctor_entity_map
        self._generate_body()

    @staticmethod
    def _handle_params(binding: Union[IdentifierConceptBinding, PartitioningConceptBinding, ReferentConceptBinding]):
        """
        Derive the relation name and madlib from the binding's table and entity type.
        """
        table = binding.column.table
        entity_concept = binding.entity_type
        name = f'{table.physical_name()}_row_to_{entity_concept}'
        madlib = f'{name} {{row:Integer}} {{val:{entity_concept}}}'
        return madlib, name

    def _generate_body(self):
        """
        Define the rule deriving the subtype population and the rule populating the map.
        """
        subtype = self._binding.entity_type
        row, parent_entity = qb.Integer.ref(), self._ctor_entity_map.value_player().ref()

        # derive subtype population
        self._generate_body_formula(row, parent_entity).define(subtype(parent_entity))

        # populate the entity subtype map
        self._generate_body_formula(row, parent_entity).define(self(row, parent_entity))

    def _generate_body_formula(self, row, parent_entity):
        """
        Body shared by both rules: the parent entity map lookup plus the optional filter.
        """
        filtering_atom = self._get_filtering_atom(row)
        return where(
            self._ctor_entity_map(row, parent_entity),
            filtering_atom
        )

    def _get_filtering_atom(self, row):
        """
        Returns the ``column == has_value`` filter for partitioning bindings, and an empty
        fragment for every other binding kind.
        """
        if isinstance(self._binding, PartitioningConceptBinding):
            filter_expr = self._binding.column == self._binding.has_value
            return self._filtering_atoms(filter_expr, row)
        else:
            return where()

    def _should_reference(self) -> bool:
        """
        Tells whether the binding's column has a ``references_column`` set.
        """
        # This class never assigns a `_role_map` attribute, so the column held by the
        # binding is consulted instead.
        # NOTE(review): confirm the binding column is the intended one for this check.
        return self._binding.column.references_column is not None

    def binding(self):
        """
        Returns the binding of the entity subtype map.
        """
        return self._binding

    def value_player(self):
        """
        Returns the subtype of the entity subtype map.
        """
        return self._binding.entity_type

    def column(self) -> BindableAttribute:
        """
        Returns the bindable column associated with this entity subtype map.
        """
        return self._binding.column

    def table(self) -> BindableTable:
        """
        Returns the table associated with the entity subtype map.
        """
        return self._binding.column.table

    def __repr__(self):
        return f'EntitySubtypeMap({self._name})'


class UnionEntityMap(AbstractMap):
    """
    A map relation that is the union of several entity maps over the same table.
    """

    def __init__(self, model, entity_concept: qb.Concept, *entity_maps: AbstractEntityMap, generate_population: bool=False):
        madlib, name = self._handle_params(entity_concept, *entity_maps)
        super().__init__(madlib, model, name)
        self._generate_population = generate_population
        self._entity_type = entity_concept
        self._entity_maps = list(entity_maps)
        self._generate_body()

    def _handle_params(self, entity_concept: qb.Concept, *entity_maps: AbstractEntityMap):
        """
        Remember the table and derive the relation name and madlib.

        :raises ValueError: if no entity map is given.
        """
        if not entity_maps:
            raise ValueError('UnionEntityMap requires at least one EntityMap')
        # any entity map can provide the table, as they must all have the same one
        first_map = entity_maps[0]
        self._table = first_map.table()
        name = f'{self._table.physical_name()}_row_to_{entity_concept}'
        return f'{name} {{row:Integer}} {{val:{entity_concept}}}', name

    def _generate_body(self):
        """
        Emit the rules for every entity map known at construction time.
        """
        for member in self._entity_maps:
            self._generate_body_rule(member)

    def _generate_body_rule(self, entity_map: AbstractEntityMap):
        """
        Emit the rules contributed by a single member entity map.
        """
        row = qb.Integer.ref()
        entity = entity_map.value_player().ref()

        if self._generate_population:
            # derive type population
            where(entity_map(row, entity)).define(self._entity_type(entity))

        # derive union entity map
        where(entity_map(row, entity)).define(self(row, entity))

    def value_player(self):
        """
        Returns the entity type of the union entity map.
        """
        return self._entity_type

    def table(self):
        """
        Returns the table shared by the member entity maps.
        """
        return self._table

    def update(self, entity_map: AbstractEntityMap):
        """
        Adds another entity map to the union, unless one with the same ``_id`` is already a member.
        """
        already_member = entity_map._id in (known._id for known in self._entity_maps)
        if already_member:
            return
        self._entity_maps.append(entity_map)
        self._generate_body_rule(entity_map)

    def __repr__(self):
        return f'UnionEntityMap({self._name})'


class ValueConverter(InternallyGeneratedRelation):
    """
    Base class for value converter relations.

    Every converter is annotated with ``annotations.external`` on construction.
    """
    def __init__(self, madlib, dsl_model, name):
        super().__init__(madlib=madlib, dsl_model=dsl_model, name=name)
        self.annotate(annotations.external)

    @abstractmethod
    def result_type(self) -> qb.Concept:
        """
        Returns the concept that converted values belong to.
        """

# Note: this conversion is only needed until CDC standardizes on the scale and size of the Decimal type.

class DecimalValueConverter(ValueConverter):
    """
    A class representing a decimal value converter relation.

    Converts decimal values into ``qb.Decimal64`` or ``qb.Decimal128`` by installing a raw
    Rel definition on the model.
    """

    def __init__(self, model, size_from: int, scale_from: int, type_to: qb.Concept):
        """
        :param model: DSL model the converter relation is created on.
        :param size_from: size of the source decimal (used in the relation name).
        :param scale_from: scale of the source decimal (used in the relation name).
        :param type_to: target type; must be ``qb.Decimal64`` or ``qb.Decimal128``.
        :raises TypeError: if ``type_to`` is not one of the supported decimal types.
        """
        # Validate before anything is created, so an unsupported target type does not
        # leave a half-initialized relationship behind on the model.
        self._validate(type_to)
        madlib, name = self._generate_madlib_and_name(size_from, scale_from, type_to)
        super().__init__(madlib, model, name)
        self._size_from = size_from
        self._scale_from = scale_from
        self._type_to = type_to
        self._size_to = 64 if type_to is qb.Decimal64 else 128
        self._scale_to = DECIMAL64_SCALE if type_to is qb.Decimal64 else DECIMAL128_SCALE
        self._generate_body()

    @staticmethod
    def _validate(type_to):
        """
        Reject every target type other than ``qb.Decimal64`` and ``qb.Decimal128``.
        """
        if type_to is not qb.Decimal64 and type_to is not qb.Decimal128:
            raise TypeError(f'Expected Decimal64 or Decimal128, got {type_to}')

    def result_type(self) -> qb.Concept:
        """
        Returns the target decimal type of the converter.
        """
        return self._type_to

    @staticmethod
    def _generate_madlib_and_name(size_from: int, scale_from: int, type_to: qb.Concept):
        """
        Derive the relation name and madlib from the source size/scale and the target type.
        """
        name = f'value_converter_{type_to}_from_{size_from}_{scale_from}'
        madlib = f'{name} {{Decimal}} {{{type_to}}}'
        return madlib, name

    def _generate_body(self):
        """
        Install the Rel source implementing the conversion as a raw source on the model.
        """
        src = textwrap.dedent(f"""
        @inline
        def {self._name}(orig, rez):
            exists((t_size, t_scale, size, scale, int_val) |
                ^FixedDecimal(t_size, t_scale, int_val, orig) and
                ::std::mirror::lower(t_size, size) and
                ::std::mirror::lower(t_scale, scale) and
                decimal({self._size_to}, {self._scale_to}, int_val / (10 ^ scale), rez)
            )""")
        assert self._model is not None, "Model must be defined before defining a relation"
        self._model.define(qb.RawSource('rel', src))

    def __repr__(self):
        return f'DecimalValueConverter({self._name})'
