"""
ChalkML REDACT Engine
=====================
Privacy-preserving data transformations with mathematical guarantees.

Core Concepts:
- Differential Privacy: ε-DP via Laplace mechanism
- k-Anonymity: Group-based anonymization
- Data Masking: Tokenization, hashing, encryption
- Information Loss: Quantifiable privacy-utility tradeoff

Privacy = Anonymity + Noise + Suppression
"""

import json
import hashlib
import secrets
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Set
from datetime import datetime
import pandas as pd
import numpy as np
from scipy.stats import laplace


class DifferentialPrivacy:
    """
    Implements ε-differential privacy via the Laplace mechanism.
    
    Definition: A mechanism M satisfies ε-DP if for all datasets D1, D2 
    differing by one record:
        P(M(D1) ∈ S) ≤ exp(ε) * P(M(D2) ∈ S)
    
    The aggregate helpers (``noisy_mean``, ``noisy_sum``) clip their inputs to
    the caller-supplied range before aggregating, because the Laplace scale is
    calibrated to that range and is only a valid bound for in-range values.
    """
    
    def __init__(self, epsilon: float = 1.0):
        """
        Args:
            epsilon: Privacy budget (smaller = more private)
                     ε = 0.1 (very private), ε = 1.0 (moderate), ε = 10 (weak)
        
        Raises:
            ValueError: If epsilon is not a positive number (this also rejects NaN).
        """
        # `not epsilon > 0` (rather than `epsilon <= 0`) also catches NaN.
        if not epsilon > 0:
            raise ValueError(f"epsilon must be a positive number, got {epsilon!r}")
        self.epsilon = epsilon
        
    def add_laplace_noise(
        self,
        value: float,
        sensitivity: float
    ) -> float:
        """
        Add Laplace noise to achieve ε-DP.
        
        Laplace(μ, b) where b = Δf/ε
        Δf = sensitivity (max change in the query from changing one record)
        
        Args:
            value: True value
            sensitivity: Global sensitivity of the query (must be >= 0)
            
        Returns:
            The value plus one Laplace(0, sensitivity / ε) sample.
        
        Raises:
            ValueError: If sensitivity is negative.
        """
        if sensitivity < 0:
            raise ValueError(f"sensitivity must be non-negative, got {sensitivity!r}")
        scale = sensitivity / self.epsilon
        noise = np.random.laplace(0, scale)
        return value + noise
    
    def noisy_count(self, count: int) -> int:
        """
        Return differentially private count.
        Sensitivity = 1 (one record changes count by at most 1)
        
        The noisy value is rounded to the nearest integer and floored at zero.
        """
        noisy = self.add_laplace_noise(float(count), sensitivity=1.0)
        return max(0, int(round(noisy)))  # Counts must be non-negative
    
    def noisy_mean(self, values: np.ndarray, value_range: Tuple[float, float]) -> float:
        """
        Return differentially private mean.
        Sensitivity = (max - min) / n
        
        Args:
            values: Numeric values to average.
            value_range: (min, max) bounds; values are clipped to this range
                so that the sensitivity above actually holds.
        
        Returns:
            0.0 for empty input, otherwise the clipped mean plus Laplace noise.
        
        Raises:
            ValueError: If value_range is inverted (min > max).
        """
        values = np.asarray(values, dtype=float)
        n = len(values)
        if n == 0:
            return 0.0
        
        min_val, max_val = value_range
        if min_val > max_val:
            raise ValueError(f"value_range must satisfy min <= max, got {value_range!r}")
        
        # Without clipping, one out-of-range record could move the mean by more
        # than (max - min) / n and the noise would be under-calibrated.
        clipped = np.clip(values, min_val, max_val)
        sensitivity = (max_val - min_val) / n
        
        return self.add_laplace_noise(float(np.mean(clipped)), sensitivity)
    
    def noisy_sum(self, values: np.ndarray, value_range: Tuple[float, float]) -> float:
        """
        Return differentially private sum.
        Sensitivity = max - min (replacing one clipped record changes the sum
        by at most this).
        
        Args:
            values: Numeric values to sum.
            value_range: (min, max) bounds; values are clipped to this range
                before summing.
        
        Returns:
            The clipped sum plus Laplace noise.
        
        Raises:
            ValueError: If value_range is inverted (min > max).
        """
        min_val, max_val = value_range
        if min_val > max_val:
            raise ValueError(f"value_range must satisfy min <= max, got {value_range!r}")
        
        clipped = np.clip(np.asarray(values, dtype=float), min_val, max_val)
        sensitivity = max_val - min_val
        
        return self.add_laplace_noise(float(np.sum(clipped)), sensitivity)


class KAnonymizer:
    """
    Implements k-anonymity: Each record is indistinguishable from at least k-1 others.
    
    Quasi-identifiers (QI): Attributes that may identify individuals when combined
    (e.g., age, zipcode, gender)
    """
    
    def __init__(self, k: int = 5):
        """
        Args:
            k: Anonymity parameter (minimum group size)
        """
        self.k = k
        
    def generalize_numeric(
        self,
        series: pd.Series,
        bins: int = 10
    ) -> pd.Series:
        """
        Generalize numeric column into equal-width ranges.
        
        Example: Age 27 -> "25.0-30.0" (labels carry one decimal place)
        
        Args:
            series: Column to generalize; non-numeric entries are coerced to NaN.
            bins: Number of equal-width bins between the column min and max.
        
        Returns:
            A string series of range labels (missing values become "nan").
            The input series is returned unchanged when it has no numeric
            values or when binning is not possible (e.g. a constant column,
            whose bin edges would not be increasing).
        """
        try:
            # Create equal-width bins
            series_numeric = pd.to_numeric(series, errors='coerce')
            min_val, max_val = series_numeric.min(), series_numeric.max()
            
            if pd.isna(min_val) or pd.isna(max_val):
                return series
            
            bin_edges = np.linspace(min_val, max_val, bins + 1)
            labels = [f"{bin_edges[i]:.1f}-{bin_edges[i+1]:.1f}" for i in range(bins)]
            
            generalized = pd.cut(series_numeric, bins=bin_edges, labels=labels, include_lowest=True)
            return generalized.astype(str)
        except (TypeError, ValueError):
            # Best effort: binning failed (degenerate edges, duplicate labels,
            # invalid bin count), so fall back to the ungeneralized column.
            return series
    
    def generalize_categorical(
        self,
        series: pd.Series,
        hierarchy: Optional[Dict[str, str]] = None
    ) -> pd.Series:
        """
        Generalize categorical column using hierarchy.
        
        Example: 
            City "Cambridge" -> State "MA" -> Region "Northeast"
        
        Args:
            series: Column to generalize.
            hierarchy: Optional mapping from a value to its generalized form.
                Values absent from the mapping are kept as they are.
        
        Returns:
            With no hierarchy, values occurring fewer than k times (and missing
            values) are replaced by "*"; otherwise the mapped series.
        """
        if hierarchy is None:
            # Default: suppress rare categories
            value_counts = series.value_counts()
            common_values = value_counts[value_counts >= self.k].index
            return series.apply(lambda x: x if x in common_values else "*")
        else:
            return series.map(hierarchy).fillna(series)
    
    def suppress_rare_groups(
        self,
        df: pd.DataFrame,
        quasi_identifiers: List[str]
    ) -> pd.DataFrame:
        """
        Suppress records in groups smaller than k.
        
        Groups are defined by unique combinations of quasi-identifiers
        (missing values form their own group). For every row whose group has
        fewer than k members, all quasi-identifier columns are set to "*".
        
        Args:
            df: Input frame; it is not modified.
            quasi_identifiers: Column names that define the groups.
        
        Returns:
            A copy of df with rare groups suppressed. With no quasi-identifiers
            or an empty frame the copy is returned unchanged.
        """
        df_copy = df.copy()
        keys = list(quasi_identifiers)
        if not keys or df_copy.empty:
            return df_copy
        
        # Per-row size of the group the row belongs to. Computing this in one
        # pass works for a single quasi-identifier as well as for several and
        # avoids building one boolean mask per small group.
        group_size = df_copy.groupby(keys, dropna=False)[keys[0]].transform("size")
        rare = (group_size < self.k).to_numpy()
        
        if rare.any():
            for qi in keys:
                # "*" cannot be stored in numeric/categorical columns as-is.
                if not pd.api.types.is_string_dtype(df_copy[qi].dtype):
                    df_copy[qi] = df_copy[qi].astype(object)
            df_copy.loc[rare, keys] = "*"
        
        return df_copy
    
    def achieve_k_anonymity(
        self,
        df: pd.DataFrame,
        quasi_identifiers: List[str],
        numeric_bins: int = 5
    ) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """
        Transform dataset to achieve k-anonymity.
        
        Numeric quasi-identifiers are binned, other quasi-identifiers have
        rare categories replaced by "*", and finally every group smaller than
        k is suppressed. Quasi-identifiers that are not columns of df are
        ignored.
        
        Args:
            df: Input frame; it is not modified.
            quasi_identifiers: Column names to treat as quasi-identifiers.
            numeric_bins: Number of bins used for numeric quasi-identifiers.
        
        Returns:
            Anonymized dataframe and metadata. Metadata holds "k",
            "quasi_identifiers" (as passed in), "transformations",
            "achieved_k" (smallest remaining group size, as int) and
            "achieved" (plain bool). The suppressed "*" group can itself be
            smaller than k, in which case "achieved" is False.
        """
        df_anon = df.copy()
        metadata = {
            "k": self.k,
            "quasi_identifiers": quasi_identifiers,
            "transformations": {}
        }
        
        # Only columns that exist can be grouped on; dict.fromkeys drops
        # duplicates while preserving order.
        present = list(dict.fromkeys(qi for qi in quasi_identifiers if qi in df.columns))
        
        for qi in present:
            if pd.api.types.is_numeric_dtype(df[qi]):
                df_anon[qi] = self.generalize_numeric(df[qi], bins=numeric_bins)
                metadata["transformations"][qi] = f"numeric_generalization_{numeric_bins}_bins"
            else:
                df_anon[qi] = self.generalize_categorical(df[qi])
                metadata["transformations"][qi] = "categorical_generalization"
        
        # Suppress small groups
        df_anon = self.suppress_rare_groups(df_anon, present)
        
        # Check if k-anonymity achieved
        if present and not df_anon.empty:
            group_sizes = df_anon.groupby(present, dropna=False).size()
            min_group_size = int(group_sizes.min())
        else:
            # Nothing to group on: all rows fall into one indistinguishable group.
            min_group_size = len(df_anon)
        
        # Plain Python types keep the metadata JSON-serializable.
        metadata["achieved_k"] = min_group_size
        metadata["achieved"] = bool(min_group_size >= self.k)
        
        return df_anon, metadata


class DataMasking:
    """
    Data masking techniques: keyed hashing, tokenization, partial masking
    and additive noise.
    """
    
    def __init__(self, secret_key: Optional[str] = None):
        """
        Args:
            secret_key: Secret for deterministic masking (optional). When
                omitted a random key is generated, so hashes produced by
                different instances will not match each other.
        """
        self.secret_key = secret_key or secrets.token_hex(32)
        
    def hash_column(
        self,
        series: pd.Series,
        algorithm: str = 'sha256'
    ) -> pd.Series:
        """
        One-way hash (irreversible).
        
        Use for: Identifiers that need to be consistent but unreadable
        
        Each non-missing value is prefixed with the secret key and hashed;
        the digest is truncated (sha256: 16 hex chars, md5: 12 hex chars,
        any other algorithm name: 8-byte blake2b). Missing values are kept.
        
        Args:
            series: Column to hash.
            algorithm: 'sha256', 'md5', or anything else for blake2b.
        
        Returns:
            Series of truncated hex digests.
        """
        def hash_value(val):
            if pd.isna(val):
                return val
            hash_input = f"{self.secret_key}{val}".encode()
            if algorithm == 'sha256':
                return hashlib.sha256(hash_input).hexdigest()[:16]
            elif algorithm == 'md5':
                return hashlib.md5(hash_input).hexdigest()[:12]
            else:
                return hashlib.blake2b(hash_input, digest_size=8).hexdigest()
        
        return series.apply(hash_value)
    
    def tokenize_column(
        self,
        series: pd.Series
    ) -> Tuple[pd.Series, Dict[Any, str]]:
        """
        Replace values with random tokens (reversible with mapping).
        
        Use for: Preserving relationships while hiding actual values
        
        Args:
            series: Column to tokenize; missing values are left as they are.
        
        Returns:
            The tokenized series and the {original value: token} mapping.
            Tokens are unique within the mapping, so it can be inverted.
        """
        token_map: Dict[Any, str] = {}
        issued: Set[str] = set()
        
        for val in series.unique():
            if not pd.notna(val):
                continue
            token = f"TOKEN_{secrets.token_hex(4).upper()}"
            # Tokens are only 32 bits wide; redraw on a clash so two distinct
            # values never share a token (which would break reversibility).
            while token in issued:
                token = f"TOKEN_{secrets.token_hex(4).upper()}"
            issued.add(token)
            token_map[val] = token
        
        tokenized = series.map(token_map).fillna(series)
        return tokenized, token_map
    
    def mask_partial(
        self,
        series: pd.Series,
        keep_prefix: int = 0,
        keep_suffix: int = 0,
        mask_char: str = '*'
    ) -> pd.Series:
        """
        Partially mask strings (e.g., credit cards, SSN).
        
        Every character between the kept prefix and suffix is replaced,
        separators included.
        
        Example: "123-45-6789" -> "*******6789" (keep_suffix=4)
        
        Args:
            series: Column to mask; values are converted with str().
            keep_prefix: Number of leading characters to keep.
            keep_suffix: Number of trailing characters to keep.
            mask_char: Replacement for each masked character.
        
        Returns:
            Masked series. Missing values are kept, and a value no longer than
            keep_prefix + keep_suffix is returned as its unmasked string.
        """
        def mask_value(val):
            if pd.isna(val):
                return val
            val_str = str(val)
            if len(val_str) <= keep_prefix + keep_suffix:
                return val_str
            
            prefix = val_str[:keep_prefix] if keep_prefix > 0 else ""
            suffix = val_str[-keep_suffix:] if keep_suffix > 0 else ""
            middle_len = len(val_str) - keep_prefix - keep_suffix
            middle = mask_char * middle_len
            
            return f"{prefix}{middle}{suffix}"
        
        return series.apply(mask_value)
    
    def add_synthetic_noise(
        self,
        series: pd.Series,
        noise_level: float = 0.1
    ) -> pd.Series:
        """
        Add random Gaussian noise to numeric data.
        
        Args:
            series: Column to perturb.
            noise_level: Noise standard deviation as a fraction of the
                column's standard deviation.
        
        Returns:
            The perturbed series. Non-numeric input is returned unchanged, and
            a copy of the input is returned when the standard deviation is
            undefined (fewer than two non-missing values).
        """
        if not pd.api.types.is_numeric_dtype(series):
            return series
        
        std = series.std()
        if pd.isna(std):
            # A NaN scale would turn every value into NaN.
            return series.copy()
        
        noise = np.random.normal(0, noise_level * std, size=len(series))
        return series + noise


class RedactEngine:
    """
    High-level engine for privacy-preserving transformations.
    
    Wraps differential privacy, k-anonymity and masking behind one interface
    and writes a JSON report for each redacted file under
    ``<workspace>/.chalkml/redact_reports``.
    """
    
    def __init__(self, workspace_path: Optional[str] = None):
        """
        Args:
            workspace_path: Workspace root; defaults to the current working
                directory. The report directory is created if missing.
        """
        if workspace_path is None:
            workspace_path = Path.cwd()
        self.workspace_path = Path(workspace_path)
        self.chalkml_dir = self.workspace_path / ".chalkml"
        self.redact_dir = self.chalkml_dir / "redact_reports"
        self.redact_dir.mkdir(parents=True, exist_ok=True)
        
    @staticmethod
    def _to_jsonable(obj: Any) -> Any:
        """Recursively convert obj into plain types that json.dump accepts."""
        if isinstance(obj, dict):
            converted = {}
            for key, value in obj.items():
                if isinstance(key, np.generic):
                    # e.g. numpy integer keys of a token map built from a CSV column
                    key = key.item()
                elif not (key is None or isinstance(key, (str, int, float, bool))):
                    key = str(key)
                converted[key] = RedactEngine._to_jsonable(value)
            return converted
        if isinstance(obj, (list, tuple, set)):
            return [RedactEngine._to_jsonable(item) for item in obj]
        if isinstance(obj, np.ndarray):
            return [RedactEngine._to_jsonable(item) for item in obj.tolist()]
        if isinstance(obj, np.generic):
            return obj.item()
        if obj is None or isinstance(obj, (str, int, float, bool)):
            return obj
        # Paths, timestamps and anything else: fall back to the text form.
        return str(obj)
    
    def apply_differential_privacy(
        self,
        df: pd.DataFrame,
        epsilon: float,
        numeric_columns: Optional[List[str]] = None,
        count_columns: Optional[List[str]] = None
    ) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """
        Apply differential privacy to specified columns.
        
        Args:
            df: Input frame; it is not modified.
            epsilon: Privacy budget, must be positive.
            numeric_columns: Columns that receive Laplace noise scaled to the
                column's observed range. Defaults to all numeric columns.
            count_columns: Numeric columns holding counts. They receive
                Laplace noise with sensitivity 1, are rounded and floored at
                zero, and are excluded from the range-based handling.
        
        Returns:
            The noised frame and metadata describing each transformation.
            Missing values are left untouched; unknown columns are skipped.
        
        Raises:
            ValueError: If epsilon is not a positive number.
        """
        if not epsilon > 0:
            raise ValueError(f"epsilon must be a positive number, got {epsilon!r}")
        
        dp = DifferentialPrivacy(epsilon=epsilon)
        df_private = df.copy()
        
        if numeric_columns is None:
            numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist()
        
        count_cols = [
            col for col in (count_columns or [])
            if col in df.columns and pd.api.types.is_numeric_dtype(df[col])
        ]
        count_set = set(count_cols)
        
        metadata = {
            "epsilon": epsilon,
            "mechanism": "Laplace",
            "transformations": {}
        }
        
        for col in count_cols:
            df_private[col] = df[col].apply(
                lambda x: dp.noisy_count(x) if pd.notna(x) else x
            )
            metadata["transformations"][col] = {
                "method": "laplace_noise_count",
                "sensitivity": 1.0,
                "scale": float(1.0 / epsilon)
            }
        
        for col in numeric_columns:
            if col not in df.columns or col in count_set:
                continue
            
            values = df[col].dropna().values
            if len(values) == 0:
                continue
            
            # NOTE: the range is read from the data itself, so it is not a
            # data-independent global sensitivity. A strict ε-DP guarantee
            # needs publicly known bounds; a constant column gets no noise.
            sensitivity = float(values.max() - values.min())
            
            # Apply DP to each value (row-level privacy)
            df_private[col] = df[col].apply(
                lambda x, s=sensitivity: dp.add_laplace_noise(x, sensitivity=s) if pd.notna(x) else x
            )
            
            metadata["transformations"][col] = {
                "method": "laplace_noise",
                "sensitivity": sensitivity,
                "scale": float(sensitivity / epsilon)
            }
        
        return df_private, metadata
    
    def apply_k_anonymity(
        self,
        df: pd.DataFrame,
        k: int,
        quasi_identifiers: List[str]
    ) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """
        Apply k-anonymity transformation.
        
        Args:
            df: Input frame.
            k: Minimum group size.
            quasi_identifiers: Columns treated as quasi-identifiers.
        
        Returns:
            The anonymized frame and the metadata produced by KAnonymizer.
        """
        anonymizer = KAnonymizer(k=k)
        df_anon, metadata = anonymizer.achieve_k_anonymity(df, quasi_identifiers)
        return df_anon, metadata
    
    def apply_masking(
        self,
        df: pd.DataFrame,
        hash_columns: Optional[List[str]] = None,
        tokenize_columns: Optional[List[str]] = None,
        mask_columns: Optional[Dict[str, Dict]] = None,
        secret_key: Optional[str] = None
    ) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """
        Apply various masking techniques.
        
        Each technique reads the original column from df, so if a column is
        listed under several techniques the last one applied wins
        (hash, then tokenize, then partial mask).
        
        Args:
            df: Input frame; it is not modified.
            hash_columns: Columns to one-way hash
            tokenize_columns: Columns to tokenize (reversible)
            mask_columns: Dict of {col: {keep_prefix, keep_suffix}}
            secret_key: Optional key for hashing. Supply the same key across
                calls to get matching hashes; by default each call uses a
                fresh random key.
        
        Returns:
            The masked frame and metadata with "transformations" and
            "token_maps" ({column: {original value: token}}).
        """
        masker = DataMasking(secret_key=secret_key)
        df_masked = df.copy()
        
        metadata = {
            "transformations": {},
            "token_maps": {}
        }
        
        if hash_columns:
            for col in hash_columns:
                if col in df.columns:
                    df_masked[col] = masker.hash_column(df[col])
                    metadata["transformations"][col] = "sha256_hash"
        
        if tokenize_columns:
            for col in tokenize_columns:
                if col in df.columns:
                    df_masked[col], token_map = masker.tokenize_column(df[col])
                    metadata["transformations"][col] = "tokenization"
                    metadata["token_maps"][col] = token_map
        
        if mask_columns:
            for col, params in mask_columns.items():
                if col in df.columns:
                    df_masked[col] = masker.mask_partial(
                        df[col],
                        keep_prefix=params.get('keep_prefix', 0),
                        keep_suffix=params.get('keep_suffix', 0)
                    )
                    metadata["transformations"][col] = f"partial_mask_{params}"
        
        return df_masked, metadata
    
    def redact_file(
        self,
        input_file: str,
        output_file: str,
        method: str,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Apply privacy transformation to file.
        
        Reads input_file as CSV, applies the chosen method, writes the result
        to output_file and stores a JSON report in the report directory.
        
        Args:
            input_file: Path of the CSV file to read.
            output_file: Path of the CSV file to write.
            method: 'dp', 'k-anonymity', 'masking'
            **kwargs: Method options.
                dp: epsilon (default 1.0), numeric_columns, count_columns.
                k-anonymity: k (default 5), quasi_identifiers (default []).
                masking: hash_columns, tokenize_columns, mask_columns,
                secret_key.
        
        Returns:
            Dict with "success", "input_rows", "output_rows", "metadata" and
            "report_path".
        
        Raises:
            ValueError: If method is not one of the supported names.
        """
        df = pd.read_csv(input_file)
        
        if method == 'dp':
            epsilon = kwargs.get('epsilon', 1.0)
            df_redacted, metadata = self.apply_differential_privacy(
                df,
                epsilon,
                numeric_columns=kwargs.get('numeric_columns'),
                count_columns=kwargs.get('count_columns')
            )
        
        elif method == 'k-anonymity':
            k = kwargs.get('k', 5)
            quasi_identifiers = kwargs.get('quasi_identifiers', [])
            df_redacted, metadata = self.apply_k_anonymity(df, k, quasi_identifiers)
        
        elif method == 'masking':
            df_redacted, metadata = self.apply_masking(
                df,
                hash_columns=kwargs.get('hash_columns'),
                tokenize_columns=kwargs.get('tokenize_columns'),
                mask_columns=kwargs.get('mask_columns'),
                secret_key=kwargs.get('secret_key')
            )
        
        else:
            raise ValueError(f"Unknown method: {method}")
        
        # Save redacted data
        df_redacted.to_csv(output_file, index=False)
        
        # Save report. One timestamp is used for both the name and the body;
        # microseconds keep reports from the same second from overwriting
        # each other.
        now = datetime.now()
        report_path = self.redact_dir / f"redact_report_{now.strftime('%Y%m%d_%H%M%S_%f')}.json"
        
        # NOTE(security): for tokenization the metadata includes the token
        # maps, i.e. the original values. The report must be protected like
        # the source data.
        report = self._to_jsonable({
            "input_file": input_file,
            "output_file": output_file,
            "method": method,
            "timestamp": now.isoformat(),
            "metadata": metadata
        })
        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2)
        
        return {
            "success": True,
            "input_rows": len(df),
            "output_rows": len(df_redacted),
            "metadata": metadata,
            "report_path": str(report_path)
        }


def get_redact_engine(workspace_path: Optional[str] = None) -> RedactEngine:
    """
    Create a RedactEngine bound to the given workspace.
    
    Args:
        workspace_path: Workspace root forwarded to RedactEngine; None lets
            the engine pick its own default.
    
    Returns:
        A newly constructed RedactEngine.
    """
    engine = RedactEngine(workspace_path=workspace_path)
    return engine
