IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    使用 PyTorch SWA 优化模型训练实践

    Yanjun发表于 2024-01-27 16:14:50
    love 0

    PyTorch 实现了 SWA(Stochastic Weight Averaging,随机加权平均),相比于传统的 SGD,使用 SWA 能够明显改善一些深度神经网络模型的测试精度(Test Accuracy)。而且,SWA 使用起来非常简单,能够加速模型训练,并提高模型的泛化能力。

    SWA 基本原理

    SWA 依赖两个重要的因素:
    第一个是,SWA 使用一个不断修改的 LR 调节器(Learning Rate Schedule),使得 SGD 能够在最优值附近进行调整,并评估最优解附近的值对应的模型的精度,而不是只选取最优解对应的模型。因为,最优解对应的模型不一定是最优的,而且泛化能力可能也不一定最好。比如,在 75% 的训练时间里,可以使用一个标准的衰减学习率(Decaying Learning Rate)策略,然后在剩余 25% 的训练时间里将学习率设置为一个比较高的固定值。如下图所示:
    SWA
    第二个是,SWA 计算的是 SGD 遍历过的神经网络权重的平均值。例如,上面提到模型训练的后 25% 时间,我们可以在这 25% 时间里的每一轮训练(every epoch)后,计算一个权重的 running 平均值,在训练结束后再设置网络模型的权重为 SWA 权重平均值。

    SWA 论文提供了对其算法的过程的描述,如下图所示:
    SWA-Alogrithm
    上面算法描述,给出了下面几个重要的内容:

    • 学习率在模型训练过程中是不断调整的,同时更新 SGD 梯度值
    • 基于模型训练的迭代次数,设置在训练过程中不同的阶段(位置)更新两个不同的权重平均值
    • 在模型训练结束后,为 SWA 权重计算 BatchNorm 统计量

    SWA 使用要点

    在 PyTorch 中,使用 SWA 训练模型的基本要点,描述如下:

    • 创建 SWA 模型

    创建 SWA 模型,直接使用 PyTorch 提供的 AveragedModel,传入上面创建的 model,model 可以是任意继承了 torch.nn.Module 模型:

    swa_model = AveragedModel(model)
    

    这样 swa_model 会在模型训练过程中持续跟踪 model 的参数的平均值。如果希望更新这个 running 平均值,需要在 optimizer.step() 之后调用 update_parameters() 函数:

    swa_model.update_parameters(model)
    
    • 使用 SWALR

    使用 SWA 时,通常配合使用 SWALR 这个 Learning Rate Scheduler 来退火(anneals)至一个固定的常量值,然后一直保持不变,使用示例:

    swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)
    # 或者
    swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, anneal_strategy="cos", anneal_epochs=5, swa_lr=0.05)
    
    • 计算模型在 DataLoader 上的 BN 统计量

    在模型训练结束之后,需要进行计算 BN(Batch Normalization)统计量并更新模型:

    torch.optim.swa_utils.update_bn(train_dataloader, swa_model)
    

    SWA 编程实践

    下面,我们基于上面提到的使用 SWA 的方法训练模型,通过实际编程来加强理解:

    1 准备数据集并定义模型

    我们使用 MNIST 数据集,神经网络模型使用 LeNet-5,代码处理逻辑如下所示:

    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torchvision.transforms import ToTensor
    
    # Download training/test data from open datasets.
    train_dataset = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor(),)
    test_dataset = datasets.MNIST(root="data", train=False, download=False, transform=ToTensor(),)
    print(f"train_dataset_size = {len(train_dataset)}, test_dataset_size = {len(test_dataset)}")
    
    batch_size = 64
    
    # Create data loaders.
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)
    
    for X, y in test_dataloader:
        print(f"Shape of X [N, C, H, W]: {X.shape}")
        print(f"Shape of y: {y.shape} {y.dtype}")
        break
    
    # Get cpu, gpu or mps device for training.
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using {device} device")
    
    # Define model
    class LeNet5Model(nn.Module):
        def __init__(self):
            super().__init__()
            self._conv = nn.Sequential(
                nn.Conv2d(1, 6, 5, 1),
                nn.MaxPool2d(2),
                nn.Conv2d(6, 16, 5, 1),
                nn.MaxPool2d(2)
            )
            self._fc = nn.Sequential(
                nn.Linear(4*4*16, 120),
                nn.Linear(120, 84),
                nn.Linear(84, 10)
            )
    
        def forward(self, x):
            x = self._conv(x)
            x = x.view(-1, 4 * 4 * 16)
            x = self._fc(x)
            return x
    

    代码比较容易,不再赘述。

    2 使用 SWA 训练模型

    首先,我们创建模型,并指定要使用的 Loss Function 和 Optimizer,实现代码如下所示:

    # Create model
    model = LeNet5Model().to(device)
    print(model)
    
    # Define loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    

    使用 SWA,直接在以往的编程基础上,复用上面创建的一些对象即可。
    下面,实现基于 SWA 训练模型的核心逻辑代码,如下所示:

    from torch.optim.swa_utils import AveragedModel, SWALR
    from torch.optim.lr_scheduler import CosineAnnealingLR
    
    def swa_train(epoch, train_loader, test_loader, model, loss_fn, optimizer, swa_start):
        # scheduler = CosineAnnealingLR(optimizer, T_max=10)
        scheduler = SWALR(optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)
        swa_scheduler = SWALR(optimizer, swa_lr=0.05)
        swa_model = AveragedModel(model)
        size = len(train_loader.dataset)
        for batch, (X, y) in enumerate(train_loader):
            optimizer.zero_grad()
            loss = loss_fn(model(X), y)
            loss.backward()
            optimizer.step()
            if batch % 100 == 0:
                loss, current = loss.item(), (batch + 1) * len(X)
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
                    
            if epoch > swa_start:
                swa_model.update_parameters(model)
                swa_scheduler.step()
            else:
                scheduler.step()
        return swa_model
    

    在模型训练过程中,使用 swa_start 控制启动 SWA 更新模型参数,并使用 swa_scheduler 来调节学习率,否则就使用默认的学习率工具控制训练过程。
    可以看到,在训练的迭代过程中,并没有对模型进行预测调用,所以 Batch Norm 层也就没有计算过神经网络中这些激活统计量。所以,在模型训练结束后,应用训练过程中使用的训练数据,对 swa_model 模型进行一次 forward 计算并更新这些统计量的值。
    上面的 test() 函数,实现了使用测试集 test_dataloader 来计算 loss 值和测试精度,实现代码如下:

    def test(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    

    现在,就可以调用上面提供的 SWA 模型训练函数 swa_train(),训练我们前面定义的 LeNet 模型,代码如下:

    epochs = 5
    swa_start = 3
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}\n-------------------------------")
        swa_model = swa_train(epoch, train_dataloader, test_dataloader, 
                  model, loss_fn, optimizer, swa_start) 
    
    # Update bn statistics for the swa_model at the end
    torch.optim.swa_utils.update_bn(train_dataloader, swa_model)
    test(test_loader, swa_model, loss_fn)
    

    在训练期间,SWA 权重并没有被用来进行预测,所以 BatchNorm 层在训练结束时并没有计算模型激活函数的统计量。我们需要在训练 SWA 模型的训练数据集上,再执行一次 forward 计算来得到这些统计量。从上面代码可以看到,BatchNorm 层在模型训练期结束后会调用函数 update_bn() 来计算激活统计量,并更新模型。
    运行上面代码,训练过程输出信息,示例如下:

    Epoch 1
    -------------------------------
    loss: 0.034225  [   64/60000]
    loss: 0.075405  [ 6464/60000]
    loss: 0.097025  [12864/60000]
    loss: 0.041041  [19264/60000]
    loss: 0.006614  [25664/60000]
    loss: 0.075064  [32064/60000]
    loss: 0.108254  [38464/60000]
    loss: 0.035989  [44864/60000]
    loss: 0.211421  [51264/60000]
    loss: 0.092199  [57664/60000]
    Epoch 2
    -------------------------------
    loss: 0.014172  [   64/60000]
    loss: 0.079780  [ 6464/60000]
    loss: 0.082239  [12864/60000]
    loss: 0.055328  [19264/60000]
    loss: 0.006914  [25664/60000]
    loss: 0.071126  [32064/60000]
    loss: 0.106919  [38464/60000]
    loss: 0.031134  [44864/60000]
    loss: 0.196080  [51264/60000]
    loss: 0.096906  [57664/60000]
    Epoch 3
    -------------------------------
    loss: 0.015581  [   64/60000]
    loss: 0.083868  [ 6464/60000]
    loss: 0.075796  [12864/60000]
    loss: 0.057100  [19264/60000]
    loss: 0.007014  [25664/60000]
    loss: 0.068562  [32064/60000]
    loss: 0.105347  [38464/60000]
    loss: 0.029306  [44864/60000]
    loss: 0.185119  [51264/60000]
    loss: 0.099400  [57664/60000]
    Epoch 4
    -------------------------------
    loss: 0.015702  [   64/60000]
    loss: 0.086666  [ 6464/60000]
    loss: 0.071594  [12864/60000]
    loss: 0.056831  [19264/60000]
    loss: 0.007097  [25664/60000]
    loss: 0.066988  [32064/60000]
    loss: 0.104115  [38464/60000]
    loss: 0.028361  [44864/60000]
    loss: 0.176726  [51264/60000]
    loss: 0.100689  [57664/60000]
    Epoch 5
    -------------------------------
    loss: 0.015535  [   64/60000]
    loss: 0.088614  [ 6464/60000]
    loss: 0.052040  [12864/60000]
    loss: 0.103008  [19264/60000]
    loss: 0.020882  [25664/60000]
    loss: 0.061086  [32064/60000]
    loss: 0.115166  [38464/60000]
    loss: 0.025734  [44864/60000]
    loss: 0.156586  [51264/60000]
    loss: 0.138262  [57664/60000]
    Test Error: 
     Accuracy: 98.4%, Avg loss: 0.046543 
    

    3 使用 SWA 模型

    使用模型,示例代码如下所示:

    classes = [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]
    
    swa_model.eval()
    for i in range(10):
        start = 13 * i + 8
        x, y = test_dataset[start][0], test_dataset[start][1]
        with torch.no_grad():
            x = x.to(device)
            pred = swa_model(x)
            predicted, actual = classes[pred[0].argmax(0)], classes[y]
            print(f'Predicted: "{predicted:<12s}", Actual: "{actual:<12s}"')
    

    可以看到,使用模型预测的结果示例:

    Predicted: "Sandal      ", Actual: "Sandal      "
    Predicted: "Shirt       ", Actual: "Shirt       "
    Predicted: "Sneaker     ", Actual: "Sneaker     "
    Predicted: "Pullover    ", Actual: "Pullover    "
    Predicted: "Sneaker     ", Actual: "Sneaker     "
    Predicted: "Ankle boot  ", Actual: "Ankle boot  "
    Predicted: "Sneaker     ", Actual: "Sneaker     "
    Predicted: "Ankle boot  ", Actual: "Ankle boot  "
    Predicted: "Dress       ", Actual: "Dress       "
    Predicted: "Ankle boot  ", Actual: "Ankle boot  "
    

    总结

    我们从使用 SWA 能为我们带来的优势出发,并结合 SWA 论文实验给定的一些结论,对 SWA 进行总结:

    • 使用 SWA 非常简单,有效提高模型训练性能

    SWA 使用了一个可以在训练模型过程中不断修改的学习率调节器(Learning Rate Schedule),能够更快地到达最优解附近区域。而且,模型训练的时候,降低迭代次数(epochs)也能够很快得到比较高的精度。

    • 得到更合适的模型参数

    在 SWA 论文中,使用预激活(Preactivation) ResNet-164 模型,在 CIFAR-100 数据集上分别使用 SWA 和 SGD 训练模型,得到结果如下图所示:
    Weight-Comparison-SGD-vs-SWA
    上图中表明,使用 SGD,训练模型得到的 Train Loss 是最优的,但是对应的 Test Error 并不是最优的,说明这个最优值可能是一个局部最优值;而使用 SWA 得到的 Train Loss 不是最优的,但是对应的 Test Error 却是最优的,说明 SWA 得到的 Train Loss 是一个比 SGD 更优的全局解。

    • 更好的泛化性能

    从泛化性能方面来考虑,对比 SGD 和 SWA,如下图所示:
    SWA-Generalization
    使用 SGD,得到的 Train Loss 是一个最优值,但是位于边界位置上,所以对应的 Test Error 变化幅度就会相对更大,从图中可以看到 Train Loss 是最优值但对应的 Test Error 并不是。
    使用 SWA,能够得到一个解,它使 Train Loss 集中在一个足够宽的平滑区域范围内,而在这个区域内得到的对应 Test Error 是全局最优解,这样就能更有潜力使获得的模型具有更好的泛化性能。

    • 使用 SWA 应用范围比较广泛

    SWA 不仅可以用于 SGD 优化器,还可以用于其他一些优化器,比如 Adam。
    SWA 还有其他一些扩展,如 SWAG、MultiSWAG、SWALP、SWAP,在对应的应用领域内,都能够得到比较好的效果。
    另外,对于基于大模型的预训练场景,鉴于 SWA 的加快模型训练速度的优势,也能够更加广泛地被应用。

    参考资源

    • https://pytorch.org/docs/stable/optim.html#weight-averaging-swa-and-ema
    • https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
    • Averaging Weights Leads to Wider Optima and Better Generalization


沪ICP备19023445号-2号
友情链接