import torch
import torch.nn as nn
import matplotlib.pyplot as plt


def _model_prediction(model, inputs):
    """Run ``model`` on ``inputs`` and return the prediction tensor.

    Models may return either a single tensor or a tuple/list whose last
    element is the prediction; both forms are accepted.
    """
    out = model(inputs)
    return out[-1] if isinstance(out, (tuple, list)) else out


def _tensor_to_numpy(tensor):
    """Detach ``tensor`` from autograd and convert it to a NumPy array (CPU or CUDA)."""
    return tensor.detach().cpu().numpy()


def _loss_targets(problem, targets):
    """Convert raw targets into the form passed to the loss for ``problem``.

    Raises
    ------
    ValueError
        If ``problem`` is not a supported problem type.
    """
    if problem == "regression":
        return targets
    if problem == "classification_binary":
        # BCE-style losses require float targets.
        return targets.float()
    if problem == "classification_multiclass":
        # One-hot / probability targets (2-D) are reduced to class indices.
        return torch.argmax(targets, dim=1) if targets.ndim > 1 else targets
    raise ValueError("Invalid problem type")


def _metric_predictions(problem, predictions):
    """Turn raw model predictions into the values handed to the metric."""
    if problem == "classification_binary":
        # Assumes the model outputs probabilities (e.g. after a sigmoid).
        return (predictions > 0.5).int()
    if problem == "classification_multiclass":
        return torch.argmax(predictions, dim=1)
    return predictions


def _print_training_summary(metrics, test_losses, epochs_done, digits, patience, greater_is_better):
    """Print best metric / test loss and the patience hint.

    ``epochs_done`` is the cumulative epoch count. When the history lists are
    shorter than that (fresh lists on a re-train), their indices are shifted so
    the reported epochs are cumulative too.
    """
    best_metric = max(metrics) if greater_is_better else min(metrics)
    best_metric_idx = metrics.index(best_metric)
    best_loss = min(test_losses)
    best_loss_idx = test_losses.index(best_loss)

    metric_offset = max(0, epochs_done - len(metrics))
    loss_offset = max(0, epochs_done - len(test_losses))
    print(f"Best metric: {round(best_metric, digits)} at epoch {best_metric_idx + 1 + metric_offset}")
    print(f"Best test loss: {round(best_loss, digits)} at epoch {best_loss_idx + 1 + loss_offset}")

    if patience is not None:
        # Number of recorded epochs after the best one.
        since_best = len(metrics) - 1 - best_metric_idx
        if since_best >= patience:
            print(f"\n---------------\n\nMetric hasn't improved in the last {since_best} epochs, consider finishing training")
        else:
            print(f"\n---------------\n\nMetric last improved {since_best} epochs ago (within patience={patience}), consider continuing training")


def train_nn(
    model, loss_fn, metric, optimizer,
    train_input, train_output,
    test_input, test_output,
    epochs,
    problem,
    patience=None,
    verbose=1,
    random_seed=None,
    train_losses=None, test_losses=None,
    metrics=None,
    epochs_counter=None,
    plot=True,
    greater_is_better=True,
):
    """
    Train a PyTorch neural network with optional plotting of losses and metrics.

    Supports regression, binary classification, and multiclass classification.
    Each epoch performs one full-batch optimisation step on the whole training
    set, followed by an evaluation on the test set.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model to train. Must return either only the output tensor or a
        tuple/list whose last element is the output tensor.
    loss_fn : torch.nn.modules.loss._Loss
        Loss function appropriate for the problem type (e.g., nn.MSELoss(), nn.BCELoss(), nn.CrossEntropyLoss()).
    metric : callable
        Metric accepting (y_true, y_pred) as numpy arrays (e.g. from sklearn.metrics).
    optimizer : torch.optim.Optimizer
        Optimizer to update the model's parameters.
    train_input : torch.Tensor
        Training input data.
    train_output : torch.Tensor
        Training target data.
    test_input : torch.Tensor
        Validation/test input data.
    test_output : torch.Tensor
        Validation/test target data.
    epochs : int
        Number of epochs to train in this call.
    problem : str
        Problem type: "regression", "classification_binary", or "classification_multiclass".
        For binary classification the model should output probabilities (predictions are
        thresholded at 0.5). For multiclass, targets may be class indices (1-D) or
        one-hot/probabilities (2-D, reduced with argmax).
    patience : int, optional
        Number of epochs without metric improvement after which stopping is suggested.
    verbose : int, optional
        Verbosity level: 0 = silent, 1 = about 10 progress lines, 2 = one line per epoch.
    random_seed : int, optional
        Seed passed to ``torch.manual_seed`` at the start of this call (after the model
        has already been constructed).
    train_losses : list, optional
        List to store training loss values. Required if plot=True.
    test_losses : list, optional
        List to store test loss values. Required if plot=True.
    metrics : list, optional
        List to store metric values. Required if plot=True.
    epochs_counter : list, optional
        A single-element list storing cumulative epochs across multiple calls (e.g., [0]).
        Updated in place.
    plot : bool, default True
        If True, generates plots of train/test loss and metric over epochs.
    greater_is_better : bool, default True
        Whether a higher metric value is better (e.g. accuracy). Set to False for
        error metrics such as mean squared error.

    Raises
    ------
    ValueError
        If plot=True and any of train_losses, test_losses, or metrics is None,
        or if ``problem`` is not one of the supported types.

    Notes
    -----
    - Handles consistent transformation of outputs for classification problems.
    - First train and re-train statuses are derived from the `epochs_counter` parameter.
    - If epochs_counter is not provided, the function will assume training starts at epoch 0.
    - Best metric and test loss are reported at the end (when verbose != 0).
    - Prints a hint regarding early stopping based on `patience`.

    Examples
    --------
    import torch
    import torch.nn as nn
    from sklearn.metrics import mean_squared_error, accuracy_score, balanced_accuracy_score

    # Regression
    model = NeuralNetworkModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    train_nn(model, nn.MSELoss(), mean_squared_error, optimizer,
             train_input, train_output,
             test_input, test_output,
             epochs=50, problem="regression",
             train_losses=[], test_losses=[], metrics=[], plot=True,
             greater_is_better=False)

    # Binary classification
    model = NeuralNetworkModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    train_nn(model, nn.BCELoss(), balanced_accuracy_score, optimizer,
             train_input, train_output,
             test_input, test_output,
             epochs=50, problem="classification_binary",
             train_losses=[], test_losses=[], metrics=[], plot=True)

    # Multiclass classification
    model = NeuralNetworkModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    train_nn(model, nn.CrossEntropyLoss(), accuracy_score, optimizer,
             train_input, train_output,
             test_input, test_output,
             epochs=50, problem="classification_multiclass",
             train_losses=[], test_losses=[], metrics=[], plot=True)
    """

    # ----------------- INPUT CHECKS -----------------
    if plot and (train_losses is None or test_losses is None or metrics is None):
        raise ValueError("Se plot=True, devi fornire le liste 'train_losses', 'test_losses' e 'metrics' inizializzate come liste vuote.")

    # Target conversions are loop-invariant; this also validates `problem` up front.
    train_target = _loss_targets(problem, train_output)
    test_target = _loss_targets(problem, test_output)
    # Ground truth handed to the metric: class indices for multiclass, raw targets otherwise.
    metric_true = _tensor_to_numpy(test_target if problem == "classification_multiclass" else test_output)

    if random_seed is not None:
        torch.manual_seed(random_seed)

    # Fall back to private containers so history is always recorded.
    train_losses = [] if train_losses is None else train_losses
    test_losses = [] if test_losses is None else test_losses
    metrics = [] if metrics is None else metrics
    nocounter = epochs_counter is None
    if nocounter:
        epochs_counter = [0]
    starting_done_epochs = epochs_counter[0]

    if verbose == 1:
        verb_print, digits = 10, 4
    elif verbose == 2:
        verb_print, digits = epochs, 6
    else:
        verb_print, digits = 10, 3
    # Guard against epochs == 0 with verbose == 2 (verb_print would be 0).
    print_every = max(1, epochs // max(1, verb_print))

    if verbose != 0:
        model_name = model.__class__.__name__
        if nocounter:
            print(f"Training started on {model_name}():\nNOTE: no epoch counter was passed, so it cannot be derived wether this is the first training loop or not. Epochs will be forced to start from 0\n")
        elif starting_done_epochs == 0:
            print(f"Training started on {model_name}() (first train):\n")
        else:
            print(f"Training started on {model_name}() (re-train):\n")

    for epoch in range(epochs):
        # ----------------- TRAINING STEP (full batch) -----------------
        model.train()
        y_pred = _model_prediction(model, train_input)
        loss = loss_fn(y_pred, train_target)
        train_losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ----------------- VALIDATION -----------------
        model.eval()
        with torch.no_grad():
            y_test_pred = _model_prediction(model, test_input)
            test_loss = loss_fn(y_test_pred, test_target).item()
            metr = metric(metric_true, _tensor_to_numpy(_metric_predictions(problem, y_test_pred)))
        test_losses.append(test_loss)
        metrics.append(metr)

        # ----------------- PRINT PROGRESS -----------------
        epochs_counter[0] += 1
        if verbose != 0 and (epoch % print_every == 0 or epoch == epochs - 1):
            print(f"Epoch {epoch+starting_done_epochs+1} ({epoch+1} from last call) - Test Loss: {round(test_loss, digits)} - Train Loss: {round(loss.item(), digits)} - Metric: {round(metr, digits)}")

    # ----------------- PLOTS -----------------
    if plot:
        if verbose != 0:
            print("\n---------------\n")
        plt.plot(train_losses, label="Train Loss", color="green")
        plt.plot(test_losses, label="Test Loss", color="red")
        plt.title("Loss over epochs")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.grid(True)
        plt.legend()
        plt.show()

        plt.plot(metrics, color="red", label="Metric")
        plt.title("Metric over epochs")
        plt.xlabel("Epoch")
        plt.ylabel("Metric")
        plt.grid(True)
        plt.legend()
        plt.show()

    # ----------------- SUMMARY -----------------
    if verbose != 0:
        print(f"\nFinished train loop at epoch {epochs+starting_done_epochs} (trained for {epochs} epochs from last call)\n\n---------------\n")
        # Skip when no history exists (e.g. epochs == 0 with empty lists).
        if metrics and test_losses:
            _print_training_summary(metrics, test_losses, epochs_counter[0], digits, patience, greater_is_better)
