"""Core ERICA class for clustering replicability analysis.

This module provides the main ERICA class that orchestrates the entire
clustering replicability analysis workflow.
"""

import os
import numpy as np
import pandas as pd
from typing import Union, List, Dict, Optional, Tuple
from datetime import datetime

from .clustering import (
    kmeans_clustering,
    agglomerative_clustering,
    iterative_clustering_subsampling,
)
from .metrics import compute_metrics_for_clam
from .data import prepare_samples_array, validate_dataset
from .utils import set_deterministic_mode, compute_config_hash


class ERICA:
    """Main class for ERICA (Evaluating Replicability via Iterative Clustering Assignments).

    This class provides a high-level interface for performing clustering replicability
    analysis using Monte Carlo subsampling methods.

    Parameters
    ----------
    data : np.ndarray or pd.DataFrame
        Input data matrix with shape (n_samples, n_features)
    k_range : list of int, optional
        Range of cluster numbers to evaluate, default [2, 3, 4, 5]
    n_iterations : int, optional
        Number of iterative clustering iterations (B), default 200
    train_percent : float, optional
        Proportion of data used for training subsample, default 0.8
    method : {'kmeans', 'agglomerative', 'both'}, optional
        Clustering method(s) to use, default 'both'
    linkages : list of str, optional
        Linkage methods for agglomerative clustering, default ['single', 'ward']
    random_seed : int, optional
        Random seed for reproducibility, default 123
    output_dir : str, optional
        Base directory for output files, default './erica_output'
    verbose : bool, optional
        Whether to print progress messages, default True

    Attributes
    ----------
    results_ : dict
        Raw clustering results keyed by ``(k, method_key)`` where ``method_key``
        is ``'kmeans'`` or ``'agglomerative_<linkage>'``
    clam_matrices_ : dict
        Dictionary of CLAM matrices for each ``(k, method_key)`` combination
    metrics_ : dict
        Dictionary of computed metrics (CRI, WCRI, TWCRI), keyed first by ``k``
        and then by ``method_key``
    output_folders_ : list of str
        Run directories created by successive calls to :meth:`run`

    Raises
    ------
    ValueError
        If ``method`` is not one of ``'kmeans'``, ``'agglomerative'`` or ``'both'``

    Examples
    --------
    >>> import numpy as np
    >>> from erica import ERICA
    >>>
    >>> # Generate sample data
    >>> data = np.random.rand(100, 50)
    >>>
    >>> # Run ERICA analysis
    >>> erica = ERICA(data=data, k_range=[2, 3, 4], n_iterations=100)
    >>> erica.run()
    >>>
    >>> # Get results
    >>> results = erica.get_results()
    >>> metrics = erica.get_metrics()
    >>>
    >>> # Get CLAM matrix for k=3, kmeans
    >>> clam = erica.get_clam_matrix(k=3, method='kmeans')
    """

    # Values accepted for the ``method`` constructor argument.
    _VALID_METHODS = ('kmeans', 'agglomerative', 'both')

    def __init__(
        self,
        data: Union[np.ndarray, pd.DataFrame],
        k_range: List[int] = None,
        n_iterations: int = 200,
        train_percent: float = 0.8,
        method: str = 'both',
        linkages: List[str] = None,
        random_seed: int = 123,
        output_dir: str = './erica_output',
        verbose: bool = True
    ):
        """Initialize ERICA analysis."""
        # Fail fast, before any side effect: an unknown method would otherwise
        # make run() silently analyse nothing.
        if method not in self._VALID_METHODS:
            raise ValueError(
                f"Unknown method {method!r}; expected one of {self._VALID_METHODS}"
            )

        # Store configuration (copy the sequences so later mutation of the
        # caller's lists cannot change this analysis' configuration)
        self.data = data
        self.k_range = list(k_range) if k_range else [2, 3, 4, 5]
        self.n_iterations = n_iterations
        self.train_percent = train_percent
        self.method = method
        self.linkages = list(linkages) if linkages else ['single', 'ward']
        self.random_seed = random_seed
        self.output_dir = output_dir
        self.verbose = verbose

        # Set deterministic mode
        set_deterministic_mode(random_seed)

        # Prepare data
        self.samples_array = prepare_samples_array(data)
        self.n_samples, self.n_features = self.samples_array.shape

        # Validate dataset
        validate_dataset(
            self.samples_array,
            min(self.k_range),
            self.train_percent
        )

        # Initialize results storage
        self.results_ = {}
        self.clam_matrices_ = {}
        self.metrics_ = {}
        self.output_folders_ = []

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        if self.verbose:
            print("ERICA initialized:")
            print(f"  Data shape: {self.samples_array.shape}")
            print(f"  K range: {self.k_range}")
            print(f"  Iterations: {self.n_iterations}")
            print(f"  Method: {self.method}")
            print(f"  Random seed: {self.random_seed}")

    def run(self) -> Dict:
        """Run the complete ERICA analysis.

        This method performs iterative clustering subsampling, clustering analysis,
        and generates CLAM matrices for all specified k values and methods.

        Returns
        -------
        dict
            Dictionary containing all analysis results including:
            - 'clam_matrices': CLAM matrices for each (k, method) combination
            - 'metrics': Computed metrics keyed by k, then by method
            - 'config': Configuration parameters used
            - 'output_folders': List of output folder paths
            - 'results': Raw clustering results for each (k, method) combination

        Examples
        --------
        >>> erica = ERICA(data=my_data, k_range=[2, 3, 4])
        >>> results = erica.run()
        >>> print(results.keys())
        """
        # Take a single timestamp so the log line and the folder name agree.
        started_at = datetime.now()
        if self.verbose:
            print(f"\nStarting ERICA analysis at {started_at.strftime('%Y-%m-%d %H:%M:%S')}")

        # Create run-specific output directory
        timestamp = started_at.strftime("%Y%m%d_%H%M%S")
        run_dir = os.path.join(self.output_dir, f"erica_run_{timestamp}")
        os.makedirs(run_dir, exist_ok=True)

        # Step 1: Perform iterative clustering subsampling
        if self.verbose:
            print("\n[1/3] Performing iterative clustering subsampling...")

        train_size = int(self.n_samples * self.train_percent)
        # Only the indices folder is consumed by the clustering steps below.
        _subsamples_folder, indices_folder = iterative_clustering_subsampling(
            samples_array=self.samples_array,
            num_samples=self.n_samples,
            num_iterations=self.n_iterations,
            subsample_size_train=train_size,
            base_save_folder_str=run_dir,
            verbose=self.verbose
        )

        # Step 2: Run clustering for each k and method
        if self.verbose:
            print("\n[2/3] Running clustering analysis...")

        if self.method == 'both':
            methods_to_run = ['kmeans', 'agglomerative']
        else:
            methods_to_run = [self.method]

        for k in self.k_range:
            for method_name in methods_to_run:
                if method_name == 'kmeans':
                    if self.verbose:
                        print(f"  Running K-Means clustering for k={k}...")

                    result = kmeans_clustering(
                        samples_array=self.samples_array,
                        k=k,
                        n_iterations=self.n_iterations,
                        indices_folder=indices_folder,
                        output_dir=run_dir,
                        verbose=self.verbose
                    )

                    self.clam_matrices_[(k, 'kmeans')] = result['clam_matrix']
                    self.results_[(k, 'kmeans')] = result

                elif method_name == 'agglomerative':
                    # One result per linkage, stored under 'agglomerative_<linkage>'.
                    for linkage in self.linkages:
                        if self.verbose:
                            print(f"  Running Agglomerative clustering (linkage={linkage}) for k={k}...")

                        result = agglomerative_clustering(
                            samples_array=self.samples_array,
                            k=k,
                            linkage=linkage,
                            n_iterations=self.n_iterations,
                            indices_folder=indices_folder,
                            output_dir=run_dir,
                            verbose=self.verbose
                        )

                        method_key = f'agglomerative_{linkage}'
                        self.clam_matrices_[(k, method_key)] = result['clam_matrix']
                        self.results_[(k, method_key)] = result

        # Step 3: Compute metrics
        if self.verbose:
            print("\n[3/3] Computing metrics...")

        self.metrics_ = self._compute_all_metrics()

        # Store output information
        self.output_folders_.append(run_dir)

        if self.verbose:
            print("\nERICA analysis complete!")
            print(f"Results saved to: {run_dir}")
            print(f"Total configurations analyzed: {len(self.results_)}")

        return self.get_results()

    def _compute_all_metrics(self) -> Dict:
        """Compute CRI, WCRI, TWCRI metrics for all results.

        Returns
        -------
        dict
            Nested dictionary ``{k: {method_key: metrics}}`` built from every
            entry currently stored in ``results_``
        """
        metrics_by_k = {}

        for (k, method_name), result in self.results_.items():
            metrics = compute_metrics_for_clam(result['clam_matrix'], k)
            metrics_by_k.setdefault(k, {})[method_name] = metrics

        return metrics_by_k

    def get_results(self) -> Dict:
        """Get all analysis results.

        Returns
        -------
        dict
            Complete results dictionary containing:
            - 'clam_matrices': All CLAM matrices
            - 'metrics': All computed metrics
            - 'config': Configuration parameters
            - 'output_folders': List of output directories
            - 'results': Raw clustering results keyed by (k, method)
        """
        return {
            'clam_matrices': self.clam_matrices_,
            'metrics': self.metrics_,
            'config': self._get_config_dict(),
            'output_folders': self.output_folders_,
            'results': self.results_
        }

    def get_clam_matrix(self, k: int, method: str = 'kmeans') -> Optional[np.ndarray]:
        """Get CLAM matrix for specific k and method.

        Parameters
        ----------
        k : int
            Number of clusters
        method : str, optional
            Clustering method name, default 'kmeans'
            For agglomerative clustering, use 'agglomerative_single' or 'agglomerative_ward'

        Returns
        -------
        np.ndarray or None
            CLAM matrix if available, None otherwise
        """
        return self.clam_matrices_.get((k, method))

    def get_metrics(self, k: Optional[int] = None) -> Dict:
        """Get computed metrics.

        Parameters
        ----------
        k : int, optional
            If specified, return metrics only for this k value
            If None, return all metrics

        Returns
        -------
        dict
            Metrics dictionary (empty if ``k`` was requested but never analysed)
        """
        if k is not None:
            return self.metrics_.get(k, {})
        return self.metrics_

    def plot_metrics(self, **kwargs):
        """Generate metrics plots.

        This method creates interactive plots showing CRI, WCRI, and TWCRI
        metrics across different k values.

        Parameters
        ----------
        **kwargs
            Additional arguments passed to plotting function

        Returns
        -------
        tuple
            (metrics_plot, optimal_k_plot) if plotly is available

        Raises
        ------
        ImportError
            If plotly is not installed

        Examples
        --------
        >>> erica = ERICA(data=my_data)
        >>> erica.run()
        >>> fig1, fig2 = erica.plot_metrics()
        >>> fig1.show()
        """
        try:
            from .plotting import create_metrics_plots
        except ImportError as exc:
            # Chain the original error so the real cause stays visible.
            raise ImportError(
                "Plotting requires plotly. Install with: pip install erica-clustering[plots]"
            ) from exc

        # Prepare metrics data for plotting
        metrics_data = self._prepare_metrics_for_plotting()

        return create_metrics_plots(metrics_data, **kwargs)

    def _prepare_metrics_for_plotting(self) -> Dict:
        """Prepare metrics data in format expected by plotting functions.

        For every k (ascending) the CRI, WCRI and TWCRI values are averaged
        across all methods stored for that k. A k with no CRI value is skipped.

        Returns
        -------
        dict
            Keys 'k_values', 'CRI_vector', 'WCRI_vector', 'TWCRI_vector'
            (parallel lists) and 'success' (always True)
        """
        def mean_of(per_method: Dict, name: str):
            values = [m[name] for m in per_method.values() if name in m]
            # NaN placeholder keeps the vectors aligned when a metric is absent.
            return np.mean(values) if values else np.nan

        cri_vector = []
        wcri_vector = []
        twcri_vector = []
        k_values = []

        # metrics_ is keyed directly by the integer k (see _compute_all_metrics),
        # so the keys themselves are the k values.
        for k in sorted(self.metrics_):
            k_metrics = self.metrics_[k]
            if not any('CRI' in m for m in k_metrics.values()):
                continue

            k_values.append(k)
            cri_vector.append(mean_of(k_metrics, 'CRI'))
            wcri_vector.append(mean_of(k_metrics, 'WCRI'))
            twcri_vector.append(mean_of(k_metrics, 'TWCRI'))

        return {
            'k_values': k_values,
            'CRI_vector': cri_vector,
            'WCRI_vector': wcri_vector,
            'TWCRI_vector': twcri_vector,
            'success': True
        }

    @staticmethod
    def _json_key(key):
        """Return ``key`` in a form the ``json`` module accepts as an object key.

        Strings, ints, floats, bools and None pass through unchanged; NumPy
        scalars are unwrapped; anything else (e.g. the ``(k, method)`` tuples)
        becomes ``str(key)``.
        """
        if isinstance(key, np.generic):
            key = key.item()
        if key is None or isinstance(key, (str, int, float, bool)):
            return key
        return str(key)

    @classmethod
    def _to_serializable(cls, value):
        """Recursively convert ``value`` into plain JSON-compatible containers.

        Dicts get JSON-safe keys, sequences and sets become lists, NumPy arrays
        become nested lists and NumPy scalars become Python scalars. Any other
        object is returned untouched so ``json`` reports it explicitly.
        """
        if isinstance(value, dict):
            return {
                cls._json_key(key): cls._to_serializable(item)
                for key, item in value.items()
            }
        if isinstance(value, (list, tuple, set, frozenset)):
            return [cls._to_serializable(item) for item in value]
        if isinstance(value, np.ndarray):
            # Recurse because object arrays may still contain NumPy values.
            return cls._to_serializable(value.tolist())
        if isinstance(value, np.generic):
            return value.item()
        return value

    def save_results(self, filepath: str) -> None:
        """Save results to a JSON file.

        Parameters
        ----------
        filepath : str
            Destination path. The content is always JSON, whatever the
            file extension.

        Raises
        ------
        TypeError
            If the results contain an object that cannot be represented in JSON
        """
        import json

        # Serialize completely before opening the file so a failure cannot
        # leave a truncated file behind.
        payload = json.dumps(self._to_serializable(self.get_results()), indent=2)

        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(payload)

        if self.verbose:
            print(f"Results saved to: {filepath}")

    def _get_config_dict(self) -> Dict:
        """Get configuration as dictionary.

        The 'config_hash' entry is computed from the user-supplied parameters
        only (not from the derived data dimensions).
        """
        hashed_params = {
            'k_range': self.k_range,
            'n_iterations': self.n_iterations,
            'train_percent': self.train_percent,
            'method': self.method,
            'linkages': self.linkages,
            'random_seed': self.random_seed,
        }
        return {
            **hashed_params,
            'n_samples': self.n_samples,
            'n_features': self.n_features,
            'config_hash': compute_config_hash(hashed_params),
        }

    def __repr__(self) -> str:
        """String representation of ERICA object."""
        return (
            f"ERICA(n_samples={self.n_samples}, n_features={self.n_features}, "
            f"k_range={self.k_range}, n_iterations={self.n_iterations}, "
            f"method='{self.method}')"
        )


