import textwrap
from typing import Any, Tuple, List
import base64

from pandas import DataFrame
import pandas as pd

from ..clients.config import Config
from ..clients.client import Client, ResourceProvider
from .. import dsl, rel, metamodel as m
from railib import api

#--------------------------------------------------
# Constants
#--------------------------------------------------

# Number of milliseconds in one day; converts a day count into milliseconds.
MILLISECONDS_PER_DAY = 24 * 60 * 60 * 1000
# Raw millisecond value that represents 1970-01-01T00:00:00 in the date/time
# encoding decoded by Resources.format_results: subtracting it from a raw
# value yields milliseconds since the Unix epoch. It equals 719_163 days;
# presumably this is Julia's `Dates` instant for DateTime(1970), whose day 1
# is 0001-01-01 — confirm against the engine's encoding.
UNIXEPOCH = 719_163 * MILLISECONDS_PER_DAY

#--------------------------------------------------
# Resources
#--------------------------------------------------

class Resources(ResourceProvider):
    """``ResourceProvider`` implementation backed by ``railib.api``.

    Engines, graphs (railib databases) and models are managed through
    ``railib.api`` calls on a lazily created ``api.Context``. Exports and
    imports are not supported by this provider; those methods always raise.
    """

    def __init__(self, profile:str|None=None, config:Config|None=None):
        super().__init__(profile, config=config)
        # railib API context: built on first use by _api_ctx(), cleared by reset().
        self._ctx = None

    def _api_ctx(self):
        """Return the cached ``api.Context``, creating it from ``self.config`` if needed."""
        if self._ctx is None:
            self._ctx = api.Context(**self.config.to_rai_config())
        return self._ctx

    def reset(self):
        """Drop the cached API context so the next call rebuilds it from config."""
        self._ctx = None

    #--------------------------------------------------
    # Engines
    #--------------------------------------------------

    def list_engines(self):
        """Return the result of ``api.list_engines`` for the current context."""
        return api.list_engines(self._api_ctx())

    def create_engine(self, name:str, size:str, pool:str=""):
        """Create engine ``name`` of the given ``size`` and wait for it to be ready.

        ``pool`` is accepted for interface compatibility but is not used here.
        """
        return api.create_engine_wait(self._api_ctx(), name, size)

    def delete_engine(self, name:str):
        """Delete engine ``name``."""
        return api.delete_engine(self._api_ctx(), name)

    #--------------------------------------------------
    # Graphs
    #--------------------------------------------------

    def list_graphs(self) -> List[Any]:
        """List all graphs (railib databases)."""
        return api.list_databases(self._api_ctx())

    def get_graph(self, name:str):
        """Fetch the graph (railib database) called ``name``."""
        return api.get_database(self._api_ctx(), name)

    def create_graph(self, name: str):
        """Create a graph (railib database) called ``name``."""
        return api.create_database(self._api_ctx(), name)

    def delete_graph(self, name:str):
        """Delete the graph (railib database) called ``name``."""
        return api.delete_database(self._api_ctx(), name)

    #--------------------------------------------------
    # Models
    #--------------------------------------------------

    def list_models(self, database: str, engine: str):
        """List the models installed in ``database``, queried through ``engine``."""
        return api.list_models(self._api_ctx(), database, engine)

    def create_models(self, database: str, engine: str, models:List[Tuple[str, str]]) -> List[Any]:
        """Install or replace Rel models in ``database`` in one write transaction.

        Args:
            database: Target database (graph) name.
            engine: Engine that runs the transaction.
            models: ``(name, code)`` pairs. Any existing model with the same
                name is deleted, and ``code`` is inserted in its place.

        Returns:
            The transaction's ``problems``, or ``[]`` if there were none.
        """
        statements = []
        for (name, code) in models:
            code = code.replace("\"", "\\\"")
            name = name.replace("\"", "\\\"")
            # NOTE(review): only double quotes are escaped. Backslashes (and '%',
            # if Rel interpolates it inside string literals) in `code` pass
            # through unchanged — confirm this is safe for all model sources.
            # The template is written flush-left rather than dedented after
            # interpolation, because textwrap.dedent would also strip common
            # leading whitespace from the model code itself.
            statements.append(
                f'def delete:rel:catalog:model["{name}"] = rel:catalog:model["{name}"]\n'
                f'def insert:rel:catalog:model["{name}"] = """{code}"""'
            )
        rel_code = "\n\n".join(statements)
        results = self.exec_raw(database, engine, rel_code, readonly=False)
        return results.problems or []

    def delete_model(self, database:str, engine:str, name:str):
        """Delete model ``name`` from ``database`` using ``engine``."""
        return api.delete_model(self._api_ctx(), database, engine, name)

    #--------------------------------------------------
    # Exports
    #--------------------------------------------------

    def list_exports(self, database: str, engine: str):
        """Not supported by this provider; always raises."""
        raise Exception("Azure doesn't support exports")

    def create_export(self, database: str, engine: str, name: str, inputs: List[str], out_fields: List[str], code: str):
        """Not supported by this provider; always raises."""
        raise Exception("Azure doesn't support exports")

    def delete_export(self, database: str, engine: str, name: str):
        """Not supported by this provider; always raises."""
        raise Exception("Azure doesn't support exports")

    #--------------------------------------------------
    # Imports
    #--------------------------------------------------

    def list_imports(self, model:str):
        """Not supported by this provider; always raises."""
        raise Exception("Azure doesn't support imports")

    def create_import_stream(self, object:str, model:str, rate = 1):
        """Not supported by this provider; always raises."""
        raise Exception("Azure doesn't support import streams")

    def delete_import(self, object:str, model:str):
        """Not supported by this provider; always raises."""
        raise Exception("Azure doesn't support imports")

    #--------------------------------------------------
    # Exec
    #--------------------------------------------------

    def exec_raw(self, database:str, engine:str, raw_code:str, readonly=True, raw_results=True):
        """Run ``raw_code`` on ``engine`` against ``database`` via ``api.exec``.

        ``raw_results`` is accepted for interface compatibility and ignored.
        """
        return api.exec(self._api_ctx(), database, engine, raw_code, readonly=readonly)

    def _has_errors(self, results):
        """Return True if any entry of ``results.problems`` is an error or exception."""
        return any(problem['is_error'] or problem['is_exception'] for problem in results.problems)

    def format_results(self, results, task:m.Task) -> Tuple[DataFrame, List[Any]]:
        """Convert raw transaction results into pandas DataFrames.

        Each entry of ``results.results`` is updated in place: ``result["table"]``
        is replaced by a DataFrame whose UInt128 / Dates.DateTime / Dates.Date
        columns are decoded. When the column count matches, its columns are
        also renamed to ``task.return_cols()``. Tables whose relation id contains
        ``":output"`` are concatenated into the returned frame.

        Returns:
            ``(data_frame, results.problems)``. ``data_frame`` is empty when an
            error/exception problem was reported or there was no output.
        """
        output_frames: List[DataFrame] = []
        if not self._has_errors(results) and len(results.results):
            ret_cols = task.return_cols()
            for result in results.results:
                # Column types come from the relation id's path segments. Empty
                # segments and ':'-prefixed segments (relation names) are
                # skipped; the remaining segments are assumed to line up 1:1
                # with the table's columns.
                types = [t for t in result["relationId"].split("/") if t != "" and not t.startswith(":")]
                result_frame:DataFrame = result["table"].to_pandas()
                for i, col in enumerate(result_frame.columns):
                    col_type = types[i]
                    if col_type == "UInt128":
                        # Base64-encode the raw bytes and drop the last two chars
                        # (the "==" padding, assuming 16-byte values).
                        result_frame[col] = result_frame[col].apply(lambda x: base64.b64encode(x.tobytes()).decode()[:-2])
                    elif col_type == "Dates.DateTime":
                        # Raw milliseconds -> Unix-epoch milliseconds.
                        result_frame[col] = pd.to_datetime(result_frame[col] - UNIXEPOCH, unit="ms")
                    elif col_type == "Dates.Date":
                        # Raw day count -> milliseconds -> Unix-epoch milliseconds.
                        result_frame[col] = pd.to_datetime(result_frame[col] * MILLISECONDS_PER_DAY - UNIXEPOCH, unit="ms")
                if len(ret_cols) and len(result_frame.columns) == len(ret_cols):
                    result_frame.columns = ret_cols
                result["table"] = result_frame
                if ":output" in result["relationId"]:
                    output_frames.append(result_frame)
        # Concatenate once at the end. Repeatedly concatenating onto an empty
        # frame is quadratic and triggers pandas' empty-entry FutureWarning.
        data_frame = pd.concat(output_frames, ignore_index=True) if output_frames else DataFrame()
        return (data_frame, results.problems)

#--------------------------------------------------
# Graph
#--------------------------------------------------

def Graph(name, dry_run=False):
    """Build a ``dsl.Graph`` named ``name`` on top of this module's ``Resources``.

    A ``Client`` is created from a fresh ``Resources()`` and ``rel.Compiler()``,
    with ``dry_run`` forwarded unchanged. The ``pyrel_base`` Rel helpers
    (``make_identity`` and ``pyrel_default``) are installed on it, and the
    graph wrapping that client is returned.
    """
    provider = Resources()
    compiler = rel.Compiler()
    graph_client = Client(provider, compiler, name, dry_run=dry_run)
    # Base Rel helper definitions installed under the "pyrel_base" model name.
    base_rules = dsl.build.raw_task("""
        @inline
        def make_identity(x..., z) =
            hash128({x...}, x..., u) and
            hash_value_uint128_convert(u, z)
            from u

        @inline
        def pyrel_default[F, c](k..., v) =
            F(k..., v) or (not F(k..., _) and v = c)
    """)
    graph_client.install("pyrel_base", base_rules)
    return dsl.Graph(graph_client, name)