from __future__ import annotations

import asyncio
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import suppress
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, TypeVar

import pandas as pd
import polars as pl

from aligned.exceptions import UnableToFindFileException
from aligned.request.retrival_request import RequestResult, RetrivalRequest
from aligned.schemas.feature import Feature, FeatureType
from aligned.split_strategy import (
    SplitDataSet,
    SplitStrategy,
    SupervisedDataSet,
    TrainTestSet,
    TrainTestValidateSet,
)
from aligned.validation.interface import Validator

if TYPE_CHECKING:
    from typing import AsyncIterator

    from aligned.local.source import DataFileReference
    from aligned.schemas.derivied_feature import AggregatedFeature, AggregateOver
    from aligned.schemas.model import EventTrigger


logger = logging.getLogger(__name__)


def split(
    data: pd.DataFrame, start_ratio: float, end_ratio: float, event_timestamp_column: str | None = None
) -> pd.Index:
    """Return the index labels of the rows lying between two ratios of `data`.

    Without `event_timestamp_column` the rows are picked by position, from
    `round(n * start_ratio)` up to, but not including, `round(n * end_ratio)`.

    With `event_timestamp_column` the rows are picked by time: the timestamp has
    to lie between the `start_ratio` and `end_ratio` quantiles of that column.
    The upper bound is exclusive unless `end_ratio` reaches 1, which mirrors the
    half-open positional slice and keeps a row that sits exactly on a boundary
    from landing in two consecutive splits (e.g. both train and test).

    Args:
        data: The frame to select rows from.
        start_ratio: Where the selection starts, as a fraction of the data.
        end_ratio: Where the selection ends, as a fraction of the data.
        event_timestamp_column: Optional column used to split by time.

    Returns:
        The index labels of the selected rows.
    """
    if event_timestamp_column:
        column = data[event_timestamp_column]
        if column.dtype != 'datetime64[ns]':
            column = pd.to_datetime(data[event_timestamp_column])
        values = column.quantile([start_ratio, end_ratio])
        lower, upper = values.iloc[0], values.iloc[1]

        # Only the last split may include its upper boundary
        below_upper = column <= upper if end_ratio >= 1 else column < upper
        return data.loc[(column >= lower) & below_upper].index

    # Appending to an empty index keeps the dtype of the original index
    index = pd.Index([], dtype=data.index.dtype)
    group_size = data.shape[0]
    start_index = round(group_size * start_ratio)
    end_index = round(group_size * end_ratio)

    if end_index >= group_size:
        return index.append(data.iloc[start_index:].index)
    return index.append(data.iloc[start_index:end_index].index)


def split_polars(
    data: pl.LazyFrame, start_ratio: float, end_ratio: float, event_timestamp_column: str | None = None
) -> pl.DataFrame:
    """Return the rows of `data` lying between two ratios of the data.

    With `event_timestamp_column` the rows are filtered on the `start_ratio` and
    `end_ratio` quantiles of that column. Otherwise the frame is collected and
    sliced by position, from `round(n * start_ratio)` up to, but not including,
    `round(n * end_ratio)`.

    Args:
        data: The lazy frame to select rows from.
        start_ratio: Where the selection starts, as a fraction of the data.
        end_ratio: Where the selection ends, as a fraction of the data.
        event_timestamp_column: Optional column used to split by time.

    Returns:
        A collected frame holding the selected rows.
    """
    if event_timestamp_column:
        # The quantiles have to be collected before they can be read by position,
        # as a lazy frame can not be indexed with `[row, column]`
        bounds = data.select(
            [
                pl.col(event_timestamp_column).quantile(start_ratio).alias('start_value'),
                pl.col(event_timestamp_column).quantile(end_ratio).alias('end_value'),
            ]
        ).collect()
        return data.filter(
            pl.col(event_timestamp_column).is_between(bounds[0, 'start_value'], bounds[0, 'end_value'])
        ).collect()

    collected = data.collect()
    group_size = collected.shape[0]
    start_index = round(group_size * start_ratio)
    end_index = round(group_size * end_ratio)

    if end_index >= group_size:
        return collected[start_index:]
    return collected[start_index:end_index]


@dataclass
class SupervisedJob:
    """A retrieval job paired with the names of the columns used as targets."""

    job: RetrivalJob
    target_columns: set[str]

    def _wrap(self, job: RetrivalJob) -> SupervisedJob:
        # Every builder method keeps the targets and only swaps the wrapped job
        return SupervisedJob(job, self.target_columns)

    async def to_pandas(self) -> SupervisedDataSet[pd.DataFrame]:
        """Load the data with pandas and describe its entity, feature and target columns."""
        frame = await self.job.to_pandas()
        result = self.job.request_result
        entity_names = {entity.name for entity in result.entities}
        # Everything that is requested, but not a target, is treated as a feature
        feature_names = {
            feature.name for feature in result.features if feature.name not in self.target_columns
        }
        return SupervisedDataSet(
            frame, entity_names, feature_names, self.target_columns, result.event_timestamp
        )

    async def to_polars(self) -> SupervisedDataSet[pl.LazyFrame]:
        """Load the data with polars and describe its entity, feature and target columns."""
        frame = await self.job.to_polars()
        result = self.job.request_result
        entity_names = {entity.name for entity in result.entities}
        feature_names = {
            feature.name for feature in result.features if feature.name not in self.target_columns
        }
        return SupervisedDataSet(
            frame, entity_names, feature_names, self.target_columns, result.event_timestamp
        )

    @property
    def request_result(self) -> RequestResult:
        """The request result of the wrapped job."""
        return self.job.request_result

    def train_set(self, train_size: float) -> SupervisedTrainJob:
        """Create a job that splits the data into a train and a test set."""
        return SupervisedTrainJob(self, train_size)

    def with_subfeatures(self) -> SupervisedJob:
        """Same targets, with `with_subfeatures` applied to the wrapped job."""
        return self._wrap(self.job.with_subfeatures())

    def cached_at(self, location: DataFileReference | str) -> SupervisedJob:
        """Same targets, with the wrapped job cached at `location`."""
        return self._wrap(self.job.cached_at(location))

    def validate(self, validator: Validator) -> SupervisedJob:
        """Same targets, with the wrapped job validated by `validator`."""
        return self._wrap(self.job.validate(validator))

    def log_each_job(self) -> SupervisedJob:
        """Same targets, with `log_each_job` applied to the wrapped job."""
        return self._wrap(self.job.log_each_job())

    def describe(self) -> str:
        """Describe the wrapped job together with the target columns."""
        job_description = self.job.describe()
        return f'{job_description} with target columns {self.target_columns}'


@dataclass
class SupervisedTrainJob:
    """Splits the data of a supervised job into a train and a test set.

    `train_size` is the ratio of the data that ends up in the train set,
    the remaining rows make up the test set.
    """

    job: SupervisedJob
    train_size: float

    async def to_pandas(self) -> TrainTestSet[pd.DataFrame]:
        """Load the data through polars, convert it to pandas and compute both indexes."""
        supervised = await self.job.to_polars()
        frame = supervised.data.collect().to_pandas()
        timestamp_column = supervised.event_timestamp_column

        # The train set ends where the test set starts
        boundary = self.train_size
        return TrainTestSet(
            data=frame,
            entity_columns=supervised.entity_columns,
            features=supervised.features,
            target_columns=supervised.target_columns,
            train_index=split(frame, 0, boundary, timestamp_column),
            test_index=split(frame, boundary, 1, timestamp_column),
            event_timestamp_column=timestamp_column,
        )

    async def to_polars(self) -> TrainTestSet[pl.DataFrame]:
        """Build the split with pandas and convert only the data frame to polars."""
        # The split is not implemented for polars yet, so the pandas indexes are reused.
        # It is a bit unclear whether the same index concept should be used for polars.
        split_set = await self.to_pandas()
        return TrainTestSet(
            data=pl.from_pandas(split_set.data),
            entity_columns=split_set.entity_columns,
            features=split_set.features,
            target_columns=split_set.target_columns,
            train_index=split_set.train_index,
            test_index=split_set.test_index,
            event_timestamp_column=split_set.event_timestamp_column,
        )

    def validation_set(self, validation_size: float) -> SupervisedValidationJob:
        """Create a job that also carves out a validation set."""
        return SupervisedValidationJob(self, validation_size)


@dataclass
class SupervisedValidationJob:
    """Splits the data of a train job into a train, a test and a validation set."""

    job: SupervisedTrainJob
    validation_size: float

    async def to_pandas(self) -> TrainTestValidateSet[pd.DataFrame]:
        """Load the train / test data with pandas and compute the three indexes."""
        train_test = await self.job.to_pandas()
        frame = train_test.data
        timestamp_column = train_test.event_timestamp_column

        # Ratios at which the test and the validation partitions start.
        # NOTE(review): with these boundaries the test partition spans `validation_size`
        # and the validation partition gets whatever is left - confirm this is intended.
        test_start = self.job.train_size
        validate_start = test_start + self.validation_size

        return TrainTestValidateSet(
            data=frame,
            entity_columns=set(train_test.entity_columns),
            features=train_test.features,
            target=train_test.target_columns,
            train_index=split(frame, 0, test_start, timestamp_column),
            test_index=split(frame, test_start, validate_start, timestamp_column),
            validate_index=split(frame, validate_start, 1, timestamp_column),
            event_timestamp_column=timestamp_column,
        )

    async def to_polars(self) -> TrainTestValidateSet[pl.DataFrame]:
        """Build the split with pandas and convert only the data frame to polars."""
        split_set = await self.to_pandas()
        polars_frame = pl.from_pandas(split_set.data)

        return TrainTestValidateSet(
            data=polars_frame,
            entity_columns=split_set.entity_columns,
            features=split_set.features,
            target=split_set.target,
            train_index=split_set.train_index,
            test_index=split_set.test_index,
            validate_index=split_set.validate_index,
            event_timestamp_column=split_set.event_timestamp_column,
        )


class RetrivalJob(ABC):
    """Base class for jobs that load data as a pandas or a polars frame.

    Subclasses implement `to_pandas` and `to_polars`. The remaining methods are
    builders that wrap the job in another job. Several of them forward to the
    wrapped job when the instance is a `ModificationJob`.
    """

    @property
    def request_result(self) -> RequestResult:
        """The request result, taken from the wrapped job of a `ModificationJob`."""
        if not isinstance(self, ModificationJob):
            raise NotImplementedError()
        return self.job.request_result

    @property
    def retrival_requests(self) -> list[RetrivalRequest]:
        """The retrieval requests, taken from the wrapped job of a `ModificationJob`."""
        if not isinstance(self, ModificationJob):
            raise NotImplementedError()
        return self.job.retrival_requests

    @abstractmethod
    async def to_pandas(self) -> pd.DataFrame:
        """Load the data as a pandas data frame."""
        raise NotImplementedError()

    @abstractmethod
    async def to_polars(self) -> pl.LazyFrame:
        """Load the data as a polars lazy frame."""
        raise NotImplementedError()

    def describe(self) -> str:
        """Describe the chain of jobs as `<wrapped description> -> <class name>`."""
        if not isinstance(self, ModificationJob):
            raise NotImplementedError(f'Describe not implemented for {self.__class__.__name__}')
        return f'{self.job.describe()} -> {self.__class__.__name__}'

    def remove_derived_features(self) -> RetrivalJob:
        """Forward the call to the wrapped job, or return `self` when nothing is wrapped."""
        if not isinstance(self, ModificationJob):
            return self
        return self.copy_with(self.job.remove_derived_features())

    def log_each_job(self) -> RetrivalJob:
        """Forward the call to the wrapped job, or wrap `self` in a `LogJob`."""
        if not isinstance(self, ModificationJob):
            return LogJob(self)
        return self.copy_with(self.job.log_each_job())

    def chuncked(self, size: int) -> DataLoaderJob:
        """Create a loader that yields the data in chunks of `size` rows."""
        return DataLoaderJob(self, size)

    def with_subfeatures(self) -> RetrivalJob:
        """Forward the call to the wrapped job, or return `self` when nothing is wrapped."""
        if not isinstance(self, ModificationJob):
            return self
        return self.copy_with(self.job.with_subfeatures())

    def cached_at(self, location: DataFileReference | str) -> RetrivalJob:
        """Cache the result at `location`, where a string is used as a parquet file path."""
        if not isinstance(location, str):
            return FileCachedJob(location, self)

        # Imported here and not at module level, as in the rest of this file
        from aligned.local.source import ParquetFileSource

        return FileCachedJob(ParquetFileSource(location), self)

    def test_size(self, test_size: float, target_column: str) -> SupervisedTrainJob:
        """Create a train / test split where the train set gets `1 - test_size` of the data."""
        supervised = SupervisedJob(self, {target_column})
        return supervised.train_set(train_size=1 - test_size)

    def train_set(self, train_size: float, target_column: str) -> SupervisedTrainJob:
        """Create a train / test split where the train set gets `train_size` of the data."""
        supervised = SupervisedJob(self, {target_column})
        return supervised.train_set(train_size=train_size)

    def validate(self, validator: Validator) -> RetrivalJob:
        """Wrap the job in a `ValidationJob` using `validator`."""
        return ValidationJob(self, validator)

    def derive_features(self, requests: list[RetrivalRequest]) -> RetrivalJob:
        """Wrap the job in a `DerivedFeatureJob` for `requests`."""
        return DerivedFeatureJob(job=self, requests=requests)

    def ensure_types(self, requests: list[RetrivalRequest]) -> RetrivalJob:
        """Wrap the job in an `EnsureTypesJob` for `requests`."""
        return EnsureTypesJob(job=self, requests=requests)

    def filter(self, include_features: set[str]) -> RetrivalJob:
        """Wrap the job in a `FilterJob` for `include_features`."""
        return FilterJob(include_features, self)

    def listen_to_events(self, events: set[EventTrigger]) -> RetrivalJob:
        """Wrap the job in a `ListenForTriggers` job for `events`."""
        return ListenForTriggers(self, events)

    def validate_entites(self) -> RetrivalJob:
        """Wrap the job in a `ValidateEntitiesJob`."""
        return ValidateEntitiesJob(self)

    def fill_missing_columns(self) -> RetrivalJob:
        """Wrap the job in a `FillMissingColumnsJob`."""
        return FillMissingColumnsJob(self)

    @staticmethod
    def from_dict(data: dict[str, list], request: list[RetrivalRequest] | RetrivalRequest) -> RetrivalJob:
        """Create a job from literal data, accepting either one request or a list of them."""
        requests = [request] if isinstance(request, RetrivalRequest) else request
        return LiteralDictJob(data, requests)


JobType = TypeVar('JobType')


class ModificationJob:
    """Mixin for jobs that wrap another job in the `job` attribute."""

    job: RetrivalJob

    def copy_with(self: JobType, job: RetrivalJob) -> JobType:
        """Return a shallow copy of this job that wraps `job` instead.

        The instance itself is left untouched, so deriving a new job from an
        existing one never changes what the existing one wraps.

        Args:
            job: The job the copy should wrap.

        Returns:
            A new instance of the same class, with every other attribute shared.
        """
        from copy import copy

        duplicate = copy(self)
        duplicate.job = job  # type: ignore
        return duplicate


@dataclass
class LiteralDictJob(RetrivalJob):
    """Job that serves data held in memory as a dict of column name to values."""

    data: dict[str, list]
    requests: list[RetrivalRequest]

    @property
    def request_result(self) -> RequestResult:
        """The request result built from `requests`."""
        return RequestResult.from_request_list(self.requests)

    @property
    def retrival_requests(self) -> list[RetrivalRequest]:
        """The requests the job was created with."""
        return self.requests

    async def to_pandas(self) -> pd.DataFrame:
        """Build a pandas data frame from the literal data."""
        frame = pd.DataFrame(self.data)
        return frame

    async def to_polars(self) -> pl.LazyFrame:
        """Build a polars frame from the literal data and return it lazily."""
        frame = pl.DataFrame(self.data)
        return frame.lazy()


@dataclass
class LogJob(RetrivalJob, ModificationJob):
    """Job that logs the result of the wrapped job before passing it on."""

    job: RetrivalJob

    @property
    def request_result(self) -> RequestResult:
        """The request result of the wrapped job."""
        return self.job.request_result

    @property
    def retrival_requests(self) -> list[RetrivalRequest]:
        """The retrieval requests of the wrapped job."""
        return self.job.retrival_requests

    async def to_pandas(self) -> pd.DataFrame:
        """Load the wrapped job with pandas and log the whole frame."""
        result = await self.job.to_pandas()
        logger.info(f'Results from {type(self.job)}')
        logger.info(result)
        return result

    async def to_polars(self) -> pl.LazyFrame:
        """Load the wrapped job with polars and log the first 10 rows."""
        result = await self.job.to_polars()
        logger.info(f'Results from {type(self.job)}')
        # Only a preview is collected, the returned frame stays lazy
        preview = result.head(10).collect()
        logger.info(preview)
        return result

    def remove_derived_features(self) -> RetrivalJob:
        """Forward the call to the wrapped job, without keeping the logging wrapper."""
        return self.job.remove_derived_features()

    def log_each_job(self) -> RetrivalJob:
        """Return the wrapped job.

        NOTE(review): this drops the logging wrapper instead of keeping it -
        confirm that returning `self.job` rather than `self` is intended.
        """
        return self.job


@dataclass
class ValidationJob(RetrivalJob, ModificationJob):
    """Job that runs a validator over the result of the wrapped job."""

    job: RetrivalJob
    validator: Validator

    @property
    def request_result(self) -> RequestResult:
        """The request result of the wrapped job."""
        return self.job.request_result

    @property
    def retrival_requests(self) -> list[RetrivalRequest]:
        """The retrieval requests of the wrapped job."""
        return self.job.retrival_requests

    @property
    def features_to_validate(self) -> set[Feature]:
        """The features of the request result built from the retrieval requests."""
        combined = RequestResult.from_request_list(self.retrival_requests)
        return combined.features

    async def to_pandas(self) -> pd.DataFrame:
        """Load the wrapped job with pandas and pass the frame through the validator."""
        features = list(self.features_to_validate)
        frame = await self.job.to_pandas()
        return await self.validator.validate_pandas(features, frame)

    async def to_polars(self) -> pl.LazyFrame:
        """Load the wrapped job with polars and pass the frame through the validator."""
        features = list(self.features_to_validate)
        frame = await self.job.to_polars()
        return await self.validator.validate_polars(features, frame)

    def with_subfeatures(self) -> RetrivalJob:
        """Same validator, with `with_subfeatures` applied to the wrapped job."""
        return ValidationJob(self.job.with_subfeatures(), self.validator)

    def cached_at(self, location: DataFileReference | str) -> RetrivalJob:
        """Cache the validated result at `location`, where a string is a parquet file path."""
        if not isinstance(location, str):
            return FileCachedJob(location, self)

        from aligned.local.source import ParquetFileSource

        return FileCachedJob(ParquetFileSource(location), self)

    def remove_derived_features(self) -> RetrivalJob:
        """Forward the call to the wrapped job, without keeping the validation."""
        return self.job.remove_derived_features()


@dataclass
class DerivedFeatureJob(RetrivalJob, ModificationJob):
    """Job that adds the derived features of `requests` to the data of the wrapped job."""

    job: RetrivalJob
    requests: list[RetrivalRequest]

    @property
    def request_result(self) -> RequestResult:
        """The request result of the wrapped job."""
        return self.job.request_result

    @property
    def retrival_requests(self) -> list[RetrivalRequest]:
        """The retrieval requests of the wrapped job."""
        return self.job.retrival_requests

    async def compute_derived_features_polars(self, df: pl.LazyFrame) -> pl.LazyFrame:
        """Add the derived features to a polars lazy frame.

        The features are processed in the rounds given by `derived_features_order()`.
        A transformation may return either a new lazy frame, which replaces `df`
        right away, or an expression, which is added together with the other
        expressions of the same round.

        Raises:
            ValueError: If a transformation returns neither a lazy frame nor an expression.
        """

        for request in self.requests:
            for feature_round in request.derived_features_order():

                # Expressions of one round are collected and added in a single `with_columns`
                round_expressions: list[pl.Expr] = []

                for feature in feature_round:
                    # A column with this name is already present, so it is not computed again
                    if feature.name in df.columns:
                        logger.info(f'Skipped adding feature {feature.name} to computation plan')
                        continue
                    logger.info(f'Adding feature to computation plan in polars: {feature.name}')

                    method = await feature.transformation.transform_polars(df, feature.name)
                    if isinstance(method, pl.LazyFrame):
                        df = method
                    elif isinstance(method, pl.Expr):
                        round_expressions.append(method.alias(feature.name))
                    else:
                        raise ValueError('Invalid result from transformation')

                if round_expressions:
                    df = df.with_columns(round_expressions)
        return df

    async def compute_derived_features_pandas(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add the derived features to a pandas data frame, modifying `df` in place.

        Each feature is computed from the columns it depends on and afterwards
        cast to the pandas type of the feature when the dtype differs.
        """
        for request in self.requests:
            for feature_round in request.derived_features_order():
                for feature in feature_round:
                    # A column with this name is already present, so it is not computed again
                    if feature.name in df.columns:
                        logger.info(f'Skipping to compute {feature.name} as it is aleady computed')
                        continue
                    logger.info(f'Computing feature with pandas: {feature.name}')
                    df[feature.name] = await feature.transformation.transform_pandas(
                        df[feature.depending_on_names]
                    )
                    if df[feature.name].dtype != feature.dtype.pandas_type:
                        if feature.dtype.is_numeric:
                            # Values that can not be parsed as numbers become NaN before the cast
                            df[feature.name] = pd.to_numeric(df[feature.name], errors='coerce').astype(
                                feature.dtype.pandas_type
                            )
                        else:
                            df[feature.name] = df[feature.name].astype(feature.dtype.pandas_type)
        return df

    async def to_pandas(self) -> pd.DataFrame:
        """Load the wrapped job with pandas and add the derived features."""
        return await self.compute_derived_features_pandas(await self.job.to_pandas())

    async def to_polars(self) -> pl.LazyFrame:
        """Load the wrapped job with polars and add the derived features."""
        return await self.compute_derived_features_polars(await self.job.to_polars())

    def remove_derived_features(self) -> RetrivalJob:
        """Forward the call to the wrapped job, which drops this derivation step."""
        return self.job.remove_derived_features()


@dataclass
class ValidateEntitiesJob(RetrivalJob, ModificationJob):
    """Job that returns an empty frame when an entity column is missing.

    The data of the wrapped job is passed through unchanged when every request
    finds all of its entity names among the columns.
    """

    job: RetrivalJob

    async def to_pandas(self) -> pd.DataFrame:
        """Load the wrapped job with pandas, or return an empty frame on missing entities."""
        data = await self.job.to_pandas()
        columns = set(data.columns)

        if any(request.entity_names - columns for request in self.retrival_requests):
            return pd.DataFrame({})
        return data

    async def to_polars(self) -> pl.LazyFrame:
        """Load the wrapped job with polars, or return an empty lazy frame on missing entities."""
        data = await self.job.to_polars()
        columns = set(data.columns)

        if any(request.entity_names - columns for request in self.retrival_requests):
            return pl.DataFrame({}).lazy()
        return data


@dataclass
class FillMissingColumnsJob(RetrivalJob, ModificationJob):
    """Job that adds the required feature columns that are missing, filled with `None`."""

    job: RetrivalJob

    async def to_pandas(self) -> pd.DataFrame:
        """Load the wrapped job with pandas and add every missing required column."""
        frame = await self.job.to_pandas()
        for request in self.retrival_requests:

            # Recomputed per request, as an earlier request may have added columns
            absent = request.all_required_feature_names - set(frame.columns)
            if absent:
                missing = absent
                logger.info(
                    f"""
Some features is missing.
Will fill values with None, but it could be a potential problem: {missing}
"""
                )
                for column_name in absent:
                    frame[column_name] = None
        return frame

    async def to_polars(self) -> pl.LazyFrame:
        """Load the wrapped job with polars and add every missing required column."""
        frame = await self.job.to_polars()
        for request in self.retrival_requests:

            absent = request.all_required_feature_names - set(frame.columns)
            if absent:
                missing = absent
                logger.info(
                    f"""
Some features is missing.
Will fill values with None, but it could be a potential problem: {missing}
"""
                )
                null_columns = [pl.lit(None).alias(column_name) for column_name in absent]
                frame = frame.with_columns(null_columns)
        return frame


@dataclass
class StreamAggregationJob(RetrivalJob, ModificationJob):
    """Job that computes aggregated features over data kept in checkpoint files.

    Every aggregation window has a checkpoint in `checkpoints`. The incoming
    data is appended to that checkpoint, and the aggregations are computed over
    the stored data before being joined back onto the incoming data.
    """

    job: RetrivalJob
    checkpoints: dict[AggregateOver, DataFileReference]

    @property
    def time_windows(self) -> set[AggregateOver]:
        """Every distinct window used by the aggregated features of the requests."""
        return {
            feature.aggregate_over
            for request in self.retrival_requests
            for feature in request.aggregated_features
        }

    @property
    def aggregated_features(self) -> dict[AggregateOver, set[AggregatedFeature]]:
        """The aggregated features of the requests, grouped by their window."""
        features = defaultdict(set)
        for request in self.retrival_requests:
            for feature in request.aggregated_features:
                features[feature.aggregate_over].add(feature)
        return features

    async def data_windows(self, window: AggregateOver, data: pl.DataFrame, now: datetime) -> pl.DataFrame:
        """Merge `data` into the checkpoint of `window` and return the stored result.

        When the window has a time window, both the stored and the incoming rows
        are limited to those newer than `now` minus that time window. A missing
        checkpoint file is treated as an empty checkpoint.

        Raises:
            ValueError: If the window has a condition, which is not supported yet.
        """
        checkpoint = self.checkpoints[window]
        filter_expr: pl.Expr | None = None

        if window.window:
            time_window = window.window
            filter_expr = pl.col(time_window.time_column.name) > now - time_window.time_window
        if window.condition:
            raise ValueError('Condition is not supported for stream aggregation, yet')

        # Only the read is guarded, so a failing write is never mistaken for a missing
        # checkpoint. Both exception types are handled, as the other jobs in this file
        # signal a missing file with `UnableToFindFileException`.
        try:
            stored = (await checkpoint.to_polars()).collect()
        except (FileNotFoundError, UnableToFindFileException):
            stored = None

        incoming = data if filter_expr is None else data.filter(filter_expr)
        if stored is None:
            new_data = incoming
        else:
            if filter_expr is not None:
                stored = stored.filter(filter_expr)
            new_data = pl.concat([stored, incoming])

        await checkpoint.write_polars(new_data.lazy())
        return new_data

    async def to_pandas(self) -> pd.DataFrame:
        """Not implemented, the aggregation is only available through polars."""
        raise NotImplementedError()

    async def to_polars(self) -> pl.LazyFrame:
        """Load the wrapped job and join the aggregations of every window onto it."""
        data = (await self.job.to_polars()).collect()

        # This is used as a dummy frame, as the pl abstraction is not good enough
        lazy_df = pl.DataFrame({}).lazy()
        # NOTE(review): this is a naive timestamp - confirm the time columns are naive UTC as well
        now = datetime.utcnow()

        # Built once, as the property regroups every request on each access
        features_by_window = self.aggregated_features

        for window in self.time_windows:

            # A list gives the transformations and the features below one shared order
            aggregations = list(features_by_window[window])

            required_features = set(window.group_by)
            # The time column only exists for windows that have a time window
            if window.window:
                required_features.add(window.window.time_column)
            for agg in aggregations:
                required_features.update(agg.derived_feature.depending_on)

            required_features_name = sorted({feature.name for feature in required_features})

            agg_transformations = await asyncio.gather(
                *[
                    agg.derived_feature.transformation.transform_polars(lazy_df, 'dummy')
                    for agg in aggregations
                ]
            )
            # Transformations that do not return an expression are left out
            agg_expr = [
                agg.alias(feature.name)
                for agg, feature in zip(agg_transformations, aggregations)
                if isinstance(agg, pl.Expr)
            ]

            window_data = await self.data_windows(window, data.select(required_features_name), now)

            agg_data = window_data.lazy().groupby(window.group_by_names).agg(agg_expr).collect()
            data = data.join(agg_data, on=window.group_by_names, how='left')

        return data.lazy()

    def remove_derived_features(self) -> RetrivalJob:
        """Forward the call to the wrapped job, which drops this aggregation step."""
        return self.job.remove_derived_features()


@dataclass
class DataLoaderJob:
    """Loads the data of a job in chunks of `chunk_size` rows.

    The data is loaded once without derived features. The derived features are
    then computed for one chunk at a time.
    """

    job: RetrivalJob
    chunk_size: int

    async def to_polars(self) -> AsyncIterator[pl.LazyFrame]:
        """Yield one lazy frame per chunk, with the derived features added."""
        from math import ceil

        from aligned.local.job import LiteralRetrivalJob

        requests = self.job.retrival_requests
        raw_job = self.job.remove_derived_features()
        raw_data = (await raw_job.to_polars()).collect()

        # Only the requested features and the entities are kept in each chunk
        result = self.job.request_result
        wanted_names = {feature.name for feature in result.features.union(result.entities)}

        chunk_count = ceil(raw_data.shape[0] / self.chunk_size)
        for chunk_number in range(chunk_count):
            first_row = chunk_number * self.chunk_size
            chunk = raw_data[first_row : first_row + self.chunk_size, :]

            literal_job = LiteralRetrivalJob(chunk.lazy(), RequestResult.from_request_list(requests))
            chunk_job = literal_job.derive_features(requests).filter(wanted_names)

            yield await chunk_job.to_polars()

    async def to_pandas(self) -> AsyncIterator[pd.DataFrame]:
        """Yield one pandas data frame per chunk, converted from the polars chunks."""
        async for lazy_chunk in self.to_polars():
            yield lazy_chunk.collect().to_pandas()


@dataclass
class RawFileCachedJob(RetrivalJob, ModificationJob):
    """Job that caches the data as it is before the derived features are computed.

    With pandas the cache holds the output of the job wrapped by the
    `DerivedFeatureJob`. The derived features are computed after reading it.
    """

    location: DataFileReference
    job: DerivedFeatureJob

    @property
    def request_result(self) -> RequestResult:
        """The request result of the wrapped job."""
        return self.job.request_result

    @property
    def retrival_requests(self) -> list[RetrivalRequest]:
        """The retrieval requests of the wrapped job."""
        return self.job.retrival_requests

    async def _load_and_cache_pandas(self) -> pd.DataFrame:
        # Fetches the data without derived features and stores it at the cache location
        logger.info('Unable to load file, so fetching from source')
        raw = await self.job.job.to_pandas()
        logger.info('Writing result to cache')
        await self.location.write_pandas(raw)
        return raw

    async def to_pandas(self) -> pd.DataFrame:
        """Read the cache, or fill it from the source, and then compute the derived features."""
        from aligned.local.job import FileFullJob
        from aligned.local.source import LiteralReference

        try:
            logger.info('Trying to read cache file')
            raw = await self.location.read_pandas()
        except UnableToFindFileException:
            raw = await self._load_and_cache_pandas()

        cached_source = FileFullJob(LiteralReference(raw), request=self.job.requests[0])
        derived_job = cached_source.derive_features(self.job.requests)
        return await derived_job.to_pandas()

    async def to_polars(self) -> pl.LazyFrame:
        """Load the wrapped job with polars, without using the cache."""
        result = await self.job.to_polars()
        return result

    def cached_at(self, location: DataFileReference | str) -> RetrivalJob:
        """Return `self`, the job is already cached and `location` is ignored."""
        return self

    def remove_derived_features(self) -> RetrivalJob:
        """Forward the call to the wrapped job, without keeping the cache."""
        return self.job.remove_derived_features()


@dataclass
class FileCachedJob(RetrivalJob, ModificationJob):
    """Job that reads the result from a cache location when it exists.

    When the cache can not be found, the wrapped job is loaded and its result
    is written to the cache location.
    """

    location: DataFileReference
    job: RetrivalJob

    @property
    def request_result(self) -> RequestResult:
        """The request result of the wrapped job."""
        return self.job.request_result

    @property
    def retrival_requests(self) -> list[RetrivalRequest]:
        """The retrieval requests of the wrapped job."""
        return self.job.retrival_requests

    async def _load_and_cache_pandas(self) -> pd.DataFrame:
        # Loads the wrapped job with pandas and stores the result at the cache location
        logger.info('Unable to load file, so fetching from source')
        result = await self.job.to_pandas()
        logger.info('Writing result to cache')
        await self.location.write_pandas(result)
        return result

    async def _load_and_cache_polars(self) -> pl.LazyFrame:
        # Loads the wrapped job with polars and stores the result at the cache location
        logger.info('Unable to load file, so fetching from source')
        result = await self.job.to_polars()
        logger.info('Writing result to cache')
        await self.location.write_polars(result)
        return result

    async def to_pandas(self) -> pd.DataFrame:
        """Read the cache with pandas, or fill it from the wrapped job."""
        try:
            logger.info('Trying to read cache file')
            return await self.location.read_pandas()
        except UnableToFindFileException:
            return await self._load_and_cache_pandas()

    async def to_polars(self) -> pl.LazyFrame:
        """Read the cache with polars, or fill it from the wrapped job."""
        try:
            logger.info('Trying to read cache file')
            return await self.location.to_polars()
        except UnableToFindFileException:
            return await self._load_and_cache_polars()

    def cached_at(self, location: DataFileReference | str) -> RetrivalJob:
        """Return `self`, the job is already cached and `location` is ignored."""
        return self

    def remove_derived_features(self) -> RetrivalJob:
        """Forward the call to the wrapped job, without keeping the cache."""
        return self.job.remove_derived_features()


@dataclass
class SplitJob:
    """Splits the pandas data of a job with a split strategy and a target column."""

    job: RetrivalJob
    target_column: str
    strategy: SplitStrategy

    @property
    def request_result(self) -> RequestResult:
        """The request result of the wrapped job."""
        wrapped = self.job
        return wrapped.request_result

    @property
    def retrival_requests(self) -> list[RetrivalRequest]:
        """The retrieval requests of the wrapped job."""
        wrapped = self.job
        return wrapped.retrival_requests

    async def use_pandas(self) -> SplitDataSet[pd.DataFrame]:
        """Load the wrapped job with pandas and hand the frame to the strategy."""
        frame = await self.job.to_pandas()
        split_data = self.strategy.split_pandas(frame, self.target_column)
        return split_data

    def remove_derived_features(self) -> RetrivalJob:
        """Forward the call to the wrapped job, without keeping the split."""
        wrapped = self.job
        return wrapped.remove_derived_features()


class FullExtractJob(RetrivalJob):
    """Retrieval job that carries an optional `limit`."""

    # presumably the maximum number of rows to extract, with `None` meaning no limit - TODO confirm
    limit: int | None


class DateRangeJob(RetrivalJob):
    """Retrieval job that carries a `start_date` and an `end_date`."""

    # presumably the bounds of the period to retrieve data for - TODO confirm if they are inclusive
    start_date: datetime
    end_date: datetime

    # NOTE(review): the string below is a bare expression and not the class docstring.
    # Its example refers to a psql config, so it may belong to another class - confirm.
    """
    ```
    psql_config = PsqlConfig(...)
    entites = psql_config.fetch("SELECT * FROM entities WHERE ...")
    ```
    """


class FactualRetrivalJob(RetrivalJob):
    """Retrieval job that carries another job in `facts`."""

    # presumably the job that provides the facts to retrieve features for - TODO confirm
    facts: RetrivalJob


@dataclass
class EnsureTypesJob(RetrivalJob, ModificationJob):
    """Job that casts the columns of the wrapped job to the types of the features."""

    job: RetrivalJob
    requests: list[RetrivalRequest]

    @property
    def request_result(self) -> RequestResult:
        """The request result of the wrapped job."""
        return self.job.request_result

    @property
    def retrival_requests(self) -> list[RetrivalRequest]:
        """The requests the job was created with."""
        return self.requests

    @staticmethod
    def _epoch_to_datetime(column_name: str) -> pl.Expr:
        # Convert from ms to us, and read the result as a UTC datetime
        return (pl.col(column_name).cast(pl.Int64) * 1000).cast(pl.Datetime(time_zone='UTC')).alias(column_name)

    async def to_pandas(self) -> pd.DataFrame:
        """Load the wrapped job with pandas and cast every required feature.

        Surrounding double quotes are stripped from string values first.
        Datetime features are parsed as UTC datetimes, string features are left
        as they are, numeric features are coerced so that unparsable values
        become NaN, and everything else is cast to the pandas type of the feature.
        """
        df = await self.job.to_pandas()
        datetime_type = FeatureType('').datetime
        string_type = FeatureType('').string

        for request in self.requests:
            for feature in request.all_required_features:

                present = df[feature.name].notnull()

                # Columns without string values raise an AttributeError on `.str`,
                # and are left untouched
                with suppress(AttributeError):
                    original = df.loc[present, feature.name]
                    stripped = original.str.strip('"')
                    # `.str` gives NaN for the non-string values of a mixed column, so keep those
                    stripped = stripped.where(stripped.notnull(), original)
                    # Null rows keep their value, all other rows get the stripped value
                    df[feature.name] = df[feature.name].where(~present, stripped)

                if feature.dtype == datetime_type:
                    df[feature.name] = pd.to_datetime(df[feature.name], infer_datetime_format=True, utc=True)
                elif feature.dtype == string_type:
                    continue
                elif feature.dtype.is_numeric:
                    df[feature.name] = pd.to_numeric(df[feature.name], errors='coerce').astype(
                        feature.dtype.pandas_type
                    )
                else:
                    df[feature.name] = df[feature.name].astype(feature.dtype.pandas_type)

            if request.event_timestamp:
                feature = request.event_timestamp
                df[feature.name] = pd.to_datetime(df[feature.name], infer_datetime_format=True, utc=True)
        return df

    async def to_polars(self) -> pl.LazyFrame:
        """Load the wrapped job with polars and cast every required feature.

        Bool features are cast through `Int8`. Datetime columns that are not a
        datetime yet are read as epoch values. Everything else is cast to the
        polars type of the feature without being strict.
        """
        df = await self.job.to_polars()
        for request in self.requests:
            for feature in request.all_required_features:
                if feature.dtype == FeatureType('').bool:
                    df = df.with_column(pl.col(feature.name).cast(pl.Int8).cast(pl.Boolean))
                elif feature.dtype == FeatureType('').datetime:
                    current_dtype = df.select([feature.name]).dtypes[0]
                    if isinstance(current_dtype, pl.Datetime):
                        continue
                    df = df.with_column(self._epoch_to_datetime(feature.name))
                else:
                    df = df.with_column(pl.col(feature.name).cast(feature.dtype.polars_type, strict=False))

            if request.event_timestamp:
                feature = request.event_timestamp
                if feature.name not in df.columns:
                    continue
                current_dtype = df.select([feature.name]).dtypes[0]
                if isinstance(current_dtype, pl.Datetime):
                    continue
                df = df.with_column(self._epoch_to_datetime(feature.name))
        return df

    def remove_derived_features(self) -> RetrivalJob:
        """Forward the call to the wrapped job, without keeping the type casting."""
        return self.job.remove_derived_features()


@dataclass
class CombineFactualJob(RetrivalJob):
    """Computes features that depend on different retrieval jobs

    The `jobs` therefore take in a list of jobs that output some data,
    and a `combined_requests` which defines the features depending on the data

    one example would be the following

    class SomeView(FeatureView):
        metadata = FeatureViewMetadata(
            name="some_view",
            batch_source=FileSource.csv_at("data.csv")
        )
        id = Entity(Int32())
        a = Int32()

    class OtherView(FeatureView):
        metadata = FeatureViewMetadata(
            name="other_view",
            batch_source=FileSource.parquet_at("other.parquet")
        )
        id = Entity(Int32())
        c = Int32()

    class Combined(CombinedFeatureView):
        metadata = CombinedMetadata(name="combined")

        some = SomeView()
        other = OtherView()

        added = some.a + other.c
    """

    # The jobs whose outputs are placed side by side before combining.
    jobs: list[RetrivalJob]
    # The requests whose derived features are computed on the combined data.
    combined_requests: list[RetrivalRequest]

    @property
    def request_result(self) -> RequestResult:
        """The merged result of all sub-jobs plus the result of the combined requests."""
        return RequestResult.from_result_list(
            [job.request_result for job in self.jobs]
        ) + RequestResult.from_request_list(self.combined_requests)

    @property
    def retrival_requests(self) -> list[RetrivalRequest]:
        """All requests of the sub-jobs followed by the combined requests, as one flat list."""
        # Each job already returns a list, so extend instead of nesting the lists;
        # the declared return type is a flat `list[RetrivalRequest]`.
        requests: list[RetrivalRequest] = []
        for job in self.jobs:
            requests.extend(job.retrival_requests)
        requests.extend(self.combined_requests)
        return requests

    async def combine_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Compute the derived features of the combined requests on a pandas frame.

        Features whose column already exists are skipped. New columns are
        assigned onto `df` itself, so the passed frame is modified in place.

        Args:
            df: The frame holding the columns the derived features depend on.

        Returns:
            The same frame object, with the derived feature columns added.
        """
        for request in self.combined_requests:
            for feature in request.derived_features:
                if feature.name in df.columns:
                    logger.info(f'Skipping feature {feature.name}, already computed')
                    continue
                logger.info(f'Computing feature: {feature.name}')
                df[feature.name] = await feature.transformation.transform_pandas(
                    df[feature.depending_on_names]
                )
        return df

    async def combine_polars_data(self, df: pl.LazyFrame) -> pl.LazyFrame:
        """Compute the derived features of the combined requests on a polars frame.

        Features whose column already exists are skipped. A transformation may
        return either an expression, which is added under the feature's name,
        or a whole new lazy frame, which replaces the current one.

        Args:
            df: The lazy frame holding the columns the derived features depend on.

        Returns:
            The lazy frame with the derived features applied.

        Raises:
            ValueError: If a transformation returns neither a `pl.Expr` nor a `pl.LazyFrame`.
        """
        for request in self.combined_requests:
            for feature in request.derived_features:
                if feature.name in df.columns:
                    logger.info(f'Skipping feature {feature.name}, already computed')
                    continue
                logger.info(f'Computing feature: {feature.name}')
                result = await feature.transformation.transform_polars(df, feature.name)
                if isinstance(result, pl.Expr):
                    df = df.with_columns([result.alias(feature.name)])
                elif isinstance(result, pl.LazyFrame):
                    df = result
                else:
                    raise ValueError(f'Unsupported transformation result type, got {type(result)}')
        return df

    async def to_pandas(self) -> pd.DataFrame:
        """Run all sub-jobs, place their frames side by side and compute the combined features.

        Returns:
            With a single job, that job's frame with the combined features added.
            With several jobs, a copy of the column-wise concatenation with the
            combined features added and repeated column names dropped.

        Raises:
            ValueError: If there are no jobs to run.
        """
        if not self.jobs:
            raise ValueError(
                'Have no jobs to fetch. This is probably an internal error.\n'
                'Please submit an issue, and describe how to reproduce it.\n'
                'Or maybe even submit a PR'
            )

        if len(self.jobs) == 1:
            df = await self.jobs[0].to_pandas()
            return await self.combine_data(df)

        dfs = await asyncio.gather(*[job.to_pandas() for job in self.jobs])
        combined = await self.combine_data(pd.concat(dfs, axis=1))
        # Build the mask from the returned frame so it always matches its columns,
        # without relying on `combine_data` mutating its argument in place.
        return combined.loc[:, ~combined.columns.duplicated()].copy()

    async def to_polars(self) -> pl.LazyFrame:
        """Run all sub-jobs as polars frames and compute the combined features.

        Raises:
            ValueError: If there are no jobs to run.
        """
        if not self.jobs:
            raise ValueError(
                'Have no jobs to fetch. This is probably an internal error.\n'
                'Please submit an issue, and describe how to reproduce it.\n'
                'Or maybe even submit a PR'
            )

        dfs: list[pl.LazyFrame] = await asyncio.gather(*[job.to_polars() for job in self.jobs])

        df = dfs[0]

        # NOTE(review): confirm that `select(pl.all())` also yields the columns of the
        # context frame; a horizontal `pl.concat` was previously considered here.
        for other_df in dfs[1:]:
            df = df.with_context(other_df).select(pl.all())

        return await self.combine_polars_data(df)

    def cached_at(self, location: DataFileReference | str) -> RetrivalJob:
        """Return a new combine job where every sub-job is cached at `location`."""
        return CombineFactualJob([job.cached_at(location) for job in self.jobs], self.combined_requests)

    def remove_derived_features(self) -> RetrivalJob:
        """Return a new combine job where every sub-job has its derived features removed."""
        return CombineFactualJob([job.remove_derived_features() for job in self.jobs], self.combined_requests)

    def describe(self) -> str:
        """Return a numbered, line-per-job description of the sub-jobs."""
        header = f'Combining {len(self.jobs)} jobs:\n'
        lines = [f'{index + 1}: {job.describe()}\n' for index, job in enumerate(self.jobs)]
        return header + ''.join(lines)


@dataclass
class FilterJob(RetrivalJob, ModificationJob):
    """Wraps a job and narrows its output to the entity columns plus `include_features`.

    An empty `include_features` leaves the wrapped job's output untouched.
    """

    include_features: set[str]
    job: RetrivalJob

    @property
    def request_result(self) -> RequestResult:
        """The wrapped job's request result, filtered by `include_features`."""
        inner_result = self.job.request_result
        return inner_result.filter_features(self.include_features)

    @property
    def retrival_requests(self) -> list[RetrivalRequest]:
        """The wrapped job's requests, each passed through `filter_features`."""
        filtered_requests = []
        for request in self.job.retrival_requests:
            # The names handed over are the request's features that are not in `include_features`.
            names_outside_filter = request.all_feature_names - self.include_features
            filtered_requests.append(request.filter_features(names_outside_filter))
        return filtered_requests

    async def to_pandas(self) -> pd.DataFrame:
        """Load the wrapped job as pandas and keep only the entity and included columns."""
        frame = await self.job.to_pandas()
        if not self.include_features:
            return frame
        entity_names = {entity.name for entity in self.request_result.entities}
        wanted_columns = list(entity_names.union(self.include_features))
        return frame[wanted_columns]

    async def to_polars(self) -> pl.LazyFrame:
        """Load the wrapped job as polars and keep only the entity and included columns."""
        frame = await self.job.to_polars()
        if not self.include_features:
            return frame
        entity_names = {entity.name for entity in self.request_result.entities}
        wanted_columns = list(entity_names.union(self.include_features))
        return frame.select(wanted_columns)

    def validate(self, validator: Validator) -> RetrivalJob:
        """Return a new filter, with the same features, around the validated inner job."""
        validated_job = self.job.validate(validator)
        return FilterJob(self.include_features, validated_job)

    def cached_at(self, location: DataFileReference | str) -> RetrivalJob:
        """Return a new filter, with the same features, around the cached inner job."""
        cached_job = self.job.cached_at(location)
        return FilterJob(self.include_features, cached_job)

    def with_subfeatures(self) -> RetrivalJob:
        """Delegate to the wrapped job; the returned job is not wrapped in a filter."""
        inner_job = self.job
        return inner_job.with_subfeatures()

    def remove_derived_features(self) -> RetrivalJob:
        """Delegate to the wrapped job; the returned job is not wrapped in a filter."""
        inner_job = self.job
        return inner_job.remove_derived_features()


@dataclass
class ListenForTriggers(RetrivalJob, ModificationJob):
    """Passes the wrapped job's data through, running every trigger's check on it first."""

    job: RetrivalJob
    triggers: set[EventTrigger]

    @property
    def request_result(self) -> RequestResult:
        """The request result of the wrapped job, unchanged."""
        inner_job = self.job
        return inner_job.request_result

    @property
    def retrival_requests(self) -> list[RetrivalRequest]:
        """The retrieval requests of the wrapped job, unchanged."""
        inner_job = self.job
        return inner_job.retrival_requests

    async def to_pandas(self) -> pd.DataFrame:
        """Load the wrapped job as pandas, await all trigger checks, then return the frame."""
        frame = await self.job.to_pandas()
        # All checks are awaited together through the module-level asyncio import.
        pending_checks = [trigger.check_pandas(frame, self.request_result) for trigger in self.triggers]
        await asyncio.gather(*pending_checks)
        return frame

    async def to_polars(self) -> pl.LazyFrame:
        """Load the wrapped job as polars, await all trigger checks, then return the frame."""
        frame = await self.job.to_polars()
        pending_checks = [trigger.check_polars(frame, self.request_result) for trigger in self.triggers]
        await asyncio.gather(*pending_checks)
        return frame

    def remove_derived_features(self) -> RetrivalJob:
        """Delegate to the wrapped job; the returned job is not wrapped with the triggers."""
        inner_job = self.job
        return inner_job.remove_derived_features()
