from pathlib import Path
from typing import Optional, Dict, Any
import time
import torch


class ProgressCallback:
  """Print human-readable training progress to stdout.

  Epochs are displayed 1-based (``trainer.epoch + 1``) both when an epoch
  starts and when it ends, so the two messages for one epoch agree.

  Args:
    print_every: Print a batch line every ``print_every`` batches, counting
      batches from 1 within the epoch. Must be a positive integer.

  Raises:
    ValueError: If ``print_every`` is not positive.
  """

  def __init__(self, print_every: int = 100):
    if print_every <= 0:
      # A zero interval would otherwise raise ZeroDivisionError on the first batch.
      raise ValueError(f"print_every must be positive, got {print_every}")
    self.print_every = print_every
    # Wall-clock time recorded by on_epoch_start; None until an epoch has started.
    self.epoch_start_time: Optional[float] = None

  def on_epoch_start(self, trainer):
    """Record the epoch start time and announce the (1-based) epoch number."""
    self.epoch_start_time = time.time()
    print(f"\nEpoch {trainer.epoch + 1} starting...")

  def on_batch_end(self, trainer, batch_idx: int, batch: Any, metrics: Dict[str, float]):
    """Print step and loss information every ``print_every`` batches.

    ``metrics`` must contain the keys ``"loss"`` and ``"avg_loss"``.
    """
    if (batch_idx + 1) % self.print_every == 0:
      print(f"Step {trainer.global_step}, Batch {batch_idx + 1}, Loss: {metrics['loss']:.4f}, Avg Loss: {metrics['avg_loss']:.4f}")

  def on_epoch_end(self, trainer, metrics: Dict[str, float]):
    """Print the (1-based) epoch number, elapsed time and ``metrics["loss"]``."""
    # Same 1-based numbering as on_epoch_start (previously printed 0-based here).
    epoch_number = trainer.epoch + 1
    if self.epoch_start_time is None:
      # on_epoch_start was never called (e.g. callback attached mid-epoch):
      # report without timing rather than crash on None arithmetic.
      print(f"Epoch {epoch_number} finished, Loss: {metrics['loss']:.4f}")
      return
    elapsed = time.time() - self.epoch_start_time
    print(f"Epoch {epoch_number} finished in {elapsed:.2f}s, Loss: {metrics['loss']:.4f}")

  def on_validation_end(self, trainer, metrics: Dict[str, float]):
    """Print ``metrics["val_loss"]``."""
    print(f"Validation Loss: {metrics['val_loss']:.4f}")


class CheckpointCallback:
  """Save checkpoints periodically, or only when a monitored metric improves.

  Behavior depends on ``save_best_only``:

  * ``False``: at the end of every epoch where
    ``(trainer.epoch + 1) % save_every_n_epochs == 0``, write
    ``checkpoint_epoch_{trainer.epoch}.pt``. If ``keep_last_n`` is set, only the
    most recent ``keep_last_n`` of these periodic files are kept on disk.
  * ``True``: at validation end, if ``metrics[monitor]`` is strictly better
    than the best value seen so far (lower for ``mode="min"``, higher for
    ``mode="max"``), write ``checkpoint_best.pt``. Each new best overwrites it.

  Files are written with ``trainer.save_checkpoint(path, **metrics)``.

  Args:
    checkpoint_dir: Directory for checkpoint files. Created if missing.
    save_every_n_epochs: Period, in epochs, of the periodic checkpoints.
    save_best_only: Save only when ``monitor`` improves.
    monitor: Metric key read in ``on_validation_end``.
    mode: ``"min"`` or ``"max"`` (case-insensitive).
    keep_last_n: Maximum number of periodic checkpoints kept on disk. ``None``
      keeps all of them. The best checkpoint is never pruned.

  Raises:
    ValueError: If ``mode`` is not ``"min"`` or ``"max"``.
  """

  def __init__(
    self,
    checkpoint_dir: str,
    save_every_n_epochs: int = 1,
    save_best_only: bool = False,
    monitor: str = "val_loss",
    mode: str = "min",
    keep_last_n: Optional[int] = None,
  ):
    normalized_mode = mode.lower()
    if normalized_mode not in ("min", "max"):
      # Previously an unknown mode was accepted silently and "best" never triggered.
      raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    self.checkpoint_dir = Path(checkpoint_dir)
    self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
    self.save_every_n_epochs = save_every_n_epochs
    self.save_best_only = save_best_only
    self.monitor = monitor
    self.mode = normalized_mode
    self.keep_last_n = keep_last_n
    self.best_value = float("inf") if normalized_mode == "min" else float("-inf")
    # Paths of periodic (non-best) checkpoints still on disk, oldest first.
    self.checkpoints = []

  def on_epoch_end(self, trainer, metrics: Dict[str, float]):
    """Write a periodic checkpoint every ``save_every_n_epochs`` epochs."""
    if not self.save_best_only and (trainer.epoch + 1) % self.save_every_n_epochs == 0:
      self._save_checkpoint(trainer, metrics)

  def on_validation_end(self, trainer, metrics: Dict[str, float]):
    """Write ``checkpoint_best.pt`` when ``metrics[monitor]`` reaches a new best."""
    if self.save_best_only and self.monitor in metrics:
      current_value = metrics[self.monitor]
      is_best = (self.mode == "min" and current_value < self.best_value) or (self.mode == "max" and current_value > self.best_value)

      if is_best:
        self.best_value = current_value
        self._save_checkpoint(trainer, metrics, is_best=True)

  def _save_checkpoint(self, trainer, metrics: Dict[str, float], is_best: bool = False):
    """Write one checkpoint and prune old periodic checkpoints if configured."""
    suffix = "best" if is_best else f"epoch_{trainer.epoch}"
    path = self.checkpoint_dir / f"checkpoint_{suffix}.pt"
    trainer.save_checkpoint(path, **metrics)

    if is_best:
      return

    # If the same path is saved again (e.g. an epoch re-run after resuming), keep
    # a single history entry. Otherwise pruning could delete the file just written.
    if path in self.checkpoints:
      self.checkpoints.remove(path)
    self.checkpoints.append(path)

    if self.keep_last_n is not None and len(self.checkpoints) > self.keep_last_n:
      old_checkpoint = self.checkpoints.pop(0)
      try:
        old_checkpoint.unlink()
      except FileNotFoundError:
        # Already removed externally; nothing left to prune.
        pass


class EarlyStoppingCallback:
  """Stop training when a monitored metric stops improving.

  Each ``on_validation_end`` call compares ``metrics[monitor]`` with the best
  value seen so far. A value counts as an improvement only if it beats the best
  by more than ``min_delta``: below ``best - min_delta`` for ``mode="min"``,
  above ``best + min_delta`` for ``mode="max"``. After ``patience`` consecutive
  non-improving calls, ``StopTraining`` is raised.

  Args:
    patience: Number of consecutive non-improving validations tolerated.
    monitor: Metric key to watch. Calls whose metrics lack it are ignored.
    mode: ``"min"`` or ``"max"`` (case-insensitive).
    min_delta: Minimum change that counts as an improvement.

  Raises:
    ValueError: If ``mode`` is not ``"min"`` or ``"max"``.
  """

  def __init__(
    self,
    patience: int = 5,
    monitor: str = "val_loss",
    mode: str = "min",
    min_delta: float = 0.0,
  ):
    normalized_mode = mode.lower()
    if normalized_mode not in ("min", "max"):
      # Previously any value other than "min" silently behaved like "max".
      raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    self.patience = patience
    self.monitor = monitor
    self.mode = normalized_mode
    self.min_delta = min_delta
    self.best_value = float("inf") if normalized_mode == "min" else float("-inf")
    # Consecutive validation calls without improvement.
    self.wait = 0
    # trainer.epoch at the moment stopping was triggered (0 until then).
    self.stopped_epoch = 0

  def on_validation_end(self, trainer, metrics: Dict[str, float]):
    """Update the improvement counter and raise StopTraining once patience runs out.

    Raises:
      StopTraining: After ``patience`` consecutive calls without improvement.
    """
    if self.monitor not in metrics:
      return

    current_value = metrics[self.monitor]

    if self.mode == "min":
      improved = current_value < (self.best_value - self.min_delta)
    else:
      improved = current_value > (self.best_value + self.min_delta)

    if improved:
      self.best_value = current_value
      self.wait = 0
    else:
      self.wait += 1
      if self.wait >= self.patience:
        self.stopped_epoch = trainer.epoch
        print(f"\nEarly stopping triggered at epoch {self.stopped_epoch}")
        # NOTE(review): `wait` counts validation calls; the message says
        # "epochs", which holds only if validation runs once per epoch.
        raise StopTraining(f"Early stopping after {self.patience} epochs without improvement")


class LearningRateSchedulerCallback:
  """Step a learning-rate scheduler once per epoch.

  Plateau-style schedulers are detected by a class in their MRO whose name
  contains ``"ReduceLROnPlateau"``. They are stepped with ``metrics[monitor]``
  and skipped for the epoch when that key is absent. All other schedulers are
  stepped with no arguments. Objects without a ``step`` attribute are ignored.

  Args:
    scheduler: The scheduler object to drive.
    monitor: Metric key passed to plateau-style schedulers. Defaults to
      ``"val_loss"``.
  """

  def __init__(self, scheduler, monitor: str = "val_loss"):
    self.scheduler = scheduler
    self.monitor = monitor

  def _is_plateau_scheduler(self) -> bool:
    """Return True if the scheduler's class or any base class is ReduceLROnPlateau-like."""
    # Walk the MRO so subclasses with unrelated names are still detected.
    return any("ReduceLROnPlateau" in klass.__name__ for klass in self.scheduler.__class__.__mro__)

  def on_epoch_end(self, trainer, metrics: Dict[str, float]):
    """Advance the scheduler by one epoch."""
    if not hasattr(self.scheduler, "step"):
      return
    if self._is_plateau_scheduler():
      # NOTE(review): if validation metrics are produced only after
      # on_epoch_end, the monitored key is missing here and a plateau
      # scheduler never steps. Confirm the trainer's hook order.
      if self.monitor in metrics:
        self.scheduler.step(metrics[self.monitor])
    else:
      self.scheduler.step()


class GradientClippingCallback:
  """Clip the gradient norm of ``trainer.model`` at the end of each batch.

  The total gradient norm returned by ``torch.nn.utils.clip_grad_norm_`` is
  stored in ``metrics["grad_norm"]`` as a Python float. This mutates the
  caller's dict.

  NOTE(review): clipping changes the parameter update only if the trainer calls
  ``on_batch_end`` after ``backward()`` but before ``optimizer.step()``. If it
  runs after the step (or after ``zero_grad()``), stale or empty gradients are
  clipped. Confirm the trainer's hook order.

  Args:
    max_norm: Maximum allowed total gradient norm.
    norm_type: Order of the norm (e.g. ``2.0`` for L2).
  """

  def __init__(self, max_norm: float = 1.0, norm_type: float = 2.0):
    self.max_norm, self.norm_type = max_norm, norm_type

  def on_batch_end(self, trainer, batch_idx: int, batch: Any, metrics: Dict[str, float]):
    """Clip gradients in place and record the reported total norm."""
    model_params = trainer.model.parameters()
    reported_norm = torch.nn.utils.clip_grad_norm_(
      model_params,
      self.max_norm,
      norm_type=self.norm_type,
    )
    metrics["grad_norm"] = reported_norm.item()


class StopTraining(Exception):
  """Signal raised by a callback to request that training stop early.

  ``EarlyStoppingCallback`` raises it once its patience is exhausted. The
  training loop is presumably expected to catch it; confirm in the trainer.
  """
