"""
Smoother - Gradient Scale Synchronization for PyTorch
"""

import torch
from torch.optim import Optimizer
from typing import Optional, Callable

# Public API for ``from <module> import *`` and the package version string.
__all__ = ["SmartAdam"]
__version__ = "0.1.0"

class SmartAdam(Optimizer):
    """Adam optimizer with Gradient Scale Synchronization.

    NOTE(review): the update implemented in :meth:`step` is the standard Adam
    rule in its "epsilon-hat" form. ``eps`` is added to the *uncorrected*
    square root of the second moment, and both bias corrections are folded
    into the step size. No gradient-scale synchronization logic is visible
    in this class; confirm whether it is meant to live elsewhere.

    Args:
        params: Iterable of parameters, or of dicts defining parameter groups.
        base_lr: Default learning rate. It is stored under the ``lr`` key of
            each parameter group, so per-group overrides must use ``lr``.
        betas: Decay rates ``(beta1, beta2)`` for the running averages of the
            gradient and of the squared gradient. Each must lie in ``[0, 1)``.
        eps: Non-negative constant added to the denominator for numerical
            stability.

    Raises:
        ValueError: If ``base_lr`` or ``eps`` is negative (or NaN), or if
            either beta lies outside ``[0, 1)``.
    """

    def __init__(self, params, base_lr=0.001, betas=(0.9, 0.999), eps=1e-8):
        # ``not 0.0 <= x`` (rather than ``x < 0.0``) also rejects NaN.
        if not 0.0 <= base_lr:
            raise ValueError(f"Invalid learning rate: {base_lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        beta1, beta2 = betas
        # beta1 == 1 would make bias_correction1 zero (division by zero in
        # step); beta2 == 1 would make bias_correction2 zero and freeze
        # every update at zero.
        if not 0.0 <= beta1 < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {beta1}")
        if not 0.0 <= beta2 < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {beta2}")
        defaults = dict(lr=base_lr, betas=betas, eps=eps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        """Perform a single optimization step over every parameter group.

        Parameters whose ``.grad`` is ``None`` are skipped and get no state.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It runs with gradient tracking re-enabled, so it may
                call ``backward()``.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Group hyper-parameters do not change inside the parameter loop.
            beta1, beta2 = group['betas']
            lr = group['lr']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("SmartAdam does not support sparse gradients")

                state = self.state[p]
                if len(state) == 0:
                    # Lazily create per-parameter state on the first update.
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1
                t = state['step']

                # Running averages of the gradient and the squared gradient.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Both bias corrections are folded into the scalar step size,
                # so eps is added to the uncorrected sqrt(exp_avg_sq) below.
                bias_correction1 = 1 - beta1 ** t
                bias_correction2 = 1 - beta2 ** t
                step_size = lr * (bias_correction2 ** 0.5) / bias_correction1

                denom = exp_avg_sq.sqrt().add_(eps)
                # In-place update under no_grad (instead of via ``.data``), so
                # autograd's version counter still sees the mutation.
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
