# Copyright (c) QuantCo and pydiverse contributors 2025-2025
# SPDX-License-Identifier: BSD-3-Clause

from pathlib import Path

import pandas as pd
import polars as pl
import pytest
import sqlalchemy as sa
from sqlalchemy.exc import ProgrammingError

import tests.util.tasks_library as m
from pydiverse.pipedag import ConfigContext, Flow, Stage, materialize
from pydiverse.pipedag.backend.table.sql.ddl import (
    CreateSchema,
    CreateTableAsSelect,
    CreateViewAsSelect,
    DropTable,
    DropView,
)
from pydiverse.pipedag.container import ExternalTableReference, Schema, Table
from pydiverse.pipedag.context.context import CacheValidationMode

# Parameterize all tests in this file with several instance_id configurations
from tests.fixtures.instances import DATABASE_INSTANCES, skip_instances, with_instances
from tests.util import swallowing_raises
from tests.util.sql import sql_table_expr

# pytest reads the module-level ``pytestmark`` name and applies the listed marks
# to every test collected from this module.
pytestmark = [with_instances(DATABASE_INSTANCES)]


@skip_instances("parquet_backend", "parquet_s3_backend", "parquet_s3_backend_db2")
def test_smoke_table_reference():
    """Smoke test: reference a table that was created outside of pipedag's stage handling.

    The first stage creates the table inside a task and references it from there;
    the second stage references the same table through an ``ExternalTableReference``
    that is constructed while the flow is being declared.
    """

    @materialize(lazy=True)
    def in_table():
        """Create ``user_controlled_schema.external_table`` and return a reference to it."""
        ext_schema = Schema("user_controlled_schema", prefix="", suffix="")
        ext_name = "external_table"
        store = ConfigContext.get().store.table_store

        # Start from a clean slate: make sure the schema exists, drop any previous table.
        store.execute(CreateSchema(ext_schema, if_not_exists=True))
        store.execute(DropTable(ext_name, ext_schema, if_exists=True, cascade=True))

        create_stmt = CreateTableAsSelect(
            ext_name,
            ext_schema,
            sql_table_expr({"col": [0, 1, 2, 3]}),
        )
        store.execute(create_stmt)

        reference = ExternalTableReference(ext_name, schema=ext_schema.get())
        return Table(reference)

    @materialize(input_type=sa.Table)
    def duplicate_table_reference(tbl: sa.sql.expression.Alias):
        """Return a second reference pointing at the table behind ``tbl``."""
        underlying = tbl.original
        return Table(ExternalTableReference(underlying.name, schema=underlying.schema))

    with Flow() as flow:
        with Stage("sql_table_reference"):
            source_tbl = in_table()
            mirrored_tbl = duplicate_table_reference(source_tbl)
            m.assert_table_equal(source_tbl, mirrored_tbl, check_dtype=False)

        with Stage("reference_constant"):
            # Here the ExternalTableReference is built at declare time. Pointing at a
            # table that pipedag code itself generates is not a recommended pattern;
            # it only serves to provide input for this feature test.
            constant_ref = ExternalTableReference("external_table", schema="user_controlled_schema")
            constant_tbl = Table(constant_ref)
            mirrored_constant_tbl = duplicate_table_reference(constant_tbl)
            m.assert_table_equal(constant_tbl, mirrored_constant_tbl, check_dtype=False)

    import filelock

    # NOTE(review): the lock presumably serializes concurrent test runs that touch
    # the shared ``user_controlled_schema`` -- confirm.
    lock_file = Path(__file__).parent / "lock"
    with filelock.FileLock(lock_file):
        result = flow.run()
        assert result.successful


@skip_instances("parquet_backend", "parquet_s3_backend", "parquet_s3_backend_db2")
@pytest.mark.skipif(pl is None, reason="polars is needed for this test")
def test_table_store():
    """Reference an externally created table and a view on top of it.

    Both references are compared against expected pandas frames; the view is
    additionally passed through the polars and lazy-polars no-op tasks. The flow
    runs three times: once with ``FORCE_CACHE_INVALID`` and twice with the
    default cache validation mode.
    """

    @materialize(version="1.1")
    def in_table():
        """(Re)create ``user_controlled_schema.external_table`` and return a reference to it."""
        ext_schema = Schema("user_controlled_schema", prefix="", suffix="")
        ext_name = "external_table"
        store = ConfigContext.get().store.table_store

        store.execute(CreateSchema(ext_schema, if_not_exists=True))
        try:
            # Best-effort drop of the view that in_view creates; a failing DROP is ignored.
            store.execute(DropView("external_view", ext_schema))
        except ProgrammingError:
            pass
        store.execute(DropTable(ext_name, ext_schema, if_exists=True))

        create_stmt = CreateTableAsSelect(
            ext_name,
            ext_schema,
            sql_table_expr({"col": [0, 1, 2, 3]}),
        )
        store.execute(create_stmt)

        reference = ExternalTableReference(ext_name, schema=ext_schema.get())
        return Table(reference)

    @materialize(input_type=sa.Table)
    def duplicate_table_reference(tbl: sa.sql.expression.Alias):
        """Return a second reference pointing at the table behind ``tbl``."""
        underlying = tbl.original
        return Table(ExternalTableReference(underlying.name, schema=underlying.schema))

    @materialize(version="1.1", input_type=sa.Table)
    def in_view(tbl: sa.sql.expression.Alias):
        """(Re)create ``user_controlled_schema.external_view`` holding rows of ``tbl`` with col > 1."""
        ext_schema = Schema("user_controlled_schema", prefix="", suffix="")
        ext_view = "external_view"
        store = ConfigContext.get().store.table_store

        try:
            # if_exists=True is not an option here since DB2 does not support it,
            # so the drop is attempted unconditionally and a failure is ignored.
            store.execute(DropView(ext_view, ext_schema))
        except ProgrammingError:
            pass

        col = tbl.c.col
        create_stmt = CreateViewAsSelect(
            ext_view,
            ext_schema,
            sa.select(col).where(col > 1),
        )
        store.execute(create_stmt)

        reference = ExternalTableReference(ext_view, schema=ext_schema.get())
        return Table(reference)

    @materialize
    def expected_out_table():
        """Expected content of the external table."""
        frame = pd.DataFrame({"col": [0, 1, 2, 3]})
        return Table(frame)

    @materialize
    def expected_out_view():
        """Expected content of the external view."""
        frame = pd.DataFrame({"col": [2, 3]})
        return Table(frame)

    @materialize(input_type=sa.Table)
    def copy_table(tbl: sa.sql.expression.Alias):
        """Select everything from ``tbl`` into a table carrying the same name."""
        return Table(sa.select(tbl), name=tbl.original.name)

    with Flow() as flow:
        with Stage("sql_table_reference"):
            external_table = in_table()
            duplicate_table_reference(external_table)
            expected_table = expected_out_table()
            copy_table(external_table)

            m.assert_table_equal(external_table, expected_table, check_dtype=False)

        with Stage("sql_view_reference"):
            external_view = in_view(external_table)
            expected_view = expected_out_view()
            m.assert_table_equal(external_view, expected_view, check_dtype=False)

            # The same view must also be readable through the polars-based tasks.
            view_as_polars = m.noop_polars(external_view)
            view_as_lazy_polars = m.noop_lazy_polars(external_view)
            m.assert_table_equal(view_as_polars, expected_view, check_dtype=False)
            m.assert_table_equal(view_as_lazy_polars, expected_view, check_dtype=False)

    # First run with forced cache invalidation, then two runs with default settings.
    forced_result = flow.run(cache_validation_mode=CacheValidationMode.FORCE_CACHE_INVALID)
    assert forced_result.successful
    for _ in range(2):
        assert flow.run().successful


def test_bad_table_reference():
    """Run a flow whose only task references ``ext_schema.this_table_does_not_exist``.

    The run is wrapped in ``swallowing_raises`` expecting a ``ValueError`` whose
    message matches the table name.
    """

    @materialize()
    def bad_table_reference():
        """Return a reference to a table that is never created in this test."""
        missing = ExternalTableReference(name="this_table_does_not_exist", schema="ext_schema")
        return Table(missing)

    with Flow() as flow:
        with Stage("sql_table_reference"):
            bad_table_reference()

    with swallowing_raises(ValueError, match="this_table_does_not_exist"):
        flow.run()
