import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import tdsbrondata
from pyspark.sql import functions as F
from pyspark.sql.types import *
from delta.tables import DeltaTable

# Columns every items DataFrame must contain, mapped to the Spark data type each
# one must have (validateItemsSchema compares the data type's class only).
requiredColumns = dict(
    DienstId=IntegerType(),
    Source=StringType(),
    SurrogateKey=LongType(),
    FacturatieMaand=StringType(),
    VerkooprelatieId=StringType(),
    ItemCode=StringType(),
    Aantal=FloatType(),
)

# Columns an items DataFrame may contain; when present they must have these
# types. writeLines only aggregates rows in which all present optional columns
# are NULL; other rows are kept as individual lines.
optionalColumns = dict(
    AfwijkendePrijs=FloatType(),
    DatumVanOrigineel=StringType(),
    DatumTotOrigineel=StringType(),
    AantalOrigineel=FloatType(),
    DurationOrigineel=FloatType(),
)

def validateItemsSchema(items):
    """Check that *items* has the expected columns with the expected Spark types.

    Every column in ``requiredColumns`` must exist; columns in ``optionalColumns``
    may be absent but must have the listed type when present. Only the class of
    each data type is compared (e.g. any ``StringType`` instance matches).

    Args:
        items: Spark DataFrame whose ``schema`` is inspected.

    Raises:
        ValueError: a required column is missing.
        TypeError: a required or present optional column has another type class.
    """
    actualTypes = dict((field.name, field.dataType) for field in items.schema.fields)

    def _checkType(name, expected):
        actual = actualTypes[name]
        if type(actual) != type(expected):
            raise TypeError(f"Column '{name}' has type {actual}, expected {expected}")

    # Required columns: presence is checked before the type, column by column.
    for name, expected in requiredColumns.items():
        if name not in actualTypes:
            raise ValueError(f"Missing required column '{name}'")
        _checkType(name, expected)

    # Optional columns: only validated when the DataFrame actually has them.
    for name, expected in optionalColumns.items():
        if name in actualTypes:
            _checkType(name, expected)

def writeItems(items):
    """Validate *items*, write them to the ``items`` Delta table, then write ``lines``.

    Every FacturatieMaand present in *items* replaces the same month(s) in the
    table, so rerunning a batch does not duplicate rows. If no Delta table exists
    yet at the path, it is created from *items*.

    Args:
        items: Spark DataFrame with the ``requiredColumns`` (and optionally some
            of the ``optionalColumns``).

    Raises:
        ValueError: a required column is missing (from validateItemsSchema).
        TypeError: a column has an unexpected type (from validateItemsSchema).
    """
    validateItemsSchema(items)

    _writeReplacingFacturatieMaanden(items, f"{tdsbrondata.tablesRootPath}/items")

    writeLines(items)


def writeLines(items):
    """Derive lines from *items* and write them to the ``lines`` Delta table.

    See ``_buildLines`` for how rows are aggregated or kept. Every FacturatieMaand
    present in *items* replaces the same month(s) in the table; if no Delta table
    exists yet at the path, it is created.

    Args:
        items: Spark DataFrame with at least the ``requiredColumns``.
    """
    lines = _buildLines(items)
    _writeReplacingFacturatieMaanden(lines, f"{tdsbrondata.tablesRootPath}/lines")


def _buildLines(items):
    """Build the lines DataFrame from *items*.

    - Rows in which every optional column present in *items* is NULL are grouped
      by DienstId, FacturatieMaand, VerkooprelatieId and ItemCode, with ``Aantal``
      summed. Their optional columns are NULL, and ``HasItems`` is True only when
      the group has more than one row and every row has a non-NULL SurrogateKey.
    - All other rows are kept one-to-one. Missing optional columns are added as
      NULL, and ``HasItems`` is True when SurrogateKey is non-NULL.

    Both parts are combined by column name. Aggregated rows have no single
    Source/SurrogateKey, so those (and any other item-only columns) are NULL on
    them.
    """
    existingOptionalColumns = [column for column in optionalColumns if column in items.columns]

    if existingOptionalColumns:
        allOptionalNull = F.lit(True)
        for column in existingOptionalColumns:
            allOptionalNull = allOptionalNull & F.col(column).isNull()
        itemsToAggregate = items.filter(allOptionalNull)
        # IS NULL never yields NULL, so negating it selects exactly the other rows.
        itemsToPreserve = items.filter(~allOptionalNull)
    else:
        itemsToAggregate = items
        itemsToPreserve = items.limit(0)

    # Count all rows and the non-NULL keys separately. collect_list would silently
    # drop NULL SurrogateKeys, which would hide them from the HasItems check.
    # NOTE(review): F.sum over a FloatType column returns DoubleType, so Aantal is
    # presumably double in the result — confirm that is intended.
    aggregated = itemsToAggregate.groupBy(
        "DienstId", "FacturatieMaand", "VerkooprelatieId", "ItemCode"
    ).agg(
        F.sum("Aantal").alias("Aantal"),
        F.count(F.lit(1)).alias("_itemCount"),
        F.count("SurrogateKey").alias("_keyCount"),
    )

    for column, columnType in optionalColumns.items():
        aggregated = aggregated.withColumn(column, F.lit(None).cast(columnType))

    aggregated = aggregated.withColumn(
        "HasItems",
        (F.col("_itemCount") > 1) & (F.col("_keyCount") == F.col("_itemCount")),
    ).drop("_itemCount", "_keyCount")

    for column, columnType in optionalColumns.items():
        if column not in itemsToPreserve.columns:
            itemsToPreserve = itemsToPreserve.withColumn(column, F.lit(None).cast(columnType))

    itemsToPreserve = itemsToPreserve.withColumn("HasItems", F.col("SurrogateKey").isNotNull())

    # The preserved rows carry Source/SurrogateKey (and any extra item columns)
    # that the aggregated rows lack. Without allowMissingColumns, unionByName
    # rejects the differing column sets.
    return aggregated.unionByName(itemsToPreserve, allowMissingColumns=True)


def _distinctFacturatieMaanden(df):
    """Return the distinct FacturatieMaand values in *df* (the list may contain None)."""
    return [row["FacturatieMaand"] for row in df.select("FacturatieMaand").distinct().collect()]


def _writeReplacingFacturatieMaanden(df, tablePath):
    """Write *df* to the Delta table at *tablePath*, replacing the months it contains.

    If a Delta table exists at *tablePath*, the rows whose FacturatieMaand occurs
    in *df* are deleted first, then *df* is appended. NULL-month rows are also
    deleted when *df* has NULL months. Otherwise *df* is written in overwrite
    mode with ``overwriteSchema``.

    NOTE: the delete and the append are two separate Delta commits. A failure
    between them leaves the affected months deleted but not rewritten.
    """
    spark = tdsbrondata._spark
    if not DeltaTable.isDeltaTable(spark, tablePath):
        df.write.format("delta").mode("overwrite").option("overwriteSchema", "true").save(tablePath)
        return

    maanden = _distinctFacturatieMaanden(df)
    knownMaanden = [maand for maand in maanden if maand is not None]

    # Build the predicate as a Column rather than an f-string, so month values
    # containing quotes cannot break or alter the delete condition.
    condition = None
    if knownMaanden:
        condition = F.col("FacturatieMaand").isin(knownMaanden)
    if len(knownMaanden) < len(maanden):
        nullCondition = F.col("FacturatieMaand").isNull()
        condition = nullCondition if condition is None else condition | nullCondition

    # An empty df contains no months, so there is nothing to delete.
    if condition is not None:
        DeltaTable.forPath(spark, tablePath).delete(condition)
    df.write.format("delta").mode("append").save(tablePath)