import torch
import matplotlib.pyplot as plt
import numpy as np

def _layer_mean(tensor):
    """Return one layer's per-neuron values as a 1-D NumPy vector.

    Tensors with a batch dimension (ndim >= 2) are averaged over axis 0 and
    flattened, so e.g. an (N, C, L) activation becomes C*L nodes. 1-D tensors
    are used as-is (no averaging); 0-D tensors become a single node. The
    tensor is moved to the CPU first so CUDA tensors can be converted.
    """
    arr = tensor.detach().cpu().numpy()
    if arr.ndim >= 2:
        arr = arr.mean(axis=0)
    return np.atleast_1d(arr).ravel()


def _signed_color(val, max_abs):
    """Map a signed value to an RGB tuple: green if > 0, red if < 0, white otherwise.

    Saturation is ``|val| / max_abs``; ``max_abs`` must be > 0 and >= ``|val|``.
    NaN values fall through to white.
    """
    if val > 0:
        norm = val / max_abs
        return (1 - norm, 1, 1 - norm)  # green
    if val < 0:
        norm = -val / max_abs
        return (1, 1 - norm, 1 - norm)  # red
    return (1, 1, 1)  # white


def _connection_polyline(x_src, ys_src, x_dst, ys_dst):
    """Return NaN-separated x/y arrays with one segment from every source node to every destination node."""
    src, dst = np.meshgrid(ys_src, ys_dst, indexing="ij")
    n_seg = src.size
    # Layout per segment: start point, end point, NaN break (so matplotlib
    # draws all segments as a single artist without joining them).
    xs = np.full(3 * n_seg, np.nan)
    ys = np.full(3 * n_seg, np.nan)
    xs[0::3] = x_src
    xs[1::3] = x_dst
    ys[0::3] = src.ravel()
    ys[1::3] = dst.ravel()
    return xs, ys


def show_activation_areas(model, class_inputs, percentile=None, only_hidden=False):
    """
    Visualize the average activations of a neural network for a given class.

    This function runs `class_inputs` through `model`, computes the mean
    activation of every neuron across all samples for each layer, and draws
    the network as columns of colored nodes (one column per layer). Each node
    is connected to every node of the next column. The forward method of the
    model must return a tuple or list of layer activations (hidden layers
    followed by the output, e.g. ``return hidden1, hidden2, output``). A single
    returned tensor is treated as a one-element tuple.

    Node coloring scheme:
        - Red: activation < 0
        - White: activation = 0
        - Green: activation > 0
        - Saturation is proportional to the magnitude of the activation,
          relative to the largest magnitude in the same layer.
        - If `percentile` is set, only neurons with activation at or above the
          specified percentile of their layer are highlighted (green); others
          are white. The input layer always uses the red/white/green scheme.

    Parameters
    ----------
    model : torch.nn.Module
        PyTorch model to visualize. Its forward method should return the activations
        of each layer as a tuple or list.
    class_inputs : torch.Tensor
        Batch of inputs, shape (N, input_dim), on the same device as `model`.
        Preferably all samples belong to a single class to visualize
        class-specific activation patterns, but it can also be used with
        multi-class inputs, producing an ensemble of average activations.
        Inputs/activations with more than two dimensions are averaged over the
        batch axis and flattened into one node per element.
    percentile : float, optional
        Percentile threshold (0-100) to highlight top neurons in hidden/output layers.
        If None, all neurons are colored based on the red/white/green scheme.
    only_hidden : bool, default=False
        If True, plots only hidden layers. If False, includes input and output
        layers. If the model returns a single tensor (no hidden layers), the
        output layer is shown instead.

    Returns
    -------
    None
        Draws on the current matplotlib figure and calls ``plt.show()``.

    Notes
    -----
    The forward pass runs in eval mode under ``torch.inference_mode()``; the
    model's previous training/eval mode is restored afterwards.

    Example
    -------
    >>> class_inputs = train_input[train_labels == 3]  # doctest: +SKIP
    >>> show_activation_areas(model, class_inputs, percentile=90, only_hidden=False)  # doctest: +SKIP
    """
    # Eval mode disables dropout / uses running batch-norm stats, but restore
    # the caller's mode so plotting mid-training does not silently change it.
    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            acts = model(class_inputs)
    finally:
        model.train(was_training)
    if not isinstance(acts, (tuple, list)):
        acts = (acts,)

    # Per-layer average activations: [input, *model outputs].
    avg_acts = [_layer_mean(class_inputs)] + [_layer_mean(a) for a in acts]

    if only_hidden:
        if len(avg_acts) > 2:
            avg_acts = avg_acts[1:-1]
            layer_names = [f"hidden{i + 1}" for i in range(len(avg_acts))]
        else:
            # The model returned only its output: there is no hidden layer to
            # isolate, so show the output layer and label it as such.
            avg_acts = avg_acts[1:]
            layer_names = ["output"]
    else:
        n_hidden = len(avg_acts) - 2
        layer_names = ["input"] + [f"hidden{i + 1}" for i in range(n_hidden)] + ["output"]

    # Draw nodes, one column per layer, vertically centred on y = 0.
    layer_coords = []
    for i, layer_vals in enumerate(avg_acts):
        n_neurons = len(layer_vals)
        coords_y = np.linspace(-n_neurons / 2, n_neurons / 2, n_neurons)
        layer_coords.append(coords_y)

        plt.text(i, n_neurons / 2 + 1, layer_names[i], fontsize=12, ha="center")

        # The raw input layer always uses the signed scheme; the percentile
        # filter only applies to layers produced by the model.
        if percentile is None or (not only_hidden and i == 0):
            max_abs = np.max(np.abs(layer_vals)) + 1e-8  # epsilon avoids division by zero
            colors = [_signed_color(v, max_abs) for v in layer_vals]
        else:
            thresh = np.percentile(layer_vals, percentile)  # once per layer, not per neuron
            colors = ["green" if v >= thresh else "white" for v in layer_vals]

        plt.scatter(np.full(n_neurons, i), coords_y, color=colors, s=500, zorder=2)

    # Draw dense connections between consecutive layers, behind the nodes.
    for i in range(len(layer_coords) - 1):
        xs, ys = _connection_polyline(i, layer_coords[i], i + 1, layer_coords[i + 1])
        plt.plot(xs, ys, color="gray", zorder=1)

    plt.axis("off")
    plt.show()
