IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    NLP学习笔记(3) - 简单文本分类模型架构与embedingbag

    52txr发表于 2024-05-17 20:11:00
    love 0

    本文主要介绍了PyTorch中的torch.nn.Embedding()和torch.nn.EmbeddingBag()两种词嵌入方法。torch.nn.Embedding()用于创建固定大小的词嵌入矩阵,将输入的整数索引映射到对应的词向量。而torch.nn.EmbeddingBag()在torch.nn.Embedding()的基础上,提供了处理变长序列的池化功能,可以计算每个序列的平均池化或求和池化结果,适合处理可变长度的文本输入。此外,文章还提到了如何使用EmbeddingBag层和Linear层快速创建文本分类模型,并详细介绍了文本分类的流程、词表、索引、EmbeddingBag层和偏移值等概念。最后,文章还给出了一个文本分类架构的示例代码,展示了如何使用PyTorch和TorchText库进行文本分类。

    EmbeddingBag聚合包

    除了前面说到的torch.nn.Embedding()方法,PyTorch还提供了torch.nn.EmbeddingBag()一种聚合方法。聚合方法有求和、求均值、求最大值等。

    torch.nn.Embedding():这个类用于创建一个固定大小的词嵌入矩阵,其中每一行代表一个词的词向量。它将输入的整数索引映射到对应的词向量。

    torch.nn.EmbeddingBag():这个类在 torch.nn.Embedding() 的基础上提供了更多的功能。它可以用于根据不定长度的输入序列计算每个序列的平均池化(average pooling)或者求和池化(sum pooling)的结果。它可以处理变长序列,并且支持权重,这使得它在处理可变长度的文本输入时非常有用。

    总结:torch.nn.Embedding() 主要用于将整数索引映射到固定维度的词向量,而 torch.nn.EmbeddingBag() 则更适合处理可变长度的文本输入,并提供了对变长序列的池化功能。

    简单文本分类

    简单文本分类模型中可使用EmbedingBag层加Linear层快速创建文本分类模型。这样的模型创建起来非常便捷,计算效率也很高。

    缺点:这种聚合方式非常简单粗暴,而导致文本之间的关系缺失,导致精度比较差。

    简单的文本分类流程:

    • 词表将单词映射到索引。
    • 使用EmbeddingBag层对文本做词嵌入聚合
    • 这个层会对每一条评论中的文本单词做embeding词嵌入,并使用默认模式“mean”计算embeding的平均值
    • 输出一个聚合结果
    • 在这一层基础上添加分类器
    • 即可快速创建一个文本分类模型

    简单文本分类架构

    下图是PyTorch官方文档的一个文本分类示例。使用TorchText库进行处理。

    文本分类示例官方

    词表(word look-up table)

    这相当于是一个字典,预先建立好的,每个单词都对应一个唯一的索引。比如“hello”可能对应索引1,“world”可能对应索引2,依此类推。这个步骤是在预处理阶段完成的,确保每个单词都有一个唯一的数字表示。

    索引(index1, index2, ..., indexn)

    输入到模型中的单词索引,也就是上一步中词表生成的数字。

    EmbeddingBag层

    类似于一个查找表,可以将单词索引转化为更高维度的向量表示。

    EmbeddingBag的特殊之处在于它可以处理变长的序列。即使每条评论长度不一样,它也能处理而不需要填充成相同长度。

    偏移值

    由于每条评论长度不同,我们需要记录每条评论的起始位置。

    比如如果有三条评论,第一个偏移量是0(从第0个单词开始),第二个偏移量可能是5(第一个评论有5个单词,第二个评论从第6个单词开始),依此类推。

    Linear Layer

    这是一个简单的全连接层,它接受EmbeddingBag输出的向量并进行分类。最终的输出就是模型的预测结果。

    分类架构具体步骤

    (1)预处理文本

    将每条评论转换成单词索引序列。比如评论["hello world"]会被转换成[1, 2]。

    (2)整合批次数据

    由于批次中的评论长度不一样,使用EmbeddingBag时,可以将所有评论合并成一个长序列。

    比如批次中有两条评论["hello", "world peace"],合并成[1, 2, 3]。

    记录偏移值[0, 1],表示第一条评论从第0个索引开始,第二条评论从第1个索引开始。

    (3)定义批次处理函数(collate_fn)

    定义一个批次处理函数,这个函数会对每个批次的数据进行预处理,生成上述的长序列和偏移值。

    在DataLoader中通过collate_fn参数指定这个函数。

    def collate_batch(batch):
        text_list, label_list, offsets = [], [], [0]
        for (text, label) in batch:
            text_list.append(torch.tensor(text, dtype=torch.int64))
            label_list.append(label)
            offsets.append(text_list[-1].size(0))
        text_list = torch.cat(text_list)
        offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
        label_list = torch.tensor(label_list, dtype=torch.int64)
        return text_list, offsets, label_list
    
    # 使用DataLoader加载数据
    from torch.utils.data import DataLoader
    train_dataloader = DataLoader(train_dataset, batch_size=8, collate_fn=collate_batch)


沪ICP备19023445号-2号
友情链接