
    Distributed Data-Parallel Training with PyTorch

    RobinDong posted on 2024-01-17 23:15:21

    Let’s get straight to the point:

    import os
    import time
    
    import torch
    import torch.nn as nn
    import torch.distributed as dist
    
    from model import resnet152
    from dataset import get_data_loaders
    from torch.nn.parallel import DistributedDataParallel as DDP
    
    learning_rate = 0.001
    num_epochs = 40
    momentum = 0.9
    weight_decay = 1e-5
    
    
    def setup():
        # initialize the process group
        dist.init_process_group("nccl")
    
    
    def cleanup():
        dist.destroy_process_group()
    
    
    def train(rank, world_size):
        setup()
    
        model = resnet152().to(rank)
        # device_ids pins DDP to this process's own GPU (one GPU per process)
        model = DDP(model, device_ids=[rank])

        # Every process resumes from the node-local checkpoint if it exists,
        # so all ranks agree on the starting epoch and the model weights.
        if os.path.exists("last.pth"):
            obj = torch.load("last.pth", map_location=f"cuda:{rank}")
            print(f"Rank{rank} load 'last.pth' with epoch: {obj['epoch']}")
            model.load_state_dict(obj["model"])
            begin = obj["epoch"] + 1
        else:
            begin = 0
        print(f"Rank{rank} begin at {begin}")
    
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
        start = time.time()
        running_loss = 0
        trainloader, testloader = get_data_loaders(rank, world_size)
    
        for epoch in range(begin, num_epochs):
            trainloader.sampler.set_epoch(epoch)
            for index, (images, labels) in enumerate(trainloader):
                # gpu
                images, labels = images.to(rank), labels.to(rank)
    
                outputs = model(images)
    
                loss = criterion(outputs, labels)
    
                # backward and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
    
            # measure accuracy on the training set (this rank's shard only)
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for data in trainloader:
                    images, labels = data
    
                    # gpu
                    images, labels = images.to(rank), labels.to(rank)
    
                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            trainset_accu = 100 * correct / total
    
            # measure accuracy on the test set
            correct = 0
            total = 0
            with torch.no_grad():
                for data in testloader:
                    images, labels = data
                    # gpu
                    images, labels = images.to(rank), labels.to(rank)
                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            testset_accu = 100 * correct / total
            # back to training mode for the next epoch
            model.train()
            if rank == 0:
                print(
                    f"[{epoch}] Accu: {trainset_accu:.2f}%, {testset_accu:.2f}% "
                    f"| {(time.time() - start)/60.0:.1f} mins, loss: {running_loss}"
                )
                torch.save(model.state_dict(), f"cifar100_{epoch}.pth")
                torch.save({"model": model.state_dict(), "epoch": epoch}, "last.pth")
            running_loss = 0.0
    
        end = time.time()
        stopWatch = end - start
        print("Training is done")
        print("Total Training Time (second):", stopWatch)
        cleanup()
    
    
    if __name__ == "__main__":
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        train(local_rank, world_size)
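
    The model and dataset modules are not shown in the post. As a rough idea of what they might look like (a hypothetical sketch, not the original code): resnet152 could simply come from torchvision (torchvision.models.resnet152(num_classes=100)), and get_data_loaders has to wrap the training set in a DistributedSampler, which is what makes trainloader.sampler.set_epoch(epoch) in the loop above meaningful:

    # dataset.py: hypothetical sketch of the helper module used above
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler


    def get_data_loaders(rank, world_size, batch_size=128):
        # rank/world_size are kept only to match the call in train.py;
        # the sampler below reads the global rank from the process group.
        train_tf = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )
        trainset = torchvision.datasets.CIFAR100(
            root="./data", train=True, download=True, transform=train_tf
        )
        testset = torchvision.datasets.CIFAR100(
            root="./data", train=False, download=True, transform=transforms.ToTensor()
        )
        # With no explicit arguments, DistributedSampler takes the global rank
        # and world size from the initialized process group, so each process
        # gets its own shard of the training set.
        train_sampler = DistributedSampler(trainset)
        trainloader = DataLoader(
            trainset, batch_size=batch_size, sampler=train_sampler, num_workers=2
        )
        testloader = DataLoader(
            testset, batch_size=batch_size, shuffle=False, num_workers=2
        )
        return trainloader, testloader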
    

    The main training code comes from this notebook (many thanks to @batuhan3526). To run this snippet on two nodes (each node has two GPUs), I need to use the powerful “torchrun”:

    torchrun \
      --rdzv-backend=c10d \
      --rdzv-endpoint=rogpt1:23456 \
      --nnodes=1:2 \
      --max-restarts=3 \
      --nproc-per-node=2 \
      train.py
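
    The same command is launched on both nodes; the c10d rendezvous at rogpt1:23456 is what groups them into one job. For every worker process, torchrun exports the rank information that train.py reads in its __main__ block. A tiny probe script (hypothetical, launched with the same torchrun command) makes this visible:

    # probe.py: prints what torchrun hands to each worker process
    import os
    import socket

    keys = ["RANK", "LOCAL_RANK", "WORLD_SIZE", "LOCAL_WORLD_SIZE"]
    print(socket.gethostname(), {key: os.environ.get(key) for key in keys})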

    With the snippet above, the rank-0 process on each node saves the checkpoint (the rank passed into train() is LOCAL_RANK, so every node has its own rank 0). If one process fails, the whole cluster restarts, and any process that finds last.pth on its node resumes training from epoch + 1.

    I also tried letting only the rank-0 process on node-0 save a single copy of the checkpoint. However, since the other node then had no checkpoint to load, the restarts kept failing and the job got stuck in an endless restart loop.
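
    One possible way around that (just a sketch of an alternative, not what the setup above does) is to keep the checkpoint only on global rank 0 and broadcast the loaded object to every other rank after a restart, so the second node never needs its own copy of last.pth:

    import os

    import torch
    import torch.distributed as dist


    def load_checkpoint_everywhere(local_rank):
        """Read 'last.pth' on global rank 0 only, then share it with all ranks."""
        payload = [None]
        if dist.get_rank() == 0 and os.path.exists("last.pth"):
            payload[0] = torch.load("last.pth", map_location="cpu")
        # Collective call: every rank must reach this line. The device argument
        # points the underlying NCCL broadcast at this process's own GPU.
        dist.broadcast_object_list(payload, src=0, device=torch.device("cuda", local_rank))
        obj = payload[0]
        if obj is None:
            return None, 0  # no checkpoint anywhere; start from scratch
        return obj["model"], obj["epoch"] + 1

    train() could call this right after wrapping the model in DDP and feed the returned state dict into model.load_state_dict().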


