import hashlib
import os
from pathlib import Path
from typing import Any

import pandas as pd
import sqlalchemy as sa

from pydiverse.pipedag import (
    Flow,
    Stage,
    Table,
    materialize,
    input_stage_versions,
    ConfigContext,
)
from pydiverse.pipedag.context import StageLockContext
from pydiverse.pipedag.core.config import PipedagConfig
from pydiverse.common.util.structlog import setup_logging


# These module-level variables are only a mock for input data that the full
# pipeline would normally source from somewhere else.
dfA_source = pd.DataFrame(
    {
        "a": [0, 1, 2, 4],
        "b": [9, 8, 7, 6],
    }
)
# Work on a copy so that dfA_source stays available as the pristine reference.
dfA = dfA_source.copy()
# Fingerprint of the mocked input, returned by has_new_input().
# The built-in hash() of a str is salted per interpreter process (unless
# PYTHONHASHSEED is fixed), so it would yield a different value on every run
# even for identical data. A SHA-256 digest of the same text is stable across
# runs; it is truncated to 64 bits so that the fingerprint remains a plain int.
# NOTE: str(dfA) is the printed representation, which pandas may abbreviate for
# large frames - good enough for this small mock.
input_hash = int.from_bytes(
    hashlib.sha256(str(dfA).encode("utf-8")).digest()[:8], "big"
)


def has_new_input():
    """Return the fingerprint of the mocked input data.

    This function is registered as the ``cache`` hook of ``input_task``; it
    simply hands back the module-level ``input_hash``.
    """
    # Reading a module-level name does not require a ``global`` declaration.
    return input_hash


@materialize(nout=2, cache=has_new_input, version="1.0")
def input_task():
    """Expose the mocked input frame ``dfA`` as two tables.

    Returns:
        A pair of ``Table`` objects that both wrap ``dfA``, named ``"dfA"`` and
        ``"dfB"`` respectively.
    """
    table_a = Table(dfA, "dfA")
    table_b = Table(dfA, "dfB")
    return table_a, table_b


def has_copy_source_fresh_input(
    tbls: dict[str, sa.sql.expression.Alias],
    other_tbls: dict[str, sa.sql.expression.Alias],
    source_cfg: ConfigContext,
    *,
    attrs: dict[str, Any],
    stage: Stage,
):
    """Return the stage hash that the source instance's table store reports.

    Registered as the ``cache`` hook of ``copy_filtered_inputs``.

    Args:
        tbls: Unused.
        other_tbls: Unused.
        source_cfg: Configuration of the instance to copy from; it is entered
            as a context manager while the hash is looked up.
        attrs: Unused.
        stage: Stage whose hash is requested from the source table store.

    Returns:
        Whatever ``get_stage_hash(stage)`` of the source table store returns.
    """
    # Only the source configuration and the stage matter for the lookup; the
    # remaining parameters are accepted but deliberately ignored.
    _ = (tbls, other_tbls, attrs)
    with source_cfg:
        return source_cfg.store.table_store.get_stage_hash(stage)


@input_stage_versions(
    input_type=sa.Table,
    cache=has_copy_source_fresh_input,
    pass_args=["attrs", "stage"],
    lazy=True,
)
def copy_filtered_inputs(
    tbls: dict[str, sa.sql.expression.Alias],
    source_tbls: dict[str, sa.sql.expression.Alias],
    source_cfg: ConfigContext,
    *,
    attrs: dict[str, Any],
    stage: Stage,
):
    """Build row-limited copies of every table found in the source stage.

    Args:
        tbls: Tables of the current stage; unused, because that schema is
            expected to be still empty (one could check for name collisions).
        source_tbls: Tables of the source stage, keyed by table name.
        source_cfg: Unused; the tables are assumed to be copyable within the
            same database engine, just from a different schema.
        attrs: Must contain ``"copy_filter_cnt"``, the row limit applied to
            every copied table.
        stage: Unused.

    Returns:
        A dict mapping each lower-cased source table name to a ``Table`` that
        wraps ``SELECT ... LIMIT copy_filter_cnt`` on that source table and
        keeps the original (not lower-cased) table name.
    """
    _ = (tbls, source_cfg, stage)
    row_limit = attrs["copy_filter_cnt"]

    copied = {}
    for name, tbl in source_tbls.items():
        limited_query = sa.select(tbl).limit(row_limit)
        copied[name.lower()] = Table(limited_query, name)
    return copied


@materialize(input_type=pd.DataFrame, version="1.0")
def double_values(df: pd.DataFrame):
    """Return a ``Table`` wrapping ``df`` with its values multiplied by two."""
    doubled = df.transform(lambda values: values * 2)
    return Table(doubled)


# noinspection PyTypeChecker
def get_flow(attrs: dict[str, Any], pipedag_config):
    """Assemble the two-stage example flow.

    Args:
        attrs: Instance attributes. ``"copy_filtered_input"`` selects the input
            source; when it is truthy, ``"copy_source"`` and ``"copy_per_user"``
            are passed to ``pipedag_config.get`` to obtain the configuration to
            copy from.
        pipedag_config: Object whose ``get`` method returns the configuration
            of another instance.

    Returns:
        A tuple ``(flow, b2, a3)``: the flow, the output of doubling input
        ``b`` once, and the output of doubling input ``a`` twice.
    """
    with Flow("test_instance_selection") as flow:
        with Stage("stage_1") as stage:
            if attrs["copy_filtered_input"]:
                # Take (filtered) inputs from another pipeline instance.
                source_cfg = pipedag_config.get(
                    attrs["copy_source"], attrs["copy_per_user"]
                )
                copied = copy_filtered_inputs(source_cfg, stage=stage, attrs=attrs)
                a, b = copied["dfa"], copied["dfb"]
            else:
                # Produce the inputs directly from the mocked source data.
                a, b = input_task()
            a2 = double_values(a)

        with Stage("stage_2"):
            b2 = double_values(b)
            a3 = double_values(a2)

    return flow, b2, a3


def check_result(result, out1, out2, *, head=999):
    """Assert that a flow run succeeded and produced the expected frames.

    Args:
        result: Run result; must expose ``successful`` and ``get``.
        out1: Output expected to equal the first ``head`` rows of
            ``dfA_source`` multiplied by 2.
        out2: Output expected to equal the first ``head`` rows of
            ``dfA_source`` multiplied by 4.
        head: Number of leading source rows to compare against.

    Raises:
        AssertionError: If the run was not successful or a frame differs
            (dtypes are not compared).
    """
    assert result.successful
    actual_doubled = result.get(out1)
    actual_quadrupled = result.get(out2)

    expected_rows = dfA_source.head(head)
    pd.testing.assert_frame_equal(
        expected_rows * 2, actual_doubled, check_dtype=False
    )
    pd.testing.assert_frame_equal(
        expected_rows * 4, actual_quadrupled, check_dtype=False
    )


if __name__ == "__main__":
    password_cfg_path = str(Path(__file__).parent / "postgres_password.yaml")

    # Snapshot the environment so the variable set below does not leak past
    # this run; the snapshot is restored in the ``finally`` clause.
    old_environ = dict(os.environ)
    try:
        # NOTE(review): presumably referenced from pipedag.yaml - confirm.
        os.environ["POSTGRES_PASSWORD_CFG"] = password_cfg_path

        setup_logging()  # you can setup the logging and/or structlog libraries as you wish
        pipedag_config = PipedagConfig(Path(__file__).parent / "pipedag.yaml")

        cfg = pipedag_config.get(instance="full")
        flow, out1, out2 = get_flow(cfg.attrs, pipedag_config)

        # The outputs are read while the StageLockContext is still open.
        with StageLockContext():
            result = flow.run(config=cfg)
            check_result(result, out1, out2)
    finally:
        # Put the process environment back exactly as it was found.
        os.environ.clear()
        os.environ.update(old_environ)
