# ------------------------------------------------------------------------------
# Project: IFE_Surrogate
# Authors: Tobias Leitgeb, Julian Tischler
# CD Lab 2025
# ------------------------------------------------------------------------------
from abc import ABC, abstractmethod
import typing
from abc import ABC, abstractmethod
from typing import Callable, Dict, Tuple, Optional
from functools import partial
from ..models import GPModel


# Type variable for annotating code that is generic over GPModel subclasses.
# The name passed to TypeVar must equal the variable it is assigned to;
# type checkers reject a mismatch, and the TypeVar's repr would otherwise be
# confusable with the GPModel class itself.
TypeGPModel = typing.TypeVar("TypeGPModel", bound=GPModel)


class Trainer(ABC):
    """Abstract base class for trainers that optimise a model's kernel parameters.

    A trainer is configured at construction time and later bound to a model
    through the ``model`` property. Binding a model builds ``nlml``, a partial
    application of the model's ``likelihood`` over the model's data, which
    ``loss_fn`` evaluates for a given set of kernel parameters.

    Subclasses must implement ``train``.
    """

    def __init__(
        self,
        key,
        number_iterations: int = 100,
        number_restarts: int = 1,
        sample_parameters: bool = True,
        save_history: bool = False,
        verbose: bool = True,
    ):
        """Store the training configuration; no model is bound yet.

        The arguments are only stored on the instance here; how they are
        interpreted is up to the concrete ``train`` implementation.

        Args:
            key: Stored unchanged as ``self.key`` (presumably a random key
                for parameter sampling -- confirm against the subclasses).
            number_iterations: Stored as ``self.number_iterations``.
            number_restarts: Stored as ``self.number_restarts``.
            sample_parameters: Stored as ``self.sample_parameters``.
            save_history: Stored as ``self.save_history``.
            verbose: Stored as ``self.verbose``.
        """
        self.key = key
        self.number_iterations = number_iterations
        self.number_restarts = number_restarts
        self.sample_parameters = sample_parameters
        self.save_history = save_history
        self.verbose = verbose
        self._model = None
        # Defined up front so the attribute always exists; the model setter
        # replaces it with the bound likelihood.
        self.nlml: Optional[Callable] = None

    @property
    def model(self):
        """The model currently bound to this trainer, or ``None``."""
        return self._model

    @model.setter
    def model(self, new_model):
        """Bind ``new_model`` and rebuild ``nlml`` from its likelihood and data.

        ``nlml`` becomes ``partial(likelihood, X, Y, sigma_sq, jitter)`` when
        the model's ``get_attributes()`` contains ``"sigma_sq"``, and
        ``partial(likelihood, X, Y, jitter)`` otherwise.

        Assigning ``None`` unbinds the model and clears ``nlml``.

        Raises:
            AttributeError: If ``new_model`` lacks a required attribute. The
                previously bound model and ``nlml`` are left untouched.
        """
        if new_model is None:
            self._model = None
            self.nlml = None
            return

        # Build the closure completely before touching any trainer state, so
        # a failing assignment cannot leave ``_model`` and ``nlml`` out of sync.
        likelihood = new_model.likelihood
        if "sigma_sq" in new_model.get_attributes():
            bound_likelihood = partial(
                likelihood,
                new_model.X,
                new_model.Y,
                new_model.sigma_sq,
                new_model.jitter,
            )
        else:
            bound_likelihood = partial(
                likelihood, new_model.X, new_model.Y, new_model.jitter
            )

        self._model = new_model
        self.nlml = bound_likelihood

    def loss_fn(self, params: Dict) -> float:
        """Evaluate the bound likelihood for the kernel parameters ``params``.

        The parameters are first written into the model's kernel via
        ``update_params``; the value then passed to ``nlml`` is whatever the
        kernel reports back through ``get_params``.

        Args:
            params: Kernel parameters forwarded to ``kernel.update_params``.

        Returns:
            The result of calling ``nlml`` with the kernel's parameters.

        Raises:
            AttributeError: If no model has been bound yet.
        """
        model = self.model
        if model is None:
            # Same exception type as the implicit ``None.kernel`` failure,
            # but with an actionable message.
            raise AttributeError(
                "Trainer.model must be set before calling loss_fn"
            )
        model.kernel.update_params(params)
        return self.nlml(model.kernel.get_params())

    @abstractmethod
    def train(self, *args, **kwargs) -> Tuple[Dict, Optional[Dict]]:
        """Run optimization and return (best_run, history)."""