import base64
from dataclasses import dataclass, field
from enum import Enum
import numbers
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, Any, cast
import json
import textwrap
from datetime import datetime, date

#--------------------------------------------------
# Base
#--------------------------------------------------

# Module-wide counter backing every Base.id. Note: this global shadows the
# builtin ``id()`` for the rest of this module.
id = 0


def next_id() -> int:
    """Advance the module-wide counter and return its new value.

    Used as the ``default_factory`` of ``Base.id`` so that every metamodel
    object gets a distinct integer identifier.
    """
    global id
    id = id + 1
    return id

@dataclass
class Base:
    """Root of every metamodel object; gives each instance a unique ``id``."""

    # Drawn from the module-wide counter; excluded from __init__ so callers
    # cannot supply their own value.
    id: int = field(default_factory=next_id, init=False)

    def __str__(self):
        # A fresh Printer per call, so variable naming starts from scratch
        # every time an object is converted to text.
        printer = Printer()
        return printer.print(self)

#--------------------------------------------------
# Data
#--------------------------------------------------

@dataclass
class Type(Base):
    """A named type with properties, parent types and an optional agent."""

    name: str = ""
    properties: List['Property'] = field(default_factory=list)
    parents: List['Type'] = field(default_factory=list)
    agent: Optional['Agent'] = None

    def isa(self, other: 'Type') -> bool:
        """Return True if ``other`` is this type or any transitive parent.

        The parent graph is walked depth-first without a visited set, so a
        cycle in ``parents`` would recurse without bound.
        """
        return self == other or any(parent.isa(other) for parent in self.parents)

    def __hash__(self) -> int:
        # Hash on the unique id so types can be used in sets / as dict keys.
        return self.id


# Fallback type for values whose type is not (yet) known.
UnknownType = Type("Unknown")

@dataclass
class Property(Base):
    """A named, typed field declared on a Type.

    ``is_input`` marks properties whose bound vars are counted as *required*
    (rather than provided) by ``Action.requires_provides``.
    """

    name: str
    type: Type
    is_input: bool = False

    def __hash__(self) -> int:
        # Hash on the unique id so properties can be dict keys (Action.bindings).
        return self.id

    @staticmethod
    def find(name:str, types:List[Type]) -> Optional['Property']:
        """Return the first property called ``name`` declared on ``types``.

        ``types`` is searched in order and each type's properties in
        declaration order; the first match wins, so earlier types take
        precedence over later ones.

        Returns:
            The matching Property, or None if no type declares one.
        """
        for t in types:
            for p in t.properties:
                if p.name == name:
                    # Return right away: a plain ``break`` here would only
                    # leave the inner loop and let a later type override it.
                    return p
        return None

@dataclass
class Agent(Base):
    """A named agent on some platform; a Type may reference one via ``agent``."""

    name: str
    # Platform family, e.g. SQL, Rel, JS, OpenAI.
    platform: str
    # Opaque payload; presumably platform-specific — not interpreted here.
    info: Any

    def __hash__(self) -> int:
        # Id-based hash, like the other metamodel classes.
        return self.id

#--------------------------------------------------
# Task
#--------------------------------------------------

class Behavior(Enum):
    """Tag describing how a Task's items are combined.

    The string ``value`` is the text the Printer emits for a task; the
    execution semantics of each tag are defined by whoever consumes the model.
    """

    Query = "query"
    Union = "union"
    OrderedChoice = "ordered_choice"
    Sequence = "sequence"
    Catch = "catch"

@dataclass
class Task(Type):
    """A Type that also carries a body: an ordered list of Actions.

    On construction every Task gets ``Task._task_builtin`` appended to its
    ``parents``.
    """

    behavior: Behavior = Behavior.Query
    items: List['Action'] = field(default_factory=list)
    bindings: Dict[Property, 'Var'] = field(default_factory=dict)
    inline: bool = False

    # Assigned on the class by the builtins set-up and read through the class
    # in __post_init__. Because it is annotated, dataclasses also exposes it
    # as an init field defaulting to None.
    _task_builtin: Any = None

    def __post_init__(self):
        # NOTE(review): appends None if a Task is built before the builtin
        # "Task" type has been assigned to Task._task_builtin.
        self.parents.append(Task._task_builtin)

    def __hash__(self) -> int:
        return self.id

    def return_cols(self) -> List[str]:
        """Return the column names this task returns.

        Items are scanned in order. The first item that is either a Call of
        a sub-Task whose own ``return_cols()`` is non-empty, or a Bind on
        ``Builtins.Return``, decides the result. Returns [] if none is found.
        """
        namer = Namer()
        for item in self.items:
            target = item.entity.value
            if item.action == ActionType.Call and isinstance(target, Task):
                nested = target.return_cols()
                if nested:
                    return nested
            if item.action == ActionType.Bind and target == Builtins.Return:
                # NOTE(review): Namer.get always yields a non-empty string, so
                # the property-name fallback after ``or`` never runs.
                return [namer.get(var) or namer.get(prop)
                        for prop, var in item.bindings.items()]
        return []


#--------------------------------------------------
# Var
#--------------------------------------------------

# Anything a Var may hold as a constant: primitives, metamodel objects, raw
# bytes, dates/datetimes, or (nested) lists of values or vars.
Value = Union[
    str, numbers.Number, bool, Task, Property, Type, bytes,
    datetime, date, List['Value'], List['Var'],
]

@dataclass
class Var(Base):
    """A variable, optionally pinned to a constant ``value``."""

    # ``id`` is inherited unchanged from Base (default_factory=next_id,
    # init=False); redeclaring it here was redundant.
    type: Type = UnknownType
    name: Optional[str] = None
    value: Optional[Value] = None

    def __hash__(self) -> int:
        return self.id

    def isa(self, other: Type) -> bool:
        """Return True if this var's declared type, or its value when that
        value is itself a Type, is ``other`` or descends from it."""
        if self.type.isa(other):
            return True
        held = self.value
        return isinstance(held, Type) and held.isa(other)

#--------------------------------------------------
# Action
#--------------------------------------------------

class ActionType(Enum):
    """Kind of operation an Action performs; ``value`` is its printed op name."""

    Get = "get"
    Call = "call"
    Persist = "persist"
    Unpersist = "unpersist"
    Bind = "bind"
    Construct = "construct"

    def is_effect(self):
        """Return True for the kinds classed as effects:
        Bind, Persist, Unpersist and Construct."""
        effects = (
            ActionType.Bind,
            ActionType.Persist,
            ActionType.Unpersist,
            ActionType.Construct,
        )
        return self in effects

@dataclass
class Action(Base):
    """A single step of a Task: an ``action`` kind applied to an ``entity`` var.

    ``types`` and ``bindings`` are typically filled through ``append``.
    """

    action: ActionType
    entity: Var
    types: List[Type] = field(default_factory=list)
    bindings: Dict[Property, Var] = field(default_factory=dict)

    def requires_provides(self, seen = None) -> Tuple[Set[Var], Set[Var]]:
        """Split the unbound vars of this action into required and provided.

        Args:
            seen: Actions already visited. It is shared across the recursion
                into sub-task items so an action is processed at most once.
                A fresh set is created when omitted.

        Returns:
            ``(requires, provides)``. Only vars with no constant ``value`` are
            reported; a list-valued var contributes its unbound ``Var``
            elements instead. When the entity is a Task with items, those
            items' results are merged in as well.
        """
        requires: Set[Var] = set()
        provides: Set[Var] = set()
        # ``is None`` rather than falsiness: an empty set passed by the caller
        # must be used (and populated), not silently replaced.
        if seen is None:
            seen = set()
        if self in seen:
            return requires, provides
        seen.add(self)

        task = self.entity.value
        if isinstance(task, Task) and task.items:
            for item in task.items:
                r, p = item.requires_provides(seen)
                requires.update(r)
                provides.update(p)

        for prop, var in self.bindings.items():
            # A Bind consumes all of its vars; otherwise input properties are
            # required and all other properties are provided.
            if self.action == ActionType.Bind or prop.is_input:
                into = requires
            else:
                into = provides

            if isinstance(var.value, list):
                into.update(v for v in var.value
                            if isinstance(v, Var) and v.value is None)
            elif var.value is None:
                into.add(var)

        return requires, provides

    def vars(self) -> Set[Var]:
        """Return every var this action references: bound vars plus the entity."""
        found = set(self.bindings.values())
        found.add(self.entity)
        return found

    def append(self, item: Union[Type, Property, Task, Agent, Var], var: Optional[Var] = None):
        """Attach ``item`` to this action.

        * a Type (including a Task) is appended to ``types``;
        * a Property is bound to ``var`` (nothing happens when ``var`` is None);
        * a Var replaces ``entity``.
        Any other item, e.g. an Agent, is ignored.
        """
        if isinstance(item, Type):
            self.types.append(item)
        elif isinstance(item, Property) and var:
            self.bindings[item] = var
        elif isinstance(item, Var):
            self.entity = item

    def __hash__(self) -> int:
        return self.id

#--------------------------------------------------
# All
#--------------------------------------------------

# Union of the concrete metamodel object kinds (used in Printer.print's signature).
AllItems = Union[
    Type, Property, Task, Agent, Var, Action,
]

#--------------------------------------------------
# Builtins
#--------------------------------------------------

class BuiltinsClass:
    """Builds the builtin types and tasks shared by every model.

    A single instance is created at import time as the module-level
    ``Builtins``.
    """

    def __init__(self) -> None:
        self.Primitive = Type("Primitive")

        self.Unknown = UnknownType
        self.Any = Type("Any")

        def positional_properties(count=20):
            # Positional columns named v0 .. v{count-1}. Use the loop index:
            # the module-level ``id`` counter (which shadows the builtin here)
            # would give arbitrary names such as v12..v31.
            return [Property(f"v{i}", self.Any) for i in range(count)]

        self.String = Type("String", parents=[self.Primitive])
        self.Number = Type("Number", parents=[self.Primitive])
        self.Int = Type("Int", parents=[self.Number])
        self.Decimal = Type("Decimal", parents=[self.Number])
        self.Bool = Type("Bool", parents=[self.Primitive])
        self.Type = Type("Type", parents=[self.Primitive])
        self.Relation = Type("Relation", parents=[self.Primitive], properties=positional_properties())
        self.Anonymous = Type("Anonymous") # A thing we assume to exist in the host DB for which we don't have information.
        self.Task = Type("Task", parents=[self.Type, self.Relation])
        # Must be published before any Task(...) below is constructed:
        # Task.__post_init__ appends it to every task's parents.
        Task._task_builtin = self.Task
        self.Return = Type("Return", parents=[self.Relation], properties=positional_properties())
        self.RawCode = Type("RawCode", properties=[Property("code", self.String)])
        self.RawData = Type("RawData", parents=[self.Relation])
        self.Aggregate = Type("Aggregate", parents=[self.Task])
        self.Extender = Type("Extender", parents=[self.Aggregate])
        self.Quantifier = Type("Quantifier", parents=[self.Task])
        # Quantifiers: properties[0] is the group, properties[1] the task
        # (Printer.action relies on this order).
        self.Not = Task("Not", parents=[self.Quantifier], properties=[Property("group", self.Any, True), Property("task", self.Task, True)])
        self.Exists = Task("Exists", parents=[self.Quantifier], properties=[Property("group", self.Any, True), Property("task", self.Task, True)])
        self.Every = Task("Every", parents=[self.Quantifier], properties=[Property("group", self.Any, True), Property("task", self.Task, True)])

        self.Identity = Type("Identity", parents=[self.String])
        self.make_identity = Task("make_identity", properties=[
            Property("params", self.Any, True),
            Property("identity", self.Any)
        ])

        self.Infix = Type("Infix")

        def binary_op(op, with_result=False):
            # Infix task with numeric inputs a, b and, optionally, an output
            # "result" (arithmetic ops) — comparisons have no result column.
            t = Task(name=op, parents=[self.Infix], properties=[
                Property("a", self.Number, True),
                Property("b", self.Number, True),
            ])
            if with_result:
                t.properties.append(Property("result", self.Number))
            return t

        def aggregate(op):
            # NOTE(review): property order here is (projection, group, result)
            # while Builder.aggregate_def uses (group, projection, result);
            # Builder.call binds params positionally — confirm which is intended.
            return Task(name=op, parents=[self.Aggregate], properties=[
                Property("projection", self.Any, True),
                Property("group", self.Any, True),
                Property("result", self.Number),
            ])

        self.gt = binary_op(">")
        self.gte = binary_op(">=")
        self.lt = binary_op("<")
        self.lte = binary_op("<=")
        self.eq = binary_op("=")
        self.neq = binary_op("!=")

        self.plus = binary_op("+", True)
        self.minus = binary_op("-", True)
        self.mult = binary_op("*", True)
        self.div = binary_op("/", True)

        self.pow = binary_op("^", True)

        self.count = aggregate("count")
        self.sum = aggregate("sum")
        self.avg = aggregate("average")
        # self.min = aggregate("min")
        # self.max = aggregate("max")

# Shared singleton holding all builtin types/tasks.
Builtins = BuiltinsClass()

#--------------------------------------------------
# Builder
#--------------------------------------------------

class Builder:
    """Convenience factory for Actions, Tasks and relation Types.

    ``to_var`` is a caller-supplied callable applied to every positional
    parameter; its result is bound via ``Action.append`` (so presumably it
    returns a Var).
    """

    def __init__(self, to_var):
        self.to_var = to_var

    def call(self, op, params:List[Any]):
        """Build a Call action on ``op``, binding ``params`` to its properties
        by position (IndexError if there are more params than properties)."""
        action = Action(ActionType.Call, Var(Builtins.Task, value=op))
        props = op.properties
        for position, param in enumerate(params):
            action.append(props[position], self.to_var(param))
        return action

    def return_(self, params:List[Any]):
        """Build a Bind on ``Builtins.Return`` with ``params`` as its columns."""
        return self.relation_action(ActionType.Bind, Builtins.Return, params)

    def relation(self, name:str, field_count:int):
        """Declare an anonymous relation Type reusing the first
        ``field_count`` positional Relation properties."""
        columns = Builtins.Relation.properties
        return Type(
            name,
            parents=[Builtins.Relation, Builtins.Anonymous],
            properties=[columns[n] for n in range(field_count)],
        )

    def relation_action(self, action_type:ActionType, op:Type|Property, params:Iterable[Any]):
        """Build an action of ``action_type`` on ``op``, binding ``params`` to
        the positional Relation properties in order."""
        action = Action(action_type, Var(Builtins.Type, value=op))
        columns = Builtins.Relation.properties
        for position, param in enumerate(params):
            action.append(columns[position], self.to_var(param))
        return action

    def ident(self, action:Action):
        """Build a ``make_identity`` call over the action's type names and
        bound vars, producing into the action's entity."""
        params = [*(Var(value=t.name) for t in action.types), *action.bindings.values()]
        return self.call(Builtins.make_identity, [Var(value=params), action.entity])

    def property_named(self, name:str, types:List[Type]):
        """Return the property ``name`` from ``types``, creating it (typed Any)
        on ``types[0]`` when none exists."""
        prop = Property.find(name, types)
        if prop:
            return prop
        prop = Property(name, Builtins.Any)
        if types:
            types[0].properties.append(prop)
        return prop

    def raw(self, code:str):
        """Build a Call on ``Builtins.RawCode`` carrying the dedented ``code``."""
        dedented = textwrap.dedent(code)
        return self.relation_action(ActionType.Call, Builtins.RawCode, [dedented])

    def raw_task(self, code:str):
        """Wrap ``code`` in a single-step Sequence task."""
        step = self.raw(textwrap.dedent(code))
        return Task(behavior=Behavior.Sequence, items=[step])

    def aggregate_def(self, op:str):
        """Define a custom aggregate task named ``op``.

        NOTE(review): property order is (group, projection, result), unlike
        the builtin aggregates' (projection, group, result).
        """
        group = Property("group", Builtins.Any, True)
        projection = Property("projection", Builtins.Any, True)
        result = Property("result", Builtins.Number)
        return Task(name=op, parents=[Builtins.Aggregate], properties=[group, projection, result])

    def eq(self, a:Any, b:Any):
        """Build a Call of the builtin ``=`` on ``a`` and ``b``."""
        return self.call(Builtins.eq, [a, b])

#--------------------------------------------------
# Printer
#--------------------------------------------------

class Namer:
    """Hands out unique, readable names for metamodel objects.

    Names are memoised per object ``id``; clashes are resolved by appending
    2, 3, ... to the base name.
    """

    def __init__(self, unnamed_vars=False):
        self.name_mapping = {}
        self.names = set()
        # When True, Vars ignore their own ``name`` and get a generic one.
        self.unnamed_vars = unnamed_vars

    def get_safe_name(self, name:str):
        """Reserve and return ``name``, suffixed with 2, 3, ... if taken."""
        candidate = name
        suffix = 2
        while candidate in self.names:
            candidate = f"{name}{suffix}"
            suffix += 1
        self.names.add(candidate)
        return candidate

    def get(self, item:Var|Task|Type|Property):
        """Return the memoised unique name for ``item``.

        Falls back to "t" for tasks and "v" for everything else when the
        item has no usable name.
        """
        if item.id in self.name_mapping:
            return self.name_mapping[item.id]

        hide_name = self.unnamed_vars and isinstance(item, Var)
        base = None if hide_name else item.name
        if not base:
            base = "t" if isinstance(item, Task) else "v"
        chosen = self.get_safe_name(base)
        self.name_mapping[item.id] = chosen
        return chosen

    def reset(self):
        """Forget every assigned name."""
        self.name_mapping.clear()
        self.names.clear()

class Printer():
    """Renders metamodel objects as indented, human-readable text."""

    def __init__(self, unnamed_vars=False):
        self.indent = 0
        self.namer = Namer(unnamed_vars=unnamed_vars)

    def indent_str(self):
        """Return the leading whitespace for the current depth (4 spaces/level)."""
        return " " * 4 * self.indent

    def print(self, item:AllItems|Base|Value, is_sub=False):
        """Render ``item`` as text.

        Raises:
            TypeError: if ``item`` is of an unsupported type (TypeError is an
                Exception subclass, so existing ``except Exception`` still works).
        """
        if isinstance(item, Task):
            return self.task(item, is_sub)
        elif isinstance(item, Type):
            return self.type(item, is_sub)
        elif isinstance(item, Property):
            return self.property(item, is_sub)
        elif isinstance(item, Agent):
            return self.agent(item, is_sub)
        elif isinstance(item, Var):
            return self.var(item, is_sub)
        elif isinstance(item, Action):
            return self.action(item, is_sub)
        elif isinstance(item, bytes):
            # Strip only the '=' padding. Slicing off a fixed two characters
            # dropped real base64 data whenever len(item) % 3 != 1.
            return base64.b64encode(item).decode().rstrip("=")
        elif isinstance(item, list):
            # Render every element, even ones elided below, so that var
            # naming is the same as when the whole list is shown.
            vs = [self.print(i, is_sub) for i in item]
            if len(item) > 20:
                return f"[{', '.join(vs[0:5])}, ... {', '.join(vs[-2:])}]"
            return f"[{', '.join(vs)}]"
        elif isinstance(item, (str, bool, int, float)):
            return json.dumps(item)
        elif isinstance(item, numbers.Number):
            # Numbers json cannot encode (Decimal, Fraction, complex, ...).
            return str(item)
        elif isinstance(item, date):
            # datetime is a subclass of date, so this covers both.
            return item.isoformat()
        raise TypeError(f"Unknown item type: {type(item)}")

    def type(self, type:Type, is_sub=False):
        return type.name

    def property(self, property:Property, is_sub=False):
        return property.name

    def task(self, task:Task, is_sub=False):
        """Render the task's behavior tag followed by its items, one per line."""
        self.indent += 1
        try:
            items = '\n'.join([self.print(i, is_sub) for i in task.items])
        finally:
            self.indent -= 1
        behavior = task.behavior.value
        if is_sub:
            # Sub-tasks are appended to their caller's line: no indent prefix.
            return f"""{behavior}\n{items}\n"""
        return f"""{self.indent_str()}{behavior}\n{items}\n"""

    def agent(self, agent:Agent, is_sub=False):
        return agent.name

    def var(self, var:Var, is_sub=False):
        """Render a var: its constant value if it has one, else its name."""
        if var.value is not None:
            if isinstance(var.value, Task):
                return "SUBTASK"
            elif isinstance(var.value, (Type, Property)):
                return str(var.value)
            return self.print(var.value, is_sub)
        return self.namer.get(var)

    def action(self, action:Action, is_sub=False):
        """Render one action as ``<op> | <body>``, followed by any sub-task."""
        op = action.action.value
        entity_value = action.entity.value
        as_relation = False
        as_quantifier = False
        subs = []
        body = ""

        if entity_value == Builtins.Return:
            op = "return"
            as_relation = True
        elif isinstance(entity_value, Task):
            if op == "call" and len(entity_value.items):
                # Calling a task that has a body: print that body inline.
                subs.append(entity_value)
            elif op == "call" and entity_value.isa(Builtins.Quantifier):
                # Quantifiers declare (group, task); print the bound task.
                as_quantifier = True
                subs.append(action.bindings[entity_value.properties[1]].value)
            else:
                as_relation = True
        elif isinstance(entity_value, (Property, Type)):
            as_relation = True

        if as_relation:
            args = " ".join([self.print(v, is_sub) for k,v in action.bindings.items()])
            if op == "return":
                body = args
            else:
                entity_value = cast(Type|Property|Task, entity_value)
                rel_name = entity_value.name or self.namer.get(entity_value)
                body = f"{rel_name}({args})"
        elif as_quantifier:
            entity_value = cast(Task, entity_value)
            group_vars:Any = action.bindings[entity_value.properties[0]].value
            group = ", ".join([self.print(v) for v in group_vars])
            if group:
                body = f"{entity_value.name.lower()}({group}) "
            else:
                body = f"{entity_value.name.lower()} "
        elif not len(subs):
            types = [t.name for t in action.types]
            args = [f"{k.name}({self.print(v, is_sub)})" for k,v in action.bindings.items()]
            body = f"{self.print(action.entity)} | " + " ".join(types + args)

        final = f"{self.indent_str()}{op :>9} | {body}"
        if len(subs):
            self.indent += 1
            try:
                final += "\n".join([self.print(s, True) for s in subs])
            finally:
                self.indent -= 1
        return final



