import torch
import matplotlib.pyplot as plt
import numpy as np

def _to_flat_numpy(values):
    """Return one sample of ``values`` as a flat 1D NumPy array.

    Tensors are detached and copied to host memory first, because
    ``Tensor.numpy()`` raises for tensors that live on another device
    (e.g. CUDA). Anything with two or more dimensions is treated as batched
    and only its first sample is kept; the result is flattened so each layer
    becomes a plain vector of neuron values (a 0-d value becomes length 1).
    """
    if isinstance(values, torch.Tensor):
        arr = values.detach().cpu().numpy()
    else:
        arr = np.asarray(values)
    if arr.ndim >= 2:
        arr = arr[0]
    return np.ravel(arr)


def _neuron_color(val, max_abs):
    """Map a neuron value to an RGB tuple on a red/white/green scale.

    Positive values fade from white to green and negative values from white
    to red, with saturation ``|val| / max_abs``. Zero (and NaN, which fails
    both comparisons) is white. ``max_abs`` must be >= ``|val|`` so every
    channel stays within [0, 1].
    """
    if val > 0:
        norm = float(val / max_abs)
        return (1 - norm, 1, 1 - norm)  # green
    if val < 0:
        norm = float(-val / max_abs)
        return (1, 1 - norm, 1 - norm)  # red
    return (1, 1, 1)  # white


def show_forward_pass(model, x_input, percentile=None):
    """
    Visualize the activations produced by one forward pass of a network.

    The model's forward method must return the activations of the logical
    layers (e.g. hidden layers and output) as a tuple or list; a single
    returned tensor is treated as the output layer. The input itself is drawn
    as the first layer, and every neuron of a layer is connected to every
    neuron of the next one. Neurons are colored by their value:

        - Red: value < 0
        - White: value = 0
        - Green: value > 0
        - Saturation proportional to the absolute value, relative to the
          largest absolute value in the same layer
        - Optionally, only the top percentile of hidden/output neurons is
          highlighted (see ``percentile``)

    Parameters
    ----------
    model : torch.nn.Module
        PyTorch model to visualize. The forward method must return the layer
        activations as a tuple or list (e.g.
        ``return hidden1_activations, hidden2_activations, output_activations``).
        The model is run in eval mode under ``torch.inference_mode()``; its
        previous train/eval mode is restored afterwards.

    x_input : torch.Tensor or array-like
        A single input to the network. A 1D input gets a batch dimension added
        before the forward pass. An input with more dimensions is passed to the
        model unchanged, and only its first sample is plotted. Non-tensor
        inputs are converted with ``torch.as_tensor`` using the default dtype.

    percentile : float, optional
        Value between 0 and 100. If specified, hidden/output neurons whose
        activation is >= that percentile of their own layer are drawn green
        and the other hidden/output neurons light gray. The input layer always
        uses the red/white/green scheme. If None, all neurons are colored
        according to the red/white/green scheme.

    Returns
    -------
    None
        The network is drawn on the current matplotlib figure and displayed
        with ``plt.show()``.

    Example
    -------
    >>> model = SimpleModel()  # doctest: +SKIP
    >>> x_sample = torch.rand(5)  # doctest: +SKIP
    >>> show_forward_pass(model, x_sample, percentile=90)  # doctest: +SKIP
    """
    if not isinstance(x_input, torch.Tensor):
        x_input = torch.as_tensor(x_input, dtype=torch.get_default_dtype())

    # Forward pass collecting every activation the model returns. Remember the
    # caller's mode so that calling this helper does not leave the model in eval.
    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            batch = x_input.unsqueeze(0) if x_input.ndim == 1 else x_input
            acts = model(batch)
    finally:
        model.train(was_training)

    # A model returning a single tensor only exposes its output layer.
    if not isinstance(acts, (tuple, list)):
        acts = (acts,)

    # The input is drawn as the first layer; every layer becomes a 1D vector.
    activations = [_to_flat_numpy(x_input)] + [_to_flat_numpy(a) for a in acts]

    # Layer names: input, hidden1..hiddenN, output.
    n_hidden = len(activations) - 2
    layer_names = ["input"] + [f"hidden{k + 1}" for k in range(n_hidden)] + ["output"]

    layer_coords = []
    for i, layer_vals in enumerate(activations):
        n_neurons = len(layer_vals)
        coords_y = np.linspace(-n_neurons / 2, n_neurons / 2, n_neurons)
        layer_coords.append(coords_y)

        # Layer label above its neurons.
        plt.text(i, n_neurons / 2 + 1, layer_names[i], fontsize=12, ha="center")

        if n_neurons == 0:
            continue

        if percentile is not None and i > 0:
            # The threshold depends only on the layer, so compute it once.
            thresh = np.percentile(layer_vals, percentile)
            colors = ["green" if v >= thresh else "lightgray" for v in layer_vals]
        else:
            # The epsilon keeps an all-zero layer from dividing by zero.
            max_abs = np.max(np.abs(layer_vals)) + 1e-8
            colors = [_neuron_color(v, max_abs) for v in layer_vals]

        # The thin outline keeps white / light-gray neurons visible on the
        # white background (scatter edges otherwise default to the face color).
        plt.scatter(
            np.full(n_neurons, i),
            coords_y,
            color=colors,
            s=500,
            zorder=2,
            edgecolors="black",
            linewidths=0.5,
        )

    # Fully connect each pair of adjacent layers, drawn behind the neurons.
    for i, (ys_left, ys_right) in enumerate(zip(layer_coords, layer_coords[1:])):
        for y1 in ys_left:
            for y2 in ys_right:
                plt.plot([i, i + 1], [y1, y2], color="gray", zorder=1)

    plt.axis("off")
    plt.show()
