2019-04-02
YOLOv3 的损失函数看起来绕,主要是因为它不是只算一个分类误差,而是同时处理四件事:框中心点是否准 、框宽高是否准 、这个位置是否有目标 、目标属于哪个类别
Keras 版本的 YOLOv3 里,损失函数经常会被包进一个 Lambda 层:
model_loss = Lambda(
yolo_loss,
output_shape=(1,),
name="yolo_loss",
arguments={
"anchors": anchors,
"num_classes": num_classes,
"ignore_thresh": 0.5,
}
)([*model_body.output, *y_true])
核心函数签名通常类似:
def yolo_loss(args, anchors, num_classes, ignore_thresh=0.5):
...
这里的 args 包含两部分。
第一部分是模型输出:
model_body.output
YOLOv3 有三个输出尺度。如果类别数是 20,每个尺度每个 anchor 输出 5 + 20 = 25 个值,三个 anchor 合在一起就是 75 个通道。
所以输出形状大致是:
batch × 13 × 13 × 75
batch × 26 × 26 × 75
batch × 52 × 52 × 75
第二部分是训练真值:
y_true
它同样有三个尺度:
batch × 13 × 13 × 3 × 25
batch × 26 × 26 × 3 × 25
batch × 52 × 52 × 3 × 25
如果类别数是 C,最后一维就是:
5 + C
其中前 5 个值一般是:
x, y, w, h, objectness
后面是类别 one-hot。
YOLOv3 的总损失可以拆成四块:
loss = xy_loss + wh_loss + confidence_loss + class_loss
这四部分分别对应:
xy_loss:中心点坐标损失wh_loss:宽高损失confidence_loss:置信度损失class_loss:类别损失检测任务本质上是定位加分类,所以损失函数自然也要同时约束位置和类别。
网络原始输出不能直接拿来和 y_true 比较。
比如某个尺度的输出是:
batch × grid × grid × 75
它需要先 reshape 成:
batch × grid × grid × 3 × (5 + C)
然后再把原始预测值解码成真实含义。
典型解码公式如下:
box_xy = (sigmoid(raw_xy) + grid) / grid_shape
box_wh = exp(raw_wh) * anchors / input_shape
box_confidence = sigmoid(raw_confidence)
box_class_probs = sigmoid(raw_class)
其中:
grid 表示当前 cell 的坐标偏移sigmoid(raw_xy) 把中心点限制在当前 cell 附近exp(raw_wh) 用来缩放 anchoranchors 是预设框宽高input_shape 用来归一化宽高这样处理后,预测框就从网络原始输出变成了归一化坐标。
下面写一个简化版,帮助理解形状变换和解码过程。
import numpy as np
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def yolo_head_numpy(feats, anchors, num_classes, input_shape):
batch, grid_h, grid_w, _ = feats.shape
num_anchors = len(anchors)
feats = feats.reshape(
batch,
grid_h,
grid_w,
num_anchors,
num_classes + 5
)
grid_y, grid_x = np.meshgrid(
np.arange(grid_h),
np.arange(grid_w),
indexing="ij"
)
grid = np.stack([grid_x, grid_y], axis=-1)
grid = grid.reshape(1, grid_h, grid_w, 1, 2)
anchors = np.asarray(anchors, dtype=np.float32).reshape(1, 1, 1, num_anchors, 2)
input_h, input_w = input_shape
box_xy = (sigmoid(feats[..., 0:2]) + grid) / np.array([grid_w, grid_h])
box_wh = np.exp(feats[..., 2:4]) * anchors / np.array([input_w, input_h])
box_confidence = sigmoid(feats[..., 4:5])
box_class_probs = sigmoid(feats[..., 5:])
return box_xy, box_wh, box_confidence, box_class_probs
这里最容易迷糊的是 grid。
它不是模型学出来的参数,而是人为构造的坐标偏移。每个 cell 都有自己的 (grid_x, grid_y),网络只需要预测在这个 cell 内部的相对偏移。
损失计算中,预测值有两种状态:
raw_predpred_xy、pred_wh有些损失为了数值稳定,会直接在 raw 空间里计算。
因此需要把 y_true 也转回类似 raw 的空间:
raw_true_xy = y_true_xy * grid_shape - grid
raw_true_wh = log(y_true_wh / anchor * input_shape)
对应代码常见写法:
raw_true_xy = y_true_xy * grid_shape[::-1] - grid
raw_true_wh = log(y_true_wh / anchors * input_shape[::-1])
注意 wh 里有 log,如果真实框宽高为 0,会出现无效值。所以实现里通常只对有目标的位置计算,空位置会被 mask 掉。
object_mask 表示某个位置、某个 anchor 是否负责预测真实目标。
形状通常是:
batch × grid × grid × 3 × 1
如果某个位置有目标:
object_mask = 1
否则:
object_mask = 0
它会参与多个损失:
代码里经常能看到:
box_loss_scale = 2 - y_true[..., 2:3] * y_true[..., 3:4]
这里的 y_true[..., 2:3] 和 y_true[..., 3:4] 是归一化后的宽高。
所以:
w * h
可以近似理解成目标框面积比例。
那么:
box_loss_scale = 2 - w * h
目标越小,w * h 越小,box_loss_scale 越大。
它的作用是让小目标的定位误差拥有更高权重。
原因也很直观:同样偏移几个像素,对大目标影响可能不大,对小目标影响却非常明显。
中心点损失常见写法:
xy_loss = object_mask * box_loss_scale * binary_crossentropy(
raw_true_xy,
raw_pred[..., 0:2],
from_logits=True
)
它有三部分:
object_mask
box_loss_scale
binary_crossentropy
含义是:
为什么 xy 用 BCE?
因为 xy 的预测会经过 sigmoid,表示 cell 内部的相对偏移,范围在 0~1。所以训练 raw 输出时,常用 from_logits=True 的二值交叉熵来处理。
宽高损失常见写法:
wh_loss = object_mask * box_loss_scale * 0.5 * square(
raw_true_wh - raw_pred[..., 2:4]
)
和 xy_loss 类似,它也只在有目标的位置计算,也会乘上 box_loss_scale。
区别在于,wh 使用的是平方误差。
这是因为 wh 在 raw 空间里通过:
pred_wh = exp(raw_wh) * anchor
进行解码,所以 raw 宽高更适合直接做差并平方。
置信度损失里最容易绕的是 ignore_mask。
直觉上,没有目标的位置就应该当负样本。但目标检测里有个细节:某些预测框虽然没有被分配为正样本,但它和真实框重合度很高。
如果把这种预测框强行当负样本惩罚,模型会很矛盾。
所以 YOLOv3 会计算每个预测框和真实框的最大 IoU:
best_iou = max(IoU(pred_box, true_box))
如果这个值大于阈值,就不把它当普通负样本惩罚。
常见逻辑:
best_iou < ignore_thresh -> 参与负样本 confidence loss
best_iou >= ignore_thresh -> 忽略这部分负样本损失
这能减少模型对“其实挺像目标的预测框”的过度惩罚。
置信度表示这个位置是否有目标。
常见写法:
confidence_loss = (
object_mask * BCE(object_mask, raw_pred[..., 4:5])
+ (1 - object_mask) * BCE(object_mask, raw_pred[..., 4:5]) * ignore_mask
)
可以拆成两部分:
正样本:object_mask = 1
负样本:object_mask = 0 且 ignore_mask = 1
有目标的位置,希望置信度接近 1。
没有目标、且不需要忽略的位置,希望置信度接近 0。
被 ignore 的位置,不强行压低置信度。
类别损失常见写法:
class_loss = object_mask * BCE(
true_class_probs,
raw_pred[..., 5:],
from_logits=True
)
只在有目标的位置计算类别损失。
YOLOv3 的类别预测通常使用 sigmoid,而不是 softmax。
这意味着每个类别都可以看成一个独立的二分类问题:
这个目标是不是 class_0
这个目标是不是 class_1
这个目标是不是 class_2
...
这种设计也方便扩展到多标签场景。
下面给一个简化版代码,主要用于理解四部分损失,不追求完整复刻工程实现。
import torch
import torch.nn.functional as F
def yolo_loss_one_scale(raw_pred, y_true, ignore_mask=None):
"""
raw_pred: batch x grid x grid x anchors x (5 + C)
y_true: batch x grid x grid x anchors x (5 + C)
"""
object_mask = y_true[..., 4:5]
true_xy = y_true[..., 0:2]
true_wh = y_true[..., 2:4]
true_cls = y_true[..., 5:]
box_loss_scale = 2.0 - true_wh[..., 0:1] * true_wh[..., 1:2]
xy_loss = object_mask * box_loss_scale * F.binary_cross_entropy_with_logits(
raw_pred[..., 0:2],
true_xy,
reduction="none"
)
wh_loss = object_mask * box_loss_scale * 0.5 * torch.square(
raw_pred[..., 2:4] - true_wh
)
conf_bce = F.binary_cross_entropy_with_logits(
raw_pred[..., 4:5],
object_mask,
reduction="none"
)
if ignore_mask is None:
ignore_mask = torch.ones_like(object_mask)
confidence_loss = object_mask * conf_bce + (1 - object_mask) * conf_bce * ignore_mask
class_loss = object_mask * F.binary_cross_entropy_with_logits(
raw_pred[..., 5:],
true_cls,
reduction="none"
)
batch_size = raw_pred.shape[0]
loss = (
xy_loss.sum()
+ wh_loss.sum()
+ confidence_loss.sum()
+ class_loss.sum()
) / batch_size
return loss
这个版本里,true_wh 为了简化直接放在 raw 空间里。完整实现中,通常要把真实宽高按照 anchor 反变换到 raw 空间。
YOLOv3 有三个输出尺度,所以通常会对每个尺度分别计算损失,再加起来:
total_loss = 0
for scale_id in range(3):
loss = yolo_loss_one_scale(
raw_preds[scale_id],
y_trues[scale_id],
ignore_masks[scale_id]
)
total_loss += loss
完整训练中,每个尺度负责不同大小目标,三个尺度一起优化。
可以把 YOLOv3 loss 粗略写成:
L = L_xy + L_wh + L_conf + L_cls
其中:
L_xy = object_mask × box_loss_scale × BCE(raw_true_xy, raw_pred_xy)
L_wh = object_mask × box_loss_scale × 0.5 × (raw_true_wh - raw_pred_wh)^2
L_conf = object_mask × BCE(1, pred_conf)
+ (1 - object_mask) × ignore_mask × BCE(0, pred_conf)
L_cls = object_mask × BCE(true_class, pred_class)
最后对所有网格、anchor、batch 求和,再做归一化。