2019-12-28
它最有意思的地方,不只是把 `Transformer Encoder` 用在了自然语言处理上,而是把“先做大规模预训练,再拿去做下游任务微调”这件事,做成了大家都能复用的标准套路。于是文本分类、问答、命名实体识别、句子匹配这些任务,突然不需要每次都从头练一个模型了。
有些模型是“看过就知道名字”,有些模型是“真的把后面的路都带偏了”。BERT 显然属于后者。
如果把这篇文章压缩成一句话,那就是:
BERT = 双向上下文建模 + 预训练表示学习 + 针对任务快速微调
下面从结构、输入、预训练目标和 Python 用法几条线,把这件事讲顺。
在 BERT 之前,很多语言模型更像是单向地读句子。比如从左往右,或者从右往左。这样做当然能学到上下文,但总有种“只看到半边”的味道。
BERT 的关键改动很直接:它想让每个 token 在编码时,同时看见左边和右边的语义线索。这样一来,同一个词在不同语境下就更容易被表示成不同的向量。
举个很直观的例子:
bank 在“river bank”和“open a bank account”里压根不是一个意思。双向上下文一旦建立起来,模型就没那么容易把它们混成一团。
送进 BERT 的输入,通常不是一份单纯的词嵌入,而是三部分相加:
Token EmbeddingPosition EmbeddingSegment Embedding可以把它理解成一句话里的每个位置都要回答三个问题:
所以像下面这种输入形式就很常见:
[CLS] 今 天 天 气 真 不 错 [SEP] 我 想 去 散 步 [SEP]
这里几个特殊符号很关键:
[CLS] 常被拿来汇总整句语义,后面接分类头很方便[SEP] 用来分隔句子或标记结束[MASK] 则主要在预训练阶段用来做“猜词填空”这一步看起来平平无奇,但很重要。因为后面编码器再强,也得先有一个带位置感、句段感的输入起点。
BERT 的主体是多层堆叠的 Transformer Encoder。每一层里最核心的部件有两个:
Multi-Head Self-AttentionFeed Forward Network外加残差连接和层归一化,保证信息传递别太拧巴。
如果把它讲得接地气一点,自注意力干的是这件事:
一个 token 在编码自己时,不是只盯着自己,而是会去看整句里哪些词和它关系更大,然后按权重把别人的信息拉过来。
这样模型就能把“局部词义”慢慢抬升成“上下文中的词义”。
比如一句“这家店的服务不错,但是菜偏咸”,到了后面几层,不错 和 服务 的关联、咸 和 菜 的关联,通常会被建得更清楚。
BERT 最经典的预训练目标,通常会提到两项:
Masked Language ModelingNext Sentence Prediction这部分是 BERT 最出圈的设计。
训练时,会把输入里的部分 token 挖掉,换成 [MASK] 或其他扰动形式,然后要求模型根据上下文把原词猜回来。
例如:
我 特 别 喜 欢 [MASK] 然 语 言 处 理
模型需要根据上下文判断这个位置可能是什么词。
这个任务的妙处在于:为了猜准一个词,模型必须同时利用左边和右边的语义信息,于是“双向”这件事不是口号,而是被直接写进训练目标里的。
这项任务会喂给模型两句话,让它判断第二句是不是第一句后面接着的内容。
它的原始出发点不难理解:
不过后来不少后续模型会对这部分做调整,甚至直接移除,因为大家发现它不一定总是带来最稳定的收益。这个现象本身也很有意思:好模型的设计,从来不是某一项技巧永远不动,而是不断被验证、精简和替换。
BERT 受欢迎,一个很现实的原因是它接任务时非常顺手。
思路很统一:
于是不同任务只是“头”不同:
[CLS] 接全连接层这套范式的真正价值不在“统一”,而在“省重练成本”。以前得给每个任务单独搭特征,现在模型自己就把大部分语言知识提前学好了。
下面直接上点代码。为了不把文章写成安装说明书,我用 transformers 的常见写法演示。
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
text = "BERT 读句子时,不会只看左边。"
inputs = tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True
)
print(inputs["input_ids"].shape)
print(inputs["attention_mask"])
这里通常会得到几个关键张量:
input_idstoken_type_idsattention_mask它们分别对应词表索引、句段编号和有效位置掩码。
import torch
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-chinese")
model.eval()
with torch.no_grad():
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
cls_vector = outputs.pooler_output
print("token 级输出:", last_hidden_state.shape)
print("句子级输出:", cls_vector.shape)
这里可以简单理解为:
last_hidden_state 给你每个 token 的上下文化表示pooler_output 给你一句话的聚合表示,常用于分类起步import torch
from transformers import BertTokenizer, BertForMaskedLM
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
mlm_model = BertForMaskedLM.from_pretrained("bert-base-chinese")
text = "自然语言处理很[MASK]。"
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
logits = mlm_model(**inputs).logits
mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
pred_id = logits[0, mask_index, :].argmax(dim=-1)
pred_token = tokenizer.decode(pred_id)
print("预测词:", pred_token)
这段代码很适合拿来建立直觉:BERT 确实是在上下文里补词,而不是机械地背词表。
import torch
import torch.nn as nn
from transformers import BertModel
class BertClassifier(nn.Module):
def __init__(self, num_labels: int):
super().__init__()
self.bert = BertModel.from_pretrained("bert-base-chinese")
self.dropout = nn.Dropout(0.1)
self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
def forward(self, input_ids, attention_mask=None, token_type_ids=None):
outputs = self.bert(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
cls_state = outputs.last_hidden_state[:, 0]
logits = self.classifier(self.dropout(cls_state))
return logits
这个结构没什么花哨动作,但很实用。很多文本分类任务,从这里就能稳稳起飞。
真正把 BERT 用起来之后,大家通常会很快感受到几件事。
第一,输入长度是现实约束,不是小细节。文本一长,显存和计算量都会往上窜,所以切分、截断、分块推理往往得认真做。
第二,BERT 很强,但不是拿来乱堆就一定赢。数据量偏小的时候,微调学习率、冻结层数、分类头形式,往往比你想象中更影响效果。
第三,预训练模型帮你解决的是“语言表示起点”问题,不是“数据标注质量”问题。如果标签本身混乱,再好的 encoder 也会被带偏。
如果只把 BERT 看成一个网络结构,那其实低估它了。
它更像是一个范式:
后面大量模型,不管是轻量化版本、蒸馏版本,还是更激进的预训练变体,基本都在这条线上继续往前拱。
所以 BERT 最值得记住的地方,不只是 MLM、[CLS]、自注意力这些关键词,而是它把 NLP 的工作流彻底改写成了“预训练 + 微调”的常态。
如果你读到这里,脑子里至少应该已经有一张比较清楚的图:
Transformer Encodertransformers + PyTorch