本文主要介绍了PyTorch中的torch.nn.Embedding()
和torch.nn.EmbeddingBag()
两种词嵌入方法。torch.nn.Embedding()
用于创建固定大小的词嵌入矩阵,将输入的整数索引映射到对应的词向量。而torch.nn.EmbeddingBag()
在torch.nn.Embedding()
的基础上,提供了处理变长序列的池化功能,可以计算每个序列的平均池化或求和池化结果,适合处理可变长度的文本输入。此外,文章还提到了如何使用EmbeddingBag层和Linear层快速创建文本分类模型,并详细介绍了文本分类的流程、词表、索引、EmbeddingBag层和偏移值等概念。最后,文章还给出了一个文本分类架构的示例代码,展示了如何使用PyTorch和TorchText库进行文本分类。
除了前面说到的torch.nn.Embedding()
方法,PyTorch还提供了torch.nn.EmbeddingBag()
一种聚合方法。聚合方法有求和、求均值、求最大值等。
torch.nn.Embedding()
:这个类用于创建一个固定大小的词嵌入矩阵,其中每一行代表一个词的词向量。它将输入的整数索引映射到对应的词向量。
torch.nn.EmbeddingBag()
:这个类在 torch.nn.Embedding()
的基础上提供了更多的功能。它可以用于根据不定长度的输入序列计算每个序列的平均池化(average pooling)或者求和池化(sum pooling)的结果。它可以处理变长序列,并且支持权重,这使得它在处理可变长度的文本输入时非常有用。
总结:torch.nn.Embedding()
主要用于将整数索引映射到固定维度的词向量,而 torch.nn.EmbeddingBag()
则更适合处理可变长度的文本输入,并提供了对变长序列的池化功能。
简单文本分类模型中可使用EmbedingBag层加Linear层快速创建文本分类模型。这样的模型创建起来非常便捷,计算效率也很高。
缺点:这种聚合方式非常简单粗暴,而导致文本之间的关系缺失,导致精度比较差。
简单的文本分类流程:
下图是PyTorch官方文档的一个文本分类示例。使用TorchText库进行处理。
这相当于是一个字典,预先建立好的,每个单词都对应一个唯一的索引。比如“hello”可能对应索引1,“world”可能对应索引2,依此类推。这个步骤是在预处理阶段完成的,确保每个单词都有一个唯一的数字表示。
输入到模型中的单词索引,也就是上一步中词表生成的数字。
类似于一个查找表,可以将单词索引转化为更高维度的向量表示。
EmbeddingBag的特殊之处在于它可以处理变长的序列。即使每条评论长度不一样,它也能处理而不需要填充成相同长度。
由于每条评论长度不同,我们需要记录每条评论的起始位置。
比如如果有三条评论,第一个偏移量是0(从第0个单词开始),第二个偏移量可能是5(第一个评论有5个单词,第二个评论从第6个单词开始),依此类推。
这是一个简单的全连接层,它接受EmbeddingBag输出的向量并进行分类。最终的输出就是模型的预测结果。
将每条评论转换成单词索引序列。比如评论["hello world"]会被转换成[1, 2]。
由于批次中的评论长度不一样,使用EmbeddingBag时,可以将所有评论合并成一个长序列。
比如批次中有两条评论["hello", "world peace"],合并成[1, 2, 3]。
记录偏移值[0, 1],表示第一条评论从第0个索引开始,第二条评论从第1个索引开始。
定义一个批次处理函数,这个函数会对每个批次的数据进行预处理,生成上述的长序列和偏移值。
在DataLoader中通过collate_fn参数指定这个函数。
def collate_batch(batch):
text_list, label_list, offsets = [], [], [0]
for (text, label) in batch:
text_list.append(torch.tensor(text, dtype=torch.int64))
label_list.append(label)
offsets.append(text_list[-1].size(0))
text_list = torch.cat(text_list)
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
label_list = torch.tensor(label_list, dtype=torch.int64)
return text_list, offsets, label_list
# 使用DataLoader加载数据
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=8, collate_fn=collate_batch)