Real-ESRGAN进行




2023-11-11

blog_main_img

Segment Anything Model,通常简称 SAM,是一类“可提示分割”模型。它不是传统意义上只会识别固定类别的分割网络,而是接收点、框、粗略 mask 等提示,然后返回目标区域的候选遮罩。

这件事很适合做交互式标注、抠图、遥感区域提取、医学影像辅助、工业缺陷圈选、自动化数据集制作。你不需要为每个类别都重新训练模型,先用提示把对象切出来,再用业务规则、分类模型或人工审核补上语义。

SAM能干哈

传统分割常见痛点是:模型只认识训练过的类别,换场景就容易掉链子;如果要做精细标注,人工一笔一笔描边很累。

SAM 的思路更像一个通用分割助手:

给一个点:告诉模型目标大概在这里
给一个框:告诉模型目标大概在这个范围
给一个粗 mask:让模型沿着已有区域继续细化
不给提示:让模型自动扫出一批候选区域

它返回的是 mask,不是类别名。比如你点了图片里的杯子,它能帮你切出杯子轮廓,但“这是杯子”这个语义仍然要由你的业务系统或别的模型来判断。

SAM 提示分割流程

三段式结构:图像、提示、mask

可以把 SAM 的工作流拆成三段:

image encoder:把整张图编码成特征
prompt encoder:把点、框、mask 等提示编码进去
mask decoder:结合图像特征和提示,输出候选遮罩

工程里最重要的优化点是:同一张图可以先 set_image,后续不同点、不同框都复用图像特征。这样交互标注时,用户连续点几次,体验会更顺。

环境准备

安装常用依赖:

pip install torch torchvision opencv-python matplotlib numpy
pip install git+https://github.com/facebookresearch/segment-anything.git

模型权重需要从官方仓库说明里下载,放到本地路径。示例里用 vit_h,如果显存紧张,可以换更小的模型类型。

import torch
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor


checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

这里的 predictor 适合交互式分割:先加载图片,再喂点或框。

用一个点切出目标

点提示分两类:正点表示“这里是目标”,负点表示“这里不是目标”。正负点结合起来,可以快速修正边界。

image_bgr = cv2.imread("demo.jpg")
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

predictor.set_image(image_rgb)

input_points = np.array([
    [420, 260],
    [510, 310],
])
input_labels = np.array([1, 0])

masks, scores, logits = predictor.predict(
    point_coords=input_points,
    point_labels=input_labels,
    multimask_output=True,
)

best_index = int(np.argmax(scores))
best_mask = masks[best_index]

print(scores[best_index], best_mask.shape)

multimask_output=True 会返回多个候选结果。交互工具里可以把候选 mask 展示给用户,或者按 score 选一个默认结果。

可视化 mask:先把效果看清楚

做分割任务,肉眼检查很重要。下面这个函数把 mask 半透明叠到原图上。

def overlay_mask(image_rgb, mask, color=(34, 211, 238), alpha=0.55):
    canvas = image_rgb.copy()
    color_layer = np.zeros_like(canvas)
    color_layer[:, :] = color

    mask_3d = mask.astype(bool)[:, :, None]
    canvas = np.where(
        mask_3d,
        (canvas * (1 - alpha) + color_layer * alpha).astype(np.uint8),
        canvas,
    )
    return canvas


preview = overlay_mask(image_rgb, best_mask)
preview_bgr = cv2.cvtColor(preview, cv2.COLOR_RGB2BGR)
cv2.imwrite("mask_preview.png", preview_bgr)

这类预览图非常适合调试:点是不是点偏了、负点有没有起作用、mask 有没有吞掉背景,一眼就能看出来。

用框提示更稳

如果你已经有检测框,比如来自人工拖框、目标检测模型或前端框选,SAM 用 box prompt 会更稳。

box = np.array([210, 120, 640, 520])

masks, scores, logits = predictor.predict(
    box=box,
    multimask_output=True,
)

best_mask = masks[int(np.argmax(scores))]
preview = overlay_mask(image_rgb, best_mask, color=(245, 158, 11))
cv2.imwrite("box_mask_preview.png", cv2.cvtColor(preview, cv2.COLOR_RGB2BGR))

框提示适合“先定位、再精修”的流程。比如检测模型先找出所有商品框,SAM 再把每个商品轮廓切出来,最后做背景替换或素材抠图。

自动分割:让模型先扫一遍

如果你没有点和框,可以用 SamAutomaticMaskGenerator 自动生成候选 mask。

SAM 自动分割遮罩地图

from segment_anything import SamAutomaticMaskGenerator


mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=24,
    pred_iou_thresh=0.88,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    min_mask_region_area=300,
)

annotations = mask_generator.generate(image_rgb)
print(len(annotations))

annotations = sorted(
    annotations,
    key=lambda item: item["area"],
    reverse=True,
)

自动分割经常会“多切”。这不一定是坏事,关键是后处理要接住。你可以按面积、置信度、矩形范围、与业务 ROI 的交集来筛。

def filter_annotations(annotations, min_area=800, max_area=180000):
    result = []
    for item in annotations:
        area = item["area"]
        if min_area <= area <= max_area:
            result.append(item)
    return result


clean_annotations = filter_annotations(annotations)

把 mask 保存成可用资产

单纯展示不够,业务里通常要保存 mask 图、裁剪图或 JSON 元数据。

def save_mask_assets(image_rgb, mask, output_prefix):
    mask_uint8 = (mask.astype(np.uint8) * 255)
    cv2.imwrite(f"{output_prefix}_mask.png", mask_uint8)

    rgba = np.zeros((*image_rgb.shape[:2], 4), dtype=np.uint8)
    rgba[:, :, :3] = image_rgb
    rgba[:, :, 3] = mask_uint8

    rgba_bgra = cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA)
    cv2.imwrite(f"{output_prefix}_cutout.png", rgba_bgra)


save_mask_assets(image_rgb, best_mask, "object_001")

mask.png 适合继续做算法处理,cutout.png 适合做前端预览或设计素材。不要只保存截图,原始 mask 才是后续处理的关键资产。

轮廓、面积和外接框

拿到 mask 后,可以用 OpenCV 做几何信息提取。

def mask_geometry(mask):
    mask_uint8 = (mask.astype(np.uint8) * 255)
    contours, _ = cv2.findContours(
        mask_uint8,
        cv2.RETR_EXTERNAL,
        cv2.CHAIN_APPROX_SIMPLE,
    )
    if not contours:
        return None

    contour = max(contours, key=cv2.contourArea)
    area = float(cv2.contourArea(contour))
    x, y, w, h = cv2.boundingRect(contour)

    return {
        "area": area,
        "bbox": [int(x), int(y), int(w), int(h)],
        "points": contour.squeeze(1).astype(int).tolist(),
    }


geometry = mask_geometry(best_mask)
print(geometry["bbox"] if geometry else "empty mask")

有了这些信息,你就能做区域筛选、目标排序、尺寸估算、前端多边形展示。SAM 负责边界,OpenCV 负责把边界变成业务可用的数据。

封装成一个 Python 类

项目里不要到处散落 predictor.set_imagepredictor.predict。可以封一个服务类,把模型加载、图片预处理、点提示、框提示都收起来。

SAM Python 工程管线

class SegmentAnythingService:
    def __init__(self, checkpoint, model_type="vit_h"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = sam_model_registry[model_type](checkpoint=checkpoint)
        self.model.to(device=self.device)
        self.predictor = SamPredictor(self.model)

    def set_image_path(self, image_path):
        image_bgr = cv2.imread(image_path)
        if image_bgr is None:
            raise ValueError(f"cannot read image: {image_path}")
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        self.predictor.set_image(image_rgb)
        return image_rgb

    def predict_by_points(self, points, labels):
        masks, scores, _ = self.predictor.predict(
            point_coords=np.array(points),
            point_labels=np.array(labels),
            multimask_output=True,
        )
        best_index = int(np.argmax(scores))
        return masks[best_index], float(scores[best_index])

    def predict_by_box(self, box):
        masks, scores, _ = self.predictor.predict(
            box=np.array(box),
            multimask_output=True,
        )
        best_index = int(np.argmax(scores))
        return masks[best_index], float(scores[best_index])

调用时就清爽多了:

service = SegmentAnythingService("sam_vit_h_4b8939.pth")
image_rgb = service.set_image_path("demo.jpg")

mask, score = service.predict_by_points(
    points=[[420, 260], [510, 310]],
    labels=[1, 0],
)

preview = overlay_mask(image_rgb, mask)
cv2.imwrite("service_preview.png", cv2.cvtColor(preview, cv2.COLOR_RGB2BGR))
print(score)

做成接口时要注意什么

SAM 模型比较重,接口服务不能每个请求都加载一次模型。常见结构是:

应用启动时加载模型
请求进来读取图片和提示
同一张图片复用图像特征
返回 mask、预览图、轮廓和置信度
大图和批任务走队列

如果做前端标注工具,可以让前端传点坐标或框坐标,后端返回 mask 的 PNG 地址和轮廓点。

from pathlib import Path
from fastapi import FastAPI
from pydantic import BaseModel


app = FastAPI()
service = SegmentAnythingService("sam_vit_h_4b8939.pth")
output_dir = Path("outputs")
output_dir.mkdir(exist_ok=True)


class PointRequest(BaseModel):
    image_path: str
    points: list[list[int]]
    labels: list[int]


@app.post("/segment/points")
def segment_by_points(request: PointRequest):
    image_rgb = service.set_image_path(request.image_path)
    mask, score = service.predict_by_points(
        points=request.points,
        labels=request.labels,
    )

    name = Path(request.image_path).stem
    output_prefix = output_dir / f"{name}_sam"
    save_mask_assets(image_rgb, mask, str(output_prefix))

    geometry = mask_geometry(mask)
    return {
        "score": score,
        "mask": f"{output_prefix}_mask.png",
        "cutout": f"{output_prefix}_cutout.png",
        "geometry": geometry,
    }

这只是示例结构。真正上线时还要处理权限、文件校验、并发队列、显存保护、任务取消和结果清理。

常见业务玩法

交互式标注:用户点一下,SAM 给 mask,人工只做微调。

检测框精修:目标检测给 box,SAM 把粗框变成精细轮廓。

商品抠图:框选商品后生成透明底素材,再接排版系统。

遥感区域提取:用点或框圈出道路、建筑、地块,再转成多边形。

医学辅助标注:医生给少量提示,模型生成候选区域,最终由专业人员确认。

缺陷区域圈选:工业图像里先框出可疑区域,再用 mask 提取边界。