from collections.abc import Iterable

import torch
from torch import nn, optim

from nfnets.utils import unitwise_norm
class AGC(optim.Optimizer):
    """Generic implementation of Adaptive Gradient Clipping (AGC).

    Wraps an existing optimizer. ``step`` first rescales the gradient of every
    AGC-managed parameter wherever its unit-wise gradient norm exceeds
    ``clipping`` times the unit-wise parameter norm, then lets the wrapped
    optimizer perform the actual update.

    Args:
        params (iterable): iterable of parameters to clip, or dicts defining
            parameter groups. Not used when ``model`` is given.
        optim (torch.optim.Optimizer): Optimizer with base class
            optim.Optimizer; it performs the parameter update.
        clipping (float, optional): clipping value (default: 1e-2)
        eps (float, optional): lower bound applied to the parameter norm
            (default: 1e-3)
        model (torch.nn.Module, optional): The original model. When given,
            every parameter of the model is clipped except those owned by
            the modules named in ``ignore_agc``.
        ignore_agc (str, Iterable, optional): Name(s) of the modules, as
            reported by ``model.named_modules()``, for AGC to ignore. Only
            consulted when ``model`` is given (default: ["fc"])

    Raises:
        ValueError: if ``clipping`` or ``eps`` is negative.
        AssertionError: if ``model`` is given but ``ignore_agc`` is empty.
        ModuleNotFoundError: if a name in ``ignore_agc`` is not a module of
            ``model``.
    """

    def __init__(self, params, optim: optim.Optimizer, clipping: float = 1e-2, eps: float = 1e-3, model=None, ignore_agc=["fc"]):
        if clipping < 0.0:
            raise ValueError("Invalid clipping value: {}".format(clipping))
        if eps < 0.0:
            raise ValueError("Invalid eps value: {}".format(eps))

        self.optim = optim
        # Stored so that code reading ``defaults`` on an optimizer finds the
        # AGC settings merged with those of the wrapped optimizer.
        self.defaults = {"clipping": clipping, "eps": eps, **optim.defaults}

        # Never mutate the caller's (or the default) list: always work on a copy.
        ignore_agc = self._normalize_ignore(ignore_agc)

        if model is not None:
            if not ignore_agc:
                # Raised explicitly (not via ``assert``) so the check survives
                # ``python -O`` while keeping the same exception type.
                raise AssertionError(
                    "You must specify ignore_agc for AGC to ignore fc-like(or other) layers")

            modules = dict(model.named_modules())
            for module_name in ignore_agc:
                if module_name not in modules:
                    raise ModuleNotFoundError(
                        "Module name {} not found in the model".format(module_name))

            # Collect the parameters owned by the ignored modules and drop
            # exactly those. Going through the module list instead would
            # re-add them via the root module and every parent module.
            ignored_ids = {
                id(p)
                for module_name in ignore_agc
                for p in modules[module_name].parameters()
            }
            agc_params = [{
                "params": [p for p in model.parameters()
                           if id(p) not in ignored_ids]
            }]
        else:
            agc_params = self._as_param_groups(params)

        self.agc_params = agc_params
        self.eps = eps
        self.clipping = clipping

        # The base-class constructor is not invoked; instead the bookkeeping
        # containers are aliased from the wrapped optimizer so both objects
        # always see the same parameter groups and state.
        self.param_groups = optim.param_groups
        self.state = optim.state

    @staticmethod
    def _normalize_ignore(ignore_agc):
        """Return ``ignore_agc`` as a fresh list of module names."""
        if ignore_agc is None:
            return []
        if isinstance(ignore_agc, str):
            # A string is itself iterable; treat it as one name, not as chars.
            return [ignore_agc]
        try:
            return list(ignore_agc)
        except TypeError:
            # Non-iterable scalar: wrap it, as a single entry.
            return [ignore_agc]

    @staticmethod
    def _param_list(params):
        """Materialise ``params`` (a tensor or an iterable of tensors) as a list."""
        if isinstance(params, torch.Tensor):
            return [params]
        return list(params)

    @classmethod
    def _as_param_groups(cls, params):
        """Convert ``params`` into a list of ``{"params": [...]}`` groups.

        Generators are materialised once here so that they can be iterated on
        every call to ``step`` / ``zero_grad``.
        """
        entries = cls._param_list(params)
        if entries and isinstance(entries[0], dict):
            return [{**group, "params": cls._param_list(group["params"])}
                    for group in entries]
        return [{"params": entries}]

    def _all_params(self):
        """Yield each parameter of the wrapped optimizer and of AGC once."""
        seen = set()
        for group in list(self.param_groups) + list(self.agc_params):
            for p in group['params']:
                if id(p) not in seen:
                    seen.add(id(p))
                    yield p

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss. It is evaluated exactly once, before
                clipping.

        Returns:
            The value returned by ``closure`` when one is given, otherwise
            whatever the wrapped optimizer's ``step()`` returns.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.agc_params:
            for p in group['params']:
                if p.grad is None:
                    continue

                # ``clamp`` applies the floor without building a helper tensor
                # and moving it to the parameter's device on every step.
                param_norm = unitwise_norm(p.detach()).clamp(min=self.eps)
                grad_norm = unitwise_norm(p.grad.detach())
                max_norm = param_norm * self.clipping

                # Rescale only where the gradient norm exceeds the allowed
                # maximum; the 1e-6 floor guards the division.
                clipped_grad = p.grad * (max_norm / grad_norm.clamp(min=1e-6))
                p.grad.copy_(
                    torch.where(grad_norm > max_norm, clipped_grad, p.grad))

        # The closure is deliberately NOT forwarded: running it a second time
        # would recompute the gradients and throw away the clipping above.
        result = self.optim.step()
        return loss if closure is not None else result

    def zero_grad(self, set_to_none: bool = False):
        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

        Covers every parameter of the wrapped optimizer as well as every
        parameter managed by AGC, so layers excluded from clipping are reset
        too.

        Arguments:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have a lower memory footprint, and can
                modestly improve performance. However, it changes certain
                behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).
        """
        for p in self._all_params():
            if p.grad is None:
                continue
            if set_to_none:
                p.grad = None
                continue
            # Cut the gradient out of any autograd graph before zeroing it
            # in place.
            if p.grad.grad_fn is not None:
                p.grad.detach_()
            else:
                p.grad.requires_grad_(False)
            p.grad.zero_()