import collections
import dataclasses
import json
from typing import Any, List, Optional, Type, TypeVar, Union, overload

import pydantic
from pydantic.main import BaseModel

from chalk.features import DataFrame, Feature, FeatureWrapper, Filter, serialize_dtype, unwrap_feature
from chalk.features.feature_field import HasOnePathObj
from chalk.features.feature_set import is_feature_set_class
from chalk.features.pseudofeatures import PSEUDONAMESPACE
from chalk.features.resolver import Cron, OfflineResolver, OnlineResolver, SinkResolver, StreamResolver
from chalk.parsed.duplicate_input_gql import (
    FeatureClassGQL,
    UpsertCDCSourceGQL,
    UpsertDataFrameGQL,
    UpsertFeatureGQL,
    UpsertFeatureIdGQL,
    UpsertFeatureReferenceGQL,
    UpsertFeatureTimeKindGQL,
    UpsertFilterGQL,
    UpsertHasManyKindGQL,
    UpsertHasOneKindGQL,
    UpsertReferencePathComponentGQL,
    UpsertResolverGQL,
    UpsertResolverInputUnionGQL,
    UpsertResolverOutputGQL,
    UpsertScalarKindGQL,
    UpsertSinkResolverGQL,
    UpsertStreamResolverGQL,
    UpsertStreamResolverParamGQL,
    UpsertStreamResolverParamKeyedStateGQL,
    UpsertStreamResolverParamMessageGQL,
    VersionInfoGQL,
)
from chalk.sql._internal.sql_source import BaseSQLSource, TableIngestMixIn
from chalk.streams import Windowed
from chalk.streams.types import (
    StreamResolverParam,
    StreamResolverParamKeyedState,
    StreamResolverParamMessage,
    StreamResolverParamMessageWindow,
)
from chalk.utils import paths
from chalk.utils.duration import timedelta_to_duration
from chalk.utils.string import to_snake_case

# Module-level generic type variable. NOTE(review): it is not referenced anywhere
# else in this module's visible code — confirm it is still needed.
T = TypeVar("T")

# `attrs` is an optional dependency: when it is not installed, the name is bound
# to None and `_convert_type` skips its attrs-class detection.
try:
    import attrs
except ImportError:
    attrs = None


@dataclasses.dataclass
class _ConvertedType:
    """Result of `_convert_type`: a display name for a Python type plus the
    list of "base class" names sent alongside it for server-side validation."""

    # Display name of the type, e.g. the class ``__name__`` or ``Windowed[int]``.
    name: str
    # Names of the classes in the type's MRO, plus synthetic markers such as
    # ``CHALK_WINDOWED``, ``__dataclass__``, ``__attrs__`` or ``__pydantic__``.
    bases: List[str]


def _get_qualified_class_name(cls: Type[Any]):
    """Return the dotted ``module.QualName`` path of ``cls``.

    Classes defined in ``builtins`` are returned without a module prefix
    (``int`` rather than ``builtins.int``).
    """
    qualname = cls.__qualname__
    module_name = cls.__module__
    if module_name == "builtins":
        return qualname
    return f"{module_name}.{qualname}"


# Kinds that `_convert_type` renders by their bare ``__name__`` inside ``Windowed[...]``.
_primitive_types = {int, float, str, bool}


def _convert_type(underlying_type: Type) -> _ConvertedType:
    """Describe ``underlying_type`` as a display name plus a list of "base class" names.

    * ``Windowed`` instances become ``Windowed[<kind>]`` with the single base
      ``CHALK_WINDOWED``. Primitive kinds use their ``__name__``; other kinds use
      ``str(kind)`` with any ``typing.`` prefix removed.
    * ``FeatureWrapper`` values are unwrapped and converted recursively.
    * Anything else must be a class. Its bases are the ``__name__`` of every
      class in its MRO, plus the Enum value type (inserted right after
      ``"Enum"``) and the synthetic markers ``__dataclass__``, ``__attrs__``
      and ``__pydantic__`` where they apply.

    :raises AssertionError: if ``underlying_type`` is not a class (after the
        Windowed / FeatureWrapper cases).
    """
    if isinstance(underlying_type, Windowed):
        kind = underlying_type._kind
        kind_name = kind.__name__ if kind in _primitive_types else str(kind).replace("typing.", "")
        return _ConvertedType(name=f"Windowed[{kind_name}]", bases=["CHALK_WINDOWED"])
    if isinstance(underlying_type, FeatureWrapper):
        return _convert_type(unwrap_feature(underlying_type).typ.underlying)
    assert isinstance(underlying_type, type), f"Underlying type is not a type: '{str(underlying_type)}'"
    base_classes = [x.__name__ for x in type.mro(underlying_type)]
    if "Enum" in base_classes:
        # Add the type of the Enum's values right after "Enum".
        # Enums are already confirmed to be of one value type by here, so the
        # first member's value is representative.
        # `__members__` is used rather than `_value2member_map_`: the latter omits
        # members whose values are unhashable (e.g. lists), which made such Enums
        # crash here with an IndexError.
        first_member = next(iter(underlying_type.__members__.values()), None)
        # A member-less Enum has no value type to report, so nothing is inserted.
        if first_member is not None:
            value_type = type(first_member.value).__name__
            base_classes.insert(base_classes.index("Enum") + 1, value_type)

    # Attrs and Dataclasses don't technically have base classes
    # Pydantic calls their base class BaseModel which is way too generic for string comparison
    # For simplicity on the server-side validation, we'll come up with our own "base class" names
    if dataclasses.is_dataclass(underlying_type):
        base_classes.append("__dataclass__")
    # The isinstance re-check keeps `attrs.has` safe even when asserts are stripped (-O).
    if attrs is not None and isinstance(underlying_type, type) and attrs.has(underlying_type):
        base_classes.append("__attrs__")
    if issubclass(underlying_type, pydantic.BaseModel):
        base_classes.append("__pydantic__")

    return _ConvertedType(
        name=underlying_type.__name__,
        bases=base_classes,
    )


def _get_feature_id(s: Feature, root_fqn: bool):
    """Build the GQL identifier for feature ``s``.

    :param s: the feature; it must belong to a feature class.
    :param root_fqn: when true, use ``s.root_fqn`` as the fqn instead of ``s.fqn``.
    :raises AssertionError: if ``s.features_cls`` is None.
    """
    owner_cls = s.features_cls
    assert owner_cls is not None
    class_name = owner_cls.__name__
    if root_fqn:
        fqn = s.root_fqn
    else:
        fqn = s.fqn
    # The namespace counts as "explicit" when it differs from the snake_cased class name.
    is_explicit_namespace = s.namespace != to_snake_case(class_name)
    return UpsertFeatureIdGQL(
        fqn=fqn,
        name=s.name,
        namespace=s.namespace,
        isPrimary=s.primary,
        className=class_name,
        attributeName=s.attribute_name,
        explicitNamespace=is_explicit_namespace,
    )


def _get_pseudofeature_id(s: Feature):
    """Build the GQL identifier for a pseudofeature.

    Only ``fqn``, ``name``, ``namespace`` and ``isPrimary`` are filled in. Unlike
    ``_get_feature_id``, ``features_cls`` is never read, so no class name,
    attribute name or explicit-namespace flag is sent.
    """
    pseudo = s
    return UpsertFeatureIdGQL(
        fqn=pseudo.fqn,
        name=pseudo.name,
        namespace=pseudo.namespace,
        isPrimary=pseudo.primary,
    )


def _convert_df(df: Type[DataFrame]) -> UpsertDataFrameGQL:
    """Convert a ``DataFrame`` type into its GQL form.

    Autogenerated columns and pseudofeature columns are dropped; the remaining
    columns are identified by their root fqn.
    """
    kept_columns = [
        column
        for column in df.columns
        if not (column.is_autogenerated or column.namespace == PSEUDONAMESPACE)
    ]
    # Dataframe filters are not serialized yet; `filters` is always sent as None.
    return UpsertDataFrameGQL(
        columns=[_get_feature_id(column, root_fqn=True) for column in kept_columns],
        filters=None,
    )


def _get_path_component(pc: HasOnePathObj) -> UpsertReferencePathComponentGQL:
    """Convert one hop of a has-one relationship path into its GQL form.

    Both ends of the hop are identified by their (non-root) ``fqn``.

    :raises AssertionError: if ``pc.parent`` is not a ``Feature``.
    """
    # Report the type of the offending parent itself. The message previously named
    # the type of the path component `pc`, which does not identify the bad value.
    assert isinstance(
        pc.parent, Feature
    ), f"Parent in relationship path not a feature, but {type(pc.parent).__name__}"
    return UpsertReferencePathComponentGQL(
        parent=_get_feature_id(pc.parent, root_fqn=False),
        child=_get_feature_id(pc.child, root_fqn=False),
        parentToChildAttributeName=pc.parent_to_child_attribute_name,
    )


# Typing-only overload signatures for `convert_type_to_gql`: each one maps an input
# kind to the GQL model the runtime implementation (directly below) returns for it.
# Overload order matters to type checkers, which pick the first matching signature.
# NOTE(review): the implementation also accepts feature-set classes and returns a
# `FeatureClassGQL` for them, but no overload declares that case, so type checkers
# will likely reject such calls — consider adding one.
@overload
def convert_type_to_gql(t: Filter, path_prefix: Optional[str] = None) -> UpsertFilterGQL:
    ...


@overload
def convert_type_to_gql(t: StreamResolver, path_prefix: Optional[str] = None) -> UpsertStreamResolverGQL:
    ...


@overload
def convert_type_to_gql(t: SinkResolver, path_prefix: Optional[str] = None) -> UpsertSinkResolverGQL:
    ...


@overload
def convert_type_to_gql(
    t: Union[OnlineResolver, OfflineResolver], path_prefix: Optional[str] = None
) -> UpsertResolverGQL:
    ...


@overload
def convert_type_to_gql(t: StreamResolverParam, path_prefix: Optional[str] = None) -> UpsertStreamResolverParamGQL:
    ...


@overload
def convert_type_to_gql(t: Feature, path_prefix: Optional[str] = None) -> UpsertFeatureGQL:
    ...


def convert_type_to_gql(
    t: Any,
    # Resolvers have their filenames attached to them. But for UI presentation, we don't
    # want to show the full path, e.g. /Users/customer-name/projects/chalk/....,
    # but we want to show just the path relative to the project root. To do this efficiently,
    # we'll pass the project root path here and delete it from those full paths
    path_prefix: Optional[str] = None,
):
    """Convert a Chalk definition object into its GraphQL upsert model.

    Dispatches on the runtime type of ``t``:

    * ``StreamResolver`` -> ``UpsertStreamResolverGQL``
    * ``SinkResolver`` -> ``UpsertSinkResolverGQL``
    * ``OnlineResolver`` / ``OfflineResolver`` -> ``UpsertResolverGQL``
    * ``StreamResolverParamMessage`` / ``StreamResolverParamKeyedState``
      -> ``UpsertStreamResolverParamGQL``
    * ``Feature`` -> ``UpsertFeatureGQL``
    * feature-set classes -> ``FeatureClassGQL``
    * ``Filter`` -> ``UpsertFilterGQL``

    :param t: the object to convert.
    :param path_prefix: project-root path stripped from the start of resolver
        filenames; ``None`` leaves filenames untouched.
    :raises ValueError: if ``t`` is none of the kinds above.
    :raises RuntimeError: if ``t`` is an autogenerated feature.
    :raises TypeError: if a scalar feature is annotated with a feature set but has no join.
    """
    if isinstance(t, StreamResolver):
        return _stream_resolver_to_gql(t, path_prefix)
    if isinstance(t, SinkResolver):
        return _sink_resolver_to_gql(t, path_prefix)
    if isinstance(t, (OnlineResolver, OfflineResolver)):
        return _resolver_to_gql(t, path_prefix)
    if isinstance(t, StreamResolverParam):
        if isinstance(t, StreamResolverParamMessage):
            return _stream_message_param_to_gql(t)
        if isinstance(t, StreamResolverParamKeyedState):
            return _stream_keyed_state_param_to_gql(t)
        # Any other param kind falls through to the checks below (and ultimately the ValueError).
    if isinstance(t, Feature):
        return _feature_to_gql(t)
    if is_feature_set_class(t):
        return _feature_class_to_gql(t)
    if isinstance(t, Filter):
        return _filter_to_gql(t)

    raise ValueError(f"Unable to convert {t} to GQL")


def _relative_filename(filename: str, path_prefix: Optional[str]) -> str:
    """Strip ``path_prefix`` from the start of ``filename``, if present.

    Only a leading occurrence is removed, so a path in which the prefix text
    reappears deeper down is not mangled (``str.replace`` would drop every copy).
    """
    if path_prefix is None or not filename.startswith(path_prefix):
        return filename
    return filename[len(path_prefix) :]


def _is_dataframe_type(f: Any) -> bool:
    """Return whether ``f`` is a ``DataFrame`` class rather than a feature."""
    return isinstance(f, type) and issubclass(f, DataFrame)


def _feature_reference(f: Feature, underlying: UpsertFeatureIdGQL) -> UpsertFeatureReferenceGQL:
    """Wrap a feature id together with the converted relationship path of ``f``."""
    return UpsertFeatureReferenceGQL(
        underlying=underlying,
        path=[_get_path_component(p) for p in f.path or []],
    )


def _stream_resolver_to_gql(t: StreamResolver, path_prefix: Optional[str]) -> UpsertStreamResolverGQL:
    """Convert a stream resolver (``convert_type_to_gql`` helper)."""
    return UpsertStreamResolverGQL(
        fqn=t.fqn,
        kind="stream",
        sourceClassName=_get_qualified_class_name(t.source.__class__),
        sourceConfig=t.source._config_to_json(),
        functionDefinition=t.function_definition,
        # A single environment string is normalized to a one-element list.
        environment=[t.environment] if isinstance(t.environment, str) else t.environment,
        doc=t.fn.__doc__,
        machineType=t.machine_type,
        output=[
            _get_feature_id(f, root_fqn=True)
            for f in t.output_features
            if not _is_dataframe_type(f) and not f.is_autogenerated
        ],
        # Window-message params are skipped; every other param is converted.
        inputs=[
            convert_type_to_gql(i)
            for i in t.signature.params
            if not isinstance(i, StreamResolverParamMessageWindow)
        ],
        owner=t.owner,
        filename=_relative_filename(t.filename, path_prefix),
        sourceLine=t.source_line,
    )


def _sink_resolver_to_gql(t: SinkResolver, path_prefix: Optional[str]) -> UpsertSinkResolverGQL:
    """Convert a sink resolver (``convert_type_to_gql`` helper)."""
    return UpsertSinkResolverGQL(
        fqn=t.fqn,
        functionDefinition=t.function_definition,
        inputs=[
            _feature_reference(f, _get_feature_id(f, root_fqn=False))
            for f in t.inputs
            if not f.is_autogenerated
        ],
        environment=t.environment,
        tags=t.tags,
        doc=t.doc,
        machineType=t.machine_type,
        bufferSize=t.buffer_size,
        debounce=None if t.debounce is None else timedelta_to_duration(t.debounce),
        maxDelay=None if t.max_delay is None else timedelta_to_duration(t.max_delay),
        upsert=t.upsert,
        owner=t.owner,
        filename=_relative_filename(t.filename, path_prefix),
        sourceLine=t.source_line,
    )


def _resolver_to_gql(t: Union[OnlineResolver, OfflineResolver], path_prefix: Optional[str]) -> UpsertResolverGQL:
    """Convert an online or offline resolver (``convert_type_to_gql`` helper)."""
    cron = t.cron
    if isinstance(cron, Cron):
        assert cron.schedule is not None, "`Cron` object must be constructed with a `schedule` property."
        cron = cron.schedule

    # `inputs` holds only regular feature references; `all_inputs` holds features,
    # pseudofeatures and dataframes, in declaration order.
    inputs = []
    all_inputs = []
    for f in t.inputs:
        if _is_dataframe_type(f):
            all_inputs.append(UpsertResolverInputUnionGQL(dataframe=_convert_df(f)))
        elif f.is_autogenerated:
            # Autogenerated features are skipped entirely.
            continue
        elif f.namespace == PSEUDONAMESPACE:
            pseudo_ref = _feature_reference(f, _get_pseudofeature_id(f))
            all_inputs.append(UpsertResolverInputUnionGQL(pseudoFeature=pseudo_ref))
        else:
            feature_ref = _feature_reference(f, _get_feature_id(f, root_fqn=False))
            inputs.append(feature_ref)
            all_inputs.append(UpsertResolverInputUnionGQL(feature=feature_ref))

    assert t.output is not None
    return UpsertResolverGQL(
        fqn=t.fqn,
        kind="offline" if isinstance(t, OfflineResolver) else "online",
        functionDefinition=t.function_definition,
        inputs=inputs,
        allInputs=all_inputs,
        output=UpsertResolverOutputGQL(
            features=[
                _get_feature_id(f, root_fqn=True)
                for f in t.output.features
                if not _is_dataframe_type(f) and not f.is_autogenerated
            ],
            dataframes=[_convert_df(f) for f in t.output.features if _is_dataframe_type(f)],
        ),
        environment=t.environment,
        tags=t.tags,
        doc=t.doc,
        cron=cron,
        machineType=t.machine_type,
        owner=t.owner,
        timeout=None if t.timeout is None else timedelta_to_duration(t.timeout),
        filename=_relative_filename(t.filename, path_prefix),
        sourceLine=t.source_line,
    )


def _stream_message_param_to_gql(t: StreamResolverParamMessage) -> UpsertStreamResolverParamGQL:
    """Convert a stream message parameter; pydantic models also ship their JSON schema."""
    converted_type = _convert_type(t.typ)
    schema = None
    if issubclass(t.typ, BaseModel):
        schema = json.loads(t.typ.schema_json())
    return UpsertStreamResolverParamGQL(
        message=UpsertStreamResolverParamMessageGQL(
            name=t.name,
            typeName=converted_type.name,
            bases=converted_type.bases,
            schema=schema,
        ),
        state=None,
    )


def _stream_keyed_state_param_to_gql(t: StreamResolverParamKeyedState) -> UpsertStreamResolverParamGQL:
    """Convert a keyed-state stream parameter.

    Default value and schema are only produced for dataclass or pydantic defaults;
    otherwise both are sent as None.
    """
    converted_type = _convert_type(t.typ)
    default_value = None
    schema = None
    if dataclasses.is_dataclass(t.default_value):
        # The dumps/loads round-trip yields plain JSON data (json.dumps raises on
        # values that are not JSON-serializable).
        default_value = json.loads(json.dumps(dataclasses.asdict(t.default_value)))
        # NOTE(review): `__pydantic_model__` is the pydantic v1 dataclass API.
        schema = pydantic.dataclasses.dataclass(t.typ).__pydantic_model__.schema()
    elif isinstance(t.default_value, BaseModel):
        assert issubclass(t.typ, BaseModel)
        default_value = json.loads(t.default_value.json())
        schema = json.loads(t.typ.schema_json())
    return UpsertStreamResolverParamGQL(
        state=UpsertStreamResolverParamKeyedStateGQL(
            name=t.name,
            defaultValue=default_value,
            typeName=converted_type.name,
            bases=converted_type.bases,
            schema=schema,
        ),
        message=None,
    )


def _feature_to_gql(t: Feature) -> UpsertFeatureGQL:
    """Convert a feature; exactly one of the has-one / has-many / feature-time / scalar kinds is set."""
    assert t.name is not None, "Feature has no name"
    assert t.namespace is not None, "Feature has no namespace"
    assert t.features_cls is not None
    scalar_kind_gql = None
    has_one_kind_gql = None
    has_many_kind_gql = None
    feature_time_kind_gql = None
    if t.is_autogenerated:
        raise RuntimeError("Autogenerated features should not be converted")
    if t.is_has_one:
        assert t.join is not None
        has_one_kind_gql = UpsertHasOneKindGQL(join=convert_type_to_gql(t.join))
    elif t.is_has_many:
        assert t.join is not None
        has_many_kind_gql = UpsertHasManyKindGQL(
            join=convert_type_to_gql(t.join),
            columns=None,
            filters=None,
        )
    elif t.is_feature_time:
        feature_time_kind_gql = UpsertFeatureTimeKindGQL()
    else:
        assert t.typ is not None
        parsed_type = t.typ.parsed_annotation
        if getattr(parsed_type, "__chalk_feature_set__", False):
            # Feature-set classes carry their namespace in `__chalk_namespace__`;
            # reading a plain `.namespace` could raise AttributeError and hide this error.
            other_namespace = getattr(parsed_type, "__chalk_namespace__", parsed_type)
            raise TypeError(
                f"Feature '{t.fqn}' is annotated with feature set '{parsed_type}' but no join was found. "
                f"Please ensure that join between keys is specified between '{t.namespace}' and '{other_namespace}'."
            )
        converted_type = _convert_type(t.typ.underlying)
        scalar_kind_gql = UpsertScalarKindGQL(
            scalarKind=converted_type.name,
            primary=t.primary,
            baseClasses=converted_type.bases,
            version=t.version and t.version.version,
            versionInfo=t.version
            and VersionInfoGQL(
                version=t.version.version,
                maximum=t.version.maximum,
                default=t.version.default,
                versions=[f.fqn for f in t.version.reference.values()],
            ),
            dtype=serialize_dtype(t.converter.pyarrow_dtype),
        )
    return UpsertFeatureGQL(
        id=_get_feature_id(t, root_fqn=False),
        maxStaleness=timedelta_to_duration(t.max_staleness),
        description=t.description,
        owner=t.owner,
        windowBuckets=None if t.window_buckets is None else sorted(t.window_buckets),
        etlOfflineToOnline=t.etl_offline_to_online,
        tags=t.tags,
        hasOneKind=has_one_kind_gql,
        hasManyKind=has_many_kind_gql,
        scalarKind=scalar_kind_gql,
        featureTimeKind=feature_time_kind_gql,
        namespacePath=str(paths.get_classpath_or_name(t.features_cls)),
        isSingleton=t.is_singleton,
    )


def _feature_class_to_gql(t: Any) -> FeatureClassGQL:
    """Convert a feature-set class using its ``__chalk_*__`` metadata attributes."""
    return FeatureClassGQL(
        isSingleton=t.__chalk_is_singleton__,
        doc=t.__doc__,
        name=t.__chalk_namespace__,
        owner=t.__chalk_owner__,
        tags=t.__chalk_tags__,
    )


def _filter_to_gql(t: Filter) -> UpsertFilterGQL:
    """Convert a filter; any feature operand must be a scalar or feature-time feature."""
    if isinstance(t.lhs, Feature):
        assert (
            t.lhs.is_scalar or t.lhs.is_feature_time
        ), f"Filters must be on scalar features, but {t.lhs.fqn} is not a scalar."

    if isinstance(t.rhs, Feature):
        assert (
            t.rhs.is_scalar or t.rhs.is_feature_time
        ), f"Filters must be on scalar features, but {t.rhs.fqn} is not a scalar."

    return UpsertFilterGQL(
        lhs=_get_feature_id(t.lhs, root_fqn=True),
        op=t.operation,
        rhs=_get_feature_id(t.rhs, root_fqn=True),
    )


def gather_cdc_sources() -> List[UpsertCDCSourceGQL]:
    """Collect, per SQL integration, the tables with change-data-capture enabled.

    Every registered ``BaseSQLSource`` that also mixes in ``TableIngestMixIn`` is
    scanned. Each ``schema.table`` key whose ingest preferences have
    ``cdc is True`` is grouped under the source's ``name``. Integrations without
    any such table do not appear in the result.
    """
    tables_by_integration = collections.defaultdict(list)
    ingesting_sources = (src for src in BaseSQLSource.registry if isinstance(src, TableIngestMixIn))
    for sql_source in ingesting_sources:
        for table_name, table_prefs in sql_source.ingested_tables.items():
            if table_prefs.cdc is not True:
                continue
            # Narrows the mixin-typed variable back to BaseSQLSource for `.name`.
            assert isinstance(sql_source, BaseSQLSource)
            tables_by_integration[sql_source.name].append(table_name)
    return [
        UpsertCDCSourceGQL(integrationName=name, schemaDotTableList=table_names)
        for name, table_names in tables_by_integration.items()
    ]
