# ------------------------------------------------------------------------------
# Project: IFE_Surrogate
# Authors: Tobias Leitgeb, Julian Tischler
# CD Lab 2025
# ------------------------------------------------------------------------------
from .trainer import Trainer
from ..models import GPModel

from functools import partial
import typing
from typing import Callable, Dict, Tuple, Optional
import optax
from jax import value_and_grad, jit, random
from typing import Tuple, Dict, Callable
from jaxtyping import Key, Array, Float, Int, Bool
import jax.numpy as jnp


# Type variable for annotating arguments that accept a GPModel or any subclass.
# The name passed to TypeVar must equal the variable it is assigned to;
# static type checkers reject a mismatch.
TypeGPModel = typing.TypeVar("TypeGPModel", bound=GPModel)
# TODO: add an analogous type variable for optax optimizers once a bound is chosen.


class OptaxTrainer(Trainer):
    """Trainer that minimises ``self.loss_fn`` with an optax optimizer.

    For a list of available optimizers consult
    https://optax.readthedocs.io/en/latest/api/optimizers.html

    Besides the attributes set here, ``train`` reads ``self.loss_fn``,
    ``self.key``, ``self.verbose``, ``self.number_iterations``,
    ``self.number_restarts``, ``self.sample_parameters`` and
    ``self.save_history``; these are presumably provided by ``Trainer``
    (not visible in this file -- confirm against the base class).
    """

    def __init__(
        self,
        optimizer: optax.GradientTransformation = optax.adam(1e-3),
        tolerance: float = 1e-2,
        patience: int = 20,
        **kwargs,
    ):
        """Store the optimizer and the early-stopping settings.

        Args:
            optimizer: optax gradient transformation used for the updates.
            tolerance: Minimum decrease of the loss that counts as an
                improvement for early stopping.
            patience: Number of consecutive iterations without such an
                improvement after which a run is stopped.
            **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(**kwargs)
        self.optimizer = optimizer
        self.tolerance = tolerance
        self.patience = patience

    def train(self, model: TypeGPModel) -> Tuple[Dict, Optional[Dict]]:
        """Optimise the kernel hyperparameters of ``model`` over several restarts.

        Each restart starts from either sampled or current kernel parameters,
        runs up to ``self.number_iterations`` optimizer steps with early
        stopping, and records the parameters with the lowest loss observed
        during that run.

        Args:
            model: Model whose ``kernel`` supplies the initial parameters.
                It is stored on ``self.model`` before optimisation starts.

        Returns:
            A tuple ``(best_run, history)``. ``best_run`` is a dict with the
            keys ``"params"``, ``"loss"`` and ``"key"`` of the restart with
            the lowest loss. ``history`` maps ``"run_<i>"`` to the same kind
            of dict for every restart with a finite loss when
            ``self.save_history`` is truthy, otherwise it is ``None``.

        Raises:
            ValueError: If no restart produced a finite loss (this includes
                ``number_restarts == 0`` and ``number_iterations == 0``).
        """
        self.model = model
        value_and_grad_fn = value_and_grad(self.loss_fn)

        @jit
        def step(params: Dict, opt_state: Dict):
            # The loss is evaluated at the incoming (pre-update) parameters.
            loss, grads = value_and_grad_fn(params)
            # Pass params as well: some optax transformations (e.g. those
            # with weight decay) require them and raise otherwise.
            updates, opt_state = self.optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, opt_state, loss

        def loop(params, opt_state):
            # best_* track the lowest loss seen and the parameters it was
            # evaluated at, so the returned pair always belongs together.
            best_params = params
            best_loss = jnp.inf
            # Reference value for early stopping; only moves when the loss
            # drops by more than the tolerance.
            ref_loss = jnp.inf
            no_improve = 0
            for it in range(self.number_iterations):
                next_params, opt_state, loss = step(params, opt_state)

                # A NaN loss compares False in both tests below, so it can
                # never become the best loss or reset the patience counter.
                if loss < best_loss:
                    best_loss = loss
                    best_params = params  # loss belongs to the pre-update params

                if loss + self.tolerance < ref_loss:
                    ref_loss = loss
                    no_improve = 0
                else:
                    no_improve += 1

                params = next_params

                if self.verbose and it % 50 == 0:
                    print(f"Iter {it}: loss={loss:.5f}, best={best_loss:.5f}")

                if no_improve >= self.patience:
                    if self.verbose:
                        print(f"Early stop at iter {it}, best loss {best_loss:.5f}")
                    break
            return best_params, best_loss

        results = {}
        for r in range(self.number_restarts):
            # NOTE(review): the new key is stored on the trainer and also used
            # directly for sampling; kept as-is so existing seeds reproduce.
            _, self.key = random.split(self.key)
            if self.sample_parameters:
                params = self.model.kernel.sample_hyperparameters(self.key)
            else:
                params = self.model.kernel.get_params()

            opt_state = self.optimizer.init(params)
            params, loss = loop(params, opt_state)

            results[f"run_{r}"] = {"params": params, "loss": loss, "key": self.key}
            if self.verbose:
                print(f"Run {r} loss: {loss}")

        # A failed run never gets past the initial inf loss, so filter on
        # finiteness; a NaN check alone would never trigger.
        results = {
            name: run for name, run in results.items() if jnp.isfinite(run["loss"])
        }
        if not results:
            raise ValueError(
                "OptaxTrainer.train: no restart produced a finite loss; "
                "check number_restarts, number_iterations and the loss function."
            )
        best_key = min(results, key=lambda name: results[name]["loss"])

        if self.verbose:
            print("Best run:", results[best_key]["loss"])

        if self.save_history:
            return results[best_key], results
        return results[best_key], None


