BERT模型源码解析(14)


return final_outputs
else: 如果不要求返回所有层
final_output = reshape_from_matrix(prev_output, input_shape) 变形
return final_output
相关辅助函数
def get_shape_list(tensor, expected_rank=None, name=None):
    """Returns a list of the shape of tensor, preferring static dimensions.

    Args:
      tensor: A tf.Tensor object to find the shape of.
      expected_rank: (optional) int. The expected rank of `tensor`. If this is
        specified and the `tensor` has a different rank, an exception will be
        thrown.
      name: Optional name of the tensor for the error message. Defaults to
        `tensor.name` when omitted.

    Returns:
      A list of dimensions of the shape of tensor. All static dimensions will
      be returned as python integers, and dynamic dimensions will be returned
      as tf.Tensor scalars.
    """
    # Fall back to the tensor's own name when the caller did not supply one.
    if name is None:
        name = tensor.name

    # The rank check runs only when an expected rank was actually requested.
    if expected_rank is not None:
        assert_rank(tensor, expected_rank, name)

    shape = tensor.shape.as_list()

    # A dimension reported as None is not known statically (dynamic dimension).
    non_static_indexes = [
        index for index, dim in enumerate(shape) if dim is None
    ]

    # Every dimension is static: the plain Python list is the answer.
    if not non_static_indexes:
        return shape

    # Replace each unknown entry with the matching element of the runtime
    # shape, leaving the statically known entries untouched.
    dyn_shape = tf.shape(tensor)
    for index in non_static_indexes:
        shape[index] = dyn_shape[index]
    return shape
■多维变二维
def reshape_to_matrix(input_tensor):
    """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).

    Args:
      input_tensor: Tensor to flatten; read through `.shape.ndims` and
        `.shape[-1]`.

    Returns:
      `input_tensor` itself when it is already rank 2; otherwise the result of
      `tf.reshape(input_tensor, [-1, width])`, where `width` is the size of
      the last dimension.

    Raises:
      ValueError: If the tensor has rank lower than 2.
    """
    ndims = input_tensor.shape.ndims
    if ndims < 2:
        raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
                         (input_tensor.shape))

    # Already a matrix: hand it back untouched.
    if ndims == 2:
        return input_tensor

    # Keep the last dimension and let -1 absorb all the leading dimensions.
    width = input_tensor.shape[-1]
    output_tensor = tf.reshape(input_tensor, [-1, width])
    return output_tensor
■二维变多维
def reshape_from_matrix(output_tensor, orig_shape_list):
    """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.

    Args:
      output_tensor: The rank 2 tensor to reshape.
      orig_shape_list: Shape of the tensor before it was flattened, as a list
        (a tuple is accepted as well).

    Returns:
      `output_tensor` unchanged when the original shape has two entries;
      otherwise `output_tensor` reshaped to the original leading dimensions
      followed by the current size of its last dimension.
    """
    # The original was already a matrix, so there is nothing to undo.
    if len(orig_shape_list) == 2:
        return output_tensor

    output_shape = get_shape_list(output_tensor)

    # Leading dimensions come from the original shape; list() lets a tuple
    # shape be concatenated with the width below.
    orig_dims = list(orig_shape_list[0:-1])
    # The last dimension is taken from the tensor as it is now.
    width = output_shape[-1]

    return tf.reshape(output_tensor, orig_dims + [width])
■秩的断言
def assert_rank(tensor, expected_rank, name=None):
"""Raises an exception if the tensor rank is not of the expected rank.
如果对不上,就报错
Args:参数:张量,期望的秩,名称(用于打印报错信息)
tensor: A tf.Tensor to check the rank of.
expected_rank: Python integer or list of integers, expected rank.
name: Optional name of the tensor for the error message.
Raises:
ValueError: If the expected shape doesn't match the actual shape.
"""
if name is None: 如果没有指定名称,则取张量的变量名称
name = tensor.name
expected_rank_dict = {}
if isinstance(expected_rank, six.integer_types): 如果指定的秩是个整数
expected_rank_dict[expected_rank] = True

经验总结扩展阅读