IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    基于 PyTorch 编程使用预训练模型

    Yanjun发表于 2024-01-29 15:46:29
    love 0

    使用预训练模型有两种方式:一种是直接使用得到的预训练模型进行推理,并满足应用的需要,使用起来非常简单;另一种是在预训练模型的基础上,进行微调,使得到的新模型能够更好地满足我们解决问题的需要,这种方式需要能够对模型进行调优有一定门槛。这里,我们尝试第一种方式直接使用预训练模型,着重关注使用预训练模型处理图片分类的过程,从而熟悉在实际应用中都需要做哪些处理工作。

    预训练模型

    预训练模型(Pre-trained Models, PTMs)是一种深度学习架构,它在大规模数据集上进行训练,以获取丰富的特征表示。训练得到的模型可以进行复用,不仅能够适用于最初要解决的问题,还可以迁移到其他类似的应用场景中,从而提高在这些新领域的应用的性能。
    预训练模型通常具有较大的参数规模,需要使用海量的数据和高昂的计算资源代价,才能完成模型训练并最终得到模型参数,这对于一些不具备基于超大规模数据训练能力的使用者来说,就无法发挥模型的作用,而且也不能很方便地在特定应用领域内探索并验证一些应用的想法。
    例如,在 NLP 领域,预训练模型应用的特别广泛,因为它们可以从海量的文本数据中学习到有用的语义信息。而从头开始训练这些 NLP 模型需要大量的计算资源,这对于基于此类模型的下游应用场景几乎是不可能的,如解决诸如语言理解、机器翻译、自动问答等问题都受到了极大的限制。所以使用预训练模型,可以极大地降低下游应用场景使用的代价和复杂度,而把精力聚焦在特定的场景的问题上。通过直接使用预训练模型,或者进行简单的微调就能够很好地完成下游的一些任务,像文本分类、序列标注和阅读理解等,从而实现性能的提升。
    在 PyTorch 中内置了很多预训练模型,我们可以直接通过 torchvision.models 提供的 API 来使用。查看当前模型库里面有哪些预训练模型:

    from torchvision import models
    dir(models)
    

    可以看到,有很多可以使用的经典神经网络的预训练模型,如 AleNet、ResNet、GoogLeNet、VGG 等,示例如下所示:

    ['AlexNet',
     'AlexNet_Weights',
     'ConvNeXt',
     'ConvNeXt_Base_Weights',
     'ConvNeXt_Large_Weights',
     'ConvNeXt_Small_Weights',
     'ConvNeXt_Tiny_Weights',
     'DenseNet',
     ... ...
     'ResNet',
     'ResNet101_Weights',
     'ResNet152_Weights',
     'ResNet18_Weights',
     'ResNet34_Weights',
     'ResNet50_Weights',
     'ShuffleNetV2',
     'ShuffleNet_V2_X0_5_Weights',
     'ShuffleNet_V2_X1_0_Weights',
     'ShuffleNet_V2_X1_5_Weights',
     'ShuffleNet_V2_X2_0_Weights',
     'SqueezeNet',
     ... ...
     'vgg',
     'vgg11',
     'vgg11_bn',
     'vgg13',
     ... ...]
    

    根据我们实际的资源和应用需求,可以选择对应的预训练模型来实现推理功能。
    下面,我们使用具有 18 层深度的 ResNet 预训练神经网络模型,来说明如何对指定的任意图片进行分类。

    使用预训练模型 ResNet-18 分类图片

    下面我们把使用预训练模型的过程,分为 4 个步骤进行操作:

    1 获取 ImageNet 数据集 label

    ResNet 基于 ImageNet 数据集进行训练,所以 label 也是来自 ImageNet。可以从网上下载对应的 label 文件,我找到的是 caffe_ilsvrc12.tar.gz,解压缩后可以得到一个 synset_words.txt 文件,里面是关于图片的 1000 个 label,内容示例:

    n01440764 tench, Tinca tinca
    n01443537 goldfish, Carassius auratus
    n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
    n01491361 tiger shark, Galeocerdo cuvieri
    n01494475 hammerhead, hammerhead shark
    n01496331 electric ray, crampfish, numbfish, torpedo
    

    可以直接处理并提取文件中第二列的 label 内容:

    with open("./synset_words.txt") as f:
            classes = [line.split(" ")[1] for line in f.readlines()]
    print(len(classes))
    

    把 label 名称直接加载到 classes 数组中,后面使用预训练模型推理后,需要找到对应的 label 名称。

    2 加载 ResNet-18 预训练模型

    直接使用 models.resnet18,会下载模型参数,并加载到内存:

    from torchvision import models
    resnet = models.resnet18(pretrained=True)
    print(resnet)
    

    可以看到 ResNet-18 模型的结构,如下所示:

    ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer2): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer3): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
      (fc): Linear(in_features=512, out_features=1000, bias=True)
    )
    

    可以看到,ResNet-18 网络模型的结构包含了哪些层,以及对应参数情况。
    另外,也可以使用 PyTorch Hub 提供的 API 来直接加载对应的预训练模型,下载后继续使用模型。PyTorch Hub 提供了几种方式,我们通过代码片段简单说明如下,不过多实践了:

    • 使用 torch.hub.load() 获取预训练模型
    model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
    
    • 使用 torch.hub.load_state_dict_from_url() 获取预训练模型
    state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
    
    • 使用 torch.hub.download_url_to_file() 获取预训练模型
    torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
    

    3 预处理待分类图片

    我们随便找了一张带兔子的图片 ./myimages/rabbit.jpeg,你也可以拿其他图片测试:
    rabbit
    然后对输入图片进行处理,得到满足 ResNet-18 模型推理输入要求的 Tensor,如下所示:

    # Define preprocessing transformers chain
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    img = Image.open("./myimages/rabbit.jpeg")
    img = preprocess(img)
    input_tensor = torch.unsqueeze(img, 0)
    

    使用上面得到的 input_tensor 就可以使用模型进行推理。

    4 推理并得到图片 label

    resnet.eval()
    output = resnet(input_tensor)
    _, indices = torch.sort(output, descending=True)
    [(I, classes[i], percentage[i].item()) for i in indices[0][:3]]
    

    上面代码通过 output = resnet(input_tensor) 进行推理,得到一个 Shape 是 torch.Size([1, 1000]) 的结果 Tensor,这里面并没有直接给出分类的 label,我们需要处理一下:对其进行降序排序,并取出 Top 3 最大的分值,并转换成百分比,表示输入图片属于某一个 label 的概率;然后,计算得到这 3 个分值对应索引位置;最后,根据索引位置从 classes 数组得到对应的 label 名称。
    运行代码,输出结果如下:

    [(tensor(331), 'hare\n', 71.56768035888672),
     (tensor(330), 'wood', 27.239879608154297),
     (tensor(332), 'Angora,', 0.4852880835533142)]
    

    通过结果可以看到,我们输入图片经过模型推理,得到第一个更优的 label 是 hare,“野兔”的意思,白兔的图片确实和 hare 这个 label 更加接近。

    参考资源

    • https://pytorch.org/hub/
    • https://pytorch.org/docs/stable/hub.html
    • Deep Learning with PyTorch


沪ICP备19023445号-2号
友情链接