import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from pathlib import Path
from typing import Optional, Callable, Dict, Any


class Trainer:
  """Small training/validation loop around a model, an optimizer and a loss.

  Supports optional automatic mixed precision (AMP), gradient accumulation,
  gradient-norm clipping, checkpointing and callback hooks.

  Attributes:
    global_step: Number of training batches processed so far (incremented
      once per batch, not once per optimizer step).
    epoch: Number of completed calls to ``train_epoch``.
    callbacks: Objects whose ``on_*`` methods are invoked by the trainer.
  """

  def __init__(
    self,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Callable,
    device: Optional[str] = None,
    mixed_precision: bool = False,
    grad_acc_steps: int = 1,
    max_grad_norm: Optional[float] = None,
    checkpoint_dir: Optional[str] = None,
  ):
    """Set up the trainer and move the model to the target device.

    Args:
      model: Module to train; moved to ``device`` in place.
      optimizer: Optimizer that updates ``model``'s parameters.
      loss_fn: Called as ``loss_fn(outputs, targets)``; must return a tensor.
      device: Device string. ``None`` selects CUDA when available, else CPU.
      mixed_precision: Enable autocast and gradient scaling during training.
      grad_acc_steps: Number of batches whose gradients are accumulated
        before each optimizer step.
      max_grad_norm: If set, gradients are clipped to this total norm before
        every optimizer step.
      checkpoint_dir: Default directory for ``save_checkpoint``; created if
        it does not exist.

    Raises:
      ValueError: If ``grad_acc_steps`` is smaller than 1.
    """
    # Zero would fail later with ZeroDivisionError and a negative value would
    # silently flip the sign of every gradient, so reject both up front.
    if grad_acc_steps < 1:
      raise ValueError(f"grad_acc_steps must be >= 1, got {grad_acc_steps}")

    self.model = model
    self.optimizer = optimizer
    self.loss_fn = loss_fn
    self.grad_acc_steps = grad_acc_steps
    self.max_grad_norm = max_grad_norm
    self.checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir else None

    if self.checkpoint_dir:
      self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    if device is None:
      self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
      self.device = torch.device(device)

    self.model.to(self.device)

    self.mixed_precision = mixed_precision
    # Bind the scaler to the device actually in use instead of assuming CUDA.
    self.scaler = torch.amp.GradScaler(self.device.type) if mixed_precision else None

    self.global_step = 0
    self.epoch = 0
    self.callbacks = []

  def add_callback(self, callback: "Callback"):
    """Register ``callback``; hooks fire in registration order."""
    self.callbacks.append(callback)

  def train_epoch(self, train_loader: DataLoader) -> Dict[str, float]:
    """Run one pass over ``train_loader`` and update the model.

    Args:
      train_loader: Iterable of batches (tuple/list or dict, see
        ``_train_step``).

    Returns:
      ``{"loss": mean_batch_loss}``; the mean is ``0.0`` for an empty loader.
    """
    self.model.train()
    total_loss = 0.0
    num_batches = 0
    # True while gradients have been accumulated but not yet applied.
    pending_grads = False

    self._trigger_callbacks("on_epoch_start")

    for batch_idx, batch in enumerate(train_loader):
      batch = self._to_device(batch)

      self._trigger_callbacks("on_batch_start", batch_idx, batch)

      loss = self._train_step(batch, batch_idx)
      pending_grads = (batch_idx + 1) % self.grad_acc_steps != 0

      total_loss += loss
      num_batches += 1

      metrics = {"loss": loss, "avg_loss": total_loss / num_batches}
      self._trigger_callbacks("on_batch_end", batch_idx, batch, metrics)

      self.global_step += 1

    # The loader length need not be a multiple of grad_acc_steps. Apply the
    # trailing partial window so its gradients are neither discarded nor
    # leaked into the next epoch. Each micro-batch was still divided by
    # grad_acc_steps, so this final update is proportionally smaller.
    if pending_grads:
      self._optimizer_step()

    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    epoch_metrics = {"loss": avg_loss}

    self._trigger_callbacks("on_epoch_end", epoch_metrics)
    self.epoch += 1

    return epoch_metrics

  def _unpack_batch(self, batch: Any):
    """Split a batch into ``(inputs, targets)``.

    Tuples/lists use positions 0 and 1. Dicts use the key ``"inputs"`` (or
    ``"input"``) and ``"targets"`` (or ``"target"``); a missing key yields
    ``None``.

    Raises:
      ValueError: If ``batch`` is not a tuple, list or dict.
    """
    if isinstance(batch, (tuple, list)):
      return batch[0], batch[1]
    if isinstance(batch, dict):
      inputs = batch.get("inputs", batch.get("input", None))
      targets = batch.get("targets", batch.get("target", None))
      return inputs, targets
    raise ValueError("Batch must be tuple, list, or dict")

  def _optimizer_step(self):
    """Clip (optionally), apply and then clear the accumulated gradients."""
    if self.mixed_precision:
      if self.max_grad_norm is not None:
        # Gradients must be unscaled before their norm is measured.
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
      self.scaler.step(self.optimizer)
      self.scaler.update()
    else:
      if self.max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
      self.optimizer.step()

    self.optimizer.zero_grad()

  def _train_step(self, batch: Any, batch_idx: int) -> float:
    """Forward/backward one batch; step the optimizer on window boundaries.

    Args:
      batch: Batch already moved to the trainer's device.
      batch_idx: Zero-based index of the batch within the epoch; the
        optimizer steps whenever ``batch_idx + 1`` is a multiple of
        ``grad_acc_steps``.

    Returns:
      The batch loss as a float, before division by ``grad_acc_steps``.

    Raises:
      ValueError: If ``batch`` is not a tuple, list or dict.
    """
    inputs, targets = self._unpack_batch(batch)

    if self.mixed_precision:
      # Autocast on the device in use rather than a hard-coded "cuda".
      with torch.amp.autocast(self.device.type):
        outputs = self.model(inputs)
        loss = self.loss_fn(outputs, targets)
        loss = loss / self.grad_acc_steps
      self.scaler.scale(loss).backward()
    else:
      outputs = self.model(inputs)
      loss = self.loss_fn(outputs, targets)
      # Scale so that the accumulated gradient is the mean over the window.
      loss = loss / self.grad_acc_steps
      loss.backward()

    if (batch_idx + 1) % self.grad_acc_steps == 0:
      self._optimizer_step()

    return loss.item() * self.grad_acc_steps

  @torch.no_grad()
  def validate(self, val_loader: DataLoader) -> Dict[str, float]:
    """Evaluate the model on ``val_loader`` without computing gradients.

    The model is switched to eval mode and left there. Autocast is not used
    here, even when ``mixed_precision`` is enabled.

    Returns:
      ``{"val_loss": mean_batch_loss}``; ``0.0`` for an empty loader.

    Raises:
      ValueError: If a batch is not a tuple, list or dict.
    """
    self.model.eval()
    total_loss = 0.0
    num_batches = 0

    self._trigger_callbacks("on_validation_start")

    for batch in val_loader:
      inputs, targets = self._unpack_batch(self._to_device(batch))

      outputs = self.model(inputs)
      loss = self.loss_fn(outputs, targets)

      total_loss += loss.item()
      num_batches += 1

    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    val_metrics = {"val_loss": avg_loss}

    self._trigger_callbacks("on_validation_end", val_metrics)

    return val_metrics

  def fit(
    self,
    train_loader: DataLoader,
    val_loader: Optional[DataLoader] = None,
    epochs: int = 1,
  ) -> Dict[str, list]:
    """Train for ``epochs`` epochs, validating after each one if requested.

    Returns:
      ``{"train_loss": [...], "val_loss": [...]}`` with one entry per epoch;
      ``"val_loss"`` stays empty when ``val_loader`` is ``None``.
    """
    history = {"train_loss": [], "val_loss": []}

    self._trigger_callbacks("on_fit_start")

    for _ in range(epochs):
      train_metrics = self.train_epoch(train_loader)
      history["train_loss"].append(train_metrics["loss"])

      if val_loader is not None:
        val_metrics = self.validate(val_loader)
        history["val_loss"].append(val_metrics["val_loss"])

    self._trigger_callbacks("on_fit_end")

    return history

  def save_checkpoint(self, path: Optional[str] = None, **extra_state):
    """Serialize trainer state with ``torch.save``.

    Args:
      path: Destination file. When ``None``, a name built from the current
        epoch and global step is created inside ``checkpoint_dir``.
      **extra_state: Additional entries merged into the checkpoint dict
        last, so they override built-in keys of the same name.

    Returns:
      The path the checkpoint was written to.

    Raises:
      ValueError: If neither ``path`` nor ``checkpoint_dir`` is available.
    """
    if path is None:
      if self.checkpoint_dir is None:
        raise ValueError("No checkpoint path or directory specified")
      path = self.checkpoint_dir / f"checkpoint_epoch_{self.epoch}_step_{self.global_step}.pt"

    checkpoint = {
      "epoch": self.epoch,
      "global_step": self.global_step,
      "model_state_dict": self.model.state_dict(),
      "optimizer_state_dict": self.optimizer.state_dict(),
    }

    if self.scaler is not None:
      checkpoint["scaler_state_dict"] = self.scaler.state_dict()

    checkpoint.update(extra_state)

    torch.save(checkpoint, path)
    return path

  def load_checkpoint(self, path: str, strict: bool = True):
    """Restore model, optimizer, scaler and counters from a checkpoint.

    Args:
      path: File previously written by ``save_checkpoint``.
      strict: Forwarded to ``model.load_state_dict``.

    Returns:
      The full checkpoint dict, including any extra state stored in it.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only load checkpoints from trusted sources. It is
    # kept because extra_state saved by callers may hold non-tensor objects.
    checkpoint = torch.load(path, map_location=self.device, weights_only=False)

    self.model.load_state_dict(checkpoint["model_state_dict"], strict=strict)
    self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    # Scaler state is restored only if it was saved and AMP is enabled now.
    if "scaler_state_dict" in checkpoint and self.scaler is not None:
      self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

    self.epoch = checkpoint.get("epoch", 0)
    self.global_step = checkpoint.get("global_step", 0)

    return checkpoint

  def _to_device(self, batch: Any) -> Any:
    """Recursively move every tensor in ``batch`` to the trainer's device.

    Tuples, lists and dicts are rebuilt with the same container type;
    anything else that is not a tensor is returned unchanged.
    """
    if isinstance(batch, torch.Tensor):
      return batch.to(self.device)
    if isinstance(batch, (tuple, list)):
      moved = [self._to_device(item) for item in batch]
      # Namedtuples take their fields as positional arguments, not as a
      # single iterable, so they need to be rebuilt differently.
      if hasattr(batch, "_fields"):
        return type(batch)(*moved)
      return type(batch)(moved)
    if isinstance(batch, dict):
      return {key: self._to_device(val) for key, val in batch.items()}
    return batch

  def _trigger_callbacks(self, hook_name: str, *args, **kwargs):
    """Call ``hook_name(self, *args, **kwargs)`` on each callback defining it."""
    for callback in self.callbacks:
      if hasattr(callback, hook_name):
        getattr(callback, hook_name)(self, *args, **kwargs)


class Callback:
  """Base class for trainer hooks; every hook is a no-op by default.

  Subclasses override only the hooks they need. The trainer passes itself as
  the first argument to each hook.
  """

  def on_fit_start(self, trainer: Trainer):
    """Invoked by ``Trainer.fit`` before the first epoch."""

  def on_fit_end(self, trainer: Trainer):
    """Invoked by ``Trainer.fit`` after the last epoch."""

  def on_epoch_start(self, trainer: Trainer):
    """Invoked at the beginning of ``Trainer.train_epoch``."""

  def on_epoch_end(self, trainer: Trainer, metrics: Dict[str, float]):
    """Invoked at the end of ``Trainer.train_epoch`` with the epoch metrics."""

  def on_batch_start(self, trainer: Trainer, batch_idx: int, batch: Any):
    """Invoked before a training batch, after it was moved to the device."""

  def on_batch_end(self, trainer: Trainer, batch_idx: int, batch: Any, metrics: Dict[str, float]):
    """Invoked after a training batch with its loss and the running average."""

  def on_validation_start(self, trainer: Trainer):
    """Invoked at the beginning of ``Trainer.validate``."""

  def on_validation_end(self, trainer: Trainer, metrics: Dict[str, float]):
    """Invoked at the end of ``Trainer.validate`` with the validation metrics."""
