BERT Model Source Code Analysis (Part 6)


  token_type_embedding_name: string. The name of the embedding table variable
      for token type ids.
  use_position_embeddings: bool. Whether to add position embeddings for the
      position of each token in the sequence.
  position_embedding_name: string. The name of the embedding table variable
      for positional embeddings.
  initializer_range: float. Range of the weight initialization (the standard
      deviation, stddev, used to initialize the weight parameters).
  max_position_embeddings: int. Maximum sequence length that might ever be
      used with this model. This can be longer than the sequence length of
      input_tensor, but cannot be shorter.
  dropout_prob: float. Dropout probability applied to the final output tensor
      (keep probability = 1 - dropout_prob).

Returns:
  float tensor with same shape as `input_tensor`.

Raises:
  ValueError: One of the tensor shapes or input values is invalid.
"""
input_shape = get_shape_list(input_tensor, expected_rank=3)  # get the shape as a list
batch_size = input_shape[0]
seq_length = input_shape[1]
width = input_shape[2]
output = input_tensor
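get_shape_list is a helper defined earlier in modeling.py. Roughly, it returns static dimensions as Python ints where they are known, and falls back to dynamic tf.shape() values otherwise (it also checks expected_rank). A simplified sketch of that behavior, not the exact implementation:

    import tensorflow as tf  # TensorFlow 1.x, as used by the original BERT code

    def get_shape_list_sketch(tensor):
      static = tensor.shape.as_list()  # ints, or None for unknown dimensions
      dynamic = tf.shape(tensor)       # 1-D int32 tensor, always defined at runtime
      # Prefer the static value; fall back to the dynamic one when unknown.
      return [dim if dim is not None else dynamic[i]
              for i, dim in enumerate(static)]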
Token-type embeddings
if use_token_type:
  if token_type_ids is None:
    # Raise an error: token_type_ids must be given when use_token_type is True.
    raise ValueError("`token_type_ids` must be specified if "
                     "`use_token_type` is True.")
  # The token-type embedding table.
  token_type_table = tf.get_variable(
      name=token_type_embedding_name,
      shape=[token_type_vocab_size, width],
      initializer=create_initializer(initializer_range))
  # This vocab will be small so we always do one-hot here, since it is always
  # faster for a small vocabulary.
  flat_token_type_ids = tf.reshape(token_type_ids, [-1])  # flatten to 1-D
  # Convert the flat ids to one-hot vectors.
  one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
  # Multiplying the one-hot matrix by the table performs the embedding lookup.
  token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
  token_type_embeddings = tf.reshape(token_type_embeddings,
                                     [batch_size, seq_length, width])
  output += token_type_embeddings  # add the token-type embeddings to the output
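As an aside, the one-hot-times-table product above is just an embedding lookup: it is equivalent to tf.gather on the same table, and for a tiny vocabulary (here only two rows) the matmul form tends to be faster, especially on TPUs. A minimal illustration; the variable names are mine, not from modeling.py:

    import tensorflow as tf

    ids = tf.constant([0, 1, 1, 0])          # flat token-type ids
    table = tf.constant([[0., 1., 2., 3.],
                         [4., 5., 6., 7.]])  # [vocab_size=2, width=4]

    one_hot = tf.one_hot(ids, depth=2)       # [4, 2]
    emb_matmul = tf.matmul(one_hot, table)   # [4, 2] x [2, 4] -> [4, 4]

    emb_gather = tf.gather(table, ids)       # same [4, 4] result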
Position embeddings
if use_position_embeddings:
  # Assert that seq_length <= max_position_embeddings holds (element-wise x <= y).
  assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
  # tf.control_dependencies is TensorFlow's flow-control mechanism; its two uses
  # are inserting dependencies and clearing them (a dependency is an op or a
  # tensor). Here it guarantees the assertion runs before the ops created below.
  with tf.control_dependencies([assert_op]):
    # tf.get_variable creates a new TensorFlow variable. Common initializers
    # include the constant initializer tf.constant_initializer, the normal
    # initializer tf.random_normal_initializer, the truncated-normal initializer
    # tf.truncated_normal_initializer, and the uniform initializer
    # tf.random_uniform_initializer.
    full_position_embeddings = tf.get_variable(
        name=position_embedding_name,
        shape=[max_position_embeddings, width],
        initializer=create_initializer(initializer_range))
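For reference, create_initializer (defined earlier in modeling.py, not shown in this installment) is essentially a one-line wrapper that picks the truncated-normal initializer with initializer_range as the standard deviation:

    def create_initializer(initializer_range=0.02):
      """Creates a `truncated_normal_initializer` with the given range."""
      return tf.truncated_normal_initializer(stddev=initializer_range)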
    # Since the position embedding table is a learned variable, we create it
    # using a (long) sequence length `max_position_embeddings`. The actual
    # sequence length might be shorter than this, for faster training of
    # tasks that do not have long sequences.
    #
    # So `full_position_embeddings` is effectively an embedding table
    # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
    # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
    # perform a slice.
    # tf.slice(inputs, begin, size, name) extracts a sub-tensor from a list,
    # array, or tensor, starting at `begin` and spanning `size` per dimension.
    position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                   [seq_length, -1])
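A tiny numeric illustration of that slice, with made-up sizes (4 positions, width 3); a size of -1 in a dimension means "to the end of that dimension":

    import tensorflow as tf

    full = tf.reshape(tf.range(12, dtype=tf.float32), [4, 3])  # [max_positions=4, width=3]
    # begin=[0, 0], size=[2, -1]: first two rows, all columns -> shape [2, 3]
    first_two = tf.slice(full, [0, 0], [2, -1])
    # Equivalent indexing form: full[:2, :]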
    num_dims = len(output.shape.as_list())  # number of dimensions
    # Only the last two dimensions are relevant (`seq_length` and `width`), so
    # we broadcast among the first dimensions, which is typically just
    # the batch size.
    position_broadcast_shape = []  # shape used to broadcast the position embeddings
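To make the broadcasting idea concrete: the position embeddings depend only on seq_length and width, so they can be reshaped to [1, seq_length, width] and added to the [batch_size, seq_length, width] output, letting TensorFlow broadcast them over the batch dimension. A hedged sketch with illustrative sizes, not the code that follows in the source:

    import tensorflow as tf

    batch_size, seq_length, width = 2, 3, 4  # illustrative sizes
    output = tf.zeros([batch_size, seq_length, width])
    position_embeddings = tf.ones([seq_length, width])

    # Reshape to [1, seq_length, width]; the addition broadcasts over the batch dim.
    broadcastable = tf.reshape(position_embeddings, [1, seq_length, width])
    output += broadcastable  # every example in the batch gets the same position info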
