2023-11-11
Segment Anything Model,通常简称 SAM,是一类“可提示分割”模型。它不是传统意义上只会识别固定类别的分割网络,而是接收点、框、粗略 mask 等提示,然后返回目标区域的候选遮罩。
这件事很适合做交互式标注、抠图、遥感区域提取、医学影像辅助、工业缺陷圈选、自动化数据集制作。你不需要为每个类别都重新训练模型,先用提示把对象切出来,再用业务规则、分类模型或人工审核补上语义。
传统分割常见痛点是:模型只认识训练过的类别,换场景就容易掉链子;如果要做精细标注,人工一笔一笔描边很累。
SAM 的思路更像一个通用分割助手:
给一个点:告诉模型目标大概在这里
给一个框:告诉模型目标大概在这个范围
给一个粗 mask:让模型沿着已有区域继续细化
不给提示:让模型自动扫出一批候选区域
它返回的是 mask,不是类别名。比如你点了图片里的杯子,它能帮你切出杯子轮廓,但“这是杯子”这个语义仍然要由你的业务系统或别的模型来判断。
可以把 SAM 的工作流拆成三段:
image encoder:把整张图编码成特征
prompt encoder:把点、框、mask 等提示编码进去
mask decoder:结合图像特征和提示,输出候选遮罩
工程里最重要的优化点是:同一张图可以先 set_image,后续不同点、不同框都复用图像特征。这样交互标注时,用户连续点几次,体验会更顺。
安装常用依赖:
pip install torch torchvision opencv-python matplotlib numpy
pip install git+https://github.com/facebookresearch/segment-anything.git
模型权重需要从官方仓库说明里下载,放到本地路径。示例里用 vit_h,如果显存紧张,可以换更小的模型类型。
import torch
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
这里的 predictor 适合交互式分割:先加载图片,再喂点或框。
点提示分两类:正点表示“这里是目标”,负点表示“这里不是目标”。正负点结合起来,可以快速修正边界。
image_bgr = cv2.imread("demo.jpg")
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
predictor.set_image(image_rgb)
input_points = np.array([
[420, 260],
[510, 310],
])
input_labels = np.array([1, 0])
masks, scores, logits = predictor.predict(
point_coords=input_points,
point_labels=input_labels,
multimask_output=True,
)
best_index = int(np.argmax(scores))
best_mask = masks[best_index]
print(scores[best_index], best_mask.shape)
multimask_output=True 会返回多个候选结果。交互工具里可以把候选 mask 展示给用户,或者按 score 选一个默认结果。
做分割任务,肉眼检查很重要。下面这个函数把 mask 半透明叠到原图上。
def overlay_mask(image_rgb, mask, color=(34, 211, 238), alpha=0.55):
canvas = image_rgb.copy()
color_layer = np.zeros_like(canvas)
color_layer[:, :] = color
mask_3d = mask.astype(bool)[:, :, None]
canvas = np.where(
mask_3d,
(canvas * (1 - alpha) + color_layer * alpha).astype(np.uint8),
canvas,
)
return canvas
preview = overlay_mask(image_rgb, best_mask)
preview_bgr = cv2.cvtColor(preview, cv2.COLOR_RGB2BGR)
cv2.imwrite("mask_preview.png", preview_bgr)
这类预览图非常适合调试:点是不是点偏了、负点有没有起作用、mask 有没有吞掉背景,一眼就能看出来。
如果你已经有检测框,比如来自人工拖框、目标检测模型或前端框选,SAM 用 box prompt 会更稳。
box = np.array([210, 120, 640, 520])
masks, scores, logits = predictor.predict(
box=box,
multimask_output=True,
)
best_mask = masks[int(np.argmax(scores))]
preview = overlay_mask(image_rgb, best_mask, color=(245, 158, 11))
cv2.imwrite("box_mask_preview.png", cv2.cvtColor(preview, cv2.COLOR_RGB2BGR))
框提示适合“先定位、再精修”的流程。比如检测模型先找出所有商品框,SAM 再把每个商品轮廓切出来,最后做背景替换或素材抠图。
如果你没有点和框,可以用 SamAutomaticMaskGenerator 自动生成候选 mask。
from segment_anything import SamAutomaticMaskGenerator
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=24,
pred_iou_thresh=0.88,
stability_score_thresh=0.92,
crop_n_layers=1,
min_mask_region_area=300,
)
annotations = mask_generator.generate(image_rgb)
print(len(annotations))
annotations = sorted(
annotations,
key=lambda item: item["area"],
reverse=True,
)
自动分割经常会“多切”。这不一定是坏事,关键是后处理要接住。你可以按面积、置信度、矩形范围、与业务 ROI 的交集来筛。
def filter_annotations(annotations, min_area=800, max_area=180000):
result = []
for item in annotations:
area = item["area"]
if min_area <= area <= max_area:
result.append(item)
return result
clean_annotations = filter_annotations(annotations)
单纯展示不够,业务里通常要保存 mask 图、裁剪图或 JSON 元数据。
def save_mask_assets(image_rgb, mask, output_prefix):
mask_uint8 = (mask.astype(np.uint8) * 255)
cv2.imwrite(f"{output_prefix}_mask.png", mask_uint8)
rgba = np.zeros((*image_rgb.shape[:2], 4), dtype=np.uint8)
rgba[:, :, :3] = image_rgb
rgba[:, :, 3] = mask_uint8
rgba_bgra = cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA)
cv2.imwrite(f"{output_prefix}_cutout.png", rgba_bgra)
save_mask_assets(image_rgb, best_mask, "object_001")
mask.png 适合继续做算法处理,cutout.png 适合做前端预览或设计素材。不要只保存截图,原始 mask 才是后续处理的关键资产。
拿到 mask 后,可以用 OpenCV 做几何信息提取。
def mask_geometry(mask):
mask_uint8 = (mask.astype(np.uint8) * 255)
contours, _ = cv2.findContours(
mask_uint8,
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE,
)
if not contours:
return None
contour = max(contours, key=cv2.contourArea)
area = float(cv2.contourArea(contour))
x, y, w, h = cv2.boundingRect(contour)
return {
"area": area,
"bbox": [int(x), int(y), int(w), int(h)],
"points": contour.squeeze(1).astype(int).tolist(),
}
geometry = mask_geometry(best_mask)
print(geometry["bbox"] if geometry else "empty mask")
有了这些信息,你就能做区域筛选、目标排序、尺寸估算、前端多边形展示。SAM 负责边界,OpenCV 负责把边界变成业务可用的数据。
项目里不要到处散落 predictor.set_image 和 predictor.predict。可以封一个服务类,把模型加载、图片预处理、点提示、框提示都收起来。
class SegmentAnythingService:
def __init__(self, checkpoint, model_type="vit_h"):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = sam_model_registry[model_type](checkpoint=checkpoint)
self.model.to(device=self.device)
self.predictor = SamPredictor(self.model)
def set_image_path(self, image_path):
image_bgr = cv2.imread(image_path)
if image_bgr is None:
raise ValueError(f"cannot read image: {image_path}")
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
self.predictor.set_image(image_rgb)
return image_rgb
def predict_by_points(self, points, labels):
masks, scores, _ = self.predictor.predict(
point_coords=np.array(points),
point_labels=np.array(labels),
multimask_output=True,
)
best_index = int(np.argmax(scores))
return masks[best_index], float(scores[best_index])
def predict_by_box(self, box):
masks, scores, _ = self.predictor.predict(
box=np.array(box),
multimask_output=True,
)
best_index = int(np.argmax(scores))
return masks[best_index], float(scores[best_index])
调用时就清爽多了:
service = SegmentAnythingService("sam_vit_h_4b8939.pth")
image_rgb = service.set_image_path("demo.jpg")
mask, score = service.predict_by_points(
points=[[420, 260], [510, 310]],
labels=[1, 0],
)
preview = overlay_mask(image_rgb, mask)
cv2.imwrite("service_preview.png", cv2.cvtColor(preview, cv2.COLOR_RGB2BGR))
print(score)
SAM 模型比较重,接口服务不能每个请求都加载一次模型。常见结构是:
应用启动时加载模型
请求进来读取图片和提示
同一张图片复用图像特征
返回 mask、预览图、轮廓和置信度
大图和批任务走队列
如果做前端标注工具,可以让前端传点坐标或框坐标,后端返回 mask 的 PNG 地址和轮廓点。
from pathlib import Path
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
service = SegmentAnythingService("sam_vit_h_4b8939.pth")
output_dir = Path("outputs")
output_dir.mkdir(exist_ok=True)
class PointRequest(BaseModel):
image_path: str
points: list[list[int]]
labels: list[int]
@app.post("/segment/points")
def segment_by_points(request: PointRequest):
image_rgb = service.set_image_path(request.image_path)
mask, score = service.predict_by_points(
points=request.points,
labels=request.labels,
)
name = Path(request.image_path).stem
output_prefix = output_dir / f"{name}_sam"
save_mask_assets(image_rgb, mask, str(output_prefix))
geometry = mask_geometry(mask)
return {
"score": score,
"mask": f"{output_prefix}_mask.png",
"cutout": f"{output_prefix}_cutout.png",
"geometry": geometry,
}
这只是示例结构。真正上线时还要处理权限、文件校验、并发队列、显存保护、任务取消和结果清理。
交互式标注:用户点一下,SAM 给 mask,人工只做微调。
检测框精修:目标检测给 box,SAM 把粗框变成精细轮廓。
商品抠图:框选商品后生成透明底素材,再接排版系统。
遥感区域提取:用点或框圈出道路、建筑、地块,再转成多边形。
医学辅助标注:医生给少量提示,模型生成候选区域,最终由专业人员确认。
缺陷区域圈选:工业图像里先框出可疑区域,再用 mask 提取边界。