2023-12-29
Real-ESRGAN 是一个面向真实场景图像修复和超分辨率的开源项目。 它的常见使用方式是:给一张图片或一个目录,输出放大和修复后的结果。
Real-ESRGAN 官方仓库提供了 Python 推理脚本,也提供了便携执行文件路线。
如果只是偶尔处理几张图,命令行已经很好用;如果你想做一个可交互的小服务,Flask 更适合做外壳。
用 Flask 包一层的好处很直接:
tile在 Mac 上要额外注意一点:Real-ESRGAN 的官方推理脚本默认优先找 CUDA,否则走 CPU;而 Apple Silicon 上常见的是 PyTorch MPS。
所以如果你想尝试 MPS,需要在自定义服务里显式把 device=torch.device("mps") 传给 RealESRGANer。
推荐先在独立虚拟环境里做,不要把依赖塞进系统 Python。
python3 -m venv .venv
source .venv/bin/activate
python -m pip install -U pip setuptools wheel
安装 PyTorch:
python -m pip install torch torchvision torchaudio
检查 MPS 是否可用:
import torch
print("mps built:", torch.backends.mps.is_built())
print("mps available:", torch.backends.mps.is_available())
如果你的机器不支持 MPS,也可以用 CPU 跑,只是速度会慢一些。
对 Flask 服务来说,设备选择建议写成函数,不要把平台判断散落在业务代码里。
import torch
def pick_device():
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
接着安装 Real-ESRGAN 相关依赖。官方仓库的安装流程包括 basicsr、facexlib、gfpgan、项目 requirements 和开发模式安装。
git clone https://github.com/xinntao/Real-ESRGAN.git
cd Real-ESRGAN
python -m pip install basicsr
python -m pip install facexlib
python -m pip install gfpgan
python -m pip install -r requirements.txt
python setup.py develop
再装 Flask 服务需要的依赖:
python -m pip install flask pillow
可以把 Flask 外壳和 Real-ESRGAN 仓库放在同一个父目录下,也可以直接在 Real-ESRGAN 项目里新建服务文件。
这里用一个更清晰的结构:
realesrgan-flask-mac/
├── app.py
├── super_resolution.py
├── runtime_data/
│ ├── uploads/
│ ├── outputs/
│ └── weights/
└── Real-ESRGAN/
如果你把服务文件放在 Real-ESRGAN 仓库外面,记得让 Python 能找到 realesrgan 包。
最省事的方式是先进入 Real-ESRGAN 仓库执行 python setup.py develop,让包以开发模式挂到环境里。
Real-ESRGAN 的推理核心是 RealESRGANer。
官方脚本里会根据 model_name 选择网络结构、权重地址和放大倍率。我们在 Flask 服务里也可以封成一个类,让模型只初始化一次。
下面是一个可改造的 super_resolution.py:
from pathlib import Path
import threading
import cv2
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
ROOT = Path(__file__).resolve().parent
WEIGHT_DIR = ROOT / "runtime_data" / "weights"
WEIGHT_DIR.mkdir(parents=True, exist_ok=True)
MODEL_URLS = {
"RealESRGAN_x4plus": (
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
4,
),
"RealESRGAN_x4plus_anime_6B": (
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
4,
),
"realesr-animevideov3": (
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
4,
),
}
def pick_device():
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
def build_network(model_name: str):
if model_name == "RealESRGAN_x4plus":
return RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
if model_name == "RealESRGAN_x4plus_anime_6B":
return RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=6,
num_grow_ch=32,
scale=4,
)
if model_name == "realesr-animevideov3":
return SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=16,
upscale=4,
act_type="prelu",
)
raise ValueError(f"unsupported model: {model_name}")
继续写服务类:
class RealESRGANService:
def __init__(self, model_name="RealESRGAN_x4plus", tile=256):
if model_name not in MODEL_URLS:
raise ValueError(f"unsupported model: {model_name}")
url, scale = MODEL_URLS[model_name]
model_path = WEIGHT_DIR / f"{model_name}.pth"
if not model_path.exists():
downloaded = load_file_from_url(
url=url,
model_dir=str(WEIGHT_DIR),
progress=True,
file_name=model_path.name,
)
model_path = Path(downloaded)
self.device = pick_device()
self.lock = threading.Lock()
self.upsampler = RealESRGANer(
scale=scale,
model_path=str(model_path),
model=build_network(model_name),
tile=tile,
tile_pad=10,
pre_pad=0,
half=False,
device=self.device,
)
def upscale_file(self, input_path: Path, output_path: Path, outscale=4.0):
image = cv2.imread(str(input_path), cv2.IMREAD_UNCHANGED)
if image is None:
raise ValueError("image decode failed")
with self.lock:
output, _ = self.upsampler.enhance(image, outscale=outscale)
output_path.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(str(output_path), output)
return output_path
这里有几个部署点:
half=False:Mac 上优先稳一点,尤其是 CPU 或 MPS 混合场景device=self.device:让 Apple Silicon 可以尝试 MPStile=256:降低内存压力,图片很大时比整图推理稳threading.Lock():避免 Flask 多请求同时操作同一个 upsampler 对象如果你只在 CPU 上跑,tile 可以更小一些;如果机器内存和显存余量比较充足,可以逐步调大。
Flask 官方文件上传文档里有几个很关键的点:从 request.files 取上传文件,保存前使用 secure_filename(),并通过 MAX_CONTENT_LENGTH 限制上传体积。
下面是 app.py:
from pathlib import Path
from uuid import uuid4
from flask import Flask, jsonify, request, send_from_directory
from werkzeug.utils import secure_filename
from super_resolution import RealESRGANService
ROOT = Path(__file__).resolve().parent
UPLOAD_DIR = ROOT / "runtime_data" / "uploads"
OUTPUT_DIR = ROOT / "runtime_data" / "outputs"
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "webp"}
app = Flask(__name__)
app.config["MAX_CONTENT_LENGTH"] = 24 * 1000 * 1000
sr_service = RealESRGANService(
model_name="RealESRGAN_x4plus",
tile=256,
)
def allowed_file(filename: str):
return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
继续写 API:
@app.post("/api/upscale")
def upscale():
if "file" not in request.files:
return jsonify({"error": "missing file field"}), 400
file = request.files["file"]
if file.filename == "":
return jsonify({"error": "empty filename"}), 400
if not allowed_file(file.filename):
return jsonify({"error": "unsupported file type"}), 400
try:
outscale = float(request.form.get("scale", 4))
except ValueError:
return jsonify({"error": "scale must be a number"}), 400
outscale = max(1.0, min(outscale, 4.0))
safe_name = secure_filename(file.filename)
suffix = Path(safe_name).suffix.lower() or ".png"
job_id = uuid4().hex
input_path = UPLOAD_DIR / f"{job_id}{suffix}"
output_path = OUTPUT_DIR / f"{job_id}_sr.png"
file.save(input_path)
try:
sr_service.upscale_file(input_path, output_path, outscale=outscale)
except Exception as exc:
return jsonify({"error": str(exc)}), 500
return jsonify(
{
"job_id": job_id,
"download_url": f"/outputs/{output_path.name}",
}
)
@app.get("/outputs/<name>")
def output_file(name):
return send_from_directory(OUTPUT_DIR, name, as_attachment=True)
@app.get("/health")
def health():
return jsonify({"status": "ok", "device": str(sr_service.device)})
if __name__ == "__main__":
app.run(host="127.0.0.1", port=5000)
启动服务:
python app.py
用 curl 测一下:
curl -F "[email protected]" -F "scale=4" http://127.0.0.1:5000/api/upscale
返回会包含下载路径:
{
"job_id": "example",
"download_url": "/outputs/example_sr.png"
}
如果你只要 API,上面已经够用。
如果想让别人直接在浏览器里上传,可以加一个非常轻的首页。
@app.get("/")
def index():
return """
<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<title>Real-ESRGAN Flask</title>
</head>
<body>
<h2>Real-ESRGAN 图片超分</h2>
<form action="/api/upscale" method="post" enctype="multipart/form-data">
<p><input type="file" name="file" accept="image/*"></p>
<p>
<label>Scale</label>
<input type="number" name="scale" value="4" min="1" max="4" step="0.5">
</p>
<button type="submit">开始处理</button>
</form>
</body>
</html>
"""
这个页面很简陋,但足够做本地测试。
如果你要更好的体验,可以加进度提示、处理前后对比图、历史记录、图片尺寸展示。
第一次请求通常会触发模型下载、权重加载、设备初始化。
更推荐在 Flask 启动时就初始化 RealESRGANService,不要等用户上传后再加载模型。
Real-ESRGAN 支持 tile,这就是 Mac 部署里非常实用的参数。
可以先从 tile=256 开始,如果机器余量大,再试 384 或 512。
PyTorch 的 MPS 后端能让 Apple Silicon 调用 Metal 加速,但并不代表所有模型和算子都同样顺滑。
如果遇到 MPS 相关报错,可以先走 CPU,或者开启 fallback 再观察:
export PYTORCH_ENABLE_MPS_FALLBACK=1
python app.py
服务代码里也可以临时强制 CPU:
def pick_device():
return torch.device("cpu")
Flask 的开发服务器可以处理多个请求,但 Real-ESRGAN 推理不适合无脑并发。
如果同一进程里共享一个模型对象,建议用锁串行化推理;如果要做排队,可以接一个任务队列,把上传和推理拆开。
本地工具可以直接 python app.py。
如果要长期给团队使用,建议至少加一个 WSGI 服务器,并限制 worker 数量:
python -m pip install gunicorn
gunicorn "app:app" --bind 127.0.0.1:5000 --workers 1 --threads 2
workers 不要开太多,因为每个 worker 都可能加载一份模型,内存会明显上升。
Real-ESRGAN 官方仓库里提供了多个推理模型,Flask 服务可以把它们做成表单参数或配置项:
RealESRGAN_x4plus:通用图片,适合照片、自然场景RealESRGAN_x4plus_anime_6B:动漫插图,模型更轻realesr-animevideov3:偏动漫视频帧,结构更小realesr-general-x4v3:通用小模型,并支持降噪强度相关参数如果是本地 Mac 小服务,建议先从轻模型开始。
通用照片用 RealESRGAN_x4plus,动漫图用 RealESRGAN_x4plus_anime_6B,大批量任务再考虑更细的队列和缓存。
图片服务很容易越跑越脏。建议一开始就把边界写进去:
secure_filename()runtime_data/uploads 和 runtime_data/outputs一个简单的清理函数可以这样写:
from pathlib import Path
def cleanup_folder(folder: Path, keep_latest=200):
files = [p for p in folder.iterdir() if p.is_file()]
files.sort(key=lambda p: p.stat().st_mtime, reverse=True)
for old_file in files[keep_latest:]:
old_file.unlink(missing_ok=True)
可以在服务启动时、接口处理后、或单独的维护脚本里调用它。
本地工具不一定需要复杂调度,但别让输出目录无限长大。
把上面的步骤压缩成一组常用命令,大概是这样:
mkdir realesrgan-flask-mac
cd realesrgan-flask-mac
python3 -m venv .venv
source .venv/bin/activate
python -m pip install -U pip setuptools wheel
python -m pip install torch torchvision torchaudio
python -m pip install flask pillow
git clone https://github.com/xinntao/Real-ESRGAN.git
cd Real-ESRGAN
python -m pip install basicsr facexlib gfpgan
python -m pip install -r requirements.txt
python setup.py develop
cd ..
mkdir -p runtime_data/uploads runtime_data/outputs runtime_data/weights
python app.py