"""
ERICA Clustering Package - Example Usage
Author: Siamak Sorooshyari & Shawn Shirazi
Package: erica-clustering (https://pypi.org/project/erica-clustering/)

This example demonstrates the main features of the ERICA package for
evaluating clustering replicability using Monte Carlo subsampling.
"""

import numpy as np
import pandas as pd
from erica import ERICA

# Set random seed for reproducibility.
# This seeds NumPy's legacy global RNG once, at import time. Examples 2-4
# call np.random.seed(42) again themselves, so their data does not depend on
# which examples ran before them. Examples 1 and 5 draw from this global
# stream without reseeding.
np.random.seed(42)


def example_1_basic_usage():
    """
    Example 1: Basic ERICA Analysis
    --------------------------------
    Simple clustering replicability analysis with K-Means.

    Runs ERICA on uniformly random data (100 samples x 50 features, no
    planted cluster structure). It prints CRI / WCRI / TWCRI for every
    tested k and writes the first figure returned by ``plot_metrics()`` to
    ``erica_metrics_plot.html``.

    Returns:
        tuple: ``(erica, metrics)`` -- the ERICA object after ``run()`` and
        the dictionary returned by ``erica.get_metrics()``.
    """
    print("="*60)
    print("Example 1: Basic ERICA Analysis")
    print("="*60)

    # Generate sample data (100 samples, 50 features) from the global RNG.
    data = np.random.rand(100, 50)
    print(f"Data shape: {data.shape}")

    # Single source of truth for the tested k values. The reporting loop
    # below iterates the same list, so changing it cannot cause a KeyError.
    k_range = [2, 3, 4, 5]

    # Initialize ERICA with K-Means clustering
    erica = ERICA(
        data=data,
        k_range=k_range,        # Test different numbers of clusters
        n_iterations=200,       # Number of Monte Carlo iterations
        method='kmeans',        # Use K-Means clustering
        train_percent=0.8,      # 80% train, 20% test split
        random_seed=123         # For reproducibility
    )

    # Run the analysis
    print("\nRunning ERICA analysis...")
    erica.run()

    # Get replicability metrics, indexed as metrics[k][method][metric_name]
    metrics = erica.get_metrics()

    # Display results for each k
    print("\nReplicability Metrics:")
    print("-" * 60)
    for k in k_range:
        km = metrics[k]['kmeans']
        print(f"K={k}: CRI={km['CRI']:.3f}, WCRI={km['WCRI']:.3f}, TWCRI={km['TWCRI']:.3f}")

    # Visualize metrics; plot_metrics() is unpacked into two figures and
    # only the first one is saved here.
    print("\nGenerating visualization plots...")
    fig1, _fig2 = erica.plot_metrics()
    fig1.write_html("erica_metrics_plot.html")
    print("Plot saved as 'erica_metrics_plot.html'")

    return erica, metrics


def example_2_comparing_methods():
    """
    Example 2: Comparing K-Means vs Agglomerative Clustering
    --------------------------------------------------------
    Compare replicability of different clustering methods.

    Builds three well-separated Gaussian clusters (30 samples x 20 features
    each, centred at 0, +5 and -5 in every feature). It then runs ERICA with
    ``method='both'`` and prints the K-Means and Agglomerative CRI side by
    side for each tested k.

    Returns:
        tuple: ``(erica_both, metrics)`` -- the ERICA object after ``run()``
        and the dictionary returned by ``get_metrics()``.
    """
    print("\n" + "="*60)
    print("Example 2: Comparing Clustering Methods")
    print("="*60)

    # Reseed so this example's data does not depend on earlier examples.
    np.random.seed(42)

    # Create 3 distinct clusters: standard-normal noise shifted by a scalar
    # offset, which NumPy broadcasts across every feature.
    n_per_cluster, n_features = 30, 20
    cluster_offsets = (0, 5, -5)
    data = np.vstack([
        np.random.randn(n_per_cluster, n_features) + offset
        for offset in cluster_offsets
    ])

    print(f"Data shape: {data.shape}")

    # Shared by the ERICA call and the reporting loop below.
    k_range = [2, 3, 4]

    # Test both methods
    erica_both = ERICA(
        data=data,
        k_range=k_range,
        n_iterations=100,
        method='both',  # Compare both methods
        random_seed=123
    )

    print("\nRunning ERICA with both K-Means and Agglomerative clustering...")
    erica_both.run()
    metrics = erica_both.get_metrics()

    # Compare methods
    print("\nMethod Comparison:")
    print("-" * 60)
    for k in k_range:
        km_cri = metrics[k]['kmeans']['CRI']
        agg_cri = metrics[k]['agglomerative']['CRI']
        print(f"K={k}:")
        print(f"  K-Means CRI:       {km_cri:.3f}")
        print(f"  Agglomerative CRI: {agg_cri:.3f}")
        print(f"  Difference:        {abs(km_cri - agg_cri):.3f}")

    return erica_both, metrics


def example_3_gene_expression_simulation():
    """
    Example 3: Simulated Gene Expression Analysis
    --------------------------------------------
    Bioinformatics-style example: choosing a number of patient subtypes.

    Simulates 4 subtypes of 50 samples each over 1000 genes. Each subtype's
    values are drawn from a gamma distribution with its own (shape, scale).
    ERICA is run with ``method='both'``, and the k with the highest K-Means
    TWCRI among the tested k values is reported. The CLAM heatmap for that
    k is written to ``gene_expression_clam.html``.

    Returns:
        tuple: ``(erica_genes, metrics, best_k)``.
    """
    print("\n" + "="*60)
    print("Example 3: Simulated Gene Expression Analysis")
    print("="*60)

    # Simulate gene expression data: 4 subtypes x 50 samples (patients),
    # 1000 genes. Reseed so the data is independent of earlier examples.
    np.random.seed(42)
    n_per_subtype = 50
    n_genes = 1000

    # (shape, scale) of the gamma distribution for each simulated subtype.
    subtype_params = [(2, 2), (3, 1.5), (1.5, 3), (2.5, 2.5)]
    gene_expression = np.vstack([
        np.random.gamma(shape, scale, (n_per_subtype, n_genes))
        for shape, scale in subtype_params
    ])

    # Derive the sample count from the data instead of hard-coding it.
    n_samples = gene_expression.shape[0]
    print(f"Gene expression data: {n_samples} samples × {n_genes} genes")

    # Candidate subtype counts, shared by the ERICA call and the best-k search.
    k_range = [2, 3, 4, 5, 6]

    # Run ERICA to find optimal number of subtypes
    erica_genes = ERICA(
        data=gene_expression,
        k_range=k_range,  # Test different subtype counts
        n_iterations=150,
        method='both',
        train_percent=0.75,
        random_seed=123
    )

    print("\nAnalyzing clustering replicability for patient stratification...")
    erica_genes.run()
    metrics = erica_genes.get_metrics()

    # Pick the tested k with the highest K-Means TWCRI (first one wins ties).
    best_k = max(k_range, key=lambda k: metrics[k]['kmeans']['TWCRI'])

    print("\nOptimal Number of Patient Subtypes:")
    print("-" * 60)
    print(f"Recommended K: {best_k}")
    print(f"TWCRI: {metrics[best_k]['kmeans']['TWCRI']:.3f}")
    print(f"CRI:   {metrics[best_k]['kmeans']['CRI']:.3f}")

    # Get CLAM matrix for optimal k
    clam_matrix = erica_genes.get_clam_matrix(k=best_k, method='kmeans')
    print(f"\nCLAM matrix shape: {clam_matrix.shape}")

    # Visualize CLAM matrix
    fig = erica_genes.plot_clam_heatmap(k=best_k, method='kmeans')
    fig.write_html("gene_expression_clam.html")
    print("CLAM heatmap saved as 'gene_expression_clam.html'")

    return erica_genes, metrics, best_k


def example_4_identifying_unstable_samples():
    """
    Example 4: Identifying Unstable/Ambiguous Samples
    ------------------------------------------------
    Find samples with low replicability (potential outliers).

    Builds two clear clusters (40 samples each, centred at 0 and 6 in every
    feature) plus 20 ambiguous samples centred halfway between them. It runs
    ERICA with k=2 and scores each sample by the fraction of its recorded
    cluster assignments that went to its most frequent cluster. Samples
    scoring below 0.6 are flagged.

    Returns:
        tuple: ``(erica, unstable_samples)`` where ``unstable_samples`` is
        an integer array of flagged row indices into the generated data.
    """
    print("\n" + "="*60)
    print("Example 4: Identifying Unstable Samples")
    print("="*60)

    # Generate data with some ambiguous samples (reseeded for independence).
    np.random.seed(42)
    n_clear, n_ambiguous, n_features = 40, 20, 15

    # Clear clusters (scalar offsets broadcast across every feature)
    cluster1 = np.random.randn(n_clear, n_features)
    cluster2 = np.random.randn(n_clear, n_features) + 6

    # Add ambiguous samples (between clusters)
    ambiguous = np.random.randn(n_ambiguous, n_features) + 3

    data = np.vstack([cluster1, cluster2, ambiguous])
    ambiguous_start = 2 * n_clear  # first row index of the ambiguous block

    # Run ERICA
    erica = ERICA(
        data=data,
        k_range=[2],
        n_iterations=100,
        method='kmeans',
        random_seed=123
    )

    print(f"Data shape: {data.shape}")
    print("Running analysis to identify unstable samples...")

    erica.run()

    # Get CLAM matrix
    clam_matrix = erica.get_clam_matrix(k=2, method='kmeans')

    # Sample-wise replicability: the fraction of a sample's recorded
    # assignments that landed in its most frequent cluster.
    # NOTE(review): this assumes the CLAM matrix is (n_samples, k), with rows
    # for samples and columns holding per-cluster assignment counts. Under
    # that layout np.diag() yields only k values, which cannot be divided by
    # the n_samples row totals, so the row-wise max is used instead.
    # Samples with no recorded assignments get NaN rather than 0/0.
    clam = np.asarray(clam_matrix, dtype=float)
    row_totals = clam.sum(axis=1)
    sample_consistency = np.divide(
        clam.max(axis=1),
        row_totals,
        out=np.full(row_totals.shape, np.nan),
        where=row_totals > 0,
    )

    # Identify unstable samples (low consistency). NaN entries compare False
    # and are therefore never flagged.
    threshold = 0.6
    unstable_samples = np.where(sample_consistency < threshold)[0]

    print(f"\nSamples with consistency < {threshold}:")
    print(f"Found {len(unstable_samples)} potentially ambiguous samples")
    print(f"Sample indices: {unstable_samples.tolist()}")
    n_from_ambiguous = int(np.count_nonzero(unstable_samples >= ambiguous_start))
    print(f"  {n_from_ambiguous} of them come from the planted ambiguous block "
          f"(indices {ambiguous_start}-{len(data) - 1})")

    n_unassessed = int(np.count_nonzero(np.isnan(sample_consistency)))
    if n_unassessed:
        print(f"  ({n_unassessed} samples had no recorded assignments and were skipped)")

    # Statistics (NaN-aware so never-assessed samples don't poison them)
    print("\nConsistency Statistics:")
    print(f"  Mean: {np.nanmean(sample_consistency):.3f}")
    print(f"  Min:  {np.nanmin(sample_consistency):.3f}")
    print(f"  Max:  {np.nanmax(sample_consistency):.3f}")

    return erica, unstable_samples


def example_5_custom_output_and_saving():
    """
    Example 5: Custom Output Directory and Result Saving
    --------------------------------------------------
    Control where results are saved and how to load them.

    Runs ERICA on random data (50 samples x 20 features) with
    ``output_dir='./my_erica_results'``, lists the files the example expects
    there, and prints the K-Means CRI for each tested k.

    Returns:
        ERICA: the ERICA object after ``run()``.
    """
    print("\n" + "="*60)
    print("Example 5: Custom Output and Result Saving")
    print("="*60)

    # Generate sample data from the global RNG (not reseeded here).
    data = np.random.rand(50, 20)

    # Shared by the ERICA call and the reporting loop below.
    k_range = [2, 3]

    # Run ERICA with custom output directory
    erica = ERICA(
        data=data,
        k_range=k_range,
        n_iterations=50,
        method='kmeans',
        output_dir='./my_erica_results',  # Custom output directory
        random_seed=123
    )

    print("Running ERICA with custom output directory...")
    erica.run()

    print(f"\nResults saved to: {erica.output_dir}")
    print("\nSaved files include:")
    print("  - train_indices.npy: Training sample indices for each iteration")
    print("  - test_indices.npy:  Test sample indices for each iteration")
    print("  - clam_matrix_k*.npy: CLAM matrices for each k")
    print("  - metrics.yaml: Computed replicability metrics")

    # Access saved results
    metrics = erica.get_metrics()

    print("\nAccess metrics from saved results:")
    for k in k_range:
        print(f"K={k}: CRI={metrics[k]['kmeans']['CRI']:.3f}")

    return erica


def main():
    """
    Run all examples in order and report the outcome.

    If an example raises, the exception message and traceback are printed,
    the remaining examples are skipped, and a non-zero value is returned.
    The script's exit status therefore reflects the failure instead of
    always reporting success.

    Returns:
        int: 0 if every example completed, 1 if any example raised.
    """
    import traceback

    print("\n")
    print("*" * 70)
    print("*" + " " * 68 + "*")
    print("*" + "  ERICA Clustering Package - Comprehensive Examples".center(68) + "*")
    print("*" + " " * 68 + "*")
    print("*" * 70)
    print("\n")

    # Run all examples; the first failure stops the run.
    try:
        # Example 1: Basic usage
        example_1_basic_usage()

        # Example 2: Method comparison
        example_2_comparing_methods()

        # Example 3: Gene expression simulation
        example_3_gene_expression_simulation()

        # Example 4: Unstable samples
        example_4_identifying_unstable_samples()

        # Example 5: Custom output
        example_5_custom_output_and_saving()
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the
        # return code rather than swallowing it.
        print(f"\nError occurred: {e}")
        traceback.print_exc()
        return 1

    print("\n" + "="*60)
    print("All examples completed successfully!")
    print("="*60)
    print("\nGenerated files:")
    print("  - erica_metrics_plot.html")
    print("  - gene_expression_clam.html")
    print("  - Various output directories with results")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())

