"""
ChalkML RELEVANCE Engine
========================
Statistical feature selection via mutual information and hypothesis testing.

Core Concept:
- Mutual Information: I(X;Y) measures statistical dependency
- Hypothesis Testing: Chi-squared, F-test, ANOVA for significance
- Redundancy Elimination: Remove correlated features that add no information
- Optimal Subset: Maximize I(X;Y) while minimizing |X|

Relevance = Information Gain + Statistical Significance - Redundancy
"""

import hashlib
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
from scipy.special import rel_entr
from scipy.stats import chi2_contingency, f_oneway, pearsonr, spearmanr
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
from sklearn.preprocessing import KBinsDiscretizer


class RelevanceAnalyzer:
    """
    Core analyzer for statistical feature relevance.
    Computes mutual information, statistical tests, and redundancy metrics.

    Every hypothesis-test method reports ``(0.0, 1.0)`` (no effect, p = 1)
    for a feature whose test cannot be computed or yields NaN, so callers can
    always compare the returned p-value against ``alpha``.
    """

    # Result reported for a feature whose test is not applicable or failed.
    _NOT_APPLICABLE: Tuple[float, float] = (0.0, 1.0)
    _log = logging.getLogger(__name__)

    def __init__(
        self,
        alpha: float = 0.05,
        n_bins: int = 10,
        random_state: Optional[int] = None
    ):
        """
        Args:
            alpha: Significance level for hypothesis tests (default: 0.05)
            n_bins: Number of bins for discretization (MI estimation)
            random_state: Seed forwarded to the mutual-information
                estimators. ``None`` (default) leaves them unseeded; pass an
                int for reproducible scores.
        """
        self.alpha = alpha
        self.n_bins = n_bins
        self.random_state = random_state

    @staticmethod
    def _as_result(statistic: Any, p_value: Any) -> Tuple[float, float]:
        """Convert a (statistic, p_value) pair to floats, mapping NaN to (0.0, 1.0)."""
        statistic, p_value = float(statistic), float(p_value)
        if np.isnan(statistic) or np.isnan(p_value):
            # NaN means the test was undefined (e.g. constant input); report
            # it the same way as any other non-applicable test.
            return (0.0, 1.0)
        return (statistic, p_value)

    def mutual_information(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        task: str = 'auto'
    ) -> Dict[str, float]:
        """
        Compute mutual information I(X_i; Y) for each feature.

        I(X;Y) = Σ p(x,y) log(p(x,y) / (p(x)p(y)))

        Args:
            X: Feature matrix
            y: Target variable
            task: 'classification', 'regression', or 'auto'. 'auto' picks
                classification when the target has fewer than 20 distinct
                values. Any value other than 'classification' (after 'auto'
                is resolved) uses the regression estimator.

        Returns:
            Dict mapping feature names to MI scores (empty if X has no columns)
        """
        if X.shape[1] == 0:
            return {}

        if task == 'auto':
            task = 'classification' if len(y.unique()) < 20 else 'regression'

        if task == 'classification':
            mi_scores = mutual_info_classif(
                X, y,
                discrete_features='auto',
                n_neighbors=3,
                random_state=self.random_state,
            )
        else:
            mi_scores = mutual_info_regression(
                X, y,
                n_neighbors=3,
                random_state=self.random_state,
            )

        return {col: float(score) for col, score in zip(X.columns, mi_scores)}

    def chi_squared_test(
        self,
        X: pd.DataFrame,
        y: pd.Series
    ) -> Dict[str, Tuple[float, float]]:
        """
        Chi-squared test for categorical features vs categorical target.

        H0: Feature and target are independent
        H1: Feature and target are dependent

        Returns:
            Dict mapping feature names to (chi2_statistic, p_value);
            (0.0, 1.0) where the test could not be computed.
        """
        results = {}

        for col in X.columns:
            try:
                contingency = pd.crosstab(X[col], y)
                chi2_stat, p_value, _dof, _expected = chi2_contingency(contingency)
                results[col] = self._as_result(chi2_stat, p_value)
            except Exception as exc:
                # Deliberate best-effort: one bad column must not abort the
                # whole analysis, but leave a trace for debugging.
                self._log.debug("chi-squared test skipped for %r: %s", col, exc)
                results[col] = self._NOT_APPLICABLE

        return results

    def f_test(
        self,
        X: pd.DataFrame,
        y: pd.Series
    ) -> Dict[str, Tuple[float, float]]:
        """
        ANOVA F-test for numerical features vs categorical target.

        H0: Feature means are equal across target classes
        H1: At least one mean differs

        Returns:
            Dict mapping feature names to (f_statistic, p_value);
            (0.0, 1.0) where the test could not be computed.
        """
        results = {}

        labels = y.unique()
        if len(labels) < 2:
            return {col: (0.0, 1.0) for col in X.columns}

        for col in X.columns:
            try:
                feature = X[col]
                # Index the single column, not the whole frame, per label.
                groups = [feature[y == label].dropna() for label in labels]
                groups = [g for g in groups if len(g) > 0]  # Remove empty groups

                if len(groups) < 2:
                    results[col] = self._NOT_APPLICABLE
                else:
                    f_stat, p_value = f_oneway(*groups)
                    results[col] = self._as_result(f_stat, p_value)
            except Exception as exc:
                # Deliberate best-effort, see chi_squared_test.
                self._log.debug("F-test skipped for %r: %s", col, exc)
                results[col] = self._NOT_APPLICABLE

        return results

    def correlation_test(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        method: str = 'pearson'
    ) -> Dict[str, Tuple[float, float]]:
        """
        Correlation test (Pearson or Spearman) for numerical features.

        Rows where either the feature or the target is missing are dropped
        before the test.

        Args:
            X: Feature matrix
            y: Target variable
            method: 'pearson' for Pearson's r; any other value uses Spearman.

        Returns:
            Dict mapping feature names to (correlation, p_value);
            (0.0, 1.0) where the test could not be computed.
        """
        corr_func = pearsonr if method == 'pearson' else spearmanr

        try:
            # Align the target with the feature rows once. When the indexes
            # already match, rows are paired by position, which stays correct
            # even if the index contains duplicate labels.
            target = y if y.index.equals(X.index) else y.reindex(X.index)
            target_values = target.to_numpy()
            target_present = target.notna().to_numpy()
        except Exception as exc:
            self._log.debug("correlation test skipped, target not alignable: %s", exc)
            return {col: self._NOT_APPLICABLE for col in X.columns}

        results = {}
        for col in X.columns:
            try:
                feature = X[col]
                valid = feature.notna().to_numpy() & target_present
                corr, p_value = corr_func(feature.to_numpy()[valid], target_values[valid])
                results[col] = self._as_result(corr, p_value)
            except Exception as exc:
                # Deliberate best-effort, see chi_squared_test.
                self._log.debug("correlation test skipped for %r: %s", col, exc)
                results[col] = self._NOT_APPLICABLE

        return results

    def compute_redundancy(
        self,
        X: pd.DataFrame,
        threshold: float = 0.8
    ) -> Dict[str, List[str]]:
        """
        Find redundant features via pairwise correlation.

        Two features are redundant if |ρ(X_i, X_j)| > threshold

        Returns:
            Dict mapping each feature to list of correlated features
            (never including the feature itself).
        """
        corr_matrix = X.corr().abs()
        redundancy_map = {}

        for col in X.columns:
            col_corr = corr_matrix[col]
            above = col_corr[col_corr > threshold].index
            # Filter the feature itself out rather than list.remove() it: a
            # constant column has NaN self-correlation (and threshold >= 1
            # excludes the diagonal), so it may not be in the list at all.
            redundancy_map[col] = [other for other in above if other != col]

        return redundancy_map

    def information_gain_ratio(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        task: str = 'auto'
    ) -> Dict[str, float]:
        """
        Information Gain Ratio (normalized by feature entropy).

        IGR(X,Y) = IG(X,Y) / H(X)

        where IG(X,Y) = H(Y) - H(Y|X)

        Args:
            X: Feature matrix
            y: Target variable
            task: Forwarded to mutual_information ('auto' by default).

        Returns:
            Dict mapping feature names to IGR scores; 0.0 for features whose
            entropy is not positive.
        """
        mi_scores = self.mutual_information(X, y, task=task)

        # scikit-learn reports MI in nats while _entropy works in bits;
        # convert so that the ratio is dimensionless.
        nats_to_bits = 1.0 / np.log(2.0)

        igr_scores = {}
        for col in X.columns:
            # Discretize continuous features before measuring their entropy
            if pd.api.types.is_float_dtype(X[col].dtype):
                discretizer = KBinsDiscretizer(n_bins=self.n_bins, encode='ordinal', strategy='quantile')
                X_discrete = discretizer.fit_transform(X[[col]])
                feature_entropy = self._entropy(X_discrete.ravel())
            else:
                feature_entropy = self._entropy(X[col])

            if feature_entropy > 0:
                igr_scores[col] = float(mi_scores[col] * nats_to_bits / feature_entropy)
            else:
                igr_scores[col] = 0.0

        return igr_scores

    def _entropy(self, x: np.ndarray) -> float:
        """Compute Shannon entropy in bits: H(X) = -Σ p(x) log2 p(x)"""
        _, counts = np.unique(x, return_counts=True)
        if counts.size == 0:
            return 0.0
        # Counts from np.unique are always >= 1, so no smoothing term is
        # needed inside the log (one would bias the result).
        probs = counts / counts.sum()
        # "0.0 - sum" instead of unary minus so a constant input yields 0.0
        # rather than -0.0.
        return float(0.0 - np.sum(probs * np.log2(probs)))


class RelevanceEngine:
    """
    High-level engine for feature selection based on relevance analysis.

    Creating an engine creates ``<workspace>/.chalkml/relevance_reports``
    (if missing); select_features_file writes its JSON reports there.
    """

    def __init__(self, workspace_path: Optional[str] = None):
        """
        Args:
            workspace_path: Root directory of the workspace; defaults to the
                current working directory.
        """
        if workspace_path is None:
            workspace_path = Path.cwd()
        self.workspace_path = Path(workspace_path)
        self.chalkml_dir = self.workspace_path / ".chalkml"
        self.relevance_dir = self.chalkml_dir / "relevance_reports"
        self.relevance_dir.mkdir(parents=True, exist_ok=True)

    def analyze_relevance(
        self,
        df: pd.DataFrame,
        target_col: str,
        alpha: float = 0.05,
        task: str = 'auto'
    ) -> Dict[str, Any]:
        """
        Comprehensive relevance analysis for all numeric features.

        Args:
            df: Data containing the features and the target column
            target_col: Name of the target column in df
            alpha: Significance level for the hypothesis test
            task: 'classification', 'regression', or 'auto'. 'auto' picks
                classification when the target has fewer than 20 distinct
                values. An explicit task is always honoured.

        Returns:
            Dict with:
                - mutual_information: I(X;Y) scores
                - statistical_tests: {test name: {feature: (statistic, p_value)}}
                - redundancy: Correlation-based redundancy map
                - information_gain_ratio: IGR scores
                - selected_features: Recommended feature subset
                - original_count / selected_count / reduction_pct

        Raises:
            ValueError: If df has no numeric feature columns.
        """
        analyzer = RelevanceAnalyzer(alpha=alpha)

        # Separate features and target
        X = df.drop(columns=[target_col])
        y = df[target_col]

        # Only numeric columns are analysed; non-numeric ones are ignored.
        numeric_cols = X.select_dtypes(include=[np.number]).columns.tolist()
        X_numeric = X[numeric_cols]

        if not numeric_cols:
            raise ValueError("No numeric features found for analysis")

        # Resolve 'auto' once so the MI estimator and the hypothesis test
        # agree, and so an explicit 'regression' is not overridden by the
        # low-cardinality heuristic.
        if task == 'auto':
            task = 'classification' if len(y.unique()) < 20 else 'regression'

        mi_scores = analyzer.mutual_information(X_numeric, y, task=task)

        if task == 'classification':
            stat_tests = analyzer.f_test(X_numeric, y)
            test_name = "f_test"
        else:
            stat_tests = analyzer.correlation_test(X_numeric, y, method='pearson')
            test_name = "correlation"

        redundancy_map = analyzer.compute_redundancy(X_numeric, threshold=0.8)
        igr_scores = analyzer.information_gain_ratio(X_numeric, y)

        # Keep features that are statistically significant AND carry
        # non-negligible mutual information.
        selected = [
            col for col in X_numeric.columns
            if stat_tests[col][1] < alpha and mi_scores[col] > 0.01
        ]

        # Remove redundant features (keep highest MI from each group)
        final_selected = self._remove_redundant(selected, mi_scores, redundancy_map)

        original_count = len(X_numeric.columns)
        return {
            "mutual_information": mi_scores,
            "statistical_tests": {test_name: stat_tests},
            "redundancy": redundancy_map,
            "information_gain_ratio": igr_scores,
            "selected_features": final_selected,
            "original_count": original_count,
            "selected_count": len(final_selected),
            "reduction_pct": (1 - len(final_selected) / original_count) * 100
        }

    def _remove_redundant(
        self,
        features: List[str],
        mi_scores: Dict[str, float],
        redundancy_map: Dict[str, List[str]]
    ) -> List[str]:
        """
        Remove redundant features, keeping highest MI from each group.

        Features are visited from highest to lowest MI (ties broken by name)
        and kept only if they are not redundant with a feature already kept.
        A feature is therefore only ever dropped in favour of one that
        survives, and exactly one member of an equal-MI redundant pair is
        kept.

        Returns:
            Alphabetically sorted list of the surviving features.
        """
        def redundant_pair(a: str, b: str) -> bool:
            # Check both directions so an asymmetric map is still honoured.
            return b in redundancy_map.get(a, ()) or a in redundancy_map.get(b, ())

        ranked = sorted(set(features), key=lambda feat: (-mi_scores[feat], feat))

        kept: List[str] = []
        for feat in ranked:
            if not any(redundant_pair(feat, winner) for winner in kept):
                kept.append(feat)

        return sorted(kept)

    def _read_with_target(self, input_file: str, target_col: str) -> pd.DataFrame:
        """Load a CSV file and verify that it contains the target column."""
        df = pd.read_csv(input_file)

        if target_col not in df.columns:
            raise ValueError(f"Target column '{target_col}' not found in data")

        return df

    def select_features_file(
        self,
        input_file: str,
        output_file: str,
        target_col: str,
        alpha: float = 0.05,
        task: str = 'auto'
    ) -> Dict[str, Any]:
        """
        Analyze relevance and save selected features to new file.

        Writes the selected feature columns plus the target to output_file
        (CSV, no index) and a JSON summary report into the engine's
        relevance_reports directory.

        Returns:
            The full analysis dict produced by analyze_relevance.

        Raises:
            ValueError: If target_col is missing from the input file, or the
                data has no numeric features.
        """
        df = self._read_with_target(input_file, target_col)

        analysis = self.analyze_relevance(df, target_col, alpha=alpha, task=task)

        # Save selected features + target
        selected_cols = analysis['selected_features'] + [target_col]
        df[selected_cols].to_csv(output_file, index=False)

        # Take one timestamp so the report's file name and its "timestamp"
        # field always refer to the same instant.
        now = datetime.now()
        report_path = self.relevance_dir / f"relevance_report_{now.strftime('%Y%m%d_%H%M%S')}.json"
        report = {
            # str() keeps the report serializable if a Path is passed in.
            "input_file": str(input_file),
            "output_file": str(output_file),
            "target": target_col,
            "alpha": alpha,
            "task": task,
            "timestamp": now.isoformat(),
            "analysis": {
                "original_features": analysis['original_count'],
                "selected_features": analysis['selected_count'],
                "reduction_pct": analysis['reduction_pct'],
                "features": analysis['selected_features']
            }
        }
        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2)

        return analysis

    def analyze_file(
        self,
        input_file: str,
        target_col: str,
        alpha: float = 0.05,
        task: str = 'auto'
    ) -> Dict[str, Any]:
        """
        Analyze relevance without creating output file (analysis only).

        Raises:
            ValueError: If target_col is missing from the input file, or the
                data has no numeric features.
        """
        df = self._read_with_target(input_file, target_col)
        return self.analyze_relevance(df, target_col, alpha=alpha, task=task)


def get_relevance_engine(workspace_path: Optional[str] = None) -> RelevanceEngine:
    """
    Build and return a RelevanceEngine for the given workspace.

    Args:
        workspace_path: Passed through unchanged to RelevanceEngine; None is
            forwarded as-is.

    Returns:
        A newly constructed RelevanceEngine.
    """
    engine = RelevanceEngine(workspace_path=workspace_path)
    return engine
