import torch
import gpytorch
import numpy as np
import os 
from math import pi

def optimise_model(model, X_train, y_train_std, likelihood, mll, *,
                   lr=0.1, max_iter=10000, tol=0.001):
    """Fit the model by minimising ``-mll(model(X_train), y_train_std)`` with Adam.

    Args:
        model: module whose ``parameters()`` are optimised; calling it on
            ``X_train`` yields the output passed to ``mll``.
        X_train: training inputs.
        y_train_std: training targets (standardised, judging by the name).
        likelihood: put into train mode together with ``model``.
            NOTE(review): its parameters are only optimised if they are also
            reachable through ``model.parameters()`` (e.g. registered as a
            submodule of ``model``) — confirm for the models used here.
        mll: callable ``mll(output, targets)`` returning a scalar to maximise
            (its negation is used as the loss).
        lr: Adam learning rate.
        max_iter: maximum number of optimisation steps.
        tol: early-stopping threshold on the absolute change in loss between
            two consecutive iterations.

    Returns:
        The loss (``-mll``) evaluated at the last iteration run (measured
        before that iteration's optimiser step), or ``None`` if
        ``max_iter`` is 0.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Train mode for model + likelihood
    model.train()
    likelihood.train()

    prev_loss = None
    loss_value = None
    for i in range(max_iter):
        optimizer.zero_grad()
        output = model(X_train)
        loss = -mll(output, y_train_std)
        loss_value = loss.item()
        loss.backward()
        optimizer.step()
        # Early stopping once the loss has plateaued.
        if prev_loss is not None and abs(prev_loss - loss_value) < tol:
            print(f"Converged at iteration {i} with change {abs(prev_loss - loss_value):.6e}")
            break
        prev_loss = loss_value
    return loss_value


def query_model(model, probe_pts, y_std, y_mean, likelihood=None, noise=0.5):
    """Predict mean and variance at ``probe_pts`` and de-standardise them.

    Args:
        model: called on ``probe_pts``; the result (or the likelihood's output
            when ``likelihood`` is given) must expose ``.mean`` and ``.variance``.
        probe_pts: query inputs; the first dimension indexes the points.
        y_std: target scale used to de-standardise (variance is scaled by
            ``y_std**2``).
        y_mean: target offset used to de-standardise the mean.
        likelihood: optional; when given, predictions are passed through it
            with a fixed per-point ``noise`` vector.
        noise: per-point noise value handed to ``likelihood``. It is applied
            before de-standardisation, i.e. in the model's output space.

    Returns:
        ``(mean_pred, var_pred)`` in the original (de-standardised) units.
    """
    model.eval()
    if likelihood is not None:
        likelihood.eval()

    with torch.no_grad():
        latent = model(probe_pts)
        if likelihood is None:
            preds = latent
        else:
            # Build the noise vector on the same dtype/device as the inputs so
            # float64 or CUDA probe points do not hit a dtype/device mismatch.
            noise_vec = torch.full(
                (probe_pts.shape[0],), noise,
                dtype=probe_pts.dtype, device=probe_pts.device,
            )
            preds = likelihood(latent, noise=noise_vec)
        mean_pred = preds.mean * y_std + y_mean     # de-standardize back to eV
        var_pred = preds.variance * (y_std ** 2)

    return mean_pred, var_pred


def save_checkpoint(X_train, y_train, noise_train, it, stage="iter", path="checkpoint.pt"):
    """Save the training data and progress markers to ``path`` via ``torch.save``.

    The file is written to ``<path>.tmp`` first and then moved over ``path``
    with ``os.replace``. An interruption mid-save therefore leaves the previous
    checkpoint intact instead of a truncated file.

    Args:
        X_train, y_train, noise_train: training data stored under the keys of
            the same names.
        it: iteration counter, stored as ``'iteration'``.
        stage: free-form stage label, stored as ``'stage'``.
        path: destination file (defaults to ``'checkpoint.pt'`` in the
            current working directory).
    """
    payload = {
        'X_train': X_train,
        'y_train': y_train,
        'noise_train': noise_train,
        'iteration': it,
        'stage': stage
    }
    tmp_path = f"{path}.tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def load_checkpoint(path="checkpoint.pt", map_location=None):
    """Load a checkpoint written by ``torch.save``, or return ``None`` if absent.

    Args:
        path: checkpoint file (defaults to ``'checkpoint.pt'`` in the current
            working directory).
        map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` to load a
            checkpoint saved from GPU tensors on a CPU-only machine.

    Returns:
        The deserialised object, or ``None`` when ``path`` does not exist.

    NOTE: ``torch.load`` may use pickle — only load checkpoints from trusted
    sources.
    """
    # EAFP: avoids the exists()/load() race of checking first.
    try:
        return torch.load(path, map_location=map_location)
    except FileNotFoundError:
        return None

def estimate_imse(model, likelihood, probe_pts, y_std, theta_range=None, phi_range=None):
    """Monte Carlo estimate of the predictive variance integrated over a spherical patch.

    The patch is ``theta_min <= theta <= theta_max`` (polar angle) by
    ``phi_min <= phi <= phi_max`` (azimuth) on the unit sphere. Its area is
    ``(phi_max - phi_min) * (cos(theta_min) - cos(theta_max))``. The estimate is
    ``area * mean(variance at probe_pts)``. It is only a fair estimate of the
    integral if ``probe_pts`` are spread uniformly (by area) over that patch —
    NOTE(review): confirm how callers generate them.

    Args:
        model: called on ``probe_pts``; the result must expose ``.variance``.
        likelihood: only switched to eval mode (not used for prediction, so
            observation noise is excluded); may be ``None``.
        probe_pts: probe inputs; the first dimension indexes the points.
        y_std: target scale; variances are multiplied by ``y_std**2`` to return
            to original units.
        theta_range: ``(theta_min, theta_max)``; ``None`` means ``(0, pi)``.
        phi_range: ``(phi_min, phi_max)``; ``None`` means ``(0, 2*pi)``.

    Returns:
        The integrated variance as a float (0.0 when there are no probe points).
    """
    model.eval()
    if likelihood is not None:
        likelihood.eval()

    theta_min, theta_max = theta_range if theta_range is not None else (0.0, np.pi)
    phi_min, phi_max = phi_range if phi_range is not None else (0.0, 2 * np.pi)
    # Area of the patch on the unit sphere (the full sphere gives 4*pi).
    region_area = (phi_max - phi_min) * (np.cos(theta_min) - np.cos(theta_max))

    n_probe = probe_pts.shape[0]
    if n_probe == 0:
        # Nothing to average over; avoid dividing by zero.
        return 0.0

    # No gradients are needed for this estimate.
    # NOTE(review): the previous comment mentioned a faster variance
    # approximation, but none is enabled here; exact predictive variances
    # are used.
    with torch.no_grad():
        preds = model(probe_pts)
        # Clamp tiny negative variances from numerical error, then rescale to
        # original target units.
        var = torch.clamp(preds.variance, min=0.0) * (y_std ** 2)

    return float(region_area * var.sum().item() / n_probe)

