YOLOv3 损失函数详解:xy、wh、置信度和类别




2019-04-02

blog_main_img

YOLOv3 的损失函数看起来绕,主要是因为它不是只算一个分类误差,而是同时处理四件事:框中心点是否准 、框宽高是否准 、这个位置是否有目标 、目标属于哪个类别

yolo_loss 的输入是什么

Keras 版本的 YOLOv3 里,损失函数经常会被包进一个 Lambda 层:

model_loss = Lambda(
    yolo_loss,
    output_shape=(1,),
    name="yolo_loss",
    arguments={
        "anchors": anchors,
        "num_classes": num_classes,
        "ignore_thresh": 0.5,
    }
)([*model_body.output, *y_true])

核心函数签名通常类似:

def yolo_loss(args, anchors, num_classes, ignore_thresh=0.5):
    ...

这里的 args 包含两部分。

第一部分是模型输出:

model_body.output

YOLOv3 有三个输出尺度。如果类别数是 20,每个尺度每个 anchor 输出 5 + 20 = 25 个值,三个 anchor 合在一起就是 75 个通道。

所以输出形状大致是:

batch × 13 × 13 × 75
batch × 26 × 26 × 75
batch × 52 × 52 × 75

第二部分是训练真值:

y_true

它同样有三个尺度:

batch × 13 × 13 × 3 × 25
batch × 26 × 26 × 3 × 25
batch × 52 × 52 × 3 × 25

如果类别数是 C,最后一维就是:

5 + C

其中前 5 个值一般是:

x, y, w, h, objectness

后面是类别 one-hot。

损失函数由四部分组成

YOLOv3 的总损失可以拆成四块:

loss = xy_loss + wh_loss + confidence_loss + class_loss

YOLOv3 损失组成

这四部分分别对应:

  • xy_loss:中心点坐标损失
  • wh_loss:宽高损失
  • confidence_loss:置信度损失
  • class_loss:类别损失

检测任务本质上是定位加分类,所以损失函数自然也要同时约束位置和类别。

yolo_head 在做什么

网络原始输出不能直接拿来和 y_true 比较。

比如某个尺度的输出是:

batch × grid × grid × 75

它需要先 reshape 成:

batch × grid × grid × 3 × (5 + C)

然后再把原始预测值解码成真实含义。

yolo_head 解码流程

典型解码公式如下:

box_xy = (sigmoid(raw_xy) + grid) / grid_shape
box_wh = exp(raw_wh) * anchors / input_shape
box_confidence = sigmoid(raw_confidence)
box_class_probs = sigmoid(raw_class)

其中:

  • grid 表示当前 cell 的坐标偏移
  • sigmoid(raw_xy) 把中心点限制在当前 cell 附近
  • exp(raw_wh) 用来缩放 anchor
  • anchors 是预设框宽高
  • input_shape 用来归一化宽高

这样处理后,预测框就从网络原始输出变成了归一化坐标。

用 NumPy 看 yolo_head 的核心逻辑

下面写一个简化版,帮助理解形状变换和解码过程。

import numpy as np


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def yolo_head_numpy(feats, anchors, num_classes, input_shape):
    batch, grid_h, grid_w, _ = feats.shape
    num_anchors = len(anchors)

    feats = feats.reshape(
        batch,
        grid_h,
        grid_w,
        num_anchors,
        num_classes + 5
    )

    grid_y, grid_x = np.meshgrid(
        np.arange(grid_h),
        np.arange(grid_w),
        indexing="ij"
    )
    grid = np.stack([grid_x, grid_y], axis=-1)
    grid = grid.reshape(1, grid_h, grid_w, 1, 2)

    anchors = np.asarray(anchors, dtype=np.float32).reshape(1, 1, 1, num_anchors, 2)
    input_h, input_w = input_shape

    box_xy = (sigmoid(feats[..., 0:2]) + grid) / np.array([grid_w, grid_h])
    box_wh = np.exp(feats[..., 2:4]) * anchors / np.array([input_w, input_h])
    box_confidence = sigmoid(feats[..., 4:5])
    box_class_probs = sigmoid(feats[..., 5:])

    return box_xy, box_wh, box_confidence, box_class_probs

这里最容易迷糊的是 grid

它不是模型学出来的参数,而是人为构造的坐标偏移。每个 cell 都有自己的 (grid_x, grid_y),网络只需要预测在这个 cell 内部的相对偏移。

为什么要把 y_true 反变换

损失计算中,预测值有两种状态:

  • 原始网络输出,也就是 raw_pred
  • 解码后的预测框,比如 pred_xypred_wh

有些损失为了数值稳定,会直接在 raw 空间里计算。

因此需要把 y_true 也转回类似 raw 的空间:

raw_true_xy = y_true_xy * grid_shape - grid
raw_true_wh = log(y_true_wh / anchor * input_shape)

对应代码常见写法:

raw_true_xy = y_true_xy * grid_shape[::-1] - grid
raw_true_wh = log(y_true_wh / anchors * input_shape[::-1])

注意 wh 里有 log,如果真实框宽高为 0,会出现无效值。所以实现里通常只对有目标的位置计算,空位置会被 mask 掉。

object_mask 是什么

object_mask 表示某个位置、某个 anchor 是否负责预测真实目标。

形状通常是:

batch × grid × grid × 3 × 1

如果某个位置有目标:

object_mask = 1

否则:

object_mask = 0

它会参与多个损失:

  • 有目标的位置才计算 xy 损失
  • 有目标的位置才计算 wh 损失
  • 有目标的位置才计算类别损失
  • 置信度损失中,有目标和无目标分开处理

box_loss_scale 是什么

代码里经常能看到:

box_loss_scale = 2 - y_true[..., 2:3] * y_true[..., 3:4]

这里的 y_true[..., 2:3]y_true[..., 3:4] 是归一化后的宽高。

所以:

w * h

可以近似理解成目标框面积比例。

那么:

box_loss_scale = 2 - w * h

目标越小,w * h 越小,box_loss_scale 越大。

它的作用是让小目标的定位误差拥有更高权重。

原因也很直观:同样偏移几个像素,对大目标影响可能不大,对小目标影响却非常明显。

xy_loss:中心点损失

中心点损失常见写法:

xy_loss = object_mask * box_loss_scale * binary_crossentropy(
    raw_true_xy,
    raw_pred[..., 0:2],
    from_logits=True
)

它有三部分:

object_mask
box_loss_scale
binary_crossentropy

含义是:

  • 只在有目标的位置计算
  • 小目标给更高定位权重
  • 用交叉熵约束 cell 内偏移

为什么 xy 用 BCE?

因为 xy 的预测会经过 sigmoid,表示 cell 内部的相对偏移,范围在 0~1。所以训练 raw 输出时,常用 from_logits=True 的二值交叉熵来处理。

wh_loss:宽高损失

宽高损失常见写法:

wh_loss = object_mask * box_loss_scale * 0.5 * square(
    raw_true_wh - raw_pred[..., 2:4]
)

xy_loss 类似,它也只在有目标的位置计算,也会乘上 box_loss_scale

区别在于,wh 使用的是平方误差。

这是因为 wh 在 raw 空间里通过:

pred_wh = exp(raw_wh) * anchor

进行解码,所以 raw 宽高更适合直接做差并平方。

ignore_mask:为什么要忽略一部分负样本

置信度损失里最容易绕的是 ignore_mask

直觉上,没有目标的位置就应该当负样本。但目标检测里有个细节:某些预测框虽然没有被分配为正样本,但它和真实框重合度很高。

如果把这种预测框强行当负样本惩罚,模型会很矛盾。

所以 YOLOv3 会计算每个预测框和真实框的最大 IoU:

best_iou = max(IoU(pred_box, true_box))

如果这个值大于阈值,就不把它当普通负样本惩罚。

ignore_mask 示意

常见逻辑:

best_iou < ignore_thresh -> 参与负样本 confidence loss
best_iou >= ignore_thresh -> 忽略这部分负样本损失

这能减少模型对“其实挺像目标的预测框”的过度惩罚。

confidence_loss:置信度损失

置信度表示这个位置是否有目标。

常见写法:

confidence_loss = (
    object_mask * BCE(object_mask, raw_pred[..., 4:5])
    + (1 - object_mask) * BCE(object_mask, raw_pred[..., 4:5]) * ignore_mask
)

可以拆成两部分:

正样本:object_mask = 1
负样本:object_mask = 0 且 ignore_mask = 1

有目标的位置,希望置信度接近 1。

没有目标、且不需要忽略的位置,希望置信度接近 0。

被 ignore 的位置,不强行压低置信度。

class_loss:类别损失

类别损失常见写法:

class_loss = object_mask * BCE(
    true_class_probs,
    raw_pred[..., 5:],
    from_logits=True
)

只在有目标的位置计算类别损失。

YOLOv3 的类别预测通常使用 sigmoid,而不是 softmax。

这意味着每个类别都可以看成一个独立的二分类问题:

这个目标是不是 class_0
这个目标是不是 class_1
这个目标是不是 class_2
...

这种设计也方便扩展到多标签场景。

PyTorch 写一个简化版 YOLOv3 Loss

下面给一个简化版代码,主要用于理解四部分损失,不追求完整复刻工程实现。

import torch
import torch.nn.functional as F


def yolo_loss_one_scale(raw_pred, y_true, ignore_mask=None):
    """
    raw_pred: batch x grid x grid x anchors x (5 + C)
    y_true:   batch x grid x grid x anchors x (5 + C)
    """
    object_mask = y_true[..., 4:5]
    true_xy = y_true[..., 0:2]
    true_wh = y_true[..., 2:4]
    true_cls = y_true[..., 5:]

    box_loss_scale = 2.0 - true_wh[..., 0:1] * true_wh[..., 1:2]

    xy_loss = object_mask * box_loss_scale * F.binary_cross_entropy_with_logits(
        raw_pred[..., 0:2],
        true_xy,
        reduction="none"
    )

    wh_loss = object_mask * box_loss_scale * 0.5 * torch.square(
        raw_pred[..., 2:4] - true_wh
    )

    conf_bce = F.binary_cross_entropy_with_logits(
        raw_pred[..., 4:5],
        object_mask,
        reduction="none"
    )

    if ignore_mask is None:
        ignore_mask = torch.ones_like(object_mask)

    confidence_loss = object_mask * conf_bce + (1 - object_mask) * conf_bce * ignore_mask

    class_loss = object_mask * F.binary_cross_entropy_with_logits(
        raw_pred[..., 5:],
        true_cls,
        reduction="none"
    )

    batch_size = raw_pred.shape[0]
    loss = (
        xy_loss.sum()
        + wh_loss.sum()
        + confidence_loss.sum()
        + class_loss.sum()
    ) / batch_size

    return loss

这个版本里,true_wh 为了简化直接放在 raw 空间里。完整实现中,通常要把真实宽高按照 anchor 反变换到 raw 空间。

总损失怎么合并

YOLOv3 有三个输出尺度,所以通常会对每个尺度分别计算损失,再加起来:

total_loss = 0

for scale_id in range(3):
    loss = yolo_loss_one_scale(
        raw_preds[scale_id],
        y_trues[scale_id],
        ignore_masks[scale_id]
    )
    total_loss += loss

完整训练中,每个尺度负责不同大小目标,三个尺度一起优化。

一个更接近公式的总览

可以把 YOLOv3 loss 粗略写成:

L = L_xy + L_wh + L_conf + L_cls

其中:

L_xy = object_mask × box_loss_scale × BCE(raw_true_xy, raw_pred_xy)
L_wh = object_mask × box_loss_scale × 0.5 × (raw_true_wh - raw_pred_wh)^2
L_conf = object_mask × BCE(1, pred_conf)
       + (1 - object_mask) × ignore_mask × BCE(0, pred_conf)
L_cls = object_mask × BCE(true_class, pred_class)

最后对所有网格、anchor、batch 求和,再做归一化。