from typing import List, Dict, Tuple, Optional
import pandas as pd


def get_table_stats(duckrun_con, table_name: str,
                   top_n_values: int = 10) -> pd.DataFrame:
    """
    Get comprehensive table statistics including NDV and value frequency analysis.

    The theory: If a value appears frequently (high repetition), it may provide better RLE compression
    even if the column has higher NDV. This function helps identify such patterns.

    For every column two queries are issued: one for the non-NULL count and
    the exact distinct count, and (only when the column has non-NULL values)
    one for the ``top_n_values`` most frequent values. Progress and a short
    summary are printed to stdout.

    Args:
        duckrun_con: Duckrun connection (from duckrun.connect()); its ``.con``
            attribute is used as the underlying DuckDB connection.
        table_name: Name of the table to analyze. It is interpolated verbatim
            into the FROM clause, so it must come from a trusted source.
        top_n_values: Number of top frequent values to show per column (default: 10)

    Returns:
        DataFrame sorted by ``repetition_score`` (descending) with columns:
        - column_name: Name of the column
        - data_type: Data type of the column
        - total_rows: Total number of rows
        - null_count: Number of NULL values
        - null_pct: Percentage of NULL values
        - ndv: Number of distinct values (exact)
        - cardinality_ratio: NDV / total_rows (lower = better for RLE)
        - top_value: Most frequent value
        - top_value_count: Count of most frequent value
        - top_value_pct: Percentage of most frequent value
        - top_n_coverage: Percentage covered by top N values
        - repetition_score: Custom score indicating RLE potential (higher = better)

        An empty DataFrame (no columns) is returned when the table exposes no
        columns at all.
    """
    def quote_ident(name) -> str:
        # Double-quote the identifier (doubling embedded quotes) so column
        # names containing spaces, punctuation or reserved words stay valid SQL.
        return '"' + str(name).replace('"', '""') + '"'

    con = duckrun_con.con  # Get underlying DuckDB connection
    from_clause = table_name

    # Get column names and types
    schema_info = con.sql(f"""
        SELECT column_name, column_type
        FROM (DESCRIBE SELECT * FROM {from_clause})
    """).df()

    if schema_info.empty:
        return pd.DataFrame()

    # Get total row count once; it is reused for every column below
    total_rows = con.sql(f"SELECT COUNT(*) FROM {from_clause}").fetchone()[0]
    n_columns = len(schema_info)
    print(f"Analyzing {n_columns} columns across {total_rows:,} rows...")

    results = []

    # enumerate() gives a true 1-based position instead of relying on the
    # DataFrame index labels happening to be 0..n-1.
    columns = zip(schema_info['column_name'], schema_info['column_type'])
    for position, (col_name, col_type) in enumerate(columns, start=1):
        print(f"  [{position}/{n_columns}] Analyzing column: {col_name}")
        col_sql = quote_ident(col_name)

        # Non-NULL count and exact distinct count in a single scan
        non_null, ndv = con.sql(f"""
        SELECT
            COUNT({col_sql}) as non_null,
            COUNT(DISTINCT {col_sql}) as ndv
        FROM {from_clause}
        """).fetchone()

        null_count = total_rows - non_null
        null_pct = (null_count / total_rows * 100) if total_rows > 0 else 0
        cardinality_ratio = (ndv / total_rows) if total_rows > 0 else 0

        # Defaults describe a column with no non-NULL values
        top_value = None
        top_value_count = 0
        top_value_pct = 0
        top_n_coverage = 0

        # Only query frequencies when there is something to count; this also
        # guarantees total_rows > 0, so the percentage never divides by zero.
        if non_null > 0:
            top_values = con.sql(f"""
            SELECT
                {col_sql} as value,
                COUNT(*) as count,
                COUNT(*) * 100.0 / {total_rows} as percentage
            FROM {from_clause}
            WHERE {col_sql} IS NOT NULL
            GROUP BY {col_sql}
            ORDER BY count DESC
            LIMIT {top_n_values}
            """).df()

            if not top_values.empty:
                most_frequent = top_values.iloc[0]
                top_value = most_frequent['value']
                top_value_count = most_frequent['count']
                top_value_pct = most_frequent['percentage']
                top_n_coverage = top_values['percentage'].sum()

        # Calculate repetition score: higher means better for RLE
        # Score considers:
        # 1. How much the top value covers (higher = better, weighted x2)
        # 2. How much top N values cover (higher = better)
        # 3. Inverse of cardinality ratio (lower cardinality = better);
        #    the +0.01 keeps the denominator non-zero.
        repetition_score = (top_value_pct * 2 + top_n_coverage) / 3 / (cardinality_ratio + 0.01)

        results.append({
            'column_name': col_name,
            'data_type': col_type,
            'total_rows': total_rows,
            'null_count': null_count,
            'null_pct': round(null_pct, 2),
            'ndv': ndv,
            'cardinality_ratio': round(cardinality_ratio, 4),
            'top_value': top_value,
            'top_value_count': top_value_count,
            'top_value_pct': round(top_value_pct, 2),
            'top_n_coverage': round(top_n_coverage, 2),
            'repetition_score': round(repetition_score, 2)
        })

    df = pd.DataFrame(results)

    # Sort by repetition score (best RLE candidates first)
    df = df.sort_values('repetition_score', ascending=False).reset_index(drop=True)

    print("\n✓ Analysis complete!")
    print("\nTop columns by repetition score (best RLE candidates):")
    # The index was reset above, so idx is the 0-based rank.
    for idx, row in df.head(5).iterrows():
        print(f"  {idx+1}. {row['column_name']}: score={row['repetition_score']}, "
              f"top_value_pct={row['top_value_pct']}%, ndv={row['ndv']:,}")

    return df


def get_value_frequency_details(duckrun_con, table_name: str, column_name: str,
                                limit: int = 20) -> pd.DataFrame:
    """
    Get detailed value frequency distribution for a specific column.

    Shows the most frequent values and their counts/percentages.
    Useful for understanding repetition patterns that drive RLE compression.

    Args:
        duckrun_con: Duckrun connection (from duckrun.connect()); its ``.con``
            attribute is used as the underlying DuckDB connection.
        table_name: Name of the table to analyze. It is interpolated verbatim
            into the FROM clause, so it must come from a trusted source.
        column_name: Name of the column to analyze. When it exactly matches a
            column reported by DESCRIBE it is double-quoted automatically;
            any other text is passed through to SQL unchanged.
        limit: Maximum number of values to return (default: 20)

    Returns:
        DataFrame ordered from most to least frequent value, with columns:
        - value: The distinct value (NULLs are excluded)
        - count: Number of occurrences
        - percentage: Percentage of total rows (rounded to 2 decimals)
        - cumulative_pct: Cumulative percentage (rounded to 2 decimals)
    """
    con = duckrun_con.con  # Get underlying DuckDB connection
    from_clause = table_name

    # Quote the identifier only when it is a real column name, so names with
    # spaces/reserved words work while callers that pass an expression or an
    # already-quoted identifier keep working exactly as before.
    known_columns = set(con.sql(f"""
        SELECT column_name
        FROM (DESCRIBE SELECT * FROM {from_clause})
    """).df()['column_name'])
    if column_name in known_columns:
        column_sql = '"' + column_name.replace('"', '""') + '"'
    else:
        column_sql = column_name

    # Get total row count (denominator for the percentages, NULL rows included)
    total_rows = con.sql(f"SELECT COUNT(*) FROM {from_clause}").fetchone()[0]

    # Get value frequencies.
    # The final ORDER BY uses cumulative_pct as a tie-breaker: every row's
    # percentage is positive, so the running total increases strictly in the
    # order the window processed the rows. Sorting on it keeps rows with equal
    # counts in that same order, so cumulative_pct never appears to decrease.
    query = f"""
    WITH value_counts AS (
        SELECT
            {column_sql} as value,
            COUNT(*) as count,
            COUNT(*) * 100.0 / {total_rows} as percentage
        FROM {from_clause}
        WHERE {column_sql} IS NOT NULL
        GROUP BY {column_sql}
        ORDER BY count DESC
        LIMIT {limit}
    )
    SELECT
        value,
        count,
        percentage,
        SUM(percentage) OVER (ORDER BY count DESC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as cumulative_pct
    FROM value_counts
    ORDER BY count DESC, cumulative_pct
    """

    df = con.sql(query).df()

    # Round percentages
    if not df.empty:
        df['percentage'] = df['percentage'].round(2)
        df['cumulative_pct'] = df['cumulative_pct'].round(2)

    return df


def find_optimal_sort_order(duckrun_con, table_name: str,
                           max_combinations: int = 10) -> pd.DataFrame:
    """
    Determine a good sort order by pure compression testing (V-Order-like heuristic).

    Procedure:
    1. Calculate cardinality (exact NDV) for each column
    2. Build candidate sort orderings from the low-cardinality columns
       (NDV / rows < 1%): ascending NDV, descending NDV, and permutations
       of the three lowest-NDV columns
    3. For each ordering, count the RLE runs of EVERY column when the table
       is sorted that way
    4. Rank orderings by total runs (fewest total runs = best compression)

    NO semantic understanding, NO query pattern assumptions.
    Pure mechanical testing of compression effectiveness.

    A run is a maximal stretch of consecutive rows holding the same value;
    consecutive NULLs form a single run. Rows that tie on the sort key are
    ordered however the engine happens to emit them, so run counts for
    columns outside the sort key may vary between executions.

    Cost: one query per (ordering, column) pair, each sorting the whole
    table. Progress and a final report are printed to stdout.

    Args:
        duckrun_con: Duckrun connection (from duckrun.connect()); its ``.con``
            attribute is used as the underlying DuckDB connection.
        table_name: Name of the table to analyze. It is interpolated verbatim
            into the FROM clause, so it must come from a trusted source.
        max_combinations: Maximum sort orderings to test (default: 10).
            Values below 1 are treated as 1.

    Returns:
        DataFrame with tested orderings ranked by compression effectiveness
        (columns: sort_order, total_runs, compression_score, plus one run
        count column per table column). An empty DataFrame is returned when
        the table has no rows or fewer than two low-cardinality columns.
    """
    from itertools import permutations

    def quote_ident(name) -> str:
        # Double-quote the identifier (doubling embedded quotes) so column
        # names containing spaces, punctuation or reserved words stay valid SQL.
        return '"' + str(name).replace('"', '""') + '"'

    con = duckrun_con.con  # Get underlying DuckDB connection
    from_clause = table_name

    # Get column names and cardinalities
    print("Step 1: Analyzing column cardinalities...")
    schema_info = con.sql(f"""
        SELECT column_name, column_type
        FROM (DESCRIBE SELECT * FROM {from_clause})
    """).df()
    all_columns = list(schema_info['column_name'])

    total_rows = con.sql(f"SELECT COUNT(*) FROM {from_clause}").fetchone()[0]

    # An empty table has no runs to measure and would otherwise cause a
    # division by zero when computing the cardinality ratio.
    if total_rows == 0:
        print("Table is empty - nothing to analyze!")
        return pd.DataFrame()

    # Calculate NDV for each column
    cardinality_map = {}
    for col in all_columns:
        ndv = con.sql(
            f"SELECT COUNT(DISTINCT {quote_ident(col)}) FROM {from_clause}"
        ).fetchone()[0]
        cardinality_ratio = ndv / total_rows
        cardinality_map[col] = {'ndv': ndv, 'ratio': cardinality_ratio}
        print(f"  {col}: {ndv:,} distinct ({cardinality_ratio*100:.4f}%)")

    # Filter to low-cardinality columns only (< 1% cardinality)
    # High cardinality columns won't benefit from reordering
    low_card_cols = [col for col, stats in cardinality_map.items()
                     if stats['ratio'] < 0.01]

    print(f"\nStep 2: Testing sort orderings for {len(low_card_cols)} low-cardinality columns...")
    print(f"Columns to test: {', '.join(low_card_cols)}")

    if len(low_card_cols) < 2:
        print("Not enough columns to test different orderings!")
        return pd.DataFrame()

    # Generate candidate orderings
    # Start with cardinality-based orderings
    sorted_by_card = sorted(low_card_cols, key=lambda c: cardinality_map[c]['ndv'])

    test_orderings = [
        sorted_by_card,  # Lowest cardinality first
        sorted_by_card[::-1],  # Highest cardinality first
    ]

    # Add some permutations of top 3 columns. The budget is checked BEFORE
    # appending so the list never grows past max_combinations.
    if len(low_card_cols) >= 3:
        for perm in permutations(sorted_by_card[:3]):
            if len(test_orderings) >= max_combinations:
                break
            candidate = list(perm)
            if candidate not in test_orderings:
                test_orderings.append(candidate)

    # Also cap the two seed orderings, but always test at least one.
    test_orderings = test_orderings[:max(max_combinations, 1)]

    # Test each ordering by calculating actual RLE runs
    print(f"\nStep 3: Testing {len(test_orderings)} different orderings...")
    results = []

    for idx, ordering in enumerate(test_orderings, 1):
        print(f"\n[{idx}/{len(test_orderings)}] Testing: {' → '.join(ordering)}")

        # Calculate RLE runs for each column with this ordering
        # We'll sort the data by the ordering and count runs
        order_clause = ', '.join(quote_ident(col) for col in ordering)

        column_rle = {}
        for col in all_columns:
            # Count runs: a new run starts on the first row and whenever the
            # value differs from the previous row. IS DISTINCT FROM is a
            # NULL-safe comparison, so NULL -> NULL continues the same run
            # while NULL <-> value starts a new one. The analyzed column is
            # aliased to cur_val so table columns named rn/prev_val cannot
            # collide with the helper columns.
            rle_query = f"""
            WITH sorted_data AS (
                SELECT
                    {quote_ident(col)} AS cur_val,
                    ROW_NUMBER() OVER (ORDER BY {order_clause}) AS rn
                FROM {from_clause}
            ),
            with_prev AS (
                SELECT
                    cur_val,
                    rn,
                    LAG(cur_val) OVER (ORDER BY rn) AS prev_val
                FROM sorted_data
            )
            SELECT COUNT(*) AS runs
            FROM with_prev
            WHERE rn = 1 OR cur_val IS DISTINCT FROM prev_val
            """

            runs = con.sql(rle_query).fetchone()[0]
            column_rle[col] = runs
            print(f"    {col}: {runs:,} runs")

        total_runs = sum(column_rle.values())
        print(f"    TOTAL: {total_runs:,} runs")

        results.append({
            'sort_order': ' → '.join(ordering),
            'total_runs': total_runs,
            # Higher = better compression; guarded in case no runs were counted
            'compression_score': (total_rows / total_runs) if total_runs else 0.0,
            **column_rle
        })

    # Create results DataFrame
    df = pd.DataFrame(results)
    df = df.sort_values('total_runs').reset_index(drop=True)

    print("\n" + "=" * 80)
    print("RESULTS: Best to Worst Compression")
    print("=" * 80)

    for idx, row in df.iterrows():
        print(f"\n{idx + 1}. {row['sort_order']}")
        print(f"   Total runs: {row['total_runs']:,}")
        print(f"   Compression score: {row['compression_score']:.2f}x")
        if idx == 0:
            print("   ⭐ BEST COMPRESSION")

    print("\n" + "=" * 80)
    print("CONCLUSION")
    print("=" * 80)
    best = df.iloc[0]
    print(f"\nOptimal sort order: {best['sort_order']}")
    print(f"This ordering achieves the fewest total RLE runs ({best['total_runs']:,})")
    print("\nThis is how V-Order actually works:")
    print("✓ No query pattern assumptions")
    print("✓ No semantic understanding")
    print("✓ Pure compression effectiveness testing")
    print("✓ Mechanical optimization based on data patterns")

    return df


# Example usage:
#
# import duckrun
#
# con = duckrun.connect('workspace/lakehouse.lakehouse')
#
# # Get RLE statistics:
# stats_df = con.get_rle_stats('my_table', top_n_values=10)
# print(stats_df)
#
# # Detailed frequency distribution for a specific column:
# freq_df = con.get_value_frequency('my_table', 'status_column', limit=20)
# print(freq_df)
#
# # Find optimal sort order (V-Order simulation):
# optimal_df = con.find_optimal_sort_order('my_table', max_combinations=10)
# print(optimal_df)
