PyTorch 实现了 SWA(Stochastic Weight Averaging,随机加权平均),相比于传统的 SGD,使用 SWA 能够明显改善一些深度神经网络模型的测试精度(Test Accuracy)。而且,SWA 使用起来非常简单,能够加速模型训练,并提高模型的泛化能力。
SWA 基本原理
SWA 依赖两个重要的因素:
第一个是,SWA 使用一个不断修改的 LR 调节器(Learning Rate Schedule),使得 SGD 能够在最优值附近进行调整,并评估最优解附近的值对应的模型的精度,而不是只选取最优解对应的模型。因为,最优解对应的模型不一定是最优的,而且泛化能力可能也不一定最好。比如,在 75% 的训练时间里,可以使用一个标准的衰减学习率(Decaying Learning Rate)策略,然后在剩余 25% 的训练时间里将学习率设置为一个比较高的固定值。如下图所示:
第二个是,SWA 计算的是 SGD 遍历过的神经网络权重的平均值。例如,上面提到模型训练的后 25% 时间,我们可以在这 25% 时间里的每一轮训练(every epoch)后,计算一个权重的 running 平均值,在训练结束后再设置网络模型的权重为 SWA 权重平均值。
SWA 论文提供了对其算法的过程的描述,如下图所示:
上面算法描述,给出了下面几个重要的内容:
SWA 使用要点
在 PyTorch 中,使用 SWA 训练模型的基本要点,描述如下:
创建 SWA 模型,直接使用 PyTorch 提供的 AveragedModel,传入上面创建的 model,model 可以是任意继承了 torch.nn.Module 模型:
swa_model = AveragedModel(model)
这样 swa_model 会在模型训练过程中持续跟踪 model 的参数的平均值。如果希望更新这个 running 平均值,需要在 optimizer.step() 之后调用 update_parameters() 函数:
swa_model.update_parameters(model)
使用 SWA 时,通常配合使用 SWALR 这个 Learning Rate Scheduler 来退火(anneals)至一个固定的常量值,然后一直保持不变,使用示例:
swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05) # 或者 swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, anneal_strategy="cos", anneal_epochs=5, swa_lr=0.05)
在模型训练结束之后,需要进行计算 BN(Batch Normalization)统计量并更新模型:
torch.optim.swa_utils.update_bn(train_dataloader, swa_model)
SWA 编程实践
下面,我们基于上面提到的使用 SWA 的方法训练模型,通过实际编程来加强理解:
1 准备数据集并定义模型
我们使用 MNIST 数据集,神经网络模型使用 LeNet-5,代码处理逻辑如下所示:
import torch from torch import nn from torch.utils.data import DataLoader from torchvision import datasets from torchvision.transforms import ToTensor # Download training/test data from open datasets. train_dataset = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor(),) test_dataset = datasets.MNIST(root="data", train=False, download=False, transform=ToTensor(),) print(f"train_dataset_size = {len(train_dataset)}, test_dataset_size = {len(test_dataset)}") batch_size = 64 # Create data loaders. train_dataloader = DataLoader(train_dataset, batch_size=batch_size) test_dataloader = DataLoader(test_dataset, batch_size=batch_size) for X, y in test_dataloader: print(f"Shape of X [N, C, H, W]: {X.shape}") print(f"Shape of y: {y.shape} {y.dtype}") break # Get cpu, gpu or mps device for training. device = ( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) print(f"Using {device} device") # Define model class LeNet5Model(nn.Module): def __init__(self): super().__init__() self._conv = nn.Sequential( nn.Conv2d(1, 6, 5, 1), nn.MaxPool2d(2), nn.Conv2d(6, 16, 5, 1), nn.MaxPool2d(2) ) self._fc = nn.Sequential( nn.Linear(4*4*16, 120), nn.Linear(120, 84), nn.Linear(84, 10) ) def forward(self, x): x = self._conv(x) x = x.view(-1, 4 * 4 * 16) x = self._fc(x) return x
代码比较容易,不再赘述。
2 使用 SWA 训练模型
首先,我们创建模型,并指定要使用的 Loss Function 和 Optimizer,实现代码如下所示:
# Create model model = LeNet5Model().to(device) print(model) # Define loss function and optimizer loss_fn = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
使用 SWA,直接在以往的编程基础上,复用上面创建的一些对象即可。
下面,实现基于 SWA 训练模型的核心逻辑代码,如下所示:
from torch.optim.swa_utils import AveragedModel, SWALR from torch.optim.lr_scheduler import CosineAnnealingLR def swa_train(epoch, train_loader, test_loader, model, loss_fn, optimizer, swa_start): # scheduler = CosineAnnealingLR(optimizer, T_max=10) scheduler = SWALR(optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05) swa_scheduler = SWALR(optimizer, swa_lr=0.05) swa_model = AveragedModel(model) size = len(train_loader.dataset) for batch, (X, y) in enumerate(train_loader): optimizer.zero_grad() loss = loss_fn(model(X), y) loss.backward() optimizer.step() if batch % 100 == 0: loss, current = loss.item(), (batch + 1) * len(X) print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") if epoch > swa_start: swa_model.update_parameters(model) swa_scheduler.step() else: scheduler.step() return swa_model
在模型训练过程中,使用 swa_start 控制启动 SWA 更新模型参数,并使用 swa_scheduler 来调节学习率,否则就使用默认的学习率工具控制训练过程。
可以看到,在训练的迭代过程中,并没有对模型进行预测调用,所以 Batch Norm 层也就没有计算过神经网络中这些激活统计量。所以,在模型训练结束后,应用训练过程中使用的训练数据,对 swa_model 模型进行一次 forward 计算并更新这些统计量的值。
上面的 test() 函数,实现了使用测试集 test_dataloader 来计算 loss 值和测试精度,实现代码如下:
def test(dataloader, model, loss_fn): size = len(dataloader.dataset) num_batches = len(dataloader) model.eval() test_loss, correct = 0, 0 with torch.no_grad(): for X, y in dataloader: X, y = X.to(device), y.to(device) pred = model(X) test_loss += loss_fn(pred, y).item() correct += (pred.argmax(1) == y).type(torch.float).sum().item() test_loss /= num_batches correct /= size print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
现在,就可以调用上面提供的 SWA 模型训练函数 swa_train(),训练我们前面定义的 LeNet 模型,代码如下:
epochs = 5 swa_start = 3 for epoch in range(epochs): print(f"Epoch {epoch + 1}\n-------------------------------") swa_model = swa_train(epoch, train_dataloader, test_dataloader, model, loss_fn, optimizer, swa_start) # Update bn statistics for the swa_model at the end torch.optim.swa_utils.update_bn(train_dataloader, swa_model) test(test_loader, swa_model, loss_fn)
在训练期间,SWA 权重并没有被用来进行预测,所以 BatchNorm 层在训练结束时并没有计算模型激活函数的统计量。我们需要在训练 SWA 模型的训练数据集上,再执行一次 forward 计算来得到这些统计量。从上面代码可以看到,BatchNorm 层在模型训练期结束后会调用函数 update_bn() 来计算激活统计量,并更新模型。
运行上面代码,训练过程输出信息,示例如下:
Epoch 1 ------------------------------- loss: 0.034225 [ 64/60000] loss: 0.075405 [ 6464/60000] loss: 0.097025 [12864/60000] loss: 0.041041 [19264/60000] loss: 0.006614 [25664/60000] loss: 0.075064 [32064/60000] loss: 0.108254 [38464/60000] loss: 0.035989 [44864/60000] loss: 0.211421 [51264/60000] loss: 0.092199 [57664/60000] Epoch 2 ------------------------------- loss: 0.014172 [ 64/60000] loss: 0.079780 [ 6464/60000] loss: 0.082239 [12864/60000] loss: 0.055328 [19264/60000] loss: 0.006914 [25664/60000] loss: 0.071126 [32064/60000] loss: 0.106919 [38464/60000] loss: 0.031134 [44864/60000] loss: 0.196080 [51264/60000] loss: 0.096906 [57664/60000] Epoch 3 ------------------------------- loss: 0.015581 [ 64/60000] loss: 0.083868 [ 6464/60000] loss: 0.075796 [12864/60000] loss: 0.057100 [19264/60000] loss: 0.007014 [25664/60000] loss: 0.068562 [32064/60000] loss: 0.105347 [38464/60000] loss: 0.029306 [44864/60000] loss: 0.185119 [51264/60000] loss: 0.099400 [57664/60000] Epoch 4 ------------------------------- loss: 0.015702 [ 64/60000] loss: 0.086666 [ 6464/60000] loss: 0.071594 [12864/60000] loss: 0.056831 [19264/60000] loss: 0.007097 [25664/60000] loss: 0.066988 [32064/60000] loss: 0.104115 [38464/60000] loss: 0.028361 [44864/60000] loss: 0.176726 [51264/60000] loss: 0.100689 [57664/60000] Epoch 5 ------------------------------- loss: 0.015535 [ 64/60000] loss: 0.088614 [ 6464/60000] loss: 0.052040 [12864/60000] loss: 0.103008 [19264/60000] loss: 0.020882 [25664/60000] loss: 0.061086 [32064/60000] loss: 0.115166 [38464/60000] loss: 0.025734 [44864/60000] loss: 0.156586 [51264/60000] loss: 0.138262 [57664/60000] Test Error: Accuracy: 98.4%, Avg loss: 0.046543
3 使用 SWA 模型
使用模型,示例代码如下所示:
classes = [ "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot", ] swa_model.eval() for i in range(10): start = 13 * i + 8 x, y = test_dataset[start][0], test_dataset[start][1] with torch.no_grad(): x = x.to(device) pred = swa_model(x) predicted, actual = classes[pred[0].argmax(0)], classes[y] print(f'Predicted: "{predicted:<12s}", Actual: "{actual:<12s}"')
可以看到,使用模型预测的结果示例:
Predicted: "Sandal ", Actual: "Sandal " Predicted: "Shirt ", Actual: "Shirt " Predicted: "Sneaker ", Actual: "Sneaker " Predicted: "Pullover ", Actual: "Pullover " Predicted: "Sneaker ", Actual: "Sneaker " Predicted: "Ankle boot ", Actual: "Ankle boot " Predicted: "Sneaker ", Actual: "Sneaker " Predicted: "Ankle boot ", Actual: "Ankle boot " Predicted: "Dress ", Actual: "Dress " Predicted: "Ankle boot ", Actual: "Ankle boot "
总结
我们从使用 SWA 能为我们带来的优势出发,并结合 SWA 论文实验给定的一些结论,对 SWA 进行总结:
SWA 使用了一个可以在训练模型过程中不断修改的学习率调节器(Learning Rate Schedule),能够更快地到达最优解附近区域。而且,模型训练的时候,降低迭代次数(epochs)也能够很快得到比较高的精度。
在 SWA 论文中,使用预激活(Preactivation) ResNet-164 模型,在 CIFAR-100 数据集上分别使用 SWA 和 SGD 训练模型,得到结果如下图所示:
上图中表明,使用 SGD,训练模型得到的 Train Loss 是最优的,但是对应的 Test Error 并不是最优的,说明这个最优值可能是一个局部最优值;而使用 SWA 得到的 Train Loss 不是最优的,但是对应的 Test Error 却是最优的,说明 SWA 得到的 Train Loss 是一个比 SGD 更优的全局解。
从泛化性能方面来考虑,对比 SGD 和 SWA,如下图所示:
使用 SGD,得到的 Train Loss 是一个最优值,但是位于边界位置上,所以对应的 Test Error 变化幅度就会相对更大,从图中可以看到 Train Loss 是最优值但对应的 Test Error 并不是。
使用 SWA,能够得到一个解,它使 Train Loss 集中在一个足够宽的平滑区域范围内,而在这个区域内得到的对应 Test Error 是全局最优解,这样就能更有潜力使获得的模型具有更好的泛化性能。
SWA 不仅可以用于 SGD 优化器,还可以用于其他一些优化器,比如 Adam。
SWA 还有其他一些扩展,如 SWAG、MultiSWAG、SWALP、SWAP,在对应的应用领域内,都能够得到比较好的效果。
另外,对于基于大模型的预训练场景,鉴于 SWA 的加快模型训练速度的优势,也能够更加广泛地被应用。
参考资源