Real-ESRGAN进行图像修复




2023-12-29

blog_main_img

Real-ESRGAN 是一个面向真实场景图像修复和超分辨率的开源项目。 它的常见使用方式是:给一张图片或一个目录,输出放大和修复后的结果。

为什么在 Mac 上用 Flask 包 Real-ESRGAN

Real-ESRGAN 官方仓库提供了 Python 推理脚本,也提供了便携执行文件路线。
如果只是偶尔处理几张图,命令行已经很好用;如果你想做一个可交互的小服务,Flask 更适合做外壳。

部署总览

用 Flask 包一层的好处很直接:

  • 对外暴露一个上传接口,调用方式统一
  • 模型只加载一次,不用每张图重新启动脚本
  • 可以限制文件类型和大小,避免随便上传
  • 可以根据图片大小动态设置 tile
  • 可以做结果缓存、任务队列、日志和监控

在 Mac 上要额外注意一点:Real-ESRGAN 的官方推理脚本默认优先找 CUDA,否则走 CPU;而 Apple Silicon 上常见的是 PyTorch MPS。
所以如果你想尝试 MPS,需要在自定义服务里显式把 device=torch.device("mps") 传给 RealESRGANer

环境准备:先让 Python 推理能跑起来

推荐先在独立虚拟环境里做,不要把依赖塞进系统 Python。

python3 -m venv .venv
source .venv/bin/activate
python -m pip install -U pip setuptools wheel

安装 PyTorch:

python -m pip install torch torchvision torchaudio

检查 MPS 是否可用:

import torch

print("mps built:", torch.backends.mps.is_built())
print("mps available:", torch.backends.mps.is_available())

如果你的机器不支持 MPS,也可以用 CPU 跑,只是速度会慢一些。
对 Flask 服务来说,设备选择建议写成函数,不要把平台判断散落在业务代码里。

import torch


def pick_device():
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

接着安装 Real-ESRGAN 相关依赖。官方仓库的安装流程包括 basicsrfacexlibgfpgan、项目 requirements 和开发模式安装。

git clone https://github.com/xinntao/Real-ESRGAN.git
cd Real-ESRGAN

python -m pip install basicsr
python -m pip install facexlib
python -m pip install gfpgan
python -m pip install -r requirements.txt
python setup.py develop

再装 Flask 服务需要的依赖:

python -m pip install flask pillow

项目目录建议

可以把 Flask 外壳和 Real-ESRGAN 仓库放在同一个父目录下,也可以直接在 Real-ESRGAN 项目里新建服务文件。
这里用一个更清晰的结构:

realesrgan-flask-mac/
├── app.py
├── super_resolution.py
├── runtime_data/
│   ├── uploads/
│   ├── outputs/
│   └── weights/
└── Real-ESRGAN/

如果你把服务文件放在 Real-ESRGAN 仓库外面,记得让 Python 能找到 realesrgan 包。
最省事的方式是先进入 Real-ESRGAN 仓库执行 python setup.py develop,让包以开发模式挂到环境里。

推理服务封装:模型单例加载

Real-ESRGAN 的推理核心是 RealESRGANer
官方脚本里会根据 model_name 选择网络结构、权重地址和放大倍率。我们在 Flask 服务里也可以封成一个类,让模型只初始化一次。

模型服务核心

下面是一个可改造的 super_resolution.py

from pathlib import Path
import threading

import cv2
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact


ROOT = Path(__file__).resolve().parent
WEIGHT_DIR = ROOT / "runtime_data" / "weights"
WEIGHT_DIR.mkdir(parents=True, exist_ok=True)


MODEL_URLS = {
    "RealESRGAN_x4plus": (
        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
        4,
    ),
    "RealESRGAN_x4plus_anime_6B": (
        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
        4,
    ),
    "realesr-animevideov3": (
        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
        4,
    ),
}


def pick_device():
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def build_network(model_name: str):
    if model_name == "RealESRGAN_x4plus":
        return RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )

    if model_name == "RealESRGAN_x4plus_anime_6B":
        return RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=6,
            num_grow_ch=32,
            scale=4,
        )

    if model_name == "realesr-animevideov3":
        return SRVGGNetCompact(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_conv=16,
            upscale=4,
            act_type="prelu",
        )

    raise ValueError(f"unsupported model: {model_name}")

继续写服务类:

class RealESRGANService:
    def __init__(self, model_name="RealESRGAN_x4plus", tile=256):
        if model_name not in MODEL_URLS:
            raise ValueError(f"unsupported model: {model_name}")

        url, scale = MODEL_URLS[model_name]
        model_path = WEIGHT_DIR / f"{model_name}.pth"
        if not model_path.exists():
            downloaded = load_file_from_url(
                url=url,
                model_dir=str(WEIGHT_DIR),
                progress=True,
                file_name=model_path.name,
            )
            model_path = Path(downloaded)

        self.device = pick_device()
        self.lock = threading.Lock()

        self.upsampler = RealESRGANer(
            scale=scale,
            model_path=str(model_path),
            model=build_network(model_name),
            tile=tile,
            tile_pad=10,
            pre_pad=0,
            half=False,
            device=self.device,
        )

    def upscale_file(self, input_path: Path, output_path: Path, outscale=4.0):
        image = cv2.imread(str(input_path), cv2.IMREAD_UNCHANGED)
        if image is None:
            raise ValueError("image decode failed")

        with self.lock:
            output, _ = self.upsampler.enhance(image, outscale=outscale)

        output_path.parent.mkdir(parents=True, exist_ok=True)
        cv2.imwrite(str(output_path), output)
        return output_path

这里有几个部署点:

  • half=False:Mac 上优先稳一点,尤其是 CPU 或 MPS 混合场景
  • device=self.device:让 Apple Silicon 可以尝试 MPS
  • tile=256:降低内存压力,图片很大时比整图推理稳
  • threading.Lock():避免 Flask 多请求同时操作同一个 upsampler 对象

如果你只在 CPU 上跑,tile 可以更小一些;如果机器内存和显存余量比较充足,可以逐步调大。

Flask 上传接口:文件校验要认真写

Flask 官方文件上传文档里有几个很关键的点:从 request.files 取上传文件,保存前使用 secure_filename(),并通过 MAX_CONTENT_LENGTH 限制上传体积。

Flask 上传流程

下面是 app.py

from pathlib import Path
from uuid import uuid4

from flask import Flask, jsonify, request, send_from_directory
from werkzeug.utils import secure_filename

from super_resolution import RealESRGANService


ROOT = Path(__file__).resolve().parent
UPLOAD_DIR = ROOT / "runtime_data" / "uploads"
OUTPUT_DIR = ROOT / "runtime_data" / "outputs"
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "webp"}

app = Flask(__name__)
app.config["MAX_CONTENT_LENGTH"] = 24 * 1000 * 1000

sr_service = RealESRGANService(
    model_name="RealESRGAN_x4plus",
    tile=256,
)


def allowed_file(filename: str):
    return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS

继续写 API:

@app.post("/api/upscale")
def upscale():
    if "file" not in request.files:
        return jsonify({"error": "missing file field"}), 400

    file = request.files["file"]
    if file.filename == "":
        return jsonify({"error": "empty filename"}), 400

    if not allowed_file(file.filename):
        return jsonify({"error": "unsupported file type"}), 400

    try:
        outscale = float(request.form.get("scale", 4))
    except ValueError:
        return jsonify({"error": "scale must be a number"}), 400

    outscale = max(1.0, min(outscale, 4.0))

    safe_name = secure_filename(file.filename)
    suffix = Path(safe_name).suffix.lower() or ".png"
    job_id = uuid4().hex

    input_path = UPLOAD_DIR / f"{job_id}{suffix}"
    output_path = OUTPUT_DIR / f"{job_id}_sr.png"

    file.save(input_path)

    try:
        sr_service.upscale_file(input_path, output_path, outscale=outscale)
    except Exception as exc:
        return jsonify({"error": str(exc)}), 500

    return jsonify(
        {
            "job_id": job_id,
            "download_url": f"/outputs/{output_path.name}",
        }
    )


@app.get("/outputs/<name>")
def output_file(name):
    return send_from_directory(OUTPUT_DIR, name, as_attachment=True)


@app.get("/health")
def health():
    return jsonify({"status": "ok", "device": str(sr_service.device)})


if __name__ == "__main__":
    app.run(host="127.0.0.1", port=5000)

启动服务:

python app.py

curl 测一下:

curl -F "[email protected]" -F "scale=4" http://127.0.0.1:5000/api/upscale

返回会包含下载路径:

{
  "job_id": "example",
  "download_url": "/outputs/example_sr.png"
}

给浏览器一个轻量页面

如果你只要 API,上面已经够用。
如果想让别人直接在浏览器里上传,可以加一个非常轻的首页。

@app.get("/")
def index():
    return """
    <!doctype html>
    <html lang="zh-CN">
    <head>
      <meta charset="utf-8">
      <title>Real-ESRGAN Flask</title>
    </head>
    <body>
      <h2>Real-ESRGAN 图片超分</h2>
      <form action="/api/upscale" method="post" enctype="multipart/form-data">
        <p><input type="file" name="file" accept="image/*"></p>
        <p>
          <label>Scale</label>
          <input type="number" name="scale" value="4" min="1" max="4" step="0.5">
        </p>
        <button type="submit">开始处理</button>
      </form>
    </body>
    </html>
    """

这个页面很简陋,但足够做本地测试。
如果你要更好的体验,可以加进度提示、处理前后对比图、历史记录、图片尺寸展示。

Mac 上最容易遇到的几个问题

1. 第一次请求很慢

第一次请求通常会触发模型下载、权重加载、设备初始化。
更推荐在 Flask 启动时就初始化 RealESRGANService,不要等用户上传后再加载模型。

2. 大图容易内存吃紧

Real-ESRGAN 支持 tile,这就是 Mac 部署里非常实用的参数。
可以先从 tile=256 开始,如果机器余量大,再试 384512

3. MPS 不一定对每个算子都顺

PyTorch 的 MPS 后端能让 Apple Silicon 调用 Metal 加速,但并不代表所有模型和算子都同样顺滑。
如果遇到 MPS 相关报错,可以先走 CPU,或者开启 fallback 再观察:

export PYTORCH_ENABLE_MPS_FALLBACK=1
python app.py

服务代码里也可以临时强制 CPU:

def pick_device():
    return torch.device("cpu")

4. 多请求并发会抢模型

Flask 的开发服务器可以处理多个请求,但 Real-ESRGAN 推理不适合无脑并发。
如果同一进程里共享一个模型对象,建议用锁串行化推理;如果要做排队,可以接一个任务队列,把上传和推理拆开。

5. 不要把开发服务器当生产入口

本地工具可以直接 python app.py
如果要长期给团队使用,建议至少加一个 WSGI 服务器,并限制 worker 数量:

python -m pip install gunicorn
gunicorn "app:app" --bind 127.0.0.1:5000 --workers 1 --threads 2

workers 不要开太多,因为每个 worker 都可能加载一份模型,内存会明显上升。

模型选择建议

Real-ESRGAN 官方仓库里提供了多个推理模型,Flask 服务可以把它们做成表单参数或配置项:

  • RealESRGAN_x4plus:通用图片,适合照片、自然场景
  • RealESRGAN_x4plus_anime_6B:动漫插图,模型更轻
  • realesr-animevideov3:偏动漫视频帧,结构更小
  • realesr-general-x4v3:通用小模型,并支持降噪强度相关参数

如果是本地 Mac 小服务,建议先从轻模型开始。
通用照片用 RealESRGAN_x4plus,动漫图用 RealESRGAN_x4plus_anime_6B,大批量任务再考虑更细的队列和缓存。

结果清理和安全边界

图片服务很容易越跑越脏。建议一开始就把边界写进去:

  • 限制上传扩展名
  • 使用 secure_filename()
  • 给上传体积设置上限
  • 输出文件使用随机 ID
  • 定期清理 runtime_data/uploadsruntime_data/outputs
  • 不把上传目录直接暴露成静态目录
  • 对外服务时加鉴权或放在内网

运行治理面板

一个简单的清理函数可以这样写:

from pathlib import Path


def cleanup_folder(folder: Path, keep_latest=200):
    files = [p for p in folder.iterdir() if p.is_file()]
    files.sort(key=lambda p: p.stat().st_mtime, reverse=True)

    for old_file in files[keep_latest:]:
        old_file.unlink(missing_ok=True)

可以在服务启动时、接口处理后、或单独的维护脚本里调用它。
本地工具不一定需要复杂调度,但别让输出目录无限长大。

一条更完整的部署命令流

把上面的步骤压缩成一组常用命令,大概是这样:

mkdir realesrgan-flask-mac
cd realesrgan-flask-mac

python3 -m venv .venv
source .venv/bin/activate
python -m pip install -U pip setuptools wheel
python -m pip install torch torchvision torchaudio
python -m pip install flask pillow

git clone https://github.com/xinntao/Real-ESRGAN.git
cd Real-ESRGAN
python -m pip install basicsr facexlib gfpgan
python -m pip install -r requirements.txt
python setup.py develop
cd ..

mkdir -p runtime_data/uploads runtime_data/outputs runtime_data/weights
python app.py