# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/classes/50_DomoDataset_Stream.ipynb.

# %% ../../nbs/classes/50_DomoDataset_Stream.ipynb 3
from __future__ import annotations

from dataclasses import dataclass, field, asdict
from typing import List, Optional, Union, Any
from enum import Enum

import httpx
from sqlglot import parse_one, exp

import domolibrary.client.DomoAuth as dmda
import domolibrary.client.DomoError as dmde

import domolibrary.routes.stream as stream_routes

# %% auto 0
# Public names exported by this module (generated from the source notebook).
__all__ = [
    "StreamConfig_Mapping_amazon_s3_assumerole",
    "StreamConfig_Mapping_snowflake_unload_v2",
    "StreamConfig_Mapping_snowflake_federated",
    "StreamConfig_Mapping_snowflake",
    "StreamConfig_Mapping_adobe_analytics_v2",
    "StreamConfig_Mapping_amazon_athena_high_bandwidth",
    "StreamConfig_Mapping_aws_athena",
    "StreamConfig_Mapping_dataset_copy",
    "StreamConfig_Mapping_default",
    "StreamConfig_Mapping_domo_csv",
    "StreamConfig_Mapping_google_sheets",
    "StreamConfig_Mapping_google_spreadsheets",
    "StreamConfig_Mapping_postgresql",
    "StreamConfig_Mapping_qualtrics",
    "StreamConfig_Mapping_sharepointonline",
    "StreamConfig_Mapping_snowflake_internal_unload",
    "StreamConfig_Mapping_snowflakekeypairauthentication",
    "StreamConfig_Mapping_snowflake_writeback",
    "StreamConfig_Mapping",
    "StreamConfig_Mappings",
    "StreamConfig",
    "Dataset_Stream_GET_Error",
    "DomoStream",
]

# %% ../../nbs/classes/50_DomoDataset_Stream.ipynb 7
@dataclass
class StreamConfig_Mapping:
    """Map a data provider's raw configuration names onto normalized categories.

    ``data_provider_type`` identifies the provider. Every other string field
    holds the *raw* configuration name that provider uses for that category,
    e.g. ``sql="query"`` means the configuration entry named ``"query"`` is
    categorized as ``"sql"``. A field left as ``None`` is not mapped.
    """

    data_provider_type: str
    sql: Optional[str] = None
    warehouse: Optional[str] = None
    database_name: Optional[str] = None
    s3_bucket_category: Optional[str] = None

    # when True, a fixed set of common query names is categorized as "sql"
    is_default: bool = False

    table_name: Optional[str] = None
    src_url: Optional[str] = None
    google_sheets_file_name: Optional[str] = None
    adobe_report_suite_id: Optional[str] = None
    qualtrics_survey_id: Optional[str] = None

    def search_keys_by_value(
        self,
        value_to_search: str,
    ) -> Optional[str]:
        """Return the category (field name) whose raw name equals ``value_to_search``.

        Args:
            value_to_search: a raw configuration name, e.g. ``"query"``.

        Returns:
            The matching field name (e.g. ``"sql"``), or ``None`` when the
            name is not mapped by this provider.
        """
        if self.is_default and value_to_search in (
            "enteredCustomQuery",
            "query",
            "customQuery",
        ):
            return "sql"

        # None would otherwise match every unmapped (None-valued) field
        if value_to_search is None:
            return None

        return next(
            (
                key
                for key, value in asdict(self).items()
                # provider metadata is not a configuration name mapping
                if key not in ("data_provider_type", "is_default")
                and value == value_to_search
            ),
            None,
        )


# Provider-specific mappings. Each keyword names a normalized category; its
# value is the raw configuration name that provider uses for that category.

StreamConfig_Mapping_amazon_s3_assumerole = StreamConfig_Mapping(
    data_provider_type="amazon_s3_assumerole",
    s3_bucket_category="filesDiscovery",
)

StreamConfig_Mapping_snowflake_unload_v2 = StreamConfig_Mapping(
    data_provider_type="snowflake_unload_v2",
    sql="query",
    database_name="databaseName",
    warehouse="warehouseName",
)

# federated sources carry no SQL configuration name
StreamConfig_Mapping_snowflake_federated = StreamConfig_Mapping(
    data_provider_type="snowflake_federated",
)

StreamConfig_Mapping_snowflake = StreamConfig_Mapping(
    data_provider_type="snowflake",
    sql="query",
    database_name="databaseName",
    warehouse="warehouseName",
)

StreamConfig_Mapping_adobe_analytics_v2 = StreamConfig_Mapping(
    data_provider_type="adobe-analytics-v2",
    sql="query",
    adobe_report_suite_id="report_suite_id",
)

StreamConfig_Mapping_amazon_athena_high_bandwidth = StreamConfig_Mapping(
    data_provider_type="amazon-athena-high-bandwidth",
    sql="enteredCustomQuery",
    database_name="databaseName",
)

StreamConfig_Mapping_aws_athena = StreamConfig_Mapping(
    data_provider_type="aws-athena",
    sql="query",
    database_name="databaseName",
    table_name="tableName",
)

StreamConfig_Mapping_dataset_copy = StreamConfig_Mapping(
    data_provider_type="dataset-copy",
    src_url="datasourceUrl",
)

# fallback used for providers without a dedicated mapping
StreamConfig_Mapping_default = StreamConfig_Mapping(
    data_provider_type="default",
    is_default=True,
)

StreamConfig_Mapping_domo_csv = StreamConfig_Mapping(
    data_provider_type="domo-csv",
    src_url="datasourceUrl",
)

StreamConfig_Mapping_google_sheets = StreamConfig_Mapping(
    data_provider_type="google-sheets",
    google_sheets_file_name="spreadsheetIDFileName",
)

StreamConfig_Mapping_google_spreadsheets = StreamConfig_Mapping(
    data_provider_type="google-spreadsheets",
    google_sheets_file_name="spreadsheetIDFileName",
)

StreamConfig_Mapping_postgresql = StreamConfig_Mapping(
    data_provider_type="postgresql",
    sql="query",
)

StreamConfig_Mapping_qualtrics = StreamConfig_Mapping(
    data_provider_type="qualtrics",
    qualtrics_survey_id="survey_id",
)

StreamConfig_Mapping_sharepointonline = StreamConfig_Mapping(
    data_provider_type="sharepointonline",
    src_url="relativeURL",
)

StreamConfig_Mapping_snowflake_internal_unload = StreamConfig_Mapping(
    data_provider_type="snowflake-internal-unload",
    sql="customQuery",
    database_name="databaseName",
    warehouse="warehouseName",
)

StreamConfig_Mapping_snowflakekeypairauthentication = StreamConfig_Mapping(
    data_provider_type="snowflakekeypairauthentication",
    sql="query",
    database_name="databaseName",
    warehouse="warehouseName",
)

StreamConfig_Mapping_snowflake_writeback = StreamConfig_Mapping(
    data_provider_type="snowflake-writeback",
    table_name="enterTableName",
    database_name="databaseName",
    warehouse="warehouseName",
)


class StreamConfig_Mappings(Enum):
    """Registry of known StreamConfig_Mapping instances, one member per provider.

    Member names are provider keys lower-cased with ``-`` replaced by ``_``.
    ``default`` is the fallback returned for providers that are not listed.
    """

    amazon_s3_assumerole = StreamConfig_Mapping_amazon_s3_assumerole
    snowflake_unload_v2 = StreamConfig_Mapping_snowflake_unload_v2
    snowflake_federated = StreamConfig_Mapping_snowflake_federated
    snowflake = StreamConfig_Mapping_snowflake
    adobe_analytics_v2 = StreamConfig_Mapping_adobe_analytics_v2
    amazon_athena_high_bandwidth = StreamConfig_Mapping_amazon_athena_high_bandwidth
    aws_athena = StreamConfig_Mapping_aws_athena
    dataset_copy = StreamConfig_Mapping_dataset_copy
    domo_csv = StreamConfig_Mapping_domo_csv
    google_sheets = StreamConfig_Mapping_google_sheets
    google_spreadsheets = StreamConfig_Mapping_google_spreadsheets
    postgresql = StreamConfig_Mapping_postgresql
    qualtrics = StreamConfig_Mapping_qualtrics
    sharepointonline = StreamConfig_Mapping_sharepointonline
    snowflake_internal_unload = StreamConfig_Mapping_snowflake_internal_unload
    snowflakekeypairauthentication = StreamConfig_Mapping_snowflakekeypairauthentication
    snowflake_writeback = StreamConfig_Mapping_snowflake_writeback

    default = StreamConfig_Mapping_default

    @classmethod
    def _missing_(cls, value):
        """Resolve a provider-key string by normalized member name; else ``default``.

        Non-string values are not normalizable; returning ``None`` lets Enum
        raise its standard ``ValueError`` instead of an ``AttributeError``.
        """
        if not isinstance(value, str):
            return None

        alt_search = value.lower().replace("-", "_")

        return next(
            (member for member in cls if member.name.lower() == alt_search),
            cls.default,
        )

    @classmethod
    def search(cls, value, debug_api: bool = False) -> StreamConfig_Mappings:
        """Return the member for a provider key such as ``"aws-athena"``.

        Args:
            value: the data provider key; ``None`` or non-strings fall back.
            debug_api: print a notice when falling back to ``default``.

        Returns:
            The matching member, or ``cls.default`` (never ``None``).
        """
        if isinstance(value, str):
            alt_search = value.lower().replace("-", "_")
            if alt_search in cls.__members__:
                return cls[alt_search]

        if debug_api:
            print(f"{value} has not been added to enum config, must implement")
        return cls.default

# %% ../../nbs/classes/50_DomoDataset_Stream.ipynb 9
@dataclass
class StreamConfig:
    """One entry of a stream's ``configuration`` list, tagged with a category.

    ``stream_category`` is the normalized field name (e.g. ``"sql"``) resolved
    from the provider's StreamConfig_Mapping, or ``None`` when the entry's
    name is not mapped. Creating an instance with category ``"sql"`` and a
    parent stream records the query and its referenced tables on that parent.
    """

    stream_category: Optional[str]
    name: str  # raw configuration name, e.g. "query"
    type: str
    value: str
    # NOTE(review): not populated anywhere in this module
    value_clean: Optional[str] = None
    parent: Any = field(repr=False, default=None)

    def __post_init__(self):
        if self.stream_category == "sql" and self.parent:
            self.process_sql()

    def process_sql(self) -> Optional[List[str]]:
        """Record this SQL on the parent stream and collect the tables it references.

        Sets ``parent.configuration_query`` and merges the lower-cased table
        names found in the query into ``parent.configuration_tables``
        (de-duplicated and sorted).

        Returns:
            The parent's updated ``configuration_tables``, or ``None`` when no
            parent is attached.
        """
        if not self.parent:
            return None

        self.parent.configuration_query = self.value

        # nothing to parse; sqlglot's parse_one rejects empty input
        if not self.value:
            return self.parent.configuration_tables

        found = {
            table.name.lower() for table in parse_one(self.value).find_all(exp.Table)
        }
        # merge once after the scan instead of re-sorting for every table
        self.parent.configuration_tables = sorted(
            found.union(self.parent.configuration_tables)
        )
        return self.parent.configuration_tables

    @classmethod
    def from_json(
        cls, obj: dict, data_provider_type: Optional[str], parent_stream: Any = None
    ):
        """Build a StreamConfig from one raw configuration entry.

        Args:
            obj: raw entry with ``"name"``, ``"type"`` and ``"value"`` keys.
            data_provider_type: the stream's data provider key, used to pick
                the StreamConfig_Mapping that categorizes ``obj["name"]``.
            parent_stream: optional owning stream; its ``has_mapping`` flag is
                set when the provider resolves to a non-default mapping.
        """
        config_name = obj["name"]

        mapping_enum = StreamConfig_Mappings.search(data_provider_type)

        stream_category = "default"
        if mapping_enum:
            stream_category = mapping_enum.value.search_keys_by_value(config_name)

            # Enum members are always truthy, so compare against the fallback
            # explicitly: only a provider-specific mapping counts as "mapped".
            if parent_stream and mapping_enum is not StreamConfig_Mappings.default:
                parent_stream.has_mapping = True

        return cls(
            stream_category=stream_category,
            name=config_name,
            type=obj["type"],
            value=obj["value"],
            parent=parent_stream,
        )

    def to_json(self):
        """Return ``{"field": category, "key": raw name, "value": value}``."""
        return {"field": self.stream_category, "key": self.name, "value": self.value}

# %% ../../nbs/classes/50_DomoDataset_Stream.ipynb 10
class Dataset_Stream_GET_Error(dmde.ClassError):
    """Error raised when a DomoStream cannot be retrieved (see ``DomoStream.get``)."""

    def __init__(self, cls_instance, message):
        # NOTE(review): cls_name_attr="id" presumably makes ClassError label the
        # failing instance by its `id` attribute — confirm in DomoError
        super().__init__(
            message=message,
            cls_instance=cls_instance,
            cls_name_attr="id",
        )


@dataclass
class DomoStream:
    """A Domo dataset stream (connector) and its parsed configuration.

    Built from the stream API response via ``get_stream_by_id`` /
    ``_from_json``, or as an unloaded shell attached to a dataset via
    ``_from_parent``.
    """

    auth: dmda.DomoAuth = field(repr=False)
    id: str
    dataset_id: str

    # owning dataset object, if any
    parent: Any = field(repr=False, default=None)

    transport_description: Optional[str] = None
    transport_version: Optional[int] = None
    update_method: Optional[str] = None
    data_provider_name: Optional[str] = None
    data_provider_key: Optional[str] = None
    account_id: Optional[str] = None
    account_display_name: Optional[str] = None
    account_userid: Optional[str] = None

    # flipped to True by StreamConfig.from_json
    has_mapping: bool = False
    configuration: List[StreamConfig] = field(default_factory=list)
    # tables referenced by the "sql" configuration (filled by StreamConfig.process_sql)
    configuration_tables: List[str] = field(default_factory=list)
    configuration_query: Optional[str] = None

    @classmethod
    def _from_parent(cls, parent):
        """Create an unloaded stream shell from a dataset exposing auth/stream_id/id."""
        st = cls(
            auth=parent.auth, id=parent.stream_id, dataset_id=parent.id, parent=parent
        )

        return st

    @classmethod
    def _from_json(cls, auth, obj):
        """Build a DomoStream (including its configuration) from a stream API response."""
        data_provider = obj.get("dataProvider", {})
        transport = obj.get("transport", {})
        datasource = obj.get("dataSource", {})

        account = obj.get("account", {})

        sd = cls(
            auth=auth,
            id=obj["id"],
            transport_description=transport["description"],
            transport_version=transport["version"],
            update_method=obj.get("updateMethod"),
            data_provider_name=data_provider["name"],
            data_provider_key=data_provider["key"],
            dataset_id=datasource["id"],
        )

        if account:
            sd.account_id = account.get("id")
            sd.account_display_name = account.get("displayName")
            sd.account_userid = account.get("userId")

        # each StreamConfig may write configuration_query/configuration_tables
        # back onto `sd` while it is constructed
        sd.configuration = [
            StreamConfig.from_json(
                obj=c_obj, data_provider_type=data_provider.get("key"), parent_stream=sd
            )
            for c_obj in obj["configuration"]
        ]

        return sd

    def generate_config_rpt(self):
        """Return ``{category: value}`` for configuration entries with a real category.

        Entries whose category is ``None`` or ``"default"`` are skipped; a later
        entry with the same category overwrites an earlier one.
        """
        rpt = {}

        for config in self.configuration:
            if config.stream_category and config.stream_category != "default":
                obj = config.to_json()
                rpt[obj["field"]] = obj["value"]

        return rpt

    @classmethod
    async def get_stream_by_id(
        cls,
        auth: dmda.DomoAuth,
        stream_id: str,
        debug_num_stacks_to_drop=2,
        debug_api: bool = False,
        return_raw: bool = False,
        parent: Any = None,
        session: Optional[httpx.AsyncClient] = None,
    ):
        """Fetch a stream by id.

        Returns:
            The raw route response when ``return_raw`` is True, otherwise a
            DomoStream with ``parent`` attached.
        """
        res = await stream_routes.get_stream_by_id(
            auth=auth,
            stream_id=stream_id,
            session=session,
            parent_class=cls.__name__,
            debug_num_stacks_to_drop=debug_num_stacks_to_drop,
            debug_api=debug_api,
        )

        if return_raw:
            return res

        st = cls._from_json(auth=auth, obj=res.response)
        st.parent = parent
        return st

    async def get(self):
        """Load this stream from the API and store it on ``parent.Stream``.

        Raises:
            Dataset_Stream_GET_Error: when there is no parent or it has no
                ``stream_id``.
        """
        if not (self.parent and self.parent.stream_id):
            raise Dataset_Stream_GET_Error(
                cls_instance=self,
                message=f"dataset {self.parent} has no stream_id",
            )

        self.parent.Stream = await self.get_stream_by_id(
            auth=self.parent.auth, stream_id=self.parent.stream_id, parent=self.parent
        )

        return self.parent.Stream

    @classmethod
    async def create_stream(
        cls,
        cnfg_body,
        auth: dmda.DomoAuth = None,
        session: Optional[httpx.AsyncClient] = None,
        debug_api: bool = False,
    ):
        """Create a new stream from ``cnfg_body``; returns the route response."""
        return await stream_routes.create_stream(
            auth=auth, body=cnfg_body, session=session, debug_api=debug_api
        )

    @classmethod
    async def update_stream(
        cls,
        cnfg_body,
        stream_id,
        auth: dmda.DomoAuth = None,
        session: Optional[httpx.AsyncClient] = None,
        debug_api: bool = False,
    ):
        """Update stream ``stream_id`` with ``cnfg_body``; returns the route response."""
        return await stream_routes.update_stream(
            auth=auth,
            stream_id=stream_id,
            body=cnfg_body,
            session=session,
            debug_api=debug_api,
        )

    @classmethod
    async def upsert_connector(
        cls,
        cnfg_body,
        match_name=None,
        auth: dmda.DomoAuth = None,
        session: Optional[httpx.AsyncClient] = None,
        debug_api: bool = False,
    ):
        """Update the stream of the dataset named ``match_name``, or create one.

        The datacenter is searched for a dataset whose name equals
        ``match_name`` (case-insensitive). If found, that dataset's stream is
        updated with ``cnfg_body``; otherwise a new stream is created.
        With no ``match_name`` there is nothing to match, so a stream is created.
        """
        if match_name is None:
            return await cls.create_stream(
                cnfg_body, auth=auth, session=session, debug_api=debug_api
            )

        import domolibrary.classes.DomoDatacenter as dmdc
        import domolibrary.classes.DomoDataset as dmds

        search_body = dmdc.DomoDatacenter.generate_search_datacenter_body_by_name(
            entity_name=match_name
        )

        search_res = await dmdc.DomoDatacenter.search_datacenter(
            auth=auth, body=search_body, session=session, debug_api=debug_api
        )

        target_name = match_name.lower()
        existing_ds_obj = next(
            (
                ds
                for ds in (search_res or [])
                # tolerate search results without a name
                if (ds.get("name") or "").lower() == target_name
            ),
            None,
        )

        if not existing_ds_obj:
            return await cls.create_stream(
                cnfg_body, auth=auth, session=session, debug_api=debug_api
            )

        existing_ds = await dmds.DomoDataset.get_from_id(
            dataset_id=existing_ds_obj.get("databaseId"), auth=auth
        )
        return await cls.update_stream(
            cnfg_body,
            stream_id=existing_ds.stream_id,
            auth=auth,
            session=session,
            debug_api=debug_api,
        )
