BERT Model Source Code Walkthrough
modeling.py
Table of Contents
Attributes
Classes
class BertConfig(object) - configuration class for the BERT model
class BertModel(object) - the BERT model class
Functions
def gelu(x) - GELU activation function (see the sketch after this list)
def get_activation(activation_string) - look up an activation function by name
def get_assignment_map_from_checkpoint - build the variable assignment map from a checkpoint
def dropout(input_tensor, dropout_prob) - dropout; randomly drops activations at the given probability
def layer_norm(input_tensor, name=None) - layer normalization
def layer_norm_and_dropout - layer normalization followed by dropout
def create_initializer(initializer_range=0.02) - create a weight initializer
def embedding_lookup - embedding lookup
def embedding_postprocessor - embedding post-processing
def create_attention_mask_from_input_mask - build an attention mask from the input mask
def attention_layer - the attention layer
def transformer_model - the Transformer model
def get_shape_list - get a tensor's shape as a list
def reshape_to_matrix(input_tensor) - reshape a tensor to a rank-2 matrix
def reshape_from_matrix(output_tensor, orig_shape_list) - reshape a rank-2 matrix back to its original shape
def assert_rank(tensor, expected_rank, name=None) - assert that a tensor has the expected rank
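A quick preview of the `gelu` entry above: BERT uses the tanh approximation of the Gaussian Error Linear Unit, gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A minimal sketch of that function, assuming TensorFlow 1.x and NumPy as imported later in the file (the exact body in the published source may differ in minor details):

import numpy as np
import tensorflow as tf

def gelu(x):
  # Tanh approximation of the Gaussian error linear unit
  # (Hendrycks & Gimpel, 2016); cdf approximates the standard
  # normal CDF evaluated at x.
  cdf = 0.5 * (1.0 + tf.tanh(
      np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))
  return x * cdf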
Source Code
License
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The main BERT model and related functions."""
Imports
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import json
import math
import re
import numpy as np
import six
import tensorflow as tf
Model Configuration
Constructor
Parameters
class BertConfig(object):
  """Configuration for `BertModel`."""

  def __init__(self,
               vocab_size,
               hidden_size=768,
               num_hidden_layers=12,
               num_attention_heads=12,
               intermediate_size=3072,
               hidden_act="gelu",
               hidden_dropout_prob=0.1,
               attention_probs_dropout_prob=0.1,
               max_position_embeddings=512,
               type_vocab_size=16,
               initializer_range=0.02):
    """Constructs BertConfig.

    Args:
      vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
      hidden_size: Size of the encoder layers and the pooler layer.
      num_hidden_layers: Number of hidden layers in the Transformer encoder.
      num_attention_heads: Number of attention heads for each attention layer
        in the Transformer encoder.
      intermediate_size: The size of the "intermediate" (i.e., feed-forward)
        layer in the Transformer encoder.
      hidden_act: The non-linear activation function (function or string) in
        the encoder and pooler.
      hidden_dropout_prob: The dropout probability for all fully connected
        layers in the embeddings, encoder, and pooler.
      attention_probs_dropout_prob: The dropout ratio for the attention
        probabilities.
      max_position_embeddings: The maximum sequence length that this model
        might ever be used with. Typically set this to something large just
        in case (e.g., 512 or 1024 or 2048).
      type_vocab_size: The vocabulary size of the `token_type_ids` passed
        into `BertModel`.
      initializer_range: The stdev of the truncated_normal_initializer for
        initializing all weight matrices.
    """
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.hidden_act = hidden_act
    self.intermediate_size = intermediate_size
    self.hidden_dropout_prob = hidden_dropout_prob
    self.attention_probs_dropout_prob = attention_probs_dropout_prob
    self.max_position_embeddings = max_position_embeddings
    self.type_vocab_size = type_vocab_size
    self.initializer_range = initializer_range
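To make the configuration class concrete, here is a small, hypothetical usage sketch; the argument values are illustrative (30522 happens to be the vocabulary size of the English uncased WordPiece model, but any vocabulary size works):

# Hypothetical example: a BERT-Base-sized configuration.
config = BertConfig(
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072)

# Later code in modeling.py splits hidden_size evenly across the
# attention heads, so hidden_size must be divisible by num_attention_heads.
assert config.hidden_size % config.num_attention_heads == 0
print(config.max_position_embeddings)  # 512 by default

Note that `initializer_range` is consumed by `create_initializer` (covered later in this walkthrough), which wraps tf.truncated_normal_initializer(stddev=initializer_range).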