import torch
import torch.nn as nn
from typing import Optional, Union, List


def count_params(model: nn.Module, trainable_only: bool = False) -> int:
  """Return the total number of elements across the model's parameters.

  Args:
    model: Module whose parameters are counted.
    trainable_only: When true, only parameters with ``requires_grad`` set
      contribute to the total.

  Returns:
    The summed ``numel()`` of the selected parameters.
  """
  total = 0
  for tensor in model.parameters():
    if trainable_only and not tensor.requires_grad:
      continue
    total += tensor.numel()
  return total


def freeze_layers(model: nn.Module, layer_names: Optional[Union[str, List[str]]] = None):
  """Disable gradients for all parameters, or only for those matching ``layer_names``.

  Args:
    model: Module whose parameters are modified in place.
    layer_names: ``None`` freezes every parameter. Otherwise a parameter is
      frozen when any of the given strings occurs as a substring of its
      name as yielded by ``model.named_parameters()``. A single string is
      treated as a one-element list.
  """
  if layer_names is None:
    for param in model.parameters():
      param.requires_grad = False
    return

  # A bare string would otherwise be iterated character by character and
  # match almost every parameter name; materialising also keeps one-shot
  # iterables usable for every parameter, not just the first.
  if isinstance(layer_names, str):
    patterns = (layer_names,)
  else:
    patterns = tuple(layer_names)

  for name, param in model.named_parameters():
    if any(pattern in name for pattern in patterns):
      param.requires_grad = False


def unfreeze_layers(model: nn.Module, layer_names: Optional[Union[str, List[str]]] = None):
  """Enable gradients for all parameters, or only for those matching ``layer_names``.

  Args:
    model: Module whose parameters are modified in place.
    layer_names: ``None`` unfreezes every parameter. Otherwise a parameter
      is unfrozen when any of the given strings occurs as a substring of its
      name as yielded by ``model.named_parameters()``. A single string is
      treated as a one-element list.
  """
  if layer_names is None:
    for param in model.parameters():
      param.requires_grad = True
    return

  # A bare string would otherwise be iterated character by character and
  # match almost every parameter name; materialising also keeps one-shot
  # iterables usable for every parameter, not just the first.
  if isinstance(layer_names, str):
    patterns = (layer_names,)
  else:
    patterns = tuple(layer_names)

  for name, param in model.named_parameters():
    if any(pattern in name for pattern in patterns):
      param.requires_grad = True


def get_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
  """Resolve ``device`` to a ``torch.device``.

  Args:
    device: A device string or ``torch.device``. When ``None``, ``"cuda"``
      is chosen if ``torch.cuda.is_available()`` reports true, otherwise
      ``"cpu"``.

  Returns:
    The corresponding ``torch.device`` instance.
  """
  if device is not None:
    return torch.device(device)

  backend = "cuda" if torch.cuda.is_available() else "cpu"
  return torch.device(backend)


def set_seed(seed: int):
  """Seed the Python, NumPy and PyTorch random number generators.

  Also switches cuDNN to deterministic mode and turns its benchmark
  autotuner off.

  Args:
    seed: Value passed unchanged to each seeding function.
  """
  import random
  import numpy as np

  # Same order as before: Python, NumPy, torch, then the CUDA generators.
  seeders = (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
  )
  for seeder in seeders:
    seeder(seed)

  cudnn = torch.backends.cudnn
  cudnn.deterministic = True
  cudnn.benchmark = False


def get_lr(optimizer: torch.optim.Optimizer) -> float:
  """Return the learning rate of the optimizer's first parameter group.

  Args:
    optimizer: Optimizer whose ``param_groups`` are inspected.

  Returns:
    The ``"lr"`` entry of the first group, or ``0.0`` when the optimizer
    has no parameter groups.
  """
  first_group = next(iter(optimizer.param_groups), None)
  if first_group is None:
    return 0.0
  return first_group["lr"]


def set_lr(optimizer: torch.optim.Optimizer, lr: float):
  """Assign ``lr`` as the learning rate of every parameter group.

  Args:
    optimizer: Optimizer whose ``param_groups`` are updated in place.
    lr: Value stored under the ``"lr"`` key of each group.
  """
  for group in optimizer.param_groups:
    group.update(lr=lr)


class ExponentialMovingAverage:
  """Keep an exponential moving average of a model's trainable parameters.

  ``shadow`` maps parameter names to their averaged values. ``backup`` holds
  the live weights while the averaged ones are swapped into the model by
  ``apply_shadow()``; ``restore()`` puts them back.
  """

  def __init__(self, model: nn.Module, decay: float = 0.999):
    """Snapshot every parameter that currently requires gradients.

    Args:
      model: Module whose parameters are tracked; a reference is kept.
      decay: Weight given to the previous average on each ``update()``;
        the current parameter value gets ``1 - decay``.
    """
    self.model = model
    self.decay = decay
    self.shadow = {}  # parameter name -> averaged tensor
    self.backup = {}  # parameter name -> live weights saved by apply_shadow()

    for name, param in model.named_parameters():
      if param.requires_grad:
        self.shadow[name] = param.data.clone()

  def update(self):
    """Blend the current parameter values into the running averages.

    Parameters that do not require gradients are skipped. A parameter that
    became trainable after construction is registered with its current
    value instead of raising ``KeyError``.
    """
    for name, param in self.model.named_parameters():
      if not param.requires_grad:
        continue
      if name not in self.shadow:
        # First time this parameter is seen as trainable: start its average
        # from the current value, exactly as __init__ does.
        self.shadow[name] = param.data.clone()
        continue
      self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * param.data

  def apply_shadow(self):
    """Copy the averaged values into the model, saving the live weights.

    Values are copied in place so the model never shares storage with
    ``shadow``; in-place edits of the model while the averages are applied
    therefore cannot corrupt them. Calling this again before ``restore()``
    keeps the first backup, so the original weights are never lost.
    Trainable parameters without an average yet are left untouched.
    """
    for name, param in self.model.named_parameters():
      if not param.requires_grad or name not in self.shadow:
        continue
      if name not in self.backup:
        self.backup[name] = param.data.clone()
      param.data.copy_(self.shadow[name])

  def restore(self):
    """Copy the weights saved by ``apply_shadow()`` back into the model.

    Exactly the parameters that were backed up are restored, regardless of
    their current ``requires_grad`` flag. With nothing backed up this is a
    no-op. The backup is cleared afterwards.
    """
    for name, param in self.model.named_parameters():
      saved = self.backup.get(name)
      if saved is not None:
        param.data.copy_(saved)
    self.backup = {}
