from __future__ import annotations
from typing import Union, Optional, Type, TYPE_CHECKING
from uuid import uuid4

from whyqd.parsers import CoreParser, ScriptParser

from whyqd.models import ConstraintsModel, FieldModel
from whyqd.crosswalk.base import BaseSchemaAction

if TYPE_CHECKING:
    from whyqd.models import SchemaActionModel, ModifierModel
    from whyqd.core import SchemaDefinition
    import modin.pandas as pd


class ActionParser:
    """Parsing functions for action scripts.

    Can process and validate any action script. Scripts may be recursive and are of the form:

        "ACTION > 'destination field' < [modifier 'source field', {action script}]"

    Modifiers are dependent on the specific Action. Fields must be found in provided SchemaDefinitions.
    """

    def __init__(self, *, schema_source: SchemaDefinition = None, schema_destination: SchemaDefinition = None) -> None:
        """
        Parameters
        ----------
        schema_source: SchemaDefinition, optional
        schema_destination: SchemaDefinition, optional
        """
        self.core = CoreParser()
        self.parser = ScriptParser()
        self.schema = []
        if schema_source and schema_destination:
            self.set_schema(schema_source=schema_source, schema_destination=schema_destination)
        self.source_modifiers = {}
        self.modifier_names = set()

    ###################################################################################################
    ### PARSE TRANSFORM SCRIPT
    ###################################################################################################

    def parse(
        self, *, script: str, action: BaseSchemaAction
    ) -> dict[str, BaseSchemaAction | FieldModel | list[ModifierModel | FieldModel]]:
        """Parse a script for a base action and return the corresponding destination field, and source structure for
        transformation. Can, optionally, also update the existing schema list.

        Parameters
        ----------
        script: str
        action: BaseSchemaAction

        Raises
        ------
        ValueError if the script term is not recognised.

        Returns
        -------
        dict of the form:
            {
                'action': BaseSchemaAction,
                'destination': FieldModel | list[FieldModel],
                'source': list[ModifierModel | FieldModel]
            }
        """
        if not self.schema:
            raise ValueError("Schema has not been provided.")
        hexed_script = self.get_hexed_script(script=script)
        root = self.parser.get_split_terms(script=hexed_script, by="<")
        split_terms = self.parser.get_split_terms(script=root[0], by=">")
        # Get destination term
        destination = None
        if len(split_terms) == 2:
            destination = self.parser.get_literal(text=split_terms[1])
            if isinstance(destination, list):
                destination = [self.get_schema_field(term=d) for d in destination]
                assert all(isinstance(d, FieldModel) for d in destination)
            else:
                destination = self.get_schema_field(term=destination)
                assert isinstance(destination, FieldModel)
        # Get source term
        source = None
        if len(root) > 1:
            source = "<".join(root[1:])
        # Process initial response
        if action.settings.name == "NEW":
            # Special case where value is assigned as default to 'destination' field
            value = self.parser.get_literal(text=root[1])
            if len(value) > 1:
                raise ValueError(f"'New' actions must only contain a single value term. ({root[1:]})")
            destination.constraints = ConstraintsModel(**{"default": {"name": value[0]}})
            return {"action": action, "destination": destination}
        if not source:
            if action.settings.structure:
                # The structure for this action requires a source term
                raise ValueError(
                    f"{action.name} action requires a source term ({action.settings.structure}) but none found."
                )
            return {"action": action, "destination": destination}
        # If action does not include a structure, then no source term should be included
        if not action.settings.structure:
            # The structure for this action requires a source term
            raise ValueError(f"{action.settings.name} action does not include a source term but one found ({source}).")
        # Source exists *and* and is required, process the second part
        # Nested sources must not have destinations as these will be autogenerated
        last_i = None
        parsed = []
        for i, stack in list(self.parser.generate_contents(text=source)):
            # Generate contents yields the deepest, right-most nested bracket first, and goes from there.
            if not last_i:
                last_i = i
            if last_i == i:
                # Process nested sources on same level, e.g. [[nested_source1], [nested_source2]]
                parsed.append(
                    (stack, {"action": None, "source": self.parser.get_listed_literal(text=stack)}, uuid4().hex)
                )
            else:
                # Structure: ACTION < SOURCE
                # Process: as pop 'up' through nested sources, replace lower levels with hash key to simplify
                # extraction of the action, then replace the hash with a nested dictionary {action: , source: }
                # where each 'source' contains the input for that action, and is the input for the higher level.
                parsed_stack = stack
                for txt, prsed, hx in parsed:
                    # replace the txt with hx
                    parsed_stack = parsed_stack.replace(f"[{txt}]", hx)
                i_prsed = []
                for s in self.parser.get_split_terms(script=parsed_stack, by=",", maxsplit=-1):
                    splt = self.parser.get_split_terms(script=s, by="<")
                    if len(splt) == 1:
                        i_prsed.extend(self.parser.get_listed_literal(text=s))
                    else:
                        for txt, prsed, hx in parsed:
                            if hx in s:
                                prsed["action"] = self.parser.get_action_model(action=splt[0])
                                if not isinstance(prsed["action"], SchemaActionModel):
                                    raise ValueError(f"Only ACTIONS of Type `SchemaAction` can be nested ({stack}).")
                                i_prsed.append(prsed)
                parsed = [(stack, {"action": None, "source": i_prsed}, uuid4().hex)]
        # Once the stack is empty, need only the prsed 'source' section of the list
        source = [p[1]["source"] for p in parsed]
        source = self.recover_fields_from_hexed_script(parsed=source, action=action)[0]
        action.validate(destination=destination, source=source)
        return {"action": action, "destination": destination, "source": source}

    ###################################################################################################
    ### IMPLEMENT VALIDATED SCRIPT
    ###################################################################################################

    def transform(
        self,
        *,
        df: pd.DataFrame,
        action: Type[BaseSchemaAction],
        destination: FieldModel,
        source: Optional[list[Union[ModifierModel, dict]]] = None,
        **kwargs,
    ) -> pd.DataFrame:
        """
        A recursive transformation. A method should be a list fields upon which actions are applied, but
        each field may have nested sub-fields requiring their own actions. Before the action on the
        current field can be completed, it is necessary to perform the actions on each sub-field.

        Parameters
        ----------
        df: DataFrame
            Working data to be transformed
        action: SchemaActionModel
        destination: FieldModel
        source: list of ModifierModel, and dicts of nested transforms, default None
        assigned: list of dict
            Specific to CATEGORISE actions. Each dict has values for: Assignment ACTION, destination schema field,
            schema category, source data column, and a list of source data column category terms assigned to that
            schema category.

        Returns
        -------
        Dataframe
            Containing the implementation of all nested transformations
        """
        if not source:
            return action.transform(df=df, destination=destination)
        flattened_source = []
        for term in source:
            if isinstance(term, dict):
                # Nested transform
                nested_destination = term.get("destination")
                if not nested_destination:
                    # Need to create a temporary column ... the action will be performed here
                    # then this nested structure will be replaced by the output of this new column
                    # Temporary column is based on initial root FieldModel destination to ensure
                    # all schema parameters carried over.
                    nested_destination = destination.copy()
                    nested_destination.name = f"nested_{uuid4().hex}"
                df = self.transform(df, term["action"], nested_destination, term.get("source"))
                flattened_source.append(nested_destination)
            else:
                flattened_source.append(term)
        # Action transform
        return action.transform(df=df, destination=destination, source=flattened_source)

    ###################################################################################################
    ### SUPPORT UTILITIES
    ###################################################################################################

    def set_schema(
        self, *, schema_source: SchemaDefinition = None, schema_destination: SchemaDefinition = None
    ) -> None:
        """Set SchemaDefinitions for the parser.

        Parameters
        ----------
        schema_source: SchemaDefinition
        schema_destination: SchemaDefinition
        """
        if not schema_source or not schema_destination:
            raise ValueError("Schema for both source and destination has not been provided.")
        self.schema = [schema_source, schema_destination]

    def get_schema_field(self, *, term: str) -> FieldModel:
        """Recover a field model from a string.

        Raises
        ------
        ValueError if the field term is not recognised.

        Returns
        -------
        FieldModel
        """
        field = None
        for s in self.schema:
            field = s.fields.get(name=term)
            if field:
                break
        if not field:
            e = f"Field name is not recognised from either of the source or destination schema fields ({term})."
            if len(self.schema) == 1:
                e = f"Field name is not recognised from the schema fields ({term})."
            raise ValueError(e)
        return field

    def get_hexed_script(self, *, script: str) -> str:
        from whyqd.crosswalk.actions import default_actions

        # Changes fields to uuid hexes
        all_fields = [field for s in self.schema for field in s.get.fields]
        script = self.parser.get_hexed_script(script=script, fields=all_fields)
        self.modifier_names = set()
        self.source_modifiers = {}
        for action in default_actions:
            if action.name in script:
                action_modifiers = {}
                if action.modifiers:
                    # TODO: this isn't right ... modifiers may share a name but have different functions/descriptions
                    action_modifiers = {m.name: m for m in action.modifiers}
                for m in set(action_modifiers.keys()).difference(self.modifier_names):
                    # Preserve original Modifiers
                    self.source_modifiers[m] = action_modifiers[m]
                    script = script.replace(m, f",{m},")
                self.modifier_names.update(set(action_modifiers.keys()))
        return ",".join([s.strip() for s in script.split(",") if s.strip()])

    def recover_fields_from_hexed_script(
        self, *, parsed: list | dict, action: BaseSchemaAction
    ) -> list[dict[str, BaseSchemaAction | FieldModel | list[ModifierModel | FieldModel]]]:
        """
        Recovered fields and modifiers from a hexed script. Doubles up as a validator since any non-recovered
        term causes a ValueError. Recursive for deeply-nested scripts.

        Parameters
        ----------
        parsed: list, dict
        action: BaseSchemaAction

        Raises
        ------
        ValueError if the script term is not recognised.

        Returns
        -------
        list of dicts of the form:
            {
                'action': BaseSchemaAction,
                'destination': FieldModel,
                'source': list[ModifierModel | FieldModel]
            }
        """
        recovered_fields = []
        modifier_names = action.modifier_terms
        if not isinstance(parsed, list):
            parsed = [parsed]
        for term in parsed:
            if not term:
                # Blank string artifacts can be introduced
                continue
            recovered = None
            if isinstance(term, str) and term in modifier_names:
                recovered = action.get_modifier(term=term)
            elif isinstance(term, str):
                recovered = self.get_schema_field(term=term)
            elif isinstance(term, list):
                recovered = self.recover_fields_from_hexed_script(parsed=term, action=action)
            elif isinstance(term, dict):
                # Check if 'destination' is a column or schema field
                destination = term.get("destination")
                if destination:
                    destination = self.get_schema_field(term=destination)
                recovered = {
                    "action": term.get("action"),
                    "destination": destination,
                    "source": self.recover_fields_from_hexed_script(source=term.get("source"), action=action),
                }
            if not recovered:
                raise ValueError(f"Term ({term}) cannot be parsed.")
            recovered_fields.append(recovered)
        return recovered_fields
