"""
Face Re-Aging with ONNX (CPU)
Based on Disney's FRAN (Face Re-Aging Network) architecture.
Model: face_reaging.onnx from VisoMaster-Fusion.
Supports image and video re-aging in a single unified view.
"""
import os
import shutil
import subprocess
import tempfile
import time
import glob as glob_mod
import cv2
import numpy as np
import onnxruntime as ort
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
MAX_VIDEO_SECONDS = 30
MAX_FRAMES = 900
MODEL_PATH = "face_reaging.onnx"
REPO_ID = "Luminia/Face-ReAging-CPU"
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def get_model_path():
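    """Return a local path to the ONNX model, downloading it from the Hub if needed."""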
if os.path.exists(MODEL_PATH):
return MODEL_PATH
return hf_hub_download(repo_id=REPO_ID, filename=MODEL_PATH)
print("Loading ONNX model...")
_so = ort.SessionOptions()
_so.intra_op_num_threads = os.cpu_count()
_so.inter_op_num_threads = os.cpu_count()
sess = ort.InferenceSession(
get_model_path(),
providers=["CPUExecutionProvider"],
sess_options=_so,
)
print("Model loaded.")
# ---------------------------------------------------------------------------
# Face detection
# ---------------------------------------------------------------------------
_face_cascade = cv2.CascadeClassifier(
cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
)
_dnn_model_path = os.path.join(os.path.dirname(__file__), "face_detection_yunet_2023mar.onnx")
YUNET_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"
def _ensure_yunet():
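    """Ensure the YuNet face-detection model is available locally and return its path."""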
global _dnn_model_path
if not os.path.exists(_dnn_model_path):
print("Downloading YuNet face detector...")
try:
path = hf_hub_download(
repo_id="opencv/opencv_zoo",
filename="models/face_detection_yunet/face_detection_yunet_2023mar.onnx",
)
_dnn_model_path = path
except Exception:
import urllib.request
urllib.request.urlretrieve(YUNET_URL, _dnn_model_path)
print("YuNet downloaded.")
return _dnn_model_path
def detect_face_box(image_rgb: np.ndarray):
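    """Detect the largest face and return its (x1, y1, x2, y2) box, or None.

    Tries the YuNet DNN detector first, falling back to the Haar cascade.
    """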
h, w = image_rgb.shape[:2]
try:
yunet_path = _ensure_yunet()
detector = cv2.FaceDetectorYN.create(yunet_path, "", (w, h), 0.5, 0.3, 5000)
        # YuNet follows OpenCV's BGR convention, so convert from RGB before detecting.
        _, faces = detector.detect(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
if faces is not None and len(faces) > 0:
best_idx = int(np.argmax([f[2] * f[3] for f in faces]))
f = faces[best_idx]
x1, y1 = int(f[0]), int(f[1])
x2, y2 = int(f[0] + f[2]), int(f[1] + f[3])
return (max(x1, 0), max(y1, 0), min(x2, w), min(y2, h))
except Exception as e:
print(f"YuNet failed, falling back to Haar: {e}")
gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
faces = _face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
if len(faces) == 0:
return None
best_idx = np.argmax([fw * fh for (_, _, fw, fh) in faces])
x, y, fw, fh = faces[best_idx]
return (x, y, x + fw, y + fh)
# ---------------------------------------------------------------------------
# Core inference
# ---------------------------------------------------------------------------
def crop_face_region(image_rgb, box):
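    """Expand the face box with generous margins (more headroom above than below)
    and return the cropped region plus its bounds in the full frame."""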
h, w = image_rgb.shape[:2]
x1, y1, x2, y2 = box
face_w, face_h = x2 - x1, y2 - y1
margin_top = int(face_h * 0.63 * 0.85)
margin_bot = int(face_h * 0.37 * 0.85)
margin_x = int(face_w * 0.85 / 2)
    # Rebalance the top margin so total vertical padding equals total horizontal
    # padding, keeping the crop roughly square before the 512x512 resize.
    margin_top = 2 * margin_x - margin_bot
l_y, r_y = max(y1 - margin_top, 0), min(y2 + margin_bot, h)
l_x, r_x = max(x1 - margin_x, 0), min(x2 + margin_x, w)
return image_rgb[l_y:r_y, l_x:r_x, :], (l_x, l_y, r_x, r_y)
def create_blend_mask(crop_h, crop_w, feather=0.15):
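    """Build an alpha mask that ramps linearly from 0 at the crop edges to 1 inside,
    so the re-aged crop feathers smoothly into the surrounding frame."""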
mask = np.ones((crop_h, crop_w), dtype=np.float32)
by, bx = max(int(crop_h * feather), 1), max(int(crop_w * feather), 1)
for i in range(by):
a = i / by
mask[i, :] *= a
mask[crop_h - 1 - i, :] *= a
for j in range(bx):
a = j / bx
mask[:, j] *= a
mask[:, crop_w - 1 - j] *= a
return mask[:, :, np.newaxis]
def reage_frame(image_rgb, source_age, target_age):
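    """Re-age one RGB frame: crop the face, run the model at 512x512, blend the result back."""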
box = detect_face_box(image_rgb)
if box is None:
return image_rgb
cropped, (l_x, l_y, r_x, r_y) = crop_face_region(image_rgb, box)
crop_h, crop_w = cropped.shape[:2]
cropped_resized = cv2.resize(cropped, (512, 512), interpolation=cv2.INTER_LINEAR)
img_t = cropped_resized.astype(np.float32) / 255.0
img_t = np.transpose(img_t, (2, 0, 1))
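    # FRAN conditions on two constant-valued planes carrying the normalized ages,
    # concatenated with the RGB image into a 5-channel input.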
src_ch = np.full((1, 512, 512), source_age / 100.0, dtype=np.float32)
tgt_ch = np.full((1, 512, 512), target_age / 100.0, dtype=np.float32)
inp = np.concatenate([img_t, src_ch, tgt_ch], axis=0)[np.newaxis, ...]
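    # The model predicts a residual (delta image) rather than the aged image itself.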
delta = sess.run(None, {"input": inp})[0]
aged = np.clip(img_t + delta[0], 0.0, 1.0)
aged_hwc = (np.transpose(aged, (1, 2, 0)) * 255).astype(np.uint8)
aged_resized = cv2.resize(aged_hwc, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)
result = image_rgb.copy()
mask = create_blend_mask(crop_h, crop_w, feather=0.12)
region = result[l_y:r_y, l_x:r_x].astype(np.float32)
blended = region * (1 - mask) + aged_resized.astype(np.float32) * mask
result[l_y:r_y, l_x:r_x] = blended.astype(np.uint8)
return result
# ---------------------------------------------------------------------------
# ffmpeg helpers
# ---------------------------------------------------------------------------
def _find_ffmpeg():
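    """Locate the ffmpeg binary, raising a gr.Error if it is not installed."""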
path = shutil.which("ffmpeg")
if path:
return path
for p in ["/usr/bin/ffmpeg", "/usr/local/bin/ffmpeg"]:
if os.path.isfile(p):
return p
raise gr.Error("ffmpeg not found.")
def _get_video_info(video_path):
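    """Return (fps, frame_count), preferring ffprobe and falling back to OpenCV."""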
ffprobe = shutil.which("ffprobe") or shutil.which("ffprobe", path="/usr/bin:/usr/local/bin")
if not ffprobe:
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
cap.release()
return fps, count
try:
import json
r = subprocess.run(
[ffprobe, "-v", "quiet", "-print_format", "json",
"-show_streams", "-select_streams", "v:0", video_path],
capture_output=True, text=True, timeout=30,
)
stream = json.loads(r.stdout)["streams"][0]
num, den = stream.get("r_frame_rate", "25/1").split("/")
fps = float(num) / float(den)
nb = stream.get("nb_frames")
count = int(nb) if nb and nb != "N/A" else int(float(stream.get("duration", 0)) * fps)
return fps, count
except Exception:
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
cap.release()
return fps, count
def _extract_frames(video_path, out_dir):
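    """Dump every frame of the video as PNGs named frame_%06d.png in out_dir."""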
ffmpeg = _find_ffmpeg()
cmd = [ffmpeg, "-i", video_path, "-vsync", "0", os.path.join(out_dir, "frame_%06d.png"), "-y"]
r = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
if r.returncode != 0:
raise gr.Error(f"Frame extraction failed: {r.stderr[-500:]}")
def _assemble_video(frames_dir, output_path, fps, audio_source=None):
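    """Encode the processed frames to an H.264 MP4, carrying over the audio track
    from audio_source when one is present."""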
ffmpeg = _find_ffmpeg()
cmd = [ffmpeg, "-y", "-framerate", str(fps), "-i", os.path.join(frames_dir, "frame_%06d.png")]
if audio_source:
cmd += ["-i", audio_source, "-map", "0:v", "-map", "1:a?", "-shortest"]
cmd += ["-c:v", "libx264", "-pix_fmt", "yuv420p", "-preset", "fast", "-crf", "20",
"-movflags", "+faststart", output_path]
r = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
if r.returncode != 0:
raise gr.Error(f"Video assembly failed: {r.stderr[-500:]}")
# ---------------------------------------------------------------------------
# Unified process function
# ---------------------------------------------------------------------------
VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv", ".m4v"}
def process(input_file, source_age, target_age, progress=gr.Progress()):
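    """Unified entry point: dispatch to image or video re-aging based on input type."""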
if input_file is None:
raise gr.Error("Please upload an image or video.")
t0 = time.time()
source_age, target_age = int(source_age), int(target_age)
# Determine if image or video
if isinstance(input_file, Image.Image):
# Direct PIL image from gr.Image
image_rgb = np.array(input_file.convert("RGB"))
box = detect_face_box(image_rgb)
if box is None:
raise gr.Error("No face detected. Please upload a clear photo with a visible face.")
result = reage_frame(image_rgb, source_age, target_age)
elapsed = time.time() - t0
info = f"Done in {elapsed:.2f}s | {source_age} -> {target_age} years"
return Image.fromarray(result), None, info
# File path (could be image or video)
    # gr.File may hand back a plain str path or a tempfile wrapper exposing .name.
    file_path = input_file if isinstance(input_file, str) else getattr(input_file, "name", str(input_file))
ext = os.path.splitext(file_path)[1].lower()
if ext in VIDEO_EXTS:
# --- Video processing ---
fps, total_frames = _get_video_info(file_path)
duration = total_frames / max(fps, 1)
if duration > MAX_VIDEO_SECONDS:
raise gr.Error(f"Video is {duration:.1f}s (max {MAX_VIDEO_SECONDS}s). Please trim it.")
if total_frames > MAX_FRAMES:
raise gr.Error(f"Video has {total_frames} frames (max {MAX_FRAMES}).")
tmp_root = tempfile.mkdtemp(prefix="reage_")
frames_in = os.path.join(tmp_root, "in")
frames_out = os.path.join(tmp_root, "out")
os.makedirs(frames_in, exist_ok=True)
os.makedirs(frames_out, exist_ok=True)
try:
progress(0, desc="Extracting frames...")
_extract_frames(file_path, frames_in)
frame_files = sorted(glob_mod.glob(os.path.join(frames_in, "frame_*.png")))
n_frames = len(frame_files)
if n_frames == 0:
raise gr.Error("No frames extracted. Is this a valid video?")
if n_frames > MAX_FRAMES:
raise gr.Error(f"{n_frames} frames (max {MAX_FRAMES}).")
faces_found, faces_missed = 0, 0
for idx, fpath in enumerate(frame_files):
progress((idx + 1) / n_frames, desc=f"Re-aging frame {idx + 1}/{n_frames}...")
frame_bgr = cv2.imread(fpath)
if frame_bgr is None:
continue
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
box = detect_face_box(frame_rgb)
if box is not None:
result_rgb = reage_frame(frame_rgb, source_age, target_age)
faces_found += 1
else:
result_rgb = frame_rgb
faces_missed += 1
out_path = os.path.join(frames_out, os.path.basename(fpath))
cv2.imwrite(out_path, cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR))
progress(1.0, desc="Assembling video...")
output_path = os.path.join(tmp_root, "output.mp4")
_assemble_video(frames_out, output_path, fps, audio_source=file_path)
elapsed = time.time() - t0
speed = n_frames / max(elapsed, 0.01)
info = (f"Done in {elapsed:.1f}s | {n_frames} frames at {speed:.1f} fps | "
f"Faces: {faces_found} found, {faces_missed} skipped | "
f"{source_age} -> {target_age} years")
return None, output_path, info
except gr.Error:
raise
except Exception as e:
raise gr.Error(f"Video processing failed: {e}")
else:
# --- Image processing ---
        image_bgr = cv2.imread(file_path)
        if image_bgr is None:
            raise gr.Error("Could not read the file. Please upload a valid image or video.")
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
box = detect_face_box(image_rgb)
if box is None:
raise gr.Error("No face detected.")
result = reage_frame(image_rgb, source_age, target_age)
elapsed = time.time() - t0
info = f"Done in {elapsed:.2f}s | {source_age} -> {target_age} years"
return Image.fromarray(result), None, info
# ---------------------------------------------------------------------------
# Gradio UI - Single unified view
# ---------------------------------------------------------------------------
with gr.Blocks(title="Face Re-Aging (CPU)", theme="NoCrypt/miku") as demo:
gr.Markdown(
"# Face Re-Aging (CPU)\n"
"Upload an **image or video** to age or de-age faces. "
f"Videos: max {MAX_VIDEO_SECONDS}s, ~0.5-2 fps on CPU."
)
with gr.Row():
with gr.Column():
file_input = gr.File(
label="Drop Image or Video Here",
file_types=["image", "video"],
)
# Also accept pasted/webcam images
img_input = gr.Image(
type="pil", label="Or paste/capture an image",
visible=True,
)
src_age = gr.Slider(minimum=5, maximum=95, value=25, step=1,
label="Source Age (current)")
tgt_age = gr.Slider(minimum=5, maximum=95, value=65, step=1,
label="Target Age (desired)")
btn = gr.Button("Re-Age", variant="primary", size="lg")
with gr.Column():
img_output = gr.Image(type="pil", label="Result (Image)")
vid_output = gr.Video(label="Result (Video)")
info_box = gr.Textbox(label="Info", interactive=False)
def on_submit_file(file_obj, source_age, target_age, progress=gr.Progress()):
if file_obj is None:
raise gr.Error("Please upload a file.")
return process(file_obj, source_age, target_age, progress)
def on_submit_image(image, source_age, target_age, progress=gr.Progress()):
if image is None:
raise gr.Error("Please provide an image.")
return process(image, source_age, target_age, progress)
btn.click(
fn=on_submit_file,
inputs=[file_input, src_age, tgt_age],
outputs=[img_output, vid_output, info_box],
)
# Also trigger on image input (for paste/webcam)
img_input.change(
fn=on_submit_image,
inputs=[img_input, src_age, tgt_age],
outputs=[img_output, vid_output, info_box],
)
gr.Markdown(
"**Model:** `face_reaging.onnx` (118 MB) from "
"[VisoMaster-Fusion](https://github.com/VisoMasterFusion/VisoMaster-Fusion) | "
"Based on [Disney FRAN](https://studios.disneyresearch.com/2022/11/30/production-ready-face-re-aging-for-visual-effects/)"
)
if __name__ == "__main__":
    demo.launch(show_error=True, ssr_mode=False)