Halo
Published 2023-07-03

Embedding

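torch.nn.Embedding maps an integer index to a vector of a specified dimension. For example, index a becomes one 10-dimensional vector and index b becomes a different 10-dimensional vector. The vectors are rows of a learnable weight matrix.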

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)

num_embeddings

Size of the dictionary. If 5000 distinct words occur in total, pass 5000; the valid indices are then 0 to 4999.

embedding_dim

Dimension of the embedding vectors, i.e. how many numbers are used to represent one symbol.
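To make these two parameters concrete, here is a minimal sketch. The sizes (5000 and 10) are just the illustrative numbers used above, and the variable names are made up for this example.

```python
import torch
import torch.nn as nn

# Illustrative sizes: 5000 dictionary entries, 10 numbers per entry.
embedding = nn.Embedding(num_embeddings=5000, embedding_dim=10)

# The layer stores one row per index in a learnable weight matrix.
weight_shape = tuple(embedding.weight.shape)

# Looking up indices returns the corresponding rows of that matrix.
indices = torch.tensor([0, 4999, 0])
vectors = embedding(indices)
output_shape = tuple(vectors.shape)

# The same index always selects the same row.
same_row = torch.equal(vectors[0], vectors[2])

# An index outside 0..4999 is rejected (an IndexError on CPU;
# RuntimeError is caught as well in case other backends differ).
try:
    embedding(torch.tensor([5000]))
    out_of_range_rejected = False
except (IndexError, RuntimeError):
    out_of_range_rejected = True

print(weight_shape, output_shape, same_row, out_of_range_rejected)
```

The weight matrix has shape (num_embeddings, embedding_dim), and the output has the shape of the index tensor with one extra dimension of size embedding_dim appended.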

padding_idx

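The padding id. For example, if the input length is fixed at 100 but sentences vary in length, the shorter ones have to be filled up with one fixed id, and padding_idx tells the layer which id that is. The embedding vector at padding_idx is initialized to all zeros and receives no gradient, so it is not updated during training and padding does not influence what the network learns. The default is None (no padding id).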
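The effect of padding_idx can be checked directly. The sketch below assumes, purely for illustration, that id 0 is the padding id; the variable names are made up.

```python
import torch
import torch.nn as nn

# Assumption for illustration: index 0 is the padding id.
embedding = nn.Embedding(num_embeddings=22, embedding_dim=3, padding_idx=0)

# The padding row starts out as all zeros.
pad_row_is_zero = bool(torch.all(embedding.weight[0] == 0))

# Backpropagate through a batch that contains the padding id.
batch = torch.tensor([[11, 12, 0, 0]])
embedding(batch).sum().backward()

# The padding row receives no gradient, while real tokens do.
pad_grad_is_zero = bool(torch.all(embedding.weight.grad[0] == 0))
token_grad_is_nonzero = bool(torch.any(embedding.weight.grad[11] != 0))

print(pad_row_is_zero, pad_grad_is_zero, token_grad_is_nonzero)
```

Because the gradient of the padding row is always zero, an optimizer step leaves that row unchanged.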

max_norm

Maximum norm. If set, any embedding vector whose norm exceeds max_norm is renormalized so that its norm equals max_norm.

norm_type

Which p-norm to compute when comparing against max_norm (the default 2 is the Euclidean norm).
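A small sketch of max_norm and norm_type working together. The weights are deliberately scaled up first so that every row exceeds the limit; the limit of 1.0 and the chosen indices are arbitrary illustrative values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding = nn.Embedding(num_embeddings=22, embedding_dim=10,
                         max_norm=1.0, norm_type=2.0)

# Scale the weights up so every row's L2 norm is far above 1.0.
with torch.no_grad():
    embedding.weight.mul_(10.0)

# The lookup renormalizes the rows it touches, in place.
indices = torch.tensor([11, 12])
vectors = embedding(indices)

with torch.no_grad():
    looked_up_norms = embedding.weight[indices].norm(p=2, dim=1)
    untouched_norm = embedding.weight[13].norm(p=2)

print(looked_up_norms, untouched_norm)
```

Only the rows that are actually looked up get renormalized, and the renormalization modifies the weight tensor itself during the forward pass. That is worth keeping in mind before doing other differentiable operations on embedding.weight.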

scale_grad_by_freq

If True, the gradient of each word is scaled by the inverse of that word's frequency in the mini-batch.
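The difference is easiest to see by comparing gradients with the flag off and on. In this sketch the helper weight_grad and the toy batch (id 11 three times, id 12 once) are made up for illustration.

```python
import torch
import torch.nn as nn

def weight_grad(scale_grad_by_freq):
    """Return the weight gradient for a toy batch (hypothetical helper)."""
    embedding = nn.Embedding(22, 3, scale_grad_by_freq=scale_grad_by_freq)
    batch = torch.tensor([11, 11, 11, 12])  # id 11 occurs 3 times, id 12 once
    embedding(batch).sum().backward()
    return embedding.weight.grad

plain = weight_grad(False)
scaled = weight_grad(True)

# Without scaling, the gradient of row 11 accumulates over its 3 occurrences.
# With scaling, it is divided by that count, so frequent words do not dominate.
print(plain[11], plain[12])
print(scaled[11], scaled[12])
```

Without scaling, row 11 accumulates three times the gradient of row 12; with scaling, both rows end up with the same magnitude.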

sparse

If True, the gradient with respect to the weight matrix is a sparse tensor.
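A minimal sketch of what sparse=True changes. The sizes and ids are illustrative; the point is that the gradient only stores the rows that were actually looked up.

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(5000, 10, sparse=True)

batch = torch.tensor([[11, 12, 13]])
embedding(batch).sum().backward()

grad = embedding.weight.grad
grad_is_sparse = grad.is_sparse
grad_shape = tuple(grad.shape)

# Only the rows that appeared in the batch are stored in the sparse gradient.
touched_rows = sorted(grad.coalesce().indices()[0].tolist())

print(grad_is_sparse, grad_shape, touched_rows)
```

This can save memory when the dictionary is large and each batch touches only a few rows. Not every optimizer accepts sparse gradients, so check the optimizer you plan to use (torch.optim.SparseAdam is one designed for them).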

Example

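  1. Raw text
['I am a boy.','How are you?','I am very lucky.']
  2. Normalize (convert to lowercase, split off punctuation)
['i','am','a','boy','.'],['how','are','you','?'],['i','am','very','lucky','.']
  3. Map each word to its dictionary id
[11,12,13,14,15],[16,17,18,19],[11,12,20,21,15]
  4. Pad to equal length and append an end marker (here id 1 pads the shorter sentence and id 0 ends every sentence)
[11,12,13,14,15,0],[16,17,18,19,1,0],[11,12,20,21,15,0]
  5. Call Embedding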
import torch
import torch.nn as nn

# Padded batch of ids from step 4: 3 sentences, 6 ids each
batch = [[11,12,13,14,15,0],[16,17,18,19,1,0],[11,12,20,21,15,0]]
# 22 possible ids (0-21), each mapped to a 10-dimensional vector
embeds = nn.Embedding(22, 10)
hello_idx = torch.LongTensor(batch)
hello_embed = embeds(hello_idx)  # shape: (3, 6, 10)
print(hello_embed)

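Output (the weights are randomly initialized, so the exact values differ from run to run; identical ids always produce identical rows):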

tensor([[[-0.0238,  0.8295, -1.3396, -1.3248,  0.0594,  1.3167, -1.0940,
          -1.1951, -1.0396, -0.0233],
         [ 0.4356,  0.6370, -0.7741,  0.8926,  0.4708, -1.5016,  1.2592,
           0.3013, -0.8454,  0.5197],
         [ 0.2925,  0.2184,  1.8541,  0.2730, -1.2973,  1.8819, -0.9783,
           0.3813, -1.9256, -0.9753],
         [-1.2932,  1.5927, -0.1280, -0.1410, -0.6610, -0.4409,  0.8779,
          -1.3380,  1.9395, -0.4685],
         [-1.6556,  0.0503, -0.0331, -0.1597, -0.5395, -0.3199,  0.7204,
          -2.7539, -0.3776, -0.6279],
         [ 0.7956,  0.4990, -1.4565,  0.4149, -0.0138,  2.1350, -0.3909,
           0.7946, -0.0683, -0.8851]],

        [[-0.2052,  0.0779,  0.0293,  0.5325, -1.1952, -0.9435, -1.9977,
           0.0736, -0.0971,  0.7654],
         [-1.5563, -0.1219, -2.0514, -2.0681,  0.3873,  0.3895,  0.7156,
           0.0918,  2.1920, -1.2471],
         [-0.8210, -0.5885,  0.4094,  0.5952, -0.5706,  1.2906,  1.3115,
          -0.3000,  2.2222,  0.5521],
         [ 0.0379, -0.9638, -0.4586, -1.1863, -0.1216,  0.8507,  0.6049,
          -0.5133, -0.1359, -0.0879],
         [-1.1264,  0.8603, -1.8749,  0.3225, -0.1212,  0.1518, -0.1020,
           1.1794,  0.9940, -0.7038],
         [ 0.7956,  0.4990, -1.4565,  0.4149, -0.0138,  2.1350, -0.3909,
           0.7946, -0.0683, -0.8851]],

        [[-0.0238,  0.8295, -1.3396, -1.3248,  0.0594,  1.3167, -1.0940,
          -1.1951, -1.0396, -0.0233],
         [ 0.4356,  0.6370, -0.7741,  0.8926,  0.4708, -1.5016,  1.2592,
           0.3013, -0.8454,  0.5197],
         [-0.2668, -1.1652, -2.0050,  1.3835, -0.7573,  1.7355, -0.3102,
          -0.4714, -0.6888,  0.1705],
         [-0.8642,  0.6668, -0.8236, -0.2389, -0.5548,  2.1566,  0.2045,
          -0.5670,  0.5661, -1.2467],
         [-1.6556,  0.0503, -0.0331, -0.1597, -0.5395, -0.3199,  0.7204,
          -2.7539, -0.3776, -0.6279],
         [ 0.7956,  0.4990, -1.4565,  0.4149, -0.0138,  2.1350, -0.3909,
           0.7946, -0.0683, -0.8851]]], grad_fn=<EmbeddingBackward0>)
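Steps 1 to 4 of the example can be sketched in code as follows. The regular expression, the constant names, and the choice to start word ids at 11 are assumptions made to reproduce the ids shown in the example; they are not a fixed convention.

```python
import re
import torch
import torch.nn as nn

# Assumptions read off the example ids: 1 pads, 0 ends, words start at 11.
PAD_ID, END_ID, FIRST_WORD_ID = 1, 0, 11

sentences = ['I am a boy.', 'How are you?', 'I am very lucky.']

# Step 2: lowercase and split punctuation into separate tokens.
tokenized = [re.findall(r"\w+|[^\w\s]", s.lower()) for s in sentences]

# Step 3: assign ids in order of first appearance.
vocab = {}
for tokens in tokenized:
    for token in tokens:
        vocab.setdefault(token, FIRST_WORD_ID + len(vocab))

# Step 4: pad to the longest sentence, then append the end marker.
max_len = max(len(tokens) for tokens in tokenized)
batch = []
for tokens in tokenized:
    ids = [vocab[token] for token in tokens]
    ids += [PAD_ID] * (max_len - len(ids))
    batch.append(ids + [END_ID])

# Step 5: look up 10-dimensional vectors for every id.
embeds = nn.Embedding(FIRST_WORD_ID + len(vocab), 10)
hello_embed = embeds(torch.tensor(batch))

print(batch)
print(hello_embed.shape)
```

The first token of the first and third sentences is the same word, so hello_embed[0, 0] and hello_embed[2, 0] are the same vector, just as the matching rows in the printed output show.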
