Embedding
torch.nn.Embedding maps an integer index to a vector of a specified dimension: for example, index a becomes one 10-dimensional vector and index b becomes a different 10-dimensional vector. The layer is simply a learnable lookup table.
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)
num_embeddings
Size of the dictionary, i.e. the number of distinct tokens. If the vocabulary has 5000 words, pass 5000; valid indices are then 0-4999.
embedding_dim
Dimension of each embedding vector, i.e. how many numbers represent one token.
padding_idx
The index used for padding. For example, if the model expects length-100 inputs but sentences vary in length, shorter ones are filled with a designated padding id. Specifying it here makes the layer treat that id specially: its embedding is initialized to zeros and its gradient is never updated during training, so the padding token carries no learned information.
max_norm
Maximum norm: if an embedding vector's norm exceeds this bound, it is renormalized at lookup time so its norm equals max_norm.
norm_type
The p of the p-norm used for the comparison against max_norm (default 2.0, i.e. the Euclidean norm).
scale_grad_by_freq
If True, gradients are scaled by the inverse frequency of each word in the mini-batch, so frequent words receive smaller updates.
sparse
If True, the gradient with respect to the weight matrix is a sparse tensor (only supported by a subset of optimizers, e.g. SGD, SparseAdam).
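The padding_idx behavior described above can be checked directly. A minimal sketch (the table size, dimension, and index values here are arbitrary):

```python
import torch
import torch.nn as nn

# Lookup table with 10 entries of dimension 4; index 0 is the padding id.
emb = nn.Embedding(10, 4, padding_idx=0)

# The row for padding_idx is initialized to zeros.
print(emb.weight[0])  # all zeros

# It also stays frozen: after a backward pass, its gradient is zero,
# while rows for real tokens receive a nonzero gradient.
out = emb(torch.tensor([0, 3, 0]))
out.sum().backward()
print(emb.weight.grad[0])  # zeros: the padding row is never updated
print(emb.weight.grad[3])  # nonzero: ordinary rows are updated
```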
Example
- Raw text
['I am a boy.','How are you?','I am very lucky.']
- Normalize (lowercase, separate punctuation)
['i','am','a','boy','.'],['how','are','you','?'],['i','am','very','lucky','.']
- Map each word to a dictionary id
[11,12,13,14,15],[16,17,18,19],[11,12,20,21,15]
- Pad to equal length and append an end marker
[11,12,13,14,15,0],[16,17,18,19,1,0],[11,12,20,21,15,0]
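The padding step can also be done with torch.nn.utils.rnn.pad_sequence instead of by hand. A sketch using the same ids (pad_sequence only pads; any end marker would still need to be appended separately):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Variable-length id sequences from the dictionary step above.
seqs = [torch.tensor([11, 12, 13, 14, 15]),
        torch.tensor([16, 17, 18, 19]),
        torch.tensor([11, 12, 20, 21, 15])]

# Pad with 0 to the longest length; batch_first gives shape (batch, max_len).
batch = pad_sequence(seqs, batch_first=True, padding_value=0)
print(batch.shape)   # torch.Size([3, 5])
print(batch[1])      # tensor([16, 17, 18, 19,  0])
```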
- Call Embedding
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn

batch = [[11, 12, 13, 14, 15, 0], [16, 17, 18, 19, 1, 0], [11, 12, 20, 21, 15, 0]]
# 22 dictionary entries (ids 0-21), each embedded into a 10-dim vector.
embeds = nn.Embedding(22, 10)
hello_idx = torch.LongTensor(batch)
hello_embed = embeds(hello_idx)   # shape: (3, 6, 10)
print(hello_embed)
Output:
tensor([[[-0.0238, 0.8295, -1.3396, -1.3248, 0.0594, 1.3167, -1.0940,
-1.1951, -1.0396, -0.0233],
[ 0.4356, 0.6370, -0.7741, 0.8926, 0.4708, -1.5016, 1.2592,
0.3013, -0.8454, 0.5197],
[ 0.2925, 0.2184, 1.8541, 0.2730, -1.2973, 1.8819, -0.9783,
0.3813, -1.9256, -0.9753],
[-1.2932, 1.5927, -0.1280, -0.1410, -0.6610, -0.4409, 0.8779,
-1.3380, 1.9395, -0.4685],
[-1.6556, 0.0503, -0.0331, -0.1597, -0.5395, -0.3199, 0.7204,
-2.7539, -0.3776, -0.6279],
[ 0.7956, 0.4990, -1.4565, 0.4149, -0.0138, 2.1350, -0.3909,
0.7946, -0.0683, -0.8851]],
[[-0.2052, 0.0779, 0.0293, 0.5325, -1.1952, -0.9435, -1.9977,
0.0736, -0.0971, 0.7654],
[-1.5563, -0.1219, -2.0514, -2.0681, 0.3873, 0.3895, 0.7156,
0.0918, 2.1920, -1.2471],
[-0.8210, -0.5885, 0.4094, 0.5952, -0.5706, 1.2906, 1.3115,
-0.3000, 2.2222, 0.5521],
[ 0.0379, -0.9638, -0.4586, -1.1863, -0.1216, 0.8507, 0.6049,
-0.5133, -0.1359, -0.0879],
[-1.1264, 0.8603, -1.8749, 0.3225, -0.1212, 0.1518, -0.1020,
1.1794, 0.9940, -0.7038],
[ 0.7956, 0.4990, -1.4565, 0.4149, -0.0138, 2.1350, -0.3909,
0.7946, -0.0683, -0.8851]],
[[-0.0238, 0.8295, -1.3396, -1.3248, 0.0594, 1.3167, -1.0940,
-1.1951, -1.0396, -0.0233],
[ 0.4356, 0.6370, -0.7741, 0.8926, 0.4708, -1.5016, 1.2592,
0.3013, -0.8454, 0.5197],
[-0.2668, -1.1652, -2.0050, 1.3835, -0.7573, 1.7355, -0.3102,
-0.4714, -0.6888, 0.1705],
[-0.8642, 0.6668, -0.8236, -0.2389, -0.5548, 2.1566, 0.2045,
-0.5670, 0.5661, -1.2467],
[-1.6556, 0.0503, -0.0331, -0.1597, -0.5395, -0.3199, 0.7204,
-2.7539, -0.3776, -0.6279],
[ 0.7956, 0.4990, -1.4565, 0.4149, -0.0138, 2.1350, -0.3909,
0.7946, -0.0683, -0.8851]]], grad_fn=<EmbeddingBackward0>)
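The output shape follows directly: 3 sentences, each of length 6, embedded into 10 dimensions gives a tensor of shape (3, 6, 10). A sketch checking this, plus the max_norm renormalization described in the parameter list (the bound 1.0 is an arbitrary choice for illustration):

```python
import torch
import torch.nn as nn

batch = torch.LongTensor([[11, 12, 13, 14, 15, 0],
                          [16, 17, 18, 19, 1, 0],
                          [11, 12, 20, 21, 15, 0]])

emb = nn.Embedding(22, 10)
out = emb(batch)
print(out.shape)  # torch.Size([3, 6, 10])

# With max_norm set, vectors are renormalized at lookup so their 2-norm
# (norm_type=2.0 by default) never exceeds the bound.
emb_clipped = nn.Embedding(22, 10, max_norm=1.0)
vecs = emb_clipped(batch)
print(vecs.norm(dim=-1).max())  # at most 1.0 (up to float rounding)
```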