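Since Google introduced Batch Normalization in 2015, many other normalization methods have followed, such as Layer Normalization, Instance Normalization, and Group Normalization. These methods differ in purpose and effect, but they share the same core: compute the mean and variance of the input over certain dimensions, normalize, and then map the normalized features with learnable parameters. This can be written in one unified form: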
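\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

Here $\gamma$ and $\beta$ are the learnable scale and shift, and $\epsilon$ is a small constant for numerical stability. Take image data as an example: the input is $x \in \mathbb{R}^{N \times C \times H \times W}$, where $N, C, H, W$ are the batch size, the number of channels, the image height, and the image width.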
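As shown in the figure above, BN computes the mean and variance over the $N, H, W$ dimensions, LN over $C, H, W$, IN over $H, W$, and GN over $C', H, W$, where $C'$ is the number of channels in each group after the channels are split into groups.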
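The dimensions over which the statistics are computed are the only difference in the normalization step of these methods, and it is exactly this difference that gives each method its own behavior and properties.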
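To make the choice of dimensions concrete, here is a minimal sketch. The `normalize` helper is a hypothetical name introduced only for illustration, and the small tensor size and group count `G = 3` are arbitrary assumptions. It applies the unified formula without $\gamma$ and $\beta$ and reports the shape of the statistics each method ends up with.

```python
import torch

def normalize(x, dims, eps=1e-5):
    """Hypothetical helper: unified formula without gamma/beta.

    Returns the normalized tensor and the shape of the statistics.
    """
    var, mean = torch.var_mean(x, dim=dims, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps), mean.shape

x = torch.randn(4, 6, 5, 5)  # (N, C, H, W), sizes chosen arbitrarily
N, C, H, W = x.shape
G = 3  # assumed number of groups, so C' = C // G = 2

stat_shapes = {}
_, stat_shapes["BN"] = normalize(x, (0, 2, 3))  # one statistic per channel
_, stat_shapes["LN"] = normalize(x, (1, 2, 3))  # one statistic per sample
_, stat_shapes["IN"] = normalize(x, (2, 3))     # one per sample and channel
_, stat_shapes["GN"] = normalize(x.view(N, G, C // G, H, W), (2, 3, 4))  # one per sample and group

for name, shape in stat_shapes.items():
    print(name, tuple(shape))
```

The only thing that changes between the four calls is the `dims` argument (plus a reshape for GN), which is the point made above.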
Each method is easy to reproduce with an equivalent PyTorch implementation:
BN
import torch
import torch.nn as nn

inputs = torch.randn(5, 256, 32, 32)  # (N, C, H, W)
bn = nn.BatchNorm2d(256)  # weight shape: (C,); the module is in training mode by default
# weight defaults to 1, so randomize it to make the affine step visible
bn.weight.data = torch.rand_like(bn.weight)
# compute statistics over (N, H, W)
var, mean = torch.var_mean(inputs, dim=(0, 2, 3), keepdim=True, unbiased=False)  # (1, C, 1, 1)
std = (var + bn.eps).sqrt()  # (1, C, 1, 1)
norm = (inputs - mean) / std  # (N, C, H, W)
print(torch.allclose(
    norm * bn.weight.view(1, 256, 1, 1) + bn.bias.view(1, 256, 1, 1),
    bn(inputs))
)  # True
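One caveat worth checking: the BN equivalence relies on the module being in training mode, which is the default. In evaluation mode `nn.BatchNorm2d` normalizes with its running statistics instead of the statistics of the current batch. The sketch below illustrates this under assumed small shapes, with an arbitrary scale and offset on the inputs so the difference is easy to see.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Scale and shift the inputs (arbitrary choice) so that the batch statistics
# are clearly different from the initial running statistics (mean 0, var 1).
inputs = torch.randn(5, 8, 4, 4) * 3 + 2  # (N, C, H, W)
bn = nn.BatchNorm2d(8)

# Manual normalization with the statistics of the current batch.
var, mean = torch.var_mean(inputs, dim=(0, 2, 3), keepdim=True, unbiased=False)
manual = (inputs - mean) / (var + bn.eps).sqrt()

bn.train()
train_match = torch.allclose(manual, bn(inputs), atol=1e-4)

bn.eval()
eval_out = bn(inputs)
eval_match = torch.allclose(manual, eval_out, atol=1e-4)

# In eval mode the output follows the running statistics stored in the module.
running = (inputs - bn.running_mean.view(1, -1, 1, 1)) / (
    bn.running_var.view(1, -1, 1, 1) + bn.eps
).sqrt()
eval_uses_running = torch.allclose(running, eval_out, atol=1e-4)

print(train_match, eval_match, eval_uses_running)
```

So the comparison in the BN snippet holds for a freshly constructed module or after `bn.train()`, but not after `bn.eval()`.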
IN
inputs = torch.randn(5, 256, 32, 32) # (N, C, H, W)
ins = nn.InstanceNorm2d(256, affine=True) # Weight shape: (C, )
# weight default is 1
ins.weight.data = torch.rand_like(ins.weight)
# compute on (H, W)
var, mean = torch.var_mean(inputs, dim=(2, 3), keepdim=True, unbiased=False) # (N, C, 1, 1)
std = (var + ins.eps).sqrt() # (N, C, 1, 1)
norm = (inputs - mean) / std # (N, C, H, W)
print(torch.allclose(
    norm * ins.weight.view(1, 256, 1, 1) + ins.bias.view(1, 256, 1, 1),
    ins(inputs))
)  # True
LN
inputs = torch.randn(5, 256, 32, 32) # (N, C, H, W)
normalized_shape = inputs.shape[1:] # Normalize on (C, H, W)
ln = nn.LayerNorm(normalized_shape)
# weight default is 1
ln.weight.data = torch.rand_like(ln.weight)
# compute on (C, H, W)
var, mean = torch.var_mean(inputs, dim=(1, 2, 3), keepdim=True, unbiased=False) # (N, 1, 1, 1)
std = (var + ln.eps).sqrt() # (N, 1, 1, 1)
norm = (inputs - mean) / std # (N, C, H, W)
print(torch.allclose(norm * ln.weight + ln.bias, ln(inputs))) # True
GN
inputs = torch.randn(5, 256, 32, 32)  # (N, C, H, W)
num_groups = 32
gn = nn.GroupNorm(num_groups=num_groups, num_channels=256)  # weight shape: (C,)
# weight defaults to 1
gn.weight.data = torch.rand_like(gn.weight)
grouped_inputs = inputs.view(5, num_groups, 256 // num_groups, 32, 32)  # (N, G, C', H, W)
# compute statistics over (C', H, W)
var, mean = torch.var_mean(grouped_inputs, dim=(2, 3, 4), keepdim=True, unbiased=False)  # (N, G, 1, 1, 1)
std = (var + gn.eps).sqrt()  # (N, G, 1, 1, 1)
norm = (grouped_inputs - mean) / std  # (N, G, C', H, W)
print(torch.allclose(
    norm.view(5, 256, 32, 32) * gn.weight.view(1, 256, 1, 1) + gn.bias.view(1, 256, 1, 1),
    gn(inputs))
)  # True
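Because GN differs from LN and IN only in how many channels share one set of statistics, its two extreme settings should reproduce them. The sketch below checks this under assumed small shapes. All affine parameters are left at their defaults (scale 1, shift 0), and LN is built with `elementwise_affine=False` because its learnable parameters have a different shape from GN's.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 6, 5, 5)  # (N, C, H, W), sizes chosen arbitrarily
C = x.shape[1]

# One group: statistics over (C, H, W), the same dimensions as LN.
gn_one_group = nn.GroupNorm(num_groups=1, num_channels=C)
ln = nn.LayerNorm(x.shape[1:], elementwise_affine=False)
same_as_ln = torch.allclose(gn_one_group(x), ln(x), atol=1e-5)

# C groups: one channel per group, statistics over (H, W), the same as IN.
gn_per_channel = nn.GroupNorm(num_groups=C, num_channels=C)
inorm = nn.InstanceNorm2d(C)
same_as_in = torch.allclose(gn_per_channel(x), inorm(x), atol=1e-5)

print(same_as_ln, same_as_in)
```

Once the affine parameters are trained the modules are no longer interchangeable, since LN learns one scale and shift per element of `(C, H, W)` while GN learns one per channel.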