from __future__ import annotations

import logging
import asyncio
from contextlib import suppress
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Literal

import numpy as np
import polars as pl
from mashumaro.types import SerializableType

from aligned.lazy_imports import pandas as pd
from aligned.schemas.codable import Codable
from aligned.schemas.feature import FeatureReference, FeatureType
from aligned.schemas.literal_value import LiteralValue
from aligned.schemas.text_vectoriser import EmbeddingModel

if TYPE_CHECKING:
    from aligned.sources.s3 import AwsS3Config
    from aligned.feature_store import ContractStore


logger = logging.getLogger(__name__)


@dataclass
class TransformationTestDefinition:
    """Example input and expected output used to self-test a transformation.

    ``input`` maps column names to column values; ``output`` lists the values
    the transformation is expected to produce, row for row.
    """

    transformation: Transformation
    input: dict[str, list]
    output: list

    @property
    def input_pandas(self) -> pd.DataFrame:
        """The test input as a pandas ``DataFrame``."""
        frame = pd.DataFrame(self.input)
        return frame

    @property
    def output_pandas(self) -> pd.Series:
        """The expected output as a pandas ``Series``."""
        expected = pd.Series(self.output)
        return expected

    @property
    def input_polars(self) -> pl.DataFrame:
        """The test input as a polars ``DataFrame`` (non-strict type inference)."""
        frame = pl.from_dict(self.input, strict=False)
        return frame

    @property
    def output_polars(self) -> pl.Series:
        """The expected output as a polars ``Series`` typed like the transformation.

        NaN values are turned into nulls, and boolean transformations are cast to
        ``pl.Boolean``. If polars rejects that (``InvalidOperationError``), the
        raw typed series is returned instead.
        """
        expected_dtype = self.transformation.dtype
        try:
            series = pl.Series(self.output, dtype=expected_dtype.polars_type).fill_nan(
                None
            )
            if expected_dtype == FeatureType.boolean():
                series = series.cast(pl.Boolean)
            return series
        except pl.exceptions.InvalidOperationError:
            return pl.Series(self.output, dtype=expected_dtype.polars_type)


def gracefull_transformation(
    df: pd.DataFrame,
    is_valid_mask: pd.Series,
    transformation: Callable[[pd.DataFrame], pd.Series],
) -> pd.Series:
    """Apply ``transformation`` only to the valid rows of ``df``.

    Args:
        df: The input frame.
        is_valid_mask: Boolean mask (aligned with ``df``) selecting the rows the
            transformation may be applied to.
        transformation: Called with the subset ``df.loc[is_valid_mask]`` and
            expected to return a Series indexed like that subset.

    Returns:
        A float Series indexed like ``is_valid_mask``. It holds the transformed
        value on valid rows and NaN elsewhere.
    """
    # Build the placeholder on the mask's own index. A default RangeIndex would
    # make `mask` misalign with frames that use a custom index and return all-NaN.
    result = pd.Series(
        np.repeat(np.nan, repeats=is_valid_mask.shape[0]), index=is_valid_mask.index
    )
    return result.mask(is_valid_mask, transformation(df.loc[is_valid_mask]))


class PsqlTransformation:
    """Mixin for transformations that can be rendered as PostgreSQL SQL."""

    def as_psql(self) -> str:
        """Return the SQL expression for this transformation.

        Raises:
            NotImplementedError: unless a subclass overrides this method.
        """
        raise NotImplementedError()


class RedshiftTransformation:
    """Mixin for transformations that can be rendered as Redshift SQL."""

    def as_redshift(self) -> str:
        """Return the Redshift SQL expression for this transformation.

        Falls back to the PostgreSQL rendering when the object also implements
        ``PsqlTransformation``.

        Raises:
            NotImplementedError: if no PostgreSQL rendering is available.
        """
        if not isinstance(self, PsqlTransformation):
            raise NotImplementedError()
        return self.as_psql()


class PolarsExprTransformation:
    """Mixin for transformations expressible as a single polars expression."""

    def polars_expr(self) -> pl.Expr | None:
        """Return the polars expression, or ``None`` if one cannot be built.

        Raises:
            NotImplementedError: unless a subclass overrides this method.
        """
        raise NotImplementedError(type(self))


class InnerTransformation(PolarsExprTransformation):
    """Mixin for transformations that wrap a single inner ``Expression``.

    Subclasses implement ``polars_expr_from`` and ``pandas_tran``. The generic
    ``polars_expr`` then composes them with the inner operand.
    """

    inner: Expression

    def polars_expr_from(self, inner: pl.Expr) -> pl.Expr:
        """Apply this transformation on top of the inner polars expression."""
        raise NotImplementedError(type(self))

    def polars_expr(self) -> pl.Expr | None:
        """Build the full expression, or ``None`` if the inner part has no expression form."""
        inner_expression = self.inner.to_polars()
        if inner_expression is None:
            return None
        return self.polars_expr_from(inner_expression)

    def pandas_tran(self, column: pd.Series) -> pd.Series:
        """Apply this transformation to an already-computed inner pandas column."""
        raise NotImplementedError(type(self))


class Transformation(Codable, SerializableType):
    """Base class for all serialisable transformations.

    Concrete subclasses are dataclasses that set a ``name`` and an output
    ``dtype``. The ``name`` is the key ``SupportedTransformations`` uses when
    deserialising.
    """

    name: str
    dtype: FeatureType

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        """Compute the transformation for a pandas frame.

        The default implementation only handles ``InnerTransformation``
        subclasses whose inner expression is a column or a nested transformation.

        Raises:
            ValueError: if the inner expression has neither a column nor a
                transformation (e.g. it is a literal).
            NotImplementedError: if the subclass has no pandas support.
        """
        if isinstance(self, InnerTransformation):
            if self.inner.column:
                return self.pandas_tran(df[self.inner.column])  # type: ignore
            if self.inner.transformation:
                inner = await self.inner.transformation.transform_pandas(df, store)
                return self.pandas_tran(inner)

            raise ValueError(
                f"Unable to transform literal value with inner transformation. {type(self)}. "
                "Consider precomputing the value."
            )

        raise NotImplementedError(type(self))

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        """Compute the transformation for a polars frame.

        Returns either an expression (the caller aliases it) or a ``LazyFrame``
        that contains the result in a column named ``alias``.

        Raises:
            NotImplementedError: if the subclass has no polars support.
        """
        if isinstance(self, PolarsExprTransformation):
            exp = self.polars_expr()
            if exp is not None:
                return exp

        if isinstance(self, InnerTransformation):
            assert self.inner.transformation is not None
            output_key = "_aligned_out"
            inner = await self.inner.transformation.transform_polars(
                df, output_key, store
            )
            if isinstance(inner, pl.Expr):
                return self.polars_expr_from(inner)

            # The inner transformation materialised its result as `output_key`
            # in the frame it returned. Build on that frame (not the original
            # `df`, which lacks the column) and store the result under `alias`.
            result = inner.with_columns(
                self.polars_expr_from(pl.col(output_key)).alias(alias)
            )
            if alias == output_key:
                return result
            return result.select(pl.exclude(output_key))

        raise NotImplementedError(type(self))

    def _serialize(self) -> dict:
        return self.to_dict()

    def should_skip(self, output_column: str, columns: list[str]) -> bool:
        """Return True when ``output_column`` already exists in ``columns``."""
        return output_column in columns

    @classmethod
    def _deserialize(cls, value: dict) -> Transformation:
        """Rebuild the concrete transformation registered under ``value["name"]``."""
        # Work on a copy so the caller's dict is left untouched.
        value = dict(value)
        name_type = value.pop("name")
        data_class = SupportedTransformations.shared().types[name_type]
        with suppress(AttributeError):
            # Dataclass fields without a default have no class attribute, so this
            # raises AttributeError (suppressed) when dtype must be supplied. For
            # classes with a default dtype, the serialised dtype is dropped.
            if data_class.dtype:
                value.pop("dtype", None)

        return data_class.from_dict(value)

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        raise NotImplementedError()

    @classmethod
    async def run_transformation_test_polars(cls) -> None:
        """Run ``test_definition`` through the polars path and assert the output.

        Transformations without a test definition or polars support
        (``NotImplementedError``) are skipped.
        """
        from polars.testing import assert_series_equal
        from aligned import ContractStore

        try:
            test = cls.test_definition()
            alias = "something"
            output_df = await test.transformation.transform_polars(
                test.input_polars.lazy(), alias=alias, store=ContractStore.empty()
            )
            if isinstance(output_df, pl.Expr):
                output_df = test.input_polars.lazy().with_columns(
                    [output_df.alias(alias)]
                )
            output = output_df.select(pl.col(alias)).collect().to_series()

            missing_columns = set(test.input_polars.columns) - set(
                output_df.collect_schema().names()
            )
            assert missing_columns == set(), f"Missing columns: {missing_columns}"

            expected = test.output_polars
            if test.transformation.dtype == FeatureType.boolean():
                is_correct = output.equals(expected.alias(alias))
                assert is_correct, (
                    f"Output for {cls.__name__} is not correct.,"
                    f"\nGot: {output},\nexpected: {expected}"
                )
            else:
                assert_series_equal(
                    expected.alias(alias), output, check_names=False, check_dtypes=False
                )
        except AttributeError as error:
            raise AssertionError(
                f"Error for transformation {cls.__name__}. Could be missing a return in the transformation"
            ) from error
        except NotImplementedError:
            pass
        except TypeError as error:
            raise ValueError(
                f"Error for transformation {cls.__name__}: {error}"
            ) from error

    @classmethod
    async def run_transformation_test_pandas(cls) -> None:
        """Run ``test_definition`` through the pandas path and assert the output."""
        from aligned import ContractStore
        from numpy.testing import assert_almost_equal

        with suppress(NotImplementedError):
            test = cls.test_definition()
            output = await test.transformation.transform_pandas(
                test.input_pandas, ContractStore.empty()
            )
            if test.transformation.dtype == FeatureType.boolean():
                is_correct = np.all(output == test.output_pandas) | output.equals(
                    test.output_pandas
                )
                assert is_correct, (
                    f"Output for {cls.__name__} is not correct.,"
                    f"\nGot: {output},\nexpected: {test.output_pandas}"
                )
            elif test.transformation.dtype == FeatureType.string():
                expected = test.output_pandas
                assert expected.equals(output), (
                    f"Output for {cls.__name__} is not correct.,"
                    f"\nGot: {output},\nexpected: {test.output_pandas}"
                )
            else:
                expected = test.output_pandas.to_numpy()
                output_np = output.to_numpy().astype("float")
                # Rows that are NaN on both sides count as equal.
                is_null = np.isnan(expected) & np.isnan(output_np)
                assert_almost_equal(expected[~is_null], output_np[~is_null])


class SupportedTransformations:
    """Registry mapping a transformation ``name`` to its class (used when deserialising)."""

    types: dict[str, type[Transformation]]

    _shared: SupportedTransformations | None = None

    def __init__(self) -> None:
        self.types = {}

        for tran_type in [
            IsNull,
            NotNull,
            PandasLambdaTransformation,
            PandasFunctionTransformation,
            PolarsLambdaTransformation,
            PolarsExpression,
            StructField,
            Contains,
            DateComponent,
            TimeDifference,
            Logarithm,
            LogarithmOnePluss,
            ToNumerical,
            HashColumns,
            ReplaceStrings,
            MultiTransformation,
            IsIn,
            BinaryTransformation,
            Inverse,
            Ordinal,
            FillNaValues,
            FillNaValuesColumns,
            Absolute,
            Round,
            Ceil,
            Floor,
            CopyTransformation,
            WordVectoriser,
            MapArgMax,
            LoadImageUrl,
            LoadImageUrlBytes,
            GrayscaleImage,
            PresignedAwsUrl,
            AppendConstString,
            AppendStrings,
            PrependConstString,
            ConcatStringAggregation,
            SumAggregation,
            MeanAggregation,
            MinAggregation,
            MaxAggregation,
            MedianAggregation,
            CountAggregation,
            CountDistinctAggregation,
            StdAggregation,
            VarianceAggregation,
            PercentileAggregation,
            JsonPath,
            Clip,
            ArrayContains,
            ArrayContainsAny,
            ArrayAtIndex,
            OllamaEmbedding,
            PolarsMapRowTransformation,
            LoadFeature,
            FormatStringTransformation,
            ListDotProduct,
        ]:
            self.add(tran_type)

    def add(self, transformation: type[Transformation]) -> None:
        """Register ``transformation`` under its ``name``.

        A later registration with the same name replaces the earlier one. A
        warning is logged when that swaps in a different class, because
        deserialisation would then silently build the wrong type.
        """
        name = transformation.name
        existing = self.types.get(name)
        if existing is not None and existing is not transformation:
            logger.warning(
                "Transformation name %r registered by %s is overridden by %s",
                name,
                existing.__name__,
                transformation.__name__,
            )
        self.types[name] = transformation

    @classmethod
    def shared(cls) -> SupportedTransformations:
        """Return the lazily created process-wide registry."""
        if cls._shared is None:
            cls._shared = SupportedTransformations()
        return cls._shared


@dataclass
class Expression(Codable):
    """An operand: a column reference, a nested transformation, or a literal.

    When several fields are set, ``to_polars`` gives ``column`` precedence over
    ``literal``, and ``literal`` precedence over ``transformation``.
    """

    column: str | None = field(default=None)
    transformation: Transformation | None = field(default=None)
    literal: LiteralValue | None = field(default=None)

    def to_polars(self) -> pl.Expr | None:
        """Return a polars expression for this operand.

        Returns ``None`` when the operand is a transformation without a pure
        expression form, or when nothing is set.
        """
        if self.column:
            return pl.col(self.column)
        if self.literal:
            return pl.lit(self.literal.python_value)

        nested = self.transformation
        if nested and isinstance(nested, PolarsExprTransformation):
            return nested.polars_expr()
        return None

    @staticmethod
    def from_value(value: Any) -> Expression:
        """Wrap a feature factory (named or derived) or a plain value as an Expression."""
        from aligned.compiler.feature_factory import FeatureFactory

        if not isinstance(value, FeatureFactory):
            return Expression(literal=LiteralValue.from_value(value))

        if value._name is not None:
            return Expression(column=value.name)

        assert value.transformation is not None
        return Expression(transformation=value.transformation.compile())


# Operator names accepted by ``BinaryTransformation.operator``.
BinaryOperators = Literal[
    # arithmetic
    "add", "sub",
    # comparison
    "eq", "neq", "gt", "gte", "lt", "lte",
    # arithmetic
    "mul", "div",
    # logical
    "or", "and",
    # power / modulo
    "pow", "mod",
]


@dataclass
class BinaryTransformation(Transformation, PolarsExprTransformation):
    """Apply a binary ``operator`` to two ``Expression`` operands.

    Operands may be columns, literals, or nested transformations. When both
    operands have a polars-expression form, the result is a single expression.
    Otherwise, operands without one are materialised as temporary columns first.
    """

    left: Expression
    right: Expression

    operator: BinaryOperators
    dtype: FeatureType = FeatureType.string()
    name: str = "binary"

    def polars_expr(self) -> pl.Expr | None:
        """Return the combined expression, or ``None`` if an operand has no expression form."""
        left_exp = self.left.to_polars()
        right_exp = self.right.to_polars()

        if left_exp is not None and right_exp is not None:
            return self._polars_expr(left_exp, right_exp)

        return None

    def _polars_expr(self, left: pl.Expr, right: pl.Expr) -> pl.Expr:
        """Combine two polars expressions with ``self.operator``.

        Raises:
            ValueError: for an unknown operator.
        """
        if self.operator == "add":
            return left + right
        elif self.operator == "sub":
            return left - right
        elif self.operator == "eq":
            return left == right
        elif self.operator == "neq":
            return left != right
        elif self.operator == "gt":
            return left > right
        elif self.operator == "gte":
            return left >= right
        elif self.operator == "lt":
            return left < right
        elif self.operator == "lte":
            return left <= right
        elif self.operator == "mul":
            return left * right
        elif self.operator == "div":
            return left / right
        elif self.operator == "or":
            return left | right
        elif self.operator == "and":
            return left & right
        elif self.operator == "pow":
            return left.pow(right)
        elif self.operator == "mod":
            return left.mod(right)

        raise ValueError(f"Unable to compute {self.operator}")

    def pandas_op(self, left: pd.Series, right: pd.Series) -> pd.Series:
        """Combine two operands with ``self.operator``.

        Either operand may be a Series or a scalar literal value.

        Raises:
            ValueError: for an unknown operator.
        """
        if self.operator == "add":
            return left + right
        elif self.operator == "sub":
            return left - right
        elif self.operator == "eq":
            return left == right
        elif self.operator == "neq":
            return left != right
        elif self.operator == "gt":
            return left > right
        elif self.operator == "gte":
            return left >= right
        elif self.operator == "lt":
            return left < right
        elif self.operator == "lte":
            return left <= right
        elif self.operator == "mul":
            return left * right
        elif self.operator == "div":
            return left / right
        elif self.operator == "or":
            return left | right
        elif self.operator == "and":
            return left & right
        elif self.operator == "pow":
            return left**right
        elif self.operator == "mod":
            # `%` (not `.mod`) so a scalar literal on the left side also works.
            return left % right

        raise ValueError(f"Unable to compute {self.operator}")

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        """Return an expression when possible, else a frame with the result in ``alias``."""
        exp = self.polars_expr()
        if exp is not None:
            return exp

        left_exp = self.left.to_polars()
        right_exp = self.right.to_polars()

        left_col = "_aligned_left"
        right_col = "_aligned_right"

        if left_exp is None and self.left.transformation:
            out = await self.left.transformation.transform_polars(df, left_col, store)
            if isinstance(out, pl.Expr):
                left_exp = out
            else:
                df = out
                left_exp = pl.col(left_col)

        if right_exp is None and self.right.transformation:
            out = await self.right.transformation.transform_polars(df, right_col, store)
            if isinstance(out, pl.Expr):
                right_exp = out
            else:
                df = out
                right_exp = pl.col(right_col)

        assert left_exp is not None
        assert right_exp is not None

        # Alias explicitly: otherwise the result takes the left operand's name
        # and either overwrites that input column or is dropped by the exclude.
        new_exp = self._polars_expr(left_exp, right_exp).alias(alias)
        return df.with_columns(new_exp).select(pl.exclude([left_col, right_col]))

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        """Evaluate both operands on ``df`` and combine them with ``pandas_op``."""
        left_series = None
        right_series = None

        if self.left.column:
            left_series = df[self.left.column]
        elif self.left.literal:
            left_series = self.left.literal.python_value
        else:
            assert self.left.transformation
            left_series = await self.left.transformation.transform_pandas(df, store)

        if self.right.column:
            right_series = df[self.right.column]
        elif self.right.literal:
            right_series = self.right.literal.python_value
        else:
            assert self.right.transformation
            right_series = await self.right.transformation.transform_pandas(df, store)

        assert left_series is not None
        assert right_series is not None

        return self.pandas_op(left_series, right_series)  # type: ignore


@dataclass
class PolarsMapRowTransformation(Transformation):
    """
    This will encode a custom method, that is not a lambda function.
    Therefore, we will store the actual code, and dynamically load it at runtime.

    This is unsafe, but will remove the ModuleImportError for custom methods.

    The loaded function is called once per row as ``function(row_dict, store)``
    and may be sync or async.
    """

    code: str
    function_name: str
    dtype: FeatureType
    name: str = "pol_map_row"

    def _load_function(self) -> Any:
        """Execute ``self.code`` and return the callable named ``function_name``."""
        # Use an explicit namespace seeded with this module's globals. Since
        # Python 3.13 (PEP 667), names bound by a bare `exec` are not visible
        # through a later `locals()` call, which the old lookup relied on.
        # SECURITY: this runs stored code; only load transformations from
        # trusted sources.
        namespace: dict[str, Any] = dict(globals())
        exec(self.code, namespace)
        return namespace[self.function_name]

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        return (
            await self.transform_polars(pl.from_pandas(df).lazy(), "value", store)
        ).collect()["value"]  # type: ignore

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        """Call the loaded function per row and add the results as column ``alias``."""
        loaded = self._load_function()
        is_async = asyncio.iscoroutinefunction(loaded)

        polars_df = df.collect()
        values: list[Any] = []
        for row in polars_df.to_dicts():
            value = loaded(row, store)
            if is_async:
                value = await value
            values.append(value)

        # Append the column to the existing frame, which keeps the input schema
        # intact instead of re-inferring every column from Python dicts. It also
        # works for an empty input frame.
        return polars_df.with_columns(pl.Series(alias, values, strict=False)).lazy()


@dataclass
class PandasFunctionTransformation(Transformation):
    """
    This will encode a custom method, that is not a lambda function.
    Therefore, we will store the actual code, and dynamically load it at runtime.

    This is unsafe, but will remove the ModuleImportError for custom methods.

    The loaded function is called as ``function(pandas_df, store)``, returns a
    Series, and may be sync or async.
    """

    code: str
    function_name: str
    dtype: FeatureType
    name: str = "pandas_code_tran"

    def _load_function(self) -> Any:
        """Execute ``self.code`` and return the callable named ``function_name``."""
        # Use an explicit namespace seeded with this module's globals. Since
        # Python 3.13 (PEP 667), names bound by a bare `exec` are not visible
        # through a later `locals()` call, which the old lookup relied on. The
        # old `function_name not in locals()` guard could also pick up a real
        # local such as `df` or `store`.
        # SECURITY: this runs stored code; only load transformations from
        # trusted sources.
        namespace: dict[str, Any] = dict(globals())
        exec(self.code, namespace)
        return namespace[self.function_name]

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        loaded = self._load_function()
        if asyncio.iscoroutinefunction(loaded):
            return await loaded(df, store)
        else:
            return loaded(df, store)

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        """Run the pandas function on a collected copy and return it with column ``alias``."""
        pandas_df = df.collect().to_pandas()
        loaded = self._load_function()
        if asyncio.iscoroutinefunction(loaded):
            pandas_df[alias] = await loaded(pandas_df, store)
        else:
            pandas_df[alias] = loaded(pandas_df, store)

        return pl.from_pandas(pandas_df).lazy()

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            transformation=PandasFunctionTransformation(
                code='async def test(df, store):\n    return df["a"] + df["b"]',
                function_name="test",
                dtype=FeatureType.int32(),
            ),
            input={
                "a": [1, 2, 3, 4, 5],
                "b": [1, 2, 3, 4, 5],
            },
            output=[2, 4, 6, 8, 10],
        )


@dataclass
class PandasLambdaTransformation(Transformation):
    """Run a dill-pickled ``callable(df, store)`` that returns a pandas Series.

    The callable may be sync or async. SECURITY: ``dill.loads`` can execute
    arbitrary code, so only load payloads from trusted sources.
    """

    method: bytes
    code: str
    dtype: FeatureType
    name: str = "pandas_lambda_tran"

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        import dill

        func = dill.loads(self.method)
        if not asyncio.iscoroutinefunction(func):
            return func(df, store)
        return await func(df, store)

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        import dill

        frame = df.collect().to_pandas()
        func = dill.loads(self.method)
        if not asyncio.iscoroutinefunction(func):
            result = func(frame, store)
        else:
            result = await func(frame, store)
        frame[alias] = result

        return pl.from_pandas(frame).lazy()


@dataclass
class PolarsFunctionTransformation(Transformation):
    """
    This will encode a custom method, that is not a lambda function.
    Therefore, we will store the actual code, and dynamically load it at runtime.

    This is unsafe, but will remove the ModuleImportError for custom methods.

    The loaded function is called as ``function(lazy_frame, alias, store)``
    and may be sync or async.

    NOTE(review): ``name`` equals PandasFunctionTransformation's name
    ("pandas_code_tran"), so a serialised instance of this class would be
    deserialised as the pandas variant. Renaming it requires registering the
    new name in SupportedTransformations at the same time; confirm and fix both.
    """

    code: str
    function_name: str
    dtype: FeatureType
    name: str = "pandas_code_tran"

    def _load_function(self) -> Any:
        """Execute ``self.code`` and return the callable named ``function_name``."""
        # Use an explicit namespace seeded with this module's globals. Since
        # Python 3.13 (PEP 667), names bound by a bare `exec` are not visible
        # through a later `locals()` call, which the old lookup relied on.
        # SECURITY: this runs stored code; only load transformations from
        # trusted sources.
        namespace: dict[str, Any] = dict(globals())
        exec(self.code, namespace)
        return namespace[self.function_name]

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        polars_df = await self.transform_polars(
            pl.from_pandas(df).lazy(), self.function_name, store
        )
        assert isinstance(polars_df, pl.LazyFrame)
        return polars_df.collect().to_pandas()[self.function_name]  # type: ignore

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        loaded = self._load_function()
        if asyncio.iscoroutinefunction(loaded):
            return await loaded(df, alias, store)
        else:
            return loaded(df, alias, store)


@dataclass
class PolarsExpression(Transformation, PolarsExprTransformation):
    """A polars expression stored in polars' JSON serialisation format."""

    polars_expression: str
    dtype: FeatureType
    name: str = "polars_expression"

    def polars_expr(self) -> pl.Expr:
        """Deserialize the stored JSON expression."""
        payload = self.polars_expression.encode()
        return pl.Expr.deserialize(payload, format="json")

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        """Evaluate the expression with polars and return the result as pandas."""
        column_name = "polars_tran_column"
        expression = pl.Expr.deserialize(
            self.polars_expression.encode(), format="json"
        ).alias(column_name)
        transformed = pl.from_pandas(df).with_columns(expression)
        return transformed[column_name].to_pandas()


@dataclass
class PolarsLambdaTransformation(Transformation):
    """A dill-pickled polars expression or ``callable(lazy_frame, alias, store)``.

    SECURITY: ``dill.loads`` can execute arbitrary code, so only load payloads
    from trusted sources.
    """

    method: bytes
    code: str
    dtype: FeatureType
    name: str = "polars_lambda_tran"

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        """Run the polars path on a converted frame and return the result as pandas.

        This delegates to ``transform_polars``, so both pickled expressions and
        pickled callables are supported. Previously only expressions were.
        """
        output_key = "polars_tran_column"
        pl_df = pl.from_pandas(df)
        result = await self.transform_polars(pl_df.lazy(), output_key, store)
        if isinstance(result, pl.Expr):
            frame = pl_df.with_columns(result.alias(output_key))
        else:
            frame = result.collect()
        return frame[output_key].to_pandas()

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        """Return the pickled expression, or the frame produced by the pickled callable."""
        import dill

        tran = dill.loads(self.method)
        if isinstance(tran, pl.Expr):
            return tran
        # Await async callables, consistent with the other function-based
        # transformations in this module.
        if asyncio.iscoroutinefunction(tran):
            return await tran(df, alias, store)
        return tran(df, alias, store)


@dataclass
class IsNull(Transformation, InnerTransformation):
    """Boolean flag that is ``True`` where the inner expression is null."""

    inner: Expression
    name: str = "is_null"
    dtype: FeatureType = FeatureType.boolean()

    def polars_expr_from(self, inner: pl.Expr) -> pl.Expr:
        return inner.is_null()

    def pandas_tran(self, column: pd.Series) -> pd.Series:
        return column.isna()

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            transformation=IsNull(inner=Expression(column="x")),
            input={"x": ["Hello", None, None, "test", None]},
            output=[False, True, True, False, True],
        )


@dataclass
class NotNull(Transformation, InnerTransformation):
    """Boolean flag that is ``True`` where the inner expression is not null."""

    inner: Expression

    name: str = "not_null"
    dtype: FeatureType = FeatureType.boolean()

    def polars_expr_from(self, inner: pl.Expr) -> pl.Expr:
        return inner.is_not_null()

    def pandas_tran(self, column: pd.Series) -> pd.Series:
        return column.notna()  # type: ignore

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            transformation=NotNull(inner=Expression(column="x")),
            input={"x": ["Hello", None, None, "test", None]},
            output=[True, False, False, True, False],
        )


@dataclass
class Inverse(Transformation, PolarsExprTransformation):
    """Logical negation of the boolean column ``key``; null rows stay null."""

    key: str

    name: str = "inverse"
    dtype: FeatureType = FeatureType.boolean()

    def __init__(self, key: str) -> None:
        self.key = key

    def polars_expr(self) -> pl.Expr:
        return pl.col(self.key).not_()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        column_name = self.key
        has_value = df[column_name].notnull()
        return gracefull_transformation(
            df,
            is_valid_mask=has_value,  # type: ignore
            transformation=lambda frame: ~frame[column_name].astype("bool"),  # type: ignore
        )

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            transformation=Inverse("x"),
            input={"x": [False, True, True, False, None]},
            output=[True, False, False, True, None],
        )


@dataclass
class TimeDifference(Transformation, PsqlTransformation, RedshiftTransformation):
    """``front - behind`` expressed as a float number of ``unit``s.

    ``unit`` is a NumPy timedelta unit code (``"W"``, ``"D"``, ``"h"``, ``"m"``,
    ``"s"``, ``"ms"``, ``"us"``, ``"ns"``); the default is seconds.
    """

    front: str
    behind: str
    unit: str

    name: str = "time-diff"
    dtype: FeatureType = FeatureType.floating_point()

    def __init__(self, front: str, behind: str, unit: str = "s") -> None:
        self.front = front
        self.behind = behind
        self.unit = unit

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        return gracefull_transformation(
            df,
            is_valid_mask=~(df[self.front].isna() | df[self.behind].isna()),
            transformation=lambda dfv: (dfv[self.front] - dfv[self.behind])
            / np.timedelta64(1, self.unit),  # type: ignore
        )

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        """Add the difference in ``self.unit`` as column ``alias``.

        The previous implementation always returned whole seconds, ignoring
        ``unit`` and truncating fractions. This now matches the pandas path.

        Raises:
            ValueError: for a unit that is not a fixed-length NumPy unit code.
        """
        nanoseconds_per_unit = {
            "W": 7 * 24 * 60 * 60 * 1_000_000_000,
            "D": 24 * 60 * 60 * 1_000_000_000,
            "h": 60 * 60 * 1_000_000_000,
            "m": 60 * 1_000_000_000,
            "s": 1_000_000_000,
            "ms": 1_000_000,
            "us": 1_000,
            "ns": 1,
        }
        if self.unit not in nanoseconds_per_unit:
            raise ValueError(f"Unsupported time unit '{self.unit}' for {self.name}")

        difference = pl.col(self.front) - pl.col(self.behind)
        return df.with_columns(
            (
                difference.dt.total_nanoseconds() / nanoseconds_per_unit[self.unit]
            ).alias(alias)
        )

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        from numpy import nan

        return TransformationTestDefinition(
            TimeDifference(front="x", behind="y"),
            input={
                "x": [
                    datetime.fromtimestamp(1),
                    datetime.fromtimestamp(2),
                    datetime.fromtimestamp(0),
                    None,
                    datetime.fromtimestamp(1),
                ],
                "y": [
                    datetime.fromtimestamp(1),
                    datetime.fromtimestamp(0),
                    datetime.fromtimestamp(2),
                    datetime.fromtimestamp(1),
                    None,
                ],
            },
            output=[0, 2, -2, nan, nan],
        )

    def as_psql(self) -> str:
        # NOTE(review): `self.unit` is not applied here; the SQL always uses
        # seconds. Confirm whether non-second units need a different datepart.
        return f"DATEDIFF('sec', {self.behind}, {self.front})"


@dataclass
class Logarithm(Transformation, PolarsExprTransformation):
    """Natural logarithm of column ``key``; non-positive and null inputs give null/NaN."""

    key: str

    name: str = "log"
    dtype: FeatureType = FeatureType.floating_point()

    def __init__(self, key: str) -> None:
        self.key = key

    def polars_expr(self) -> pl.Expr:
        value = pl.col(self.key)
        return pl.when(value > 0).then(value.log()).otherwise(pl.lit(None))

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        column = df[self.key]
        # De Morgan of "not (missing or <= 0)".
        usable = column.notna() & (column > 0)
        return gracefull_transformation(
            df,
            is_valid_mask=usable,
            transformation=lambda frame: np.log(frame[self.key]),  # type: ignore
        )

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        from numpy import nan

        return TransformationTestDefinition(
            transformation=Logarithm("x"),
            input={"x": [1, 0, np.e, None, -1]},
            output=[0, nan, 1, nan, nan],
        )


@dataclass
class LogarithmOnePluss(Transformation, PolarsExprTransformation):
    """``log(1 + x)`` of column ``key``; inputs <= -1 and nulls give null/NaN."""

    key: str

    name: str = "log1p"
    dtype: FeatureType = FeatureType.floating_point()

    def __init__(self, key: str) -> None:
        self.key = key

    def polars_expr(self) -> pl.Expr:
        value = pl.col(self.key)
        # `log1p` is more accurate than `(x + 1).log()` for x near zero and
        # matches the `np.log1p` used by the pandas path.
        return pl.when(value > -1).then(value.log1p()).otherwise(pl.lit(None))

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        return gracefull_transformation(
            df,
            is_valid_mask=~(df[self.key].isna() | (df[self.key] <= -1)),
            transformation=lambda dfv: np.log1p(dfv[self.key]),  # type: ignore
        )

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        from numpy import nan

        return TransformationTestDefinition(
            LogarithmOnePluss("x"),
            input={"x": [1, 0, np.e - 1, None, -1]},
            output=[0.6931471806, 0, 1, nan, nan],
        )


@dataclass
class ToNumerical(Transformation, PolarsExprTransformation):
    """Convert column ``key`` to numbers; unparsable values become NaN/null."""

    key: str

    name: str = "to-num"
    dtype: FeatureType = FeatureType.floating_point()

    def __init__(self, key: str) -> None:
        self.key = key

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        from pandas import to_numeric

        return to_numeric(df[self.key], errors="coerce")  # type: ignore

    def polars_expr(self) -> pl.Expr:
        # `strict=False` turns unparsable values into null instead of raising,
        # mirroring `errors="coerce"` in the pandas path.
        return pl.col(self.key).cast(pl.Float64, strict=False)

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            ToNumerical("x"),
            input={"x": ["1", "0", "10.5", None, "-20"]},
            output=[1, 0, 10.5, None, -20],
        )


@dataclass
class DateComponent(Transformation, PolarsExprTransformation):
    """Extract a date/time ``component`` (e.g. ``"hour"``) from column ``key``.

    The pandas path reads ``Series.dt.<component>`` directly. The polars path
    maps the component name to a ``dt`` method after casting to ``pl.Datetime``.

    NOTE(review): pandas' ``dt.weekday`` is Monday=0 while polars'
    ``dt.weekday()`` is Monday=1, so the two paths may disagree for
    "weekday"/"dayofweek". Confirm which convention is intended.
    """

    key: str
    component: str

    name: str = "date-component"
    dtype: FeatureType = FeatureType.int32()

    def __init__(self, key: str, component: str) -> None:
        self.key = key
        self.component = component

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        return gracefull_transformation(
            df,
            is_valid_mask=df[self.key].notna(),  # type: ignore
            transformation=lambda frame: getattr(frame[self.key].dt, self.component),  # type: ignore
        )

    def polars_expr(self) -> pl.Expr:
        accessor = pl.col(self.key).cast(pl.Datetime).dt
        # Component name -> polars `dt` namespace method.
        method_for_component = {
            "day": "day",
            "days": "ordinal_day",
            "epoch": "epoch",
            "hour": "hour",
            "hours": "total_hours",
            "iso_year": "iso_year",
            "microsecond": "microsecond",
            "microseconds": "total_microseconds",
            "millisecond": "millisecond",
            "milliseconds": "total_milliseconds",
            "minute": "minute",
            "minutes": "total_minutes",
            "month": "month",
            "nanosecond": "nanosecond",
            "nanoseconds": "total_nanoseconds",
            "ordinal_day": "ordinal_day",
            "quarter": "quarter",
            "second": "second",
            "seconds": "total_seconds",
            "week": "week",
            "weekday": "weekday",
            "year": "year",
            "dayofweek": "weekday",
        }
        method_name = method_for_component.get(self.component)
        if method_name is None:
            raise NotImplementedError(
                f"Date component {self.component} is not implemented. Maybe setup a PR and contribute?"
            )
        return getattr(accessor, method_name)()

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        raw_values = [
            "2022-04-02T20:20:50",
            None,
            "2022-02-20T23:20:50",
            "1993-04-02T01:20:50",
        ]
        parsed = [datetime.fromisoformat(raw) if raw else None for raw in raw_values]
        return TransformationTestDefinition(
            transformation=DateComponent(key="x", component="hour"),
            input={"x": parsed},
            output=[20, None, 23, 1],
        )


@dataclass
class ArrayAtIndex(Transformation, PolarsExprTransformation):
    """Selects the element at position `index` from every list value in `key`.

    Both code paths delegate to polars' `list.get`, so negative indices count
    from the end of each list.
    """

    key: str
    index: int

    name: str = "array_at_index"
    # NOTE(review): the result has the list's inner type, so boolean is only a
    # default; callers presumably pass the real element dtype — confirm.
    dtype: FeatureType = FeatureType.boolean()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        return pl.Series(df[self.key]).list.get(self.index).to_pandas()

    def polars_expr(self) -> pl.Expr:
        return pl.col(self.key).list.get(self.index)

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        # Exercises this transformation (not ArrayContains) with boolean lists so
        # the expected output matches the default boolean dtype.
        return TransformationTestDefinition(
            ArrayAtIndex("x", 0),
            input={"x": [[True, False], [False], [True, None]]},
            output=[True, False, True],
        )


@dataclass
class ArrayContainsAny(Transformation, PolarsExprTransformation):
    """Checks if an array contains at least one of several values

    some_array = List(String())
    contains_char = some_array.contains_any(["a", "b"])
    """

    key: str
    # Named `value` so the dataclass field matches both the `__init__` parameter
    # and the attribute that the transform methods read.
    value: LiteralValue

    name: str = "array_contains_any"
    dtype: FeatureType = FeatureType.boolean()

    def __init__(self, key: str, value: Any | LiteralValue) -> None:
        self.key = key
        if isinstance(value, LiteralValue):
            self.value = value
        else:
            self.value = LiteralValue.from_value(value)

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        vals = self.value.python_value
        return (
            pl.Series(df[self.key])
            .list.eval(pl.element().is_in(vals))
            .list.any()
            .to_pandas()
        )

    def polars_expr(self) -> pl.Expr:
        vals = self.value.python_value
        # Mark each element that is in `vals`, then reduce per row with any().
        return pl.col(self.key).list.eval(pl.element().is_in(vals)).list.any()

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            ArrayContainsAny("x", LiteralValue.from_value(["test", "nah"])),
            input={"x": [["Hello", "test"], ["nah"], ["espania", None]]},
            output=[True, True, False],
        )


@dataclass
class ArrayContains(Transformation, PolarsExprTransformation):
    """Flags list values that include a given element.

    some_array = List(String())
    contains_a_char = some_array.contains("a")
    """

    key: str
    value: LiteralValue

    name: str = "array_contains"
    dtype: FeatureType = FeatureType.boolean()

    def __init__(self, key: str, value: Any | LiteralValue) -> None:
        self.key = key
        # Raw python values are wrapped so serialisation always sees a LiteralValue.
        self.value = (
            value if isinstance(value, LiteralValue) else LiteralValue.from_value(value)
        )

    def polars_expr(self) -> pl.Expr:
        needle = self.value.python_value
        return pl.col(self.key).list.contains(needle)

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        needle = self.value.python_value
        as_polars = pl.Series(df[self.key])
        return as_polars.list.contains(needle).to_pandas()

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            ArrayContains("x", LiteralValue.from_value("test")),
            input={"x": [["Hello", "test"], ["nah"], ["test", "espania", None]]},
            output=[True, False, True],
        )


@dataclass
class Contains(Transformation, PolarsExprTransformation):
    """Flags string values that contain `value` (interpreted as a pattern by both paths).

    some_string = String()
    contains_a_char = some_string.contains("a")
    """

    key: str
    value: str

    name: str = "contains"
    dtype: FeatureType = FeatureType.boolean()

    def __init__(self, key: str, value: str) -> None:
        self.key, self.value = key, value

    def polars_expr(self) -> pl.Expr:
        return pl.col(self.key).str.contains(self.value)

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        present = df[self.key].notna()

        def search(subset: pd.DataFrame) -> pd.Series:
            return subset[self.key].astype("str").str.contains(self.value)  # type: ignore

        # Rows with a missing value are left as NaN by the helper.
        return gracefull_transformation(
            df,
            is_valid_mask=present,  # type: ignore
            transformation=search,
        )

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            Contains("x", "es"),
            input={"x": ["Hello", "Test", "nah", "test", "espania", None]},
            output=[False, True, False, True, True, None],
        )


@dataclass
class Ordinal(Transformation):
    """Encodes each category of `key` as its position in `orders`.

    Values not listed in `orders` (and missing values) end up missing.
    """

    key: str
    orders: list[str]

    name: str = "ordinal"
    dtype: FeatureType = FeatureType.int32()

    def __init__(self, key: str, orders: list[str]) -> None:
        self.key = key
        self.orders = orders

    @property
    def orders_dict(self) -> dict[str, int]:
        """Category -> position lookup; a repeated category keeps its last position."""
        return {category: position for position, category in enumerate(self.orders)}

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        lookup = self.orders_dict
        return df[self.key].map(lookup)  # type: ignore

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        categories = list(self.orders)
        positions = list(range(len(categories)))
        lookup_frame = pl.DataFrame({self.key: categories, alias: positions})
        # A left join keeps every input row; unmatched categories get a null position.
        return df.join(lookup_frame.lazy(), on=self.key, how="left")

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            Ordinal("x", ["a", "b", "c", "d"]),
            input={"x": ["a", "b", "a", None, "d", "p"]},
            output=[0, 1, 0, None, 3, None],
        )


@dataclass
class ReplaceStrings(Transformation):
    """Applies each `(pattern, replacement)` pair in `values`, in order, as a regex
    substitution on the string column `key`. Missing values stay missing.
    """

    key: str
    values: list[tuple[str, str]]

    name: str = "replace"
    dtype: FeatureType = FeatureType.string()

    def __init__(self, key: str, values: list[tuple[str, str]]) -> None:
        self.key = key
        self.values = values

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        result = df[self.key].copy()
        # `isna` and `isnull` are aliases in pandas, so one check suffices.
        has_value = ~result.isna()
        result.loc[~has_value] = np.nan

        if not has_value.any():
            # Nothing to rewrite. Skipping also avoids calling `.str` on an
            # all-missing column, which pandas rejects for non-string dtypes.
            return result  # type: ignore

        for pattern, replacement in self.values:
            result.loc[has_value] = result.loc[has_value].str.replace(
                pattern, replacement, regex=True
            )

        return result  # type: ignore

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        # Round-trips through the pandas implementation so both paths share one regex engine.
        collected = df.collect()
        pandas_column = collected.select(self.key).to_pandas()
        transformed = await self.transform_pandas(pandas_column, store)
        return collected.with_columns(pl.Series(transformed).alias(alias)).lazy()


@dataclass
class IsIn(Transformation, PolarsExprTransformation):
    """Flags rows whose `key` value exactly equals one of `values` (case-sensitive)."""

    values: list
    key: str

    name = "isin"
    dtype = FeatureType.boolean()

    def polars_expr(self) -> pl.Expr:
        allowed = self.values
        return pl.col(self.key).is_in(allowed)

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        allowed = self.values
        return df[self.key].isin(allowed)  # type: ignore

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            IsIn(values=["hello", "test"], key="x"),
            input={"x": ["No", "Hello", "hello", "test", "nah", "nehtest"]},
            output=[False, False, True, True, False, False],
        )


@dataclass
class FillNaValuesColumns(Transformation):
    """Fills missing values in `key` with the same row's value from `fill_key`.

    For the floating-point dtype, the polars path also treats NaN as missing.
    """

    key: str
    fill_key: str
    dtype: FeatureType

    name: str = "fill_missing_key"

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        fallback = df[self.fill_key]
        return df[self.key].fillna(fallback)  # type: ignore

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        fallback = pl.col(self.fill_key)
        filled = pl.col(self.key)
        if self.dtype == FeatureType.floating_point():
            filled = filled.fill_nan(fallback)
        return filled.fill_null(fallback)

    def should_skip(self, output_column: str, columns: list[str]) -> bool:
        # Never skipped, even when the output column already exists.
        return False

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            FillNaValuesColumns("x", "y", dtype=FeatureType.int32()),
            input={
                "x": [1, 1, None, None, 3, 3, None, 4, 5, None],
                "y": [1, 2, 1, 2, 7, 2, 4, 1, 1, 9],
            },
            output=[1, 1, 1, 2, 3, 3, 4, 4, 5, 9],
        )


@dataclass
class FillNaValues(Transformation, PolarsExprTransformation):
    """Replaces missing values in `key` with the constant `value`.

    For the floating-point dtype, the polars path also replaces NaN.
    """

    key: str
    value: LiteralValue
    dtype: FeatureType

    name: str = "fill_missing"

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        replacement = self.value.python_value
        return df[self.key].fillna(replacement)  # type: ignore

    def polars_expr(self) -> pl.Expr:
        replacement = self.value.python_value
        filled = pl.col(self.key)
        if self.dtype == FeatureType.floating_point():
            filled = filled.fill_nan(replacement)
        return filled.fill_null(replacement)

    def should_skip(self, output_column: str, columns: list[str]) -> bool:
        # Never skipped, even when the output column already exists.
        return False

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            FillNaValues("x", LiteralValue.from_value(3), dtype=FeatureType.int32()),
            input={"x": [1, 1, None, None, 3, 3, None, 4, 5, None]},
            output=[1, 1, 3, 3, 3, 3, 3, 4, 5, 3],
        )


@dataclass
class CopyTransformation(Transformation, PolarsExprTransformation):
    """Passes the `key` column through unchanged."""

    key: str
    dtype: FeatureType

    name: str = "nothing"

    def polars_expr(self) -> pl.Expr:
        source = self.key
        return pl.col(source)

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        source = self.key
        return df[source]  # type: ignore


@dataclass
class Floor(Transformation, PolarsExprTransformation):
    """Rounds each value of `key` down to the nearest whole number.

    NOTE(review): numpy's and polars' floor keep a floating-point result even
    though `dtype` defaults to int64; presumably a later cast handles this.
    """

    key: str
    dtype: FeatureType = FeatureType.int64()

    name: str = "floor"

    def polars_expr(self) -> pl.Expr:
        return pl.col(self.key).floor()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        column = df[self.key]
        return np.floor(column)  # type: ignore

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            Floor("x"),
            input={"x": [1.3, 1.9, None]},
            output=[1, 1, None],
        )


@dataclass
class Ceil(Transformation, PolarsExprTransformation):
    """Rounds each value of `key` up to the nearest whole number.

    NOTE(review): numpy's and polars' ceil keep a floating-point result even
    though `dtype` defaults to int64; presumably a later cast handles this.
    """

    key: str
    dtype: FeatureType = FeatureType.int64()

    name: str = "ceil"

    def polars_expr(self) -> pl.Expr:
        return pl.col(self.key).ceil()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        column = df[self.key]
        return np.ceil(column)  # type: ignore

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            Ceil("x"),
            input={"x": [1.3, 1.9, None]},
            output=[2, 2, None],
        )


@dataclass
class Round(Transformation, InnerTransformation):
    """Rounds the result of `inner` to zero decimal places.

    NOTE(review): the pandas path goes through numpy, which rounds halves to the
    nearest even value; the tie-breaking of polars' `round` depends on the
    installed polars version — confirm both paths agree on .5 inputs.
    """

    inner: Expression
    dtype: FeatureType = FeatureType.int64()

    name: str = "round"

    def polars_expr_from(self, inner: pl.Expr) -> pl.Expr:
        return inner.round(0)

    def pandas_tran(self, column: pd.Series) -> pd.Series:
        return np.round(column)  # type: ignore

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            Round(Expression(column="x")),
            input={"x": [1.3, 1.9, None]},
            output=[1, 2, None],
        )


@dataclass
class Absolute(Transformation, InnerTransformation):
    """Takes the absolute value of the result of `inner`."""

    inner: Expression
    dtype: FeatureType = FeatureType.floating_point()

    name: str = "abs"

    def polars_expr_from(self, inner: pl.Expr) -> pl.Expr:
        return inner.abs()

    def pandas_tran(self, column: pd.Series) -> pd.Series:
        # np.abs is used explicitly to avoid shadowing the builtin `abs`.
        return np.abs(column)  # type: ignore

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            Absolute(Expression(column="x")),
            input={"x": [-13, 19, None]},
            output=[13, 19, None],
        )


@dataclass
class MapArgMax(Transformation):
    """Maps each row to the label associated with its highest-valued column.

    `column_mappings` maps a column name to the label emitted when that column
    holds the row's maximum.

    With a single column, the value is compared against 0.5 instead:
    - rows above 0.5 get the label;
    - otherwise boolean labels get their negation;
    - string labels get "not <label>";
    - any other label type gets null.
    """

    column_mappings: dict[str, LiteralValue]
    name = "map_arg_max"

    @property
    def dtype(self) -> FeatureType:  # type: ignore
        # The output type is taken from the first mapped label.
        return list(self.column_mappings.values())[0].dtype

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        lazy = pl.from_pandas(df).lazy()
        result = await self.transform_polars(lazy, "feature", store)
        if isinstance(result, pl.Expr):
            # The single-column branch returns an (already aliased) expression,
            # which has no `.collect()`; apply it to the frame first.
            result = lazy.with_columns(result)
        return result.collect().to_pandas()["feature"]  # type: ignore

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        if len(self.column_mappings) == 1:
            key, value = list(self.column_mappings.items())[0]
            label = value.python_value
            if self.dtype == FeatureType.boolean():
                fallback = pl.lit(not label)
            elif self.dtype == FeatureType.string():
                fallback = pl.lit(f"not {label}")
            else:
                fallback = pl.lit(None)
            # Labels are wrapped in `pl.lit`: polars parses bare strings passed
            # to `then`/`otherwise` as column names, not literal values.
            return (
                pl.when(pl.col(key) > 0.5)
                .then(pl.lit(label))
                .otherwise(fallback)
                .alias(alias)
            )

        features = list(self.column_mappings.keys())
        arg_max_alias = f"{alias}_arg_max"
        array_row_alias = f"{alias}_row"
        # Lookup table: position of each feature column -> its label.
        # UInt32 matches the index type that `list.arg_max` produces.
        mapper = pl.DataFrame(
            {
                alias: [
                    self.column_mappings[feature].python_value for feature in features
                ],
                arg_max_alias: list(range(0, len(features))),
            }
        ).with_columns(pl.col(arg_max_alias).cast(pl.UInt32))
        sub = df.with_columns(
            pl.concat_list(pl.col(features)).alias(array_row_alias)
        ).with_columns(pl.col(array_row_alias).list.arg_max().alias(arg_max_alias))
        return sub.join(mapper.lazy(), on=arg_max_alias, how="left").select(
            pl.exclude([arg_max_alias, array_row_alias])
        )

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            MapArgMax(
                {
                    "a_prob": LiteralValue.from_value("a"),
                    "b_prob": LiteralValue.from_value("b"),
                    "c_prob": LiteralValue.from_value("c"),
                }
            ),
            input={
                "a_prob": [0.01, 0.9, 0.25],
                "b_prob": [0.9, 0.05, 0.15],
                "c_prob": [0.09, 0.05, 0.6],
            },
            output=["b", "a", "c"],
        )


@dataclass
class WordVectoriser(Transformation):
    """Embeds the text in `key` using the configured `EmbeddingModel`."""

    key: str
    model: EmbeddingModel

    name = "word_vectoriser"
    dtype = FeatureType.embedding(768)

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        embedder = self.model
        return await embedder.vectorise_polars(df, self.key, alias)

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        embedder = self.model
        texts = df[self.key]
        return await embedder.vectorise_pandas(texts)  # type: ignore


@dataclass
class LoadImageUrlBytes(Transformation):
    """Reads the raw bytes behind every URL in `image_url_key` into column `alias`."""

    image_url_key: str

    # NOTE(review): LoadImageUrl uses the same name "load_image"; if
    # transformations are resolved by name, one of them shadows the other.
    name = "load_image"
    dtype = FeatureType.binary()

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        from aligned.sources.local import StorageFileSource

        # Materialise once, so the upstream plan is not evaluated again when the
        # returned frame is collected. This also avoids `with_context`, which
        # is deprecated in polars 1.x.
        collected = df.collect()
        urls = collected[self.image_url_key].to_list()

        logger.info("Fetching image bytes")
        images = await asyncio.gather(*[StorageFileSource(url).read() for url in urls])
        logger.info("Loaded all images")

        return collected.with_columns(pl.Series(alias, images)).lazy()


@dataclass
class LoadImageUrl(Transformation):
    """Loads every image URL in `image_url_key` and decodes it into a numpy array.

    Decoding uses Pillow (`PIL.Image.open`).
    """

    image_url_key: str

    # NOTE(review): LoadImageUrlBytes uses the same name "load_image"; if
    # transformations are resolved by name, one of them shadows the other.
    name = "load_image"
    dtype = FeatureType.array()

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        from io import BytesIO

        from PIL import Image

        from aligned.sources.local import StorageFileSource

        # Materialise once, so the upstream plan is not evaluated again when the
        # returned frame is collected. This also avoids `with_context`, which
        # is deprecated in polars 1.x.
        collected = df.collect()
        urls = collected[self.image_url_key].to_list()

        images = await asyncio.gather(*[StorageFileSource(url).read() for url in urls])
        data = [np.asarray(Image.open(BytesIO(buffer))) for buffer in images]
        return collected.with_columns(pl.Series(alias, data)).lazy()


@dataclass
class GrayscaleImage(Transformation):
    """Averages three-dimensional image arrays over axis 2 (presumably the colour
    channels); arrays of any other rank are passed through untouched.
    """

    image_key: str

    name = "grayscale_image"
    dtype = FeatureType.array()

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        def to_gray(image):
            if len(image.shape) != 3:
                return image
            return np.mean(image, axis=2)

        def convert_batch(images) -> pl.Series:
            return pl.Series([to_gray(image) for image in images.to_list()])

        return pl.col(self.image_key).map_batches(convert_batch).alias(alias)


@dataclass
class AppendConstString(Transformation, PolarsExprTransformation):
    """Appends the constant `string` to every value of `key`.

    Missing values are treated as empty strings, so they become just `string`.
    """

    key: str
    string: str

    name = "append_const_string"
    dtype = FeatureType.string()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        # fillna("") mirrors the polars expression's fill_null(""); without it,
        # pandas would propagate NaN for missing values.
        return df[self.key].fillna("") + self.string

    def polars_expr(self) -> pl.Expr:
        return pl.concat_str(
            [pl.col(self.key).fill_null(""), pl.lit(self.string)], separator=""
        )


@dataclass
class AppendStrings(Transformation, PolarsExprTransformation):
    """Concatenates `first_key` and `second_key`, joined by `sep`.

    Missing values on either side are treated as empty strings.
    """

    first_key: str
    second_key: str
    sep: str

    name = "append_strings"
    dtype = FeatureType.string()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        # fillna("") mirrors the polars expression's fill_null(""); without it,
        # pandas would produce NaN whenever either side is missing.
        first = df[self.first_key].fillna("")
        second = df[self.second_key].fillna("")
        return first + self.sep + second

    def polars_expr(self) -> pl.Expr:
        return pl.concat_str(
            [
                pl.col(self.first_key).fill_null(""),
                pl.col(self.second_key).fill_null(""),
            ],
            separator=self.sep,
        )


@dataclass
class PrependConstString(Transformation, PolarsExprTransformation):
    """Prefixes every value of `key` with the constant `string`.

    Missing values are treated as empty strings, so they become just `string`.
    """

    string: str
    key: str

    name = "prepend_const_string"
    dtype = FeatureType.string()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        # fillna("") mirrors the polars expression's fill_null(""); without it,
        # pandas would propagate NaN for missing values.
        return self.string + df[self.key].fillna("")

    def polars_expr(self) -> pl.Expr:
        return pl.concat_str(
            [pl.lit(self.string), pl.col(self.key).fill_null("")], separator=""
        )


@dataclass
class ConcatStringAggregation(
    Transformation, PsqlTransformation, RedshiftTransformation
):
    """String aggregation of `key`, joined by `separator`, for SQL backends."""

    key: str
    separator: str = field(default=" ")

    name = "concat_string_agg"
    dtype = FeatureType.string()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        pdf = await self.transform_polars(pl.from_pandas(df).lazy(), self.name, store)
        assert isinstance(pdf, pl.LazyFrame)
        return pdf.collect().to_pandas()[self.name]  # type: ignore

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        # NOTE(review): `concat_str` over a single column returns each row's value
        # as-is rather than joining rows; if an aggregate is intended here,
        # `str.join(self.separator)` is presumably what was meant — confirm with
        # the callers before changing.
        return df.with_columns(
            pl.concat_str(pl.col(self.key), separator=self.separator).alias(alias)
        )

    def _sql_separator(self) -> str:
        """Separator escaped for a single-quoted SQL literal (`'` doubled)."""
        return self.separator.replace("'", "''")

    def as_psql(self) -> str:
        return f"array_to_string(array_agg({self.key}), '{self._sql_separator()}')"

    def as_redshift(self) -> str:
        separator = self._sql_separator()
        return f'listagg("{self.key}", \'{separator}\') within group (order by "{self.key}")'


@dataclass
class SumAggregation(Transformation, PsqlTransformation, RedshiftTransformation):
    """Sum of `key`; only usable through the polars or SQL code paths."""

    key: str

    name = "sum_agg"
    dtype = FeatureType.floating_point()

    def as_psql(self) -> str:
        return f"SUM({self.key})"

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        # Same as pl.sum(self.key), spelled like the sibling aggregations.
        return pl.col(self.key).sum()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        raise NotImplementedError()


@dataclass
class MeanAggregation(Transformation, PsqlTransformation, RedshiftTransformation):
    """Arithmetic mean of `key`; only usable through the polars or SQL code paths."""

    key: str

    name = "mean_agg"
    dtype = FeatureType.floating_point()

    def as_psql(self) -> str:
        return f"AVG({self.key})"

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        column = pl.col(self.key)
        return column.mean()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        raise NotImplementedError()


@dataclass
class MinAggregation(Transformation, PsqlTransformation, RedshiftTransformation):
    """Minimum of `key`; only usable through the polars or SQL code paths."""

    key: str

    name = "min_agg"
    dtype = FeatureType.floating_point()

    def as_psql(self) -> str:
        return f"MIN({self.key})"

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        column = pl.col(self.key)
        return column.min()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        raise NotImplementedError()


@dataclass
class MaxAggregation(Transformation, PsqlTransformation, RedshiftTransformation):
    """Maximum of `key`; only usable through the polars or SQL code paths."""

    key: str

    name = "max_agg"
    dtype = FeatureType.floating_point()

    def as_psql(self) -> str:
        return f"MAX({self.key})"

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        column = pl.col(self.key)
        return column.max()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        raise NotImplementedError()


@dataclass
class CountAggregation(Transformation, PsqlTransformation, RedshiftTransformation):
    """Count of `key` values; only usable through the polars or SQL code paths."""

    key: str

    name = "count_agg"
    dtype = FeatureType.floating_point()

    def as_psql(self) -> str:
        return f"COUNT({self.key})"

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        column = pl.col(self.key)
        return column.count()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        raise NotImplementedError()


@dataclass
class CountDistinctAggregation(
    Transformation, PsqlTransformation, RedshiftTransformation
):
    """Number of distinct non-null values of `key`.

    Only usable through the polars or SQL code paths.
    """

    key: str

    name = "count_distinct_agg"
    dtype = FeatureType.floating_point()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        raise NotImplementedError()

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        # `unique_counts` yields one count per distinct value, not the number of
        # distinct values. `n_unique` gives the scalar that SQL computes, but it
        # counts null as a value, so nulls are dropped first to match
        # COUNT(DISTINCT ...).
        return pl.col(self.key).drop_nulls().n_unique()

    def as_psql(self) -> str:
        return f"COUNT(DISTINCT {self.key})"


@dataclass
class StdAggregation(Transformation, PsqlTransformation, RedshiftTransformation):
    """Standard deviation of `key`; only usable through the polars or SQL code paths."""

    key: str

    name = "std_agg"
    dtype = FeatureType.floating_point()

    def as_psql(self) -> str:
        return f"STDDEV({self.key})"

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        column = pl.col(self.key)
        return column.std()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        raise NotImplementedError()


@dataclass
class VarianceAggregation(Transformation, PsqlTransformation, RedshiftTransformation):
    """Variance of `key`; only usable through the polars or SQL code paths."""

    key: str

    name = "var_agg"
    dtype = FeatureType.floating_point()

    def as_psql(self) -> str:
        return f"variance({self.key})"

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        column = pl.col(self.key)
        return column.var()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        raise NotImplementedError()


@dataclass
class MedianAggregation(Transformation, PsqlTransformation, RedshiftTransformation):
    """Median of `key`; only usable through the polars or SQL code paths."""

    key: str

    name = "median_agg"
    dtype = FeatureType.floating_point()

    def as_psql(self) -> str:
        return f"percentile_cont(0.5) WITHIN GROUP(ORDER BY {self.key})"

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        column = pl.col(self.key)
        return column.median()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        raise NotImplementedError()


@dataclass
class PercentileAggregation(
    Transformation, PsqlTransformation, RedshiftTransformation, PolarsExprTransformation
):
    """Continuous percentile of `key` (`percentile` in [0, 1]).

    Only usable through the polars or SQL code paths.
    """

    key: str
    percentile: float

    name = "percentile_agg"
    dtype = FeatureType.floating_point()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        raise NotImplementedError()

    def polars_expr(self) -> pl.Expr:
        # polars' `quantile` defaults to "nearest" interpolation. SQL's
        # percentile_cont interpolates linearly, so request that explicitly to
        # keep both backends in agreement.
        return pl.col(self.key).quantile(self.percentile, interpolation="linear")

    def as_psql(self) -> str:
        return f"percentile_cont({self.percentile}) WITHIN GROUP(ORDER BY {self.key})"


@dataclass
class Clip(Transformation, InnerTransformation):
    """Limits the result of `inner` to the closed range [`lower`, `upper`]."""

    inner: Expression
    lower: LiteralValue
    upper: LiteralValue

    name = "clip"
    dtype = FeatureType.floating_point()

    def _bounds(self) -> tuple:
        """Return the python values of (lower, upper)."""
        return self.lower.python_value, self.upper.python_value

    def polars_expr_from(self, inner: pl.Expr) -> pl.Expr:
        low, high = self._bounds()
        return inner.clip(lower_bound=low, upper_bound=high)

    def pandas_tran(self, column: pd.Series) -> pd.Series:
        low, high = self._bounds()
        return column.clip(lower=low, upper=high)  # type: ignore

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            transformation=Clip(
                inner=Expression(column="a"),
                lower=LiteralValue.from_value(0),
                upper=LiteralValue.from_value(1),
            ),
            input={"a": [-1, 0.1, 0.9, 2]},
            output=[0, 0.1, 0.9, 1],
        )


@dataclass
class PresignedAwsUrl(Transformation):
    """Turns S3 object paths in `key` into presigned download URLs.

    URLs are valid for `max_age_seconds`.
    """

    config: AwsS3Config
    key: str

    max_age_seconds: int = field(default=30)

    name = "presigned_aws_url"
    dtype = FeatureType.string()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        from aioaws.s3 import S3Client
        from httpx import AsyncClient

        # The column is processed eagerly inside this block, so the HTTP client
        # can be closed as soon as signing is done instead of leaking.
        async with AsyncClient() as http_client:
            s3 = S3Client(http_client, config=self.config.s3_config)
            return df[self.key].apply(
                lambda x: s3.signed_download_url(x, max_age=self.max_age_seconds)
            )  # type: ignore

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        from aioaws.s3 import S3Client
        from httpx import AsyncClient

        # The client is deliberately left open: `map_elements` only runs when the
        # returned frame is collected, after this method has returned.
        s3 = S3Client(AsyncClient(), config=self.config.s3_config)

        return df.with_columns(
            pl.col(self.key)
            .map_elements(
                lambda x: s3.signed_download_url(x, max_age=self.max_age_seconds),
                # An explicit return type avoids polars inferring it from the
                # first element and matches `dtype`.
                return_dtype=pl.String,
            )
            .alias(alias)
        )


@dataclass
class StructField(Transformation):
    """Extracts the sub-field `field` from column `key`.

    Struct columns use `struct.field`. String columns are read with the JSON
    path `$.<field>`.
    """

    key: str
    field: str

    name = "struct_field"
    dtype = FeatureType.string()

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        source_dtype = df.schema[self.key]
        if not source_dtype.is_(pl.Utf8):
            return pl.col(self.key).struct.field(self.field).alias(alias)

        # String columns are assumed to hold JSON documents.
        json_lookup = JsonPath(self.key, f"$.{self.field}")
        return await json_lookup.transform_polars(df, alias, store)

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        lazy = pl.from_pandas(df).lazy()
        result = await self.transform_polars(lazy, "feature", store)

        frame = result if isinstance(result, pl.LazyFrame) else lazy.select(result)
        return frame.collect().to_pandas()["feature"]  # type: ignore


@dataclass
class OllamaGenerate(Transformation):
    """Runs an Ollama `generate` call for each prompt in `key`.

    The call uses `model` and the `system` prompt. If `host_env` is set, the
    host is read from that environment variable.
    """

    key: str
    model: str
    system: str

    host_env: str | None = None
    # NOTE(review): OllamaEmbedding uses the same name "ollama_embedding"; if
    # transformations are resolved by name, one of them shadows the other.
    name = "ollama_embedding"
    dtype = FeatureType.json()

    def _resolve_host(self) -> str | None:
        """Host from the `host_env` environment variable, or None when `host_env` is unset."""
        import os

        return os.getenv(self.host_env) if self.host_env else None

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        from ollama import AsyncClient

        client = AsyncClient(host=self._resolve_host())

        # Results are collected positionally, so frames whose index is not a
        # 0..n-1 RangeIndex are handled correctly.
        responses = []
        for prompt in df[self.key]:
            responses.append(
                await client.generate(
                    model=self.model,
                    prompt=prompt,  # type: ignore
                    system=self.system,
                )
            )

        return pd.Series(responses, dtype=object)

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        def generate_embedding(values: pl.Series) -> pl.Series:
            from ollama import Client

            client = Client(host=self._resolve_host())

            # Each response is stringified to fit the declared String return dtype.
            return pl.Series(
                [
                    str(
                        client.generate(
                            model=self.model,
                            prompt=value,
                            system=self.system,
                        )
                    )
                    for value in values
                ]
            )

        return pl.col(self.key).map_batches(
            generate_embedding, return_dtype=pl.String()
        )


@dataclass
class OllamaEmbedding(Transformation):
    """Embeds each value in `key` with the Ollama `embeddings` endpoint using `model`.

    If `host_env` is set, the host is read from that environment variable.
    """

    key: str
    model: str

    host_env: str | None = None
    name = "ollama_embedding"
    dtype = FeatureType.embedding(768)

    def _resolve_host(self) -> str | None:
        """Host from the `host_env` environment variable, or None when `host_env` is unset."""
        import os

        return os.getenv(self.host_env) if self.host_env else None

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        from ollama import AsyncClient

        client = AsyncClient(host=self._resolve_host())

        # Vectors are collected into a list and the Series is built once at the
        # end. Assigning a list through a scalar `iloc` is unreliable in pandas,
        # and positional collection also works for non-RangeIndex frames.
        vectors = []
        for text in df[self.key]:
            embedded: dict[str, list] = await client.embeddings(
                self.model,
                text,  # type: ignore
            )
            vectors.append(embedded["embedding"])

        return pd.Series(vectors, dtype=object)

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        def generate_embedding(values: pl.Series) -> pl.Series:
            from ollama import Client

            client = Client(host=self._resolve_host())

            embeddings = [
                client.embeddings(self.model, value)["embedding"]
                for value in values  # type: ignore
            ]
            return pl.Series(embeddings)

        return pl.col(self.key).map_batches(
            generate_embedding, return_dtype=pl.List(pl.Float64())
        )


@dataclass
class JsonPath(Transformation, PolarsExprTransformation):
    """Extract the value at a JSONPath from a column of JSON strings."""

    key: str  # column containing JSON-encoded strings
    path: str  # JSONPath expression passed to ``json_path_match``

    name = "json_path"
    dtype = FeatureType.string()

    def polars_expr(self) -> pl.Expr:
        json_column = pl.col(self.key)
        return json_column.str.json_path_match(self.path)

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        # Reuse polars' JSONPath matcher, then convert the result back to pandas.
        as_polars = pl.Series(df[self.key])
        matched = as_polars.str.json_path_match(self.path)
        return matched.to_pandas()


@dataclass
class Split(Transformation, InnerTransformation):
    """Split a string expression into a list of substrings on ``separator``.

    ``separator`` is matched literally in both the pandas and polars paths.
    """

    inner: Expression
    separator: str
    name = "split"
    dtype: FeatureType = FeatureType.array(FeatureType.string())

    def pandas_tran(self, column: pd.Series) -> pd.Series:
        # With the default ``regex=None``, pandas treats any separator longer
        # than one character as a regular expression. ``regex=False`` matches
        # polars' literal ``str.split`` below.
        return column.str.split(self.separator, regex=False)

    def polars_expr_from(self, inner: pl.Expr) -> pl.Expr:
        return inner.str.split(self.separator)


@dataclass
class LoadFeature(Transformation):
    """Look up ``feature`` from the store using entity columns of the input.

    ``entities`` maps the store's entity names to the input column names that
    hold them.
    """

    entities: dict[str, str]
    feature: FeatureReference
    # Optional list column: when set (polars path only), each element is
    # looked up separately and the results are collected back per input row.
    explode_key: str | None
    dtype: FeatureType
    name = "load_feature"

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        """Fetch ``feature`` for the entity values in ``df``.

        NOTE(review): this path does not use ``explode_key``.
        """
        entities = {}
        for key, df_key in self.entities.items():
            entities[key] = df[df_key]

        values = await store.features_for(
            entities, features=[self.feature.identifier]
        ).to_pandas()
        return values[self.feature.name]  # type: ignore

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        """Fetch ``feature`` and append it to ``df`` as column ``alias``.

        The loaded values are concatenated horizontally with ``df``. This is a
        positional join, so row order must match the input.
        """
        group_keys: list[str] = []

        if self.explode_key:
            # Tag each input row so exploded lookups can be regrouped later.
            group_keys = ["row_nr"]
            entity_df = df.with_row_index("row_nr").explode(self.explode_key)
        else:
            entity_df = df

        entities = entity_df.rename(
            {df_key: key for key, df_key in self.entities.items()}
        )

        values = (
            await store.features_for(
                entities.collect(), features=[self.feature.identifier]
            )
            .with_subfeatures()
            .to_polars()
        )

        if group_keys:
            # ``group_by`` does not keep group order by default. Keep it and
            # sort by the row tag so the groups line up with ``df`` rows in
            # the positional concat below.
            values = (
                values.group_by(group_keys, maintain_order=True)
                .agg([pl.col(col) for col in values.columns if col not in group_keys])
                .sort(group_keys)
            )

        values = values.select(pl.col(self.feature.name).alias(alias))

        return pl.concat([df, values.lazy()], how="horizontal")


@dataclass
class FormatStringTransformation(Transformation):
    """Render ``format`` with :meth:`str.format` once per row.

    Placeholders are filled by column name, e.g. ``"{first} {last}"`` with
    ``keys=["first", "last"]``.
    """

    format: str
    keys: list[str]  # columns passed to the template in the pandas path
    name = "format_string"

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        """Format each row of ``df[keys]``; the result is aligned with ``df``."""
        rows = df[self.keys].to_dict(orient="records")  # type: ignore
        formatted = [self.format.format(**row) for row in rows]
        return pd.Series(formatted, index=df.index)

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        """Return ``df`` with an extra string column ``alias``.

        Every column of the row is available to the template. The existing
        frame is kept as-is instead of being rebuilt from row dicts, which
        dropped all columns on empty input and re-inferred column dtypes.
        """
        polars_df = df.collect()

        formatted = [
            self.format.format(**row) for row in polars_df.iter_rows(named=True)
        ]

        return polars_df.with_columns(
            pl.Series(alias, formatted, dtype=pl.String)
        ).lazy()


@dataclass
class ListDotProduct(Transformation):
    """Dot product of two list columns, ``left`` and ``right``, per row."""

    left: str
    right: str

    name = "list_dot_product"
    dtype = FeatureType.floating_point()

    @staticmethod
    def _polars_is_at_least_1_8() -> bool:
        """Return True when the installed polars version is >= 1.8.

        Only the leading ``major.minor`` digits are parsed, so pre-release
        tags such as ``1.9.0b1`` do not raise. A version that cannot be parsed
        is treated as new enough.
        """
        import re

        match = re.match(r"(\d+)\.(\d+)", pl.__version__)
        if match is None:
            return True
        return (int(match.group(1)), int(match.group(2))) >= (1, 8)

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        """Run the polars implementation and return its result as pandas."""
        pl_df = pl.from_pandas(df)
        res = await self.transform_polars(pl_df.lazy(), "output", store)
        if isinstance(res, pl.Expr):
            return pl_df.with_columns(res.alias("output"))["output"].to_pandas()
        else:
            return res.collect()["output"].to_pandas()

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        """Return a dot-product expression, or ``df`` with an ``alias`` column.

        A tuple comparison is used so that e.g. polars 2.0 (minor < 8) still
        takes the expression path.
        """
        if self._polars_is_at_least_1_8():
            return (pl.col(self.left) * pl.col(self.right)).list.sum()

        # Older polars: explode both lists in lockstep, compute the dot product
        # per original row, then append it to ``df`` positionally.
        dot_product = (
            df.select(self.left, self.right)
            .with_row_index(name="index")
            .explode(self.left, self.right)
            .group_by("index", maintain_order=True)
            .agg(pl.col(self.left).dot(self.right).alias(alias))
            .drop("index")
        )
        return pl.concat([df, dot_product], how="horizontal")

    @staticmethod
    def test_definition() -> TransformationTestDefinition:
        return TransformationTestDefinition(
            transformation=ListDotProduct("left", "right"),
            input={
                "left": [[1, 2, 3], [2, 3]],
                "right": [[1, 1, 1], [2, 2]],
            },
            output=[6, 10],
        )


@dataclass
class HashColumns(Transformation, PolarsExprTransformation):
    """Hash the string concatenation of ``columns`` into a single value.

    NOTE(review): ``concat_str`` is called without a separator, so values
    such as ("ab", "c") and ("a", "bc") produce the same hash. Changing this
    would change every existing hash, so it is left as-is.
    """

    columns: list[str]

    name = "hash_columns"
    dtype = FeatureType.uint64()

    def polars_expr(self) -> pl.Expr:
        joined = pl.concat_str(self.columns)
        return joined.hash()

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        # Round-trip through polars so both backends share one implementation.
        frame = pl.from_pandas(df)
        result = await self.transform_polars(frame.lazy(), "output", store)
        if not isinstance(result, pl.Expr):
            return result.collect()["output"].to_pandas()
        with_output = frame.with_columns(result.alias("output"))
        return with_output["output"].to_pandas()


@dataclass
class MultiTransformation(Transformation):
    """Apply a sequence of transformations, threading the frame through them.

    Each entry is ``(transformation, sub_alias)``. A truthy ``sub_alias``
    names an intermediate column that is dropped from the final result. A
    missing ``sub_alias`` writes to the final ``alias`` instead.
    """

    transformations: list[tuple[Transformation, str | None]]
    name = "multi"
    dtype = FeatureType.string()

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        """Run every transformation in order; keep ``alias``, drop intermediates."""
        intermediate_cols: list[str] = []

        for tran, sub_alias in self.transformations:
            target = sub_alias or alias
            output = await tran.transform_polars(df, target, store)

            if sub_alias:
                intermediate_cols.append(sub_alias)

            if isinstance(output, pl.Expr):
                df = df.with_columns(output.alias(target))
            else:
                df = output

        # Drop every occurrence of ``alias``, not just the first, so the final
        # column survives even when it was also used as a sub-alias more than
        # once. Then de-duplicate while keeping order.
        exclude_cols = list(
            dict.fromkeys(col for col in intermediate_cols if col != alias)
        )

        return df.select(pl.exclude(exclude_cols))

    async def transform_pandas(
        self, df: pd.DataFrame, store: ContractStore
    ) -> pd.Series:
        """Run the polars implementation and return its result as pandas."""
        pl_df = pl.from_pandas(df)
        res = await self.transform_polars(pl_df.lazy(), "output", store)
        if isinstance(res, pl.Expr):
            return pl_df.with_columns(res.alias("output"))["output"].to_pandas()
        else:
            return res.collect()["output"].to_pandas()
