BERT 技术博客:从双向上下文到微调落地




2019-12-28

blog_main_img

它最有意思的地方,不只是把 `Transformer Encoder` 用在了自然语言处理上,而是把“先做大规模预训练,再拿去做下游任务微调”这件事,做成了大家都能复用的标准套路。于是文本分类、问答、命名实体识别、句子匹配这些任务,突然不需要每次都从头练一个模型了。

有些模型是“看过就知道名字”,有些模型是“真的把后面的路都带偏了”。BERT 显然属于后者。

如果把这篇文章压缩成一句话,那就是:

BERT = 双向上下文建模 + 预训练表示学习 + 针对任务快速微调

下面从结构、输入、预训练目标和 Python 用法几条线,把这件事讲顺。

BERT 结构图

BERT 到底解决了什么

BERT 之前,很多语言模型更像是单向地读句子。比如从左往右,或者从右往左。这样做当然能学到上下文,但总有种“只看到半边”的味道。

BERT 的关键改动很直接:它想让每个 token 在编码时,同时看见左边和右边的语义线索。这样一来,同一个词在不同语境下就更容易被表示成不同的向量。

举个很直观的例子:

bank 在“river bank”和“open a bank account”里压根不是一个意思。双向上下文一旦建立起来,模型就没那么容易把它们混成一团。

先看输入:BERT 不是只吃词向量

送进 BERT 的输入,通常不是一份单纯的词嵌入,而是三部分相加:

  • Token Embedding
  • Position Embedding
  • Segment Embedding

可以把它理解成一句话里的每个位置都要回答三个问题:

  • 我是谁
  • 我在第几个位置
  • 我属于句子 A 还是句子 B

所以像下面这种输入形式就很常见:

[CLS] 今 天 天 气 真 不 错 [SEP] 我 想 去 散 步 [SEP]

这里几个特殊符号很关键:

  • [CLS] 常被拿来汇总整句语义,后面接分类头很方便
  • [SEP] 用来分隔句子或标记结束
  • [MASK] 则主要在预训练阶段用来做“猜词填空”

这一步看起来平平无奇,但很重要。因为后面编码器再强,也得先有一个带位置感、句段感的输入起点。

编码器的核心:一层层 Transformer Encoder 往上叠

BERT 的主体是多层堆叠的 Transformer Encoder。每一层里最核心的部件有两个:

  • Multi-Head Self-Attention
  • Feed Forward Network

外加残差连接和层归一化,保证信息传递别太拧巴。

如果把它讲得接地气一点,自注意力干的是这件事:

一个 token 在编码自己时,不是只盯着自己,而是会去看整句里哪些词和它关系更大,然后按权重把别人的信息拉过来。

这样模型就能把“局部词义”慢慢抬升成“上下文中的词义”。

比如一句“这家店的服务不错,但是菜偏咸”,到了后面几层,不错服务 的关联、 的关联,通常会被建得更清楚。

BERT 为什么能预训练得这么有效

BERT 最经典的预训练目标,通常会提到两项:

  • Masked Language Modeling
  • Next Sentence Prediction

Masked Language Modeling:让模型学会补词

这部分是 BERT 最出圈的设计。

训练时,会把输入里的部分 token 挖掉,换成 [MASK] 或其他扰动形式,然后要求模型根据上下文把原词猜回来。

例如:

我 特 别 喜 欢 [MASK] 然 语 言 处 理

模型需要根据上下文判断这个位置可能是什么词。

这个任务的妙处在于:为了猜准一个词,模型必须同时利用左边和右边的语义信息,于是“双向”这件事不是口号,而是被直接写进训练目标里的。

Next Sentence Prediction:让模型有一点句间关系意识

这项任务会喂给模型两句话,让它判断第二句是不是第一句后面接着的内容。

它的原始出发点不难理解:

  • 帮模型学句子之间的关系
  • 让问答、匹配、推断这类任务更容易接上

不过后来不少后续模型会对这部分做调整,甚至直接移除,因为大家发现它不一定总是带来最稳定的收益。这个现象本身也很有意思:好模型的设计,从来不是某一项技巧永远不动,而是不断被验证、精简和替换。

预训练与微调流程图

微调为什么这么省事

BERT 受欢迎,一个很现实的原因是它接任务时非常顺手。

思路很统一:

  • 先拿预训练好的参数做通用文本表示
  • 再在顶部接一个轻量任务头
  • 用你的数据把整套网络一起微调,或者只微调一部分

于是不同任务只是“头”不同:

  • 文本分类:拿 [CLS] 接全连接层
  • 序列标注:取每个 token 的隐状态分别分类
  • 问答:预测答案的起点和终点
  • 句子相似度:比较两个句子的表示

这套范式的真正价值不在“统一”,而在“省重练成本”。以前得给每个任务单独搭特征,现在模型自己就把大部分语言知识提前学好了。

一个简洁但够用的 Python 体验版

下面直接上点代码。为了不把文章写成安装说明书,我用 transformers 的常见写法演示。

1. 先把文本变成 BERT 能吃的输入

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

text = "BERT 读句子时,不会只看左边。"
inputs = tokenizer(
    text,
    return_tensors="pt",
    padding=True,
    truncation=True
)

print(inputs["input_ids"].shape)
print(inputs["attention_mask"])

这里通常会得到几个关键张量:

  • input_ids
  • token_type_ids
  • attention_mask

它们分别对应词表索引、句段编号和有效位置掩码。

2. 取出 BERT 编好的语义表示

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-chinese")
model.eval()

with torch.no_grad():
    outputs = model(**inputs)

last_hidden_state = outputs.last_hidden_state
cls_vector = outputs.pooler_output

print("token 级输出:", last_hidden_state.shape)
print("句子级输出:", cls_vector.shape)

这里可以简单理解为:

  • last_hidden_state 给你每个 token 的上下文化表示
  • pooler_output 给你一句话的聚合表示,常用于分类起步

3. 让模型玩一次 masked language modeling

import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
mlm_model = BertForMaskedLM.from_pretrained("bert-base-chinese")

text = "自然语言处理很[MASK]。"
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    logits = mlm_model(**inputs).logits

mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
pred_id = logits[0, mask_index, :].argmax(dim=-1)
pred_token = tokenizer.decode(pred_id)

print("预测词:", pred_token)

这段代码很适合拿来建立直觉:BERT 确实是在上下文里补词,而不是机械地背词表。

4. 分类任务接起来通常就这么几步

import torch
import torch.nn as nn
from transformers import BertModel

class BertClassifier(nn.Module):
    def __init__(self, num_labels: int):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-chinese")
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        cls_state = outputs.last_hidden_state[:, 0]
        logits = self.classifier(self.dropout(cls_state))
        return logits

这个结构没什么花哨动作,但很实用。很多文本分类任务,从这里就能稳稳起飞。

Python 使用路线图

用 BERT 时几个常见的工程感受

真正把 BERT 用起来之后,大家通常会很快感受到几件事。

第一,输入长度是现实约束,不是小细节。文本一长,显存和计算量都会往上窜,所以切分、截断、分块推理往往得认真做。

第二,BERT 很强,但不是拿来乱堆就一定赢。数据量偏小的时候,微调学习率、冻结层数、分类头形式,往往比你想象中更影响效果。

第三,预训练模型帮你解决的是“语言表示起点”问题,不是“数据标注质量”问题。如果标签本身混乱,再好的 encoder 也会被带偏。

BERT 的价值,不止是一种模型

如果只把 BERT 看成一个网络结构,那其实低估它了。

它更像是一个范式:

  • 先在大语料上学通用语言知识
  • 再把知识迁移到具体任务
  • 把“特征工程”尽量压缩成“表示学习”

后面大量模型,不管是轻量化版本、蒸馏版本,还是更激进的预训练变体,基本都在这条线上继续往前拱。

所以 BERT 最值得记住的地方,不只是 MLM[CLS]、自注意力这些关键词,而是它把 NLP 的工作流彻底改写成了“预训练 + 微调”的常态。

如果你读到这里,脑子里至少应该已经有一张比较清楚的图:

  • 输入不是单一词向量,而是 token、位置、句段信息的组合
  • 主体是多层 Transformer Encoder
  • 预训练里最核心的动作是 masked language modeling
  • 下游任务通常只需要在顶部接一个轻量头再微调
  • Python 里最常见的落地方式,就是 transformers + PyTorch