# Source code for nfnets.utils

import torch
from torch import nn

from nfnets import WSConv2d

import warnings


def replace_conv(module: nn.Module, conv_class=WSConv2d):
    """Recursively replace plain ``nn.Conv2d`` layers and drop ``nn.BatchNorm2d``.

    The model is modified in place. Every sub-module whose type is exactly
    ``torch.nn.Conv2d`` is replaced by a newly constructed ``conv_class``
    instance built from the original layer's ``in_channels``,
    ``out_channels``, ``kernel_size``, ``stride``, ``padding``, ``dilation``,
    ``groups`` and bias flag. Only these hyper-parameters are forwarded: the
    replacement is a fresh layer and the original layer's parameters are not
    copied into it. Every sub-module whose type is exactly
    ``torch.nn.BatchNorm2d`` is replaced by ``torch.nn.Identity()``.

    Usage:
        replace_conv(model)  # in-place replacement

    Args:
        module (nn.Module): Model whose convolutions must be replaced.
        conv_class (Class): Convolution class to instantiate
            (WSConv2d or ScaledStdConv2d).

    Warns:
        UserWarning: Emitted once per call as a reminder that the function is
            meant for non-residual models only.
    """
    # Warn here, in the public entry point, and recurse through a private
    # helper so the reminder is issued once per call rather than once per
    # visited sub-module.
    warnings.warn("Make sure to use it with non-residual models only")
    _replace_conv_children(module, conv_class)


def _replace_conv_children(module: nn.Module, conv_class) -> None:
    """Swap the direct children of ``module``, then recurse into its children."""
    # Iterate over a snapshot so that reassigning attributes on ``module``
    # cannot interfere with the iteration over its children.
    for name, child in list(module.named_children()):
        # Exact type match on purpose: subclasses of Conv2d / BatchNorm2d are
        # left untouched, so ``isinstance`` would change behaviour here.
        if type(child) is nn.Conv2d:
            replacement = conv_class(
                child.in_channels,
                child.out_channels,
                child.kernel_size,
                child.stride,
                child.padding,
                child.dilation,
                child.groups,
                child.bias is not None,
            )
            setattr(module, name, replacement)
        elif type(child) is nn.BatchNorm2d:
            setattr(module, name, nn.Identity())

    # Second pass: descend into whatever children the module holds now.
    for child in module.children():
        _replace_conv_children(child, conv_class)
def unitwise_norm(x: torch.Tensor):
    """Return the L2 norm of ``x`` reduced over axes chosen by its rank.

    * rank 0 or 1: reduce over dim 0 without keeping the reduced dim.
    * rank 2 or 3: reduce over dim 0, keeping it as a size-1 dim.
    * rank 4: reduce over dims 1, 2 and 3, keeping them as size-1 dims.

    Args:
        x (torch.Tensor): Tensor of rank 0 to 4.

    Returns:
        torch.Tensor: Square root of the sum of squares over the selected dims.

    Raises:
        ValueError: If ``x`` has more than four dimensions.
    """
    rank = x.ndim

    if rank > 4:
        raise ValueError('Wrong input dimensions')

    if rank == 4:
        reduce_over = [1, 2, 3]
        keep = True
    else:
        # NOTE(review): for rank 2/3 the reduction runs over dim 0; confirm
        # this matches the intended "unit" axis for the layers that use it.
        reduce_over = 0
        keep = rank > 1

    squared = x ** 2
    return torch.sum(squared, dim=reduce_over, keepdim=keep) ** 0.5