
from datetime import timedelta, datetime
from datetime import datetime
from pyspark.sql import Window
from pyspark.sql import functions as f
from pyspark.sql.types import *
import pandas as pd
import sys

# Version identifier of this helper module.
__version__ = '0.0.15'
class spark_functions():
    def __init__(self, spark=None, health_table_name = None) -> None:
        """
        Keep the objects the other methods of this class rely on.

        Parameters:
        - spark: Session-like object; `read_dwh_table` calls `spark.table` and
          `spark.createDataFrame` on it.
        - health_table_name: Name of the table `read_dwh_table` appends its
          freshness records to.
        """
        self.health_table_name, self.spark = health_table_name, spark
    def sample_function(self):
        """Print a fixed message; useful as a quick check that the class is callable."""
        message = "Sample is working"
        print(message)

    def get_top_duplicates(self,df,col='customer_hash',n=2):
        """
        Return the `n` most frequent values of `col` together with their counts.

        Parameters:
        - df (DataFrame): The Spark DataFrame to inspect.
        - col (str): Column whose values are grouped and counted.
        - n (int): Maximum number of rows to return.

        Returns:
        - DataFrame: One row per value of `col` with a 'count' column, sorted by
          'count' in descending order and truncated to `n` rows.
        """
        # count(col) only counts non-null values, so the null group reports 0.
        occurrences = f.count(col).alias('count')
        per_value = df.groupBy(col).agg(occurrences)
        most_frequent_first = per_value.orderBy(f.col('count').desc_nulls_last())
        return most_frequent_first.limit(n)

    def sdf_to_dwh(self,sdf,table_address,mode,mergeSchema = "true"):
        """
        Save a Spark DataFrame as a table.

        Parameters:
        - sdf (DataFrame): The Spark DataFrame to write.
        - table_address (str): Name of the destination table.
        - mode (str): Save mode handed to the DataFrame writer (e.g. 'append').
        - mergeSchema (str): Value of the writer's "mergeSchema" option.
        """
        writer = sdf.write.mode(mode)
        writer = writer.option("mergeSchema", mergeSchema)
        writer.saveAsTable(table_address)

    def sdf_fillDown(self,sdf,groupCol,orderCol,cols_to_fill):
        """
        Forward-fill nulls: within each group, replace every value of the listed
        columns with the latest non-null value seen so far in `orderCol` order.

        Parameters:
        - sdf (DataFrame): The Spark DataFrame to process.
        - groupCol: Column(s) that define the groups (window partition).
        - orderCol: Column(s) that define the order inside each group.
        - cols_to_fill (list): Names of the columns to fill.

        Returns:
        - DataFrame: `sdf` with the listed columns overwritten by their filled version.
        """
        # An ordered window without an explicit frame runs from the start of the
        # partition up to the current row, which is what makes `last` a forward fill.
        ordered_group = Window.partitionBy(groupCol).orderBy(orderCol)

        for name in cols_to_fill:
            latest_known = f.last(name, ignorenulls=True).over(ordered_group)
            sdf = sdf.withColumn(name, latest_known)
        return sdf
    
    def sdf_fillUp(self,sdf,groupCol,orderCol,cols_to_fill):
        """
        Backward-fill nulls: same mechanism as a forward fill, but the group is
        walked in descending `orderCol` order (nulls last), so each row receives
        the nearest non-null value that follows it in ascending order.

        Parameters:
        - sdf (DataFrame): The Spark DataFrame to process.
        - groupCol: Column(s) that define the groups (window partition).
        - orderCol (str): Name of the column that defines the order inside each group.
        - cols_to_fill (list): Names of the columns to fill.

        Returns:
        - DataFrame: `sdf` with the listed columns overwritten by their filled version.
        """
        newest_first = f.col(orderCol).desc_nulls_last()
        reversed_group = Window.partitionBy(groupCol).orderBy(newest_first)

        for name in cols_to_fill:
            next_known = f.last(name, ignorenulls=True).over(reversed_group)
            sdf = sdf.withColumn(name, next_known)
        return sdf
    
    def sdf_fill_gaps(self,sdf,groupCol,orderCol,cols_to_fill,direction='both'):
        """
        Fill nulls in the listed columns in the requested direction.

        Parameters:
        - sdf (DataFrame): The Spark DataFrame to process.
        - groupCol: Column(s) that define the groups.
        - orderCol: Column that defines the order inside each group.
        - cols_to_fill (list): Names of the columns to fill.
        - direction (str): 'up' applies `sdf_fillUp`, 'down' applies `sdf_fillDown`;
          any other value (including the default 'both') applies `sdf_fillDown`
          first and `sdf_fillUp` on its result.

        Returns:
        - DataFrame: The filled DataFrame.
        """
        if direction == 'up':
            return self.sdf_fillUp(sdf,groupCol,orderCol,cols_to_fill)
        if direction == 'down':
            return self.sdf_fillDown(sdf,groupCol,orderCol,cols_to_fill)

        # Forward fill first, then let the backward fill cover leading gaps.
        forward_filled = self.sdf_fillDown(sdf,groupCol,orderCol,cols_to_fill)
        return self.sdf_fillUp(forward_filled,groupCol,orderCol,cols_to_fill)
    
    @staticmethod
    def single_value_expr(partition_col, order_col, value_col, ascending=False):
        """
        Build a column expression that picks `value_col` from the row holding the
        smallest (`ascending=True`) or largest (`ascending=False`) `order_col`
        within each `partition_col` group.

        Declared as a static method because it takes no `self`; it can be called
        on the class or on an instance.

        Parameters:
        - partition_col: Column(s) that define the groups.
        - order_col (str): Name of the column whose min/max selects the row.
        - value_col (str): Name of the column whose value is returned.
        - ascending (bool): True selects the row with the minimum `order_col`,
          False (default) the row with the maximum.

        Returns:
        - Column: `first(..., ignorenulls=True)` over the values of `value_col`
          on the selected row(s); all other rows contribute null.
        """
        # The window is deliberately NOT ordered. An ordered window's default
        # frame ends at the current row, which turns max() into a running maximum
        # that always equals the row's own value, so every row would match.
        whole_partition = Window.partitionBy(partition_col)
        extreme = f.min if ascending else f.max
        on_selected_row = f.col(order_col) == extreme(order_col).over(whole_partition)

        # NOTE(review): this nests a window function inside the aggregate `first`.
        # Spark may reject that depending on how the caller uses the expression
        # (e.g. inside groupBy().agg()) -- confirm against the call sites.
        return f.first(f.when(on_selected_row, f.col(value_col)), True)

    def read_dwh_table(self,table_name, last_update_column=None, save_health=True):
        """
        Read a table and, on a best-effort basis, record how fresh it is.

        When `save_health` is true, the maximum of `last_update_column` (cast to
        timestamp, ignoring rows dated at or after the start of tomorrow) is
        appended to `self.health_table_name` together with the table name and the
        time of the check. A failure of this bookkeeping never prevents the table
        from being returned; it is logged as a warning instead of being hidden.

        Parameters:
        - table_name (str): Name of the table to read via `self.spark.table`.
        - last_update_column (str, optional): Column holding the update time of
          each row. Without it no health record can be computed, so the health
          step is skipped.
        - save_health (bool): Whether to write the health record.

        Returns:
        - DataFrame: The table as a Spark DataFrame.
        """
        import logging  # function-scope import keeps the module's import block untouched

        sdf = self.spark.table(table_name)
        if not save_health:
            return sdf

        # Without a column to inspect or a table to write to, the health step can
        # only fail; skip it up front instead of running a Spark job for nothing.
        if last_update_column is None or self.health_table_name is None:
            return sdf

        try:
            tomorrow = (datetime.today() + timedelta(days=1)).strftime('%Y-%m-%d')
            update_ts = f.col(last_update_column).cast('timestamp')
            last_update = (
                sdf.filter(update_ts < tomorrow)
                .select(f.max(update_ts).alias('last_update'))
                .collect()[0]['last_update']
            )
            # NOTE(review): adding 5h30 to datetime.now() only yields IST when the
            # local clock is UTC -- confirm the runtime's timezone.
            health_data = {'table_name': [table_name], 'last_update': [last_update],
                           'update_date_IST':[datetime.now() + timedelta(hours=5, minutes=30)]}
            health_sdf = self.spark.createDataFrame(pd.DataFrame(data=health_data))
            self.sdf_to_dwh(health_sdf,self.health_table_name,'append')
        except Exception:
            # Still best-effort, but leave a trace so a broken health table or a
            # wrong column name does not go unnoticed.
            logging.getLogger(__name__).warning(
                "Could not record health data for table %s", table_name, exc_info=True
            )
        return sdf

    def remove_duplicates_keep_latest(self,sdf, partition_col: str, order_col: str):
        """
        Removes duplicate rows based on the partition_col, keeping only the row with the highest value in order_col.

        This is `remove_duplicates` with descending order; it delegates to that
        method so the de-duplication logic lives in one place.

        Parameters:
        - sdf (DataFrame): The Spark DataFrame to process.
        - partition_col (str): The name of the column to partition the data (e.g., 'customer_hash').
        - order_col (str): The name of the column to order data within each partition (e.g., 'created_at').
        Returns:
        - DataFrame: A new DataFrame with duplicates removed based on partition_col, keeping only the latest record based on order_col.
        """
        return self.remove_duplicates(sdf, partition_col, order_col, ascending=False)
    def remove_duplicates(self,sdf, partition_col: str, order_col: str, ascending = False):
        """
        Removes duplicate rows based on the partition_col, keeping a single row per partition.
        Which row is kept is decided by order_col and the ascending flag.

        Parameters:
        - sdf (DataFrame): The Spark DataFrame to process.
        - partition_col (str): The name of the column to partition the data (e.g., 'customer_hash').
        - order_col (str): The name of the column to order data within each partition (e.g., 'created_at').
        - ascending (bool): True keeps the row with the lowest order_col, False (default) the row
          with the highest. Rows with a null order_col are ranked last in both cases.

        Returns:
        - DataFrame: A new DataFrame with one row per partition_col value. When several rows tie on
          order_col, which of them is kept is not defined.
        """
        # Define the window specification
        sort_key = f.col(order_col)
        sort_key = sort_key.asc_nulls_last() if ascending else sort_key.desc_nulls_last()
        windowSpec = Window.partitionBy(partition_col).orderBy(sort_key)

        # Pick a helper column name that is not already taken: reusing an existing
        # name would overwrite that column and then drop it from the result.
        # Compared case-insensitively, matching Spark's default column resolution.
        taken = {name.lower() for name in sdf.columns}
        rank_col = "row_number"
        while rank_col.lower() in taken:
            rank_col = "_" + rank_col

        # Rank rows within each partition and keep only the top-ranked row
        ranked = sdf.withColumn(rank_col, f.row_number().over(windowSpec))
        return ranked.filter(f.col(rank_col) == 1).drop(rank_col)
    
    def attribute_actions(
        self,
        action_table,
        action_table_date_column: str,
        action_table_id_column: str,
        cta_table,
        cta_table_date_column: str,
        action_entity: str,
        attribution_days: int,
        attribution_chronology: str = 'last'
    ):
        """
        Attributes actions from the `action_table` to events in the `cta_table` within a specified attribution window.

        An action is matched to a CTA of the same `action_entity` when the action date is on or
        after the CTA date and no later than the CTA date plus `attribution_days`. If several
        CTAs qualify, the one with the latest ('last') or earliest ('first') CTA date is kept.

        Args:
            action_table (DataFrame): The table containing user actions, such as transactions.
            action_table_date_column (str): The column name representing the date of the action in `action_table`.
            action_table_id_column (str): The unique identifier column for actions in `action_table`.
            cta_table (DataFrame): The table containing call-to-action events, like campaigns or banners.
            cta_table_date_column (str): The column name representing the date of the event in `cta_table`.
            action_entity (str): The entity (e.g., user ID) used to join `action_table` and `cta_table`.
            attribution_days (int): The number of days within which an action can be attributed to a CTA.
            attribution_chronology (Literal['last', 'first'], optional): Whether to attribute to the most recent ('last')
                or earliest ('first') CTA within the attribution window. Defaults to 'last'.

        Returns:
            DataFrame: The `action_table` left-joined with the columns of the attributed CTA row;
            actions without a qualifying CTA get nulls in those columns.

        Raises:
            ValueError: If `attribution_chronology` is not 'last' or 'first'.

        """
        # Validate the arguments before touching any DataFrame so that a bad
        # value fails immediately and is not masked by an unrelated Spark error.
        if attribution_chronology == 'last':
            ascending_order = False
        elif attribution_chronology == 'first':
            ascending_order = True
        else:
            raise ValueError("`attribution_chronology` must be either 'last' or 'first'.")

        # Retain only the columns of action_table needed for the matching
        action_table_slim = action_table.select(
            action_table_id_column, action_table_date_column, action_entity
        )

        # Join the action table with the CTA table on the action_entity and filter by the attribution window.
        # NOTE(review): this presumably relies on cta_table not having columns named like
        # action_table_id_column / action_table_date_column -- confirm against the callers.
        action_date = f.col(action_table_date_column)
        cta_date = f.col(cta_table_date_column)
        action_table_attributed = (
            action_table_slim
            .join(cta_table, [action_entity], 'inner')
            .filter(action_date >= cta_date)
            .filter(action_date <= f.date_add(cta_date, attribution_days))
        )

        # Deduplicate actions to retain only the most relevant CTA based on chronology
        action_table_attributed = (
            self.remove_duplicates(
                action_table_attributed,
                partition_col=action_table_id_column,
                order_col=cta_table_date_column,
                ascending=ascending_order
            )
            # These columns already exist on action_table; dropping them avoids
            # duplicates in the join below.
            .drop(action_table_date_column, action_entity)
        )

        # Join the attributed CTAs back to the original action table
        return action_table.join(
            action_table_attributed, [action_table_id_column], 'left'
        )
    
    def prefix_column_names(self,sdf, prefix, col_list=None, exclude_col_list=None):
        """
        Add a prefix to specified columns in a Spark DataFrame.

        Parameters:
        sdf (DataFrame): The Spark DataFrame.
        prefix (str): The prefix to add to the column names.
        col_list (str or iterable of str, optional): Column(s) to rename. If None, all columns
            (minus `exclude_col_list`) are renamed. A single name may be passed as a plain string.
        exclude_col_list (str or iterable of str, optional): Column(s) to exclude from renaming.
            Only used if col_list is None.

        Returns:
        DataFrame: The DataFrame with renamed columns.
        """

        def as_name_list(names):
            # A bare string is one column name, not an iterable of characters;
            # any other iterable (list, tuple, set, ...) is taken as a collection of names.
            return [names] if isinstance(names, str) else list(names)

        if col_list is not None:
            targets = as_name_list(col_list)
        elif exclude_col_list is None:
            targets = list(sdf.columns)
        else:
            excluded = as_name_list(exclude_col_list)
            targets = [name for name in sdf.columns if name not in excluded]

        # Rename columns by adding the prefix
        for name in targets:
            sdf = sdf.withColumnRenamed(name, prefix + name)

        return sdf
    
    def flatten_sdf(df, columns_to_flatten=None, keywords=None):
        """
        Recursively flatten specified struct, array, and map columns in a DataFrame.
        Keep all original columns that weren't flattened, plus flattened columns
        containing any of the specified keywords.
        
        Args:
            df: PySpark DataFrame with nested columns
            columns_to_flatten: List of column names to flatten or a single column name as string.
                            If None, all columns are considered.
            keywords: List of keywords or a single keyword as string. 
                    Only keep newly generated columns containing any of these keywords.
            
        Returns:
            DataFrame with original columns and filtered flattened columns
        """
        # Store the original column names for later
        original_columns = list(df.columns)
        
        # Handle case where columns_to_flatten is a single string
        if isinstance(columns_to_flatten, str):
            columns_to_flatten = [columns_to_flatten]
        
        # Handle case where keywords is a single string
        if isinstance(keywords, str):
            keywords = [keywords]
        
        # If no specific columns are provided, consider all columns
        if columns_to_flatten is None:
            columns_to_flatten = list(df.columns)
        else:
            columns_to_flatten = list(columns_to_flatten)  # Create a copy
        
        # Get list of columns that will not be flattened and should be preserved at the end
        columns_to_be_flattened = [col for col in columns_to_flatten if col in df.columns]
        preserved_columns = [col for col in original_columns if col not in columns_to_be_flattened]
        
        # First, flatten all nested structures
        result_df = df
        
        # Track columns that are currently being targeted for flattening
        current_flatten_targets = columns_to_be_flattened.copy()
        
        # Track all flattened column names
        all_flattened_columns = []
        
        # Continue flattening until no more nested structures found
        while current_flatten_targets:
            next_targets = []
            
            for col_name in current_flatten_targets:
                if col_name not in result_df.columns:
                    continue
                    
                col_type = result_df.schema[col_name].dataType
                
                # Handle struct columns
                if isinstance(col_type, StructType):
                    # Extract nested fields
                    nested_cols = []
                    
                    for field in col_type.fields:
                        nested_col_name = f"{col_name}_{field.name}"
                        nested_cols.append(f.col(f"{col_name}.{field.name}").alias(nested_col_name))
                        all_flattened_columns.append(nested_col_name)
                        
                        # Check if this field needs further flattening
                        field_type = col_type[field.name].dataType
                        if isinstance(field_type, StructType) or isinstance(field_type, ArrayType) or isinstance(field_type, MapType):
                            next_targets.append(nested_col_name)
                    
                    # Replace the struct column with its flattened fields
                    cols_to_select = [c for c in result_df.columns if c != col_name]
                    result_df = result_df.select(*cols_to_select, *nested_cols)
                
                # Handle array columns
                elif isinstance(col_type, ArrayType):
                    element_type = col_type.elementType
                    
                    # Process array of structs
                    if isinstance(element_type, StructType):
                        # Explode array
                        exploded_col = f"{col_name}_exploded"
                        result_df = result_df.withColumn(exploded_col, f.explode_outer(f.col(col_name)))
                        
                        # Extract fields from exploded column
                        nested_cols = []
                        
                        for field in element_type.fields:
                            nested_col_name = f"{col_name}_{field.name}"
                            nested_cols.append(f.col(f"{exploded_col}.{field.name}").alias(nested_col_name))
                            all_flattened_columns.append(nested_col_name)
                            
                            # Check if this field needs further flattening
                            field_type = element_type[field.name].dataType
                            if isinstance(field_type, StructType) or isinstance(field_type, ArrayType) or isinstance(field_type, MapType):
                                next_targets.append(nested_col_name)
                        
                        # Replace the array column with its flattened fields
                        cols_to_select = [c for c in result_df.columns if c != col_name and c != exploded_col]
                        result_df = result_df.select(*cols_to_select, *nested_cols)
                    
                    # Process array of primitives or other non-struct types
                    else:
                        # Convert to string
                        result_df = result_df.withColumn(col_name, f.concat_ws(",", f.col(col_name)))
                
                # Handle map columns
                elif isinstance(col_type, MapType):
                    try:
                        # Create a temporary view of the data with just the map column
                        result_df.createOrReplaceTempView("temp_map_view")
                        
                        # Use SQL to get distinct keys
                        keys_df = result_df.sparkSession.sql(f"""
                            SELECT DISTINCT explode(map_keys({col_name})) as key
                            FROM temp_map_view
                            LIMIT 1000
                        """)
                        
                        # Collect the keys (this is a small operation as we've limited to 1000 distinct keys)
                        sample_keys = [row.key for row in keys_df.collect()]
                        
                        # For each key, create a column
                        for key in sample_keys:
                            # Convert the key to a safe column name
                            safe_key = str(key).replace(" ", "_").replace("-", "_").replace(".", "_")
                            col_alias = f"{col_name}_{safe_key}"
                            
                            # Get the value for this key
                            result_df = result_df.withColumn(col_alias, f.col(col_name).getItem(key))
                            all_flattened_columns.append(col_alias)
                            
                            # Check if this value needs further flattening
                            value_type = col_type.valueType
                            if isinstance(value_type, StructType) or isinstance(value_type, ArrayType) or isinstance(value_type, MapType):
                                next_targets.append(col_alias)
                        
                        # Drop the original map column
                        result_df = result_df.drop(col_name)
                        
                    except Exception as e:
                        # If there's an error, try a simpler approach
                        print(f"Error handling map column {col_name}: {str(e)}")
                        # Convert to string as fallback
                        result_df = result_df.withColumn(col_name, f.to_json(f.col(col_name)))
            
            # Update target columns for next iteration
            current_flatten_targets = next_targets
        
        # Now filter columns based on keywords if provided
        if keywords:
            keywords_lower = [k.lower() for k in keywords]
            keyword_columns = [
                col for col in all_flattened_columns 
                if any(keyword in col.lower() for keyword in keywords_lower)
            ]
        else:
            keyword_columns = all_flattened_columns
        
        # Combine preserved columns and keyword-matching flattened columns
        final_columns = preserved_columns + keyword_columns
        
        # Verify all columns exist in result_df
        existing_columns = [col for col in final_columns if col in result_df.columns]
        
        # Return the final result
        return result_df.select(*existing_columns)
    
    