| """ |
| Face Re-Aging with ONNX (CPU) |
| Based on Disney's FRAN (Face Re-Aging Network) architecture. |
| Model: face_reaging.onnx from VisoMaster-Fusion. |
| Supports image and video re-aging in a single unified view. |
| """ |
|
|
import glob as glob_mod
import os
import shutil
import subprocess
import tempfile
import threading
import time

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image
|
|
| |
| |
| |
# Upper bounds for uploaded videos; process() rejects anything above either.
MAX_VIDEO_SECONDS = 30  # seconds of footage
MAX_FRAMES = 900  # decoded frames

# Re-aging network file name, and the Hugging Face repo that get_model_path()
# downloads it from when no local copy exists.
MODEL_PATH = "face_reaging.onnx"
REPO_ID = "Luminia/Face-ReAging-CPU"
|
|
| |
| |
| |
def get_model_path():
    """Return a filesystem path to the re-aging ONNX model.

    A file named ``MODEL_PATH`` (resolved against the current working
    directory) is used as-is when present; otherwise the model is obtained
    from the ``REPO_ID`` Hugging Face repo via ``hf_hub_download``.
    """
    have_local_copy = os.path.exists(MODEL_PATH)
    if have_local_copy:
        return MODEL_PATH
    downloaded = hf_hub_download(repo_id=REPO_ID, filename=MODEL_PATH)
    return downloaded
|
|
print("Loading ONNX model...")
_so = ort.SessionOptions()
# os.cpu_count() returns None when the CPU count cannot be determined, and
# the thread-count options need an int. 0 asks ONNX Runtime to pick its own
# default instead.
# NOTE: inter_op_num_threads only takes effect in parallel execution mode.
_n_threads = os.cpu_count() or 0
_so.intra_op_num_threads = _n_threads
_so.inter_op_num_threads = _n_threads
# One shared CPU session, created at import time and used by reage_frame().
sess = ort.InferenceSession(
    get_model_path(),
    providers=["CPUExecutionProvider"],
    sess_options=_so,
)
print("Model loaded.")
|
|
| |
| |
| |
# Fallback face detector: the frontal-face Haar cascade bundled with OpenCV.
_face_cascade = cv2.CascadeClassifier(
    cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
)

# Primary face detector (YuNet). The path initially points next to this
# script; _ensure_yunet() may rebind it to wherever the model was downloaded.
_dnn_model_path = os.path.join(
    os.path.dirname(__file__),
    "face_detection_yunet_2023mar.onnx",
)

# Direct-download fallback used by _ensure_yunet() for the YuNet model.
YUNET_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"
|
|
|
|
# Set after a failed download so per-frame callers do not hit the network
# again for every video frame (detect_face_box then uses the Haar cascade).
_yunet_download_failed = False


def _download_atomically(url, dest):
    """Download *url* to *dest* through a temp file in the same directory.

    The temp file is renamed onto *dest* only after the download completes,
    so an interrupted transfer never leaves a truncated file at *dest*.
    """
    import urllib.request

    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(dest) or ".", suffix=".part")
    os.close(fd)
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, dest)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def _ensure_yunet():
    """Return a local path to the YuNet face-detector model, downloading it if needed.

    Tries the Hugging Face Hub first, then ``YUNET_URL``. On success from the
    Hub, the module-level ``_dnn_model_path`` is rebound to the downloaded file.

    Raises:
        Exception: whatever the last download attempt raised. After one
            failed attempt, later calls raise ``RuntimeError`` immediately
            instead of retrying.
    """
    global _dnn_model_path, _yunet_download_failed
    if os.path.exists(_dnn_model_path):
        return _dnn_model_path
    if _yunet_download_failed:
        raise RuntimeError("YuNet download failed earlier in this process; not retrying.")

    print("Downloading YuNet face detector...")
    try:
        try:
            _dnn_model_path = hf_hub_download(
                repo_id="opencv/opencv_zoo",
                filename="models/face_detection_yunet/face_detection_yunet_2023mar.onnx",
            )
        except Exception:
            _download_atomically(YUNET_URL, _dnn_model_path)
    except Exception:
        _yunet_download_failed = True
        raise
    print("YuNet downloaded.")
    return _dnn_model_path
|
|
|
|
# Per-thread YuNet detector cache. Building a detector loads the ONNX model
# from disk, which is wasteful to repeat for every video frame. The cache is
# thread-local because setInputSize() mutates the detector and handlers may
# run on several worker threads.
_yunet_local = threading.local()


def _get_yunet(width, height):
    """Return this thread's YuNet detector, configured for a width x height input."""
    detector = getattr(_yunet_local, "detector", None)
    if detector is None:
        detector = cv2.FaceDetectorYN.create(
            _ensure_yunet(), "", (width, height), 0.5, 0.3, 5000
        )
        _yunet_local.detector = detector
    else:
        detector.setInputSize((width, height))
    return detector


def detect_face_box(image_rgb: np.ndarray):
    """Find the largest face in an RGB image.

    YuNet is tried first. If it cannot be loaded or run, or finds nothing
    usable, the Haar cascade is used instead.

    Args:
        image_rgb: H x W x 3 image in RGB channel order.

    Returns:
        ``(x1, y1, x2, y2)`` pixel box of the largest detected face (clipped
        to the image for YuNet results), or ``None`` if no face is found.
    """
    h, w = image_rgb.shape[:2]
    try:
        detector = _get_yunet(w, h)
        # YuNet, like OpenCV's other DNN face models, expects BGR input.
        _, faces = detector.detect(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
        if faces is not None and len(faces) > 0:
            # Each row starts with x, y, w, h; pick the largest area.
            best = max(faces, key=lambda f: f[2] * f[3])
            x1, y1 = max(int(best[0]), 0), max(int(best[1]), 0)
            x2 = min(int(best[0] + best[2]), w)
            y2 = min(int(best[1] + best[3]), h)
            if x2 > x1 and y2 > y1:  # ignore boxes that clip to nothing
                return (x1, y1, x2, y2)
    except Exception as e:
        print(f"YuNet failed, falling back to Haar: {e}")

    gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
    faces = _face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60))
    if len(faces) == 0:
        return None
    x, y, fw, fh = (int(v) for v in max(faces, key=lambda f: f[2] * f[3]))
    return (x, y, x + fw, y + fh)
|
|
| |
| |
| |
def crop_face_region(image_rgb, box):
    """Cut an expanded region around a face box out of an image.

    The box is padded sideways by 0.425x its width on each side. Below, it is
    padded by ``int(0.3145 * height)``. Above, it is padded by enough that the
    total vertical padding equals the total horizontal padding. The result is
    then clipped to the image bounds.

    Args:
        image_rgb: H x W x C image array.
        box: ``(x1, y1, x2, y2)`` face box.

    Returns:
        ``(crop, (left, top, right, bottom))``, where *crop* is a view into
        *image_rgb* and the tuple gives its coordinates in the full image.
    """
    img_h, img_w = image_rgb.shape[:2]
    left, top, right, bottom = box
    box_w = right - left
    box_h = bottom - top

    pad_x = int(box_w * 0.85 / 2)
    pad_below = int(box_h * 0.37 * 0.85)
    # Top padding makes the vertical padding total equal to the horizontal one.
    pad_above = 2 * pad_x - pad_below

    y_lo = max(top - pad_above, 0)
    y_hi = min(bottom + pad_below, img_h)
    x_lo = max(left - pad_x, 0)
    x_hi = min(right + pad_x, img_w)

    crop = image_rgb[y_lo:y_hi, x_lo:x_hi, :]
    return crop, (x_lo, y_lo, x_hi, y_hi)
|
|
|
|
def create_blend_mask(crop_h, crop_w, feather=0.15):
    """Build a feathered alpha mask for pasting a crop back into a frame.

    The mask starts at 1.0. The first and last ``max(int(size * feather), 1)``
    rows and columns are scaled by a linear ramp ``k / n``, which is 0 on the
    outermost row/column.

    Returns:
        float32 array of shape ``(crop_h, crop_w, 1)``, so it broadcasts
        over colour channels.
    """
    weights = np.ones((crop_h, crop_w), dtype=np.float32)
    ramp_rows = max(int(crop_h * feather), 1)
    ramp_cols = max(int(crop_w * feather), 1)

    for k in range(ramp_rows):
        factor = k / ramp_rows
        weights[k, :] *= factor
        weights[-1 - k, :] *= factor

    for k in range(ramp_cols):
        factor = k / ramp_cols
        weights[:, k] *= factor
        weights[:, -1 - k] *= factor

    return weights[..., None]
|
|
|
|
def reage_frame(image_rgb, source_age, target_age, box=None):
    """Re-age the largest face in an RGB frame with the ONNX model.

    The face region is cropped, resized to 512x512 and fed to the model as
    5 channels: RGB scaled to [0, 1], plus constant planes for
    ``source_age / 100`` and ``target_age / 100``. The model's output is
    added to the input as a delta. The result is resized back and
    feather-blended into a copy of the frame.

    Args:
        image_rgb: H x W x 3 uint8 RGB image.
        source_age: current age in years.
        target_age: desired age in years.
        box: optional ``(x1, y1, x2, y2)`` face box (as from
            ``detect_face_box``). When omitted, the face is detected here.

    Returns:
        A new re-aged array. The *same* ``image_rgb`` object is returned,
        unmodified, when no face is found or the crop region is empty.
    """
    if box is None:
        box = detect_face_box(image_rgb)
    if box is None:
        return image_rgb

    cropped, (l_x, l_y, r_x, r_y) = crop_face_region(image_rgb, box)
    crop_h, crop_w = cropped.shape[:2]
    if crop_h == 0 or crop_w == 0:
        # cv2.resize rejects empty images; treat a degenerate box as "no face".
        return image_rgb
    cropped_resized = cv2.resize(cropped, (512, 512), interpolation=cv2.INTER_LINEAR)

    # Build the (1, 5, 512, 512) model input: CHW RGB plus two age planes.
    img_t = cropped_resized.astype(np.float32) / 255.0
    img_t = np.transpose(img_t, (2, 0, 1))
    src_ch = np.full((1, 512, 512), source_age / 100.0, dtype=np.float32)
    tgt_ch = np.full((1, 512, 512), target_age / 100.0, dtype=np.float32)
    inp = np.concatenate([img_t, src_ch, tgt_ch], axis=0)[np.newaxis, ...]

    # The network predicts a residual that is added to the input crop.
    delta = sess.run(None, {"input": inp})[0]
    aged = np.clip(img_t + delta[0], 0.0, 1.0)
    aged_hwc = (np.transpose(aged, (1, 2, 0)) * 255).astype(np.uint8)
    aged_resized = cv2.resize(aged_hwc, (crop_w, crop_h), interpolation=cv2.INTER_LINEAR)

    # Feather the crop edges so the pasted region has no visible seam.
    result = image_rgb.copy()
    mask = create_blend_mask(crop_h, crop_w, feather=0.12)
    region = result[l_y:r_y, l_x:r_x].astype(np.float32)
    blended = region * (1 - mask) + aged_resized.astype(np.float32) * mask
    result[l_y:r_y, l_x:r_x] = blended.astype(np.uint8)
    return result
|
|
| |
| |
| |
def _find_ffmpeg():
    """Locate the ffmpeg executable.

    ``PATH`` is searched first. If that fails, the usual absolute install
    locations are checked.

    Raises:
        gr.Error: if no ffmpeg binary can be found.
    """
    found = shutil.which("ffmpeg")
    if not found:
        fallbacks = ("/usr/bin/ffmpeg", "/usr/local/bin/ffmpeg")
        found = next((cand for cand in fallbacks if os.path.isfile(cand)), None)
    if found is None:
        raise gr.Error("ffmpeg not found.")
    return found
|
|
|
|
def _cv2_video_info(video_path):
    """Return ``(fps, frame_count)`` from OpenCV's metadata; fps defaults to 25."""
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()
    return fps, count


def _get_video_info(video_path):
    """Return ``(fps, frame_count)`` for the first video stream of *video_path*.

    ffprobe is preferred when it is installed. The frame count comes from the
    stream's ``nb_frames``. If that is missing, it is derived from the stream
    duration, then from the container duration. OpenCV metadata is the
    fallback when ffprobe is unavailable, fails, or reports none of these.
    The count is an estimate, so callers should re-check the number of
    frames actually extracted.
    """
    import json

    ffprobe = shutil.which("ffprobe") or shutil.which("ffprobe", path="/usr/bin:/usr/local/bin")
    if not ffprobe:
        return _cv2_video_info(video_path)
    try:
        r = subprocess.run(
            [ffprobe, "-v", "quiet", "-print_format", "json",
             "-show_streams", "-show_format", "-select_streams", "v:0", video_path],
            capture_output=True, text=True, timeout=30,
        )
        probe = json.loads(r.stdout)
        stream = probe["streams"][0]
        num, den = stream.get("r_frame_rate", "25/1").split("/")
        fps = float(num) / float(den)
        nb = stream.get("nb_frames")
        if nb and nb != "N/A":
            return fps, int(nb)
        # Streams in some containers (e.g. Matroska/WebM) often carry neither
        # nb_frames nor a per-stream duration. Without this fallback the count
        # would be 0 and the caller's length limits would be skipped.
        for duration in (stream.get("duration"), probe.get("format", {}).get("duration")):
            if duration not in (None, "", "N/A"):
                return fps, int(float(duration) * fps)
    except Exception:
        # ffprobe crashed, timed out, or produced unexpected output.
        return _cv2_video_info(video_path)
    return _cv2_video_info(video_path)
|
|
|
|
def _extract_frames(video_path, out_dir):
    """Decode the frames of *video_path* into ``out_dir/frame_%06d.png``.

    At most ``MAX_FRAMES + 1`` frames are written. One more than the limit is
    enough for the caller to detect an over-long video, and it avoids
    decoding (and storing) the rest when the container metadata
    under-reported the length.

    Raises:
        gr.Error: if ffmpeg exits with a non-zero status.
    """
    ffmpeg = _find_ffmpeg()
    cmd = [
        ffmpeg, "-nostdin", "-y",  # never wait on stdin; overwrite existing files
        "-i", video_path,
        "-vsync", "0",  # one output image per decoded frame
        "-frames:v", str(MAX_FRAMES + 1),
        os.path.join(out_dir, "frame_%06d.png"),
    ]
    r = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
    if r.returncode != 0:
        raise gr.Error(f"Frame extraction failed: {r.stderr[-500:]}")
|
|
|
|
def _assemble_video(frames_dir, output_path, fps, audio_source=None):
    """Encode ``frames_dir/frame_%06d.png`` into an H.264 MP4 at *fps*.

    If *audio_source* is given, its audio streams (if any) are muxed in and
    the output ends with the shorter of video and audio.

    Raises:
        gr.Error: if ffmpeg exits with a non-zero status.
    """
    ffmpeg = _find_ffmpeg()
    if not fps or fps <= 0:
        fps = 25.0  # ffmpeg rejects a non-positive -framerate
    cmd = [ffmpeg, "-nostdin", "-y", "-framerate", str(fps),
           "-i", os.path.join(frames_dir, "frame_%06d.png")]
    if audio_source:
        cmd += ["-i", audio_source, "-map", "0:v", "-map", "1:a?", "-shortest"]
    cmd += [
        # libx264 with yuv420p requires even width and height. Trim one edge
        # row/column from odd-sized frames instead of failing the encode.
        "-vf", "crop=trunc(iw/2)*2:trunc(ih/2)*2",
        "-c:v", "libx264", "-pix_fmt", "yuv420p", "-preset", "fast", "-crf", "20",
        "-movflags", "+faststart", output_path,
    ]
    r = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    if r.returncode != 0:
        raise gr.Error(f"Video assembly failed: {r.stderr[-500:]}")
|
|
| |
| |
| |
# Lower-case file extensions routed to the ffmpeg video pipeline; any other
# upload is treated as a still image.
VIDEO_EXTS = {
    ".mp4", ".m4v", ".mov", ".avi", ".mkv", ".webm", ".flv", ".wmv",
    ".mpg", ".mpeg", ".3gp", ".ts", ".ogv",
}
|
|
|
|
def _reage_still(image_rgb, source_age, target_age, t0, no_face_message):
    """Re-age one RGB image and package the ``(image, video, info)`` outputs."""
    result = reage_frame(image_rgb, source_age, target_age)
    # reage_frame returns the input array object itself when it finds no face,
    # and a modified copy otherwise. The identity check therefore replaces a
    # second face-detection pass.
    if result is image_rgb:
        raise gr.Error(no_face_message)
    elapsed = time.time() - t0
    info = f"Done in {elapsed:.2f}s | {source_age} -> {target_age} years"
    return Image.fromarray(result), None, info


def _read_image_rgb(file_path):
    """Load *file_path* as an RGB uint8 array; raise gr.Error if undecodable."""
    image_bgr = cv2.imread(file_path)
    if image_bgr is not None:
        return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    # cv2.imread returns None for formats its build cannot decode; PIL may
    # still be able to read the file.
    try:
        with Image.open(file_path) as img:
            return np.array(img.convert("RGB"))
    except (OSError, ValueError) as e:
        raise gr.Error("Could not read the file. Please upload a valid image or video.") from e


def _render_video(file_path, tmp_root, fps, source_age, target_age, progress):
    """Extract, re-age and re-encode all frames; return (output_path, n_frames, found, missed)."""
    frames_in = os.path.join(tmp_root, "in")
    frames_out = os.path.join(tmp_root, "out")
    os.makedirs(frames_in, exist_ok=True)
    os.makedirs(frames_out, exist_ok=True)

    progress(0, desc="Extracting frames...")
    _extract_frames(file_path, frames_in)

    frame_files = sorted(glob_mod.glob(os.path.join(frames_in, "frame_*.png")))
    n_frames = len(frame_files)
    if n_frames == 0:
        raise gr.Error("No frames extracted. Is this a valid video?")
    if n_frames > MAX_FRAMES:
        raise gr.Error(f"{n_frames} frames (max {MAX_FRAMES}).")

    faces_found, faces_missed = 0, 0
    for idx, fpath in enumerate(frame_files):
        progress((idx + 1) / n_frames, desc=f"Re-aging frame {idx + 1}/{n_frames}...")
        out_path = os.path.join(frames_out, os.path.basename(fpath))
        frame_bgr = cv2.imread(fpath)
        if frame_bgr is None:
            # Pass the frame through untouched. ffmpeg's numbered-image input
            # stops at the first gap in frame_%06d.png, which would silently
            # truncate the output video.
            shutil.copyfile(fpath, out_path)
            faces_missed += 1
            continue
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        result_rgb = reage_frame(frame_rgb, source_age, target_age)
        # Getting the same object back means no face was found (see _reage_still).
        if result_rgb is frame_rgb:
            faces_missed += 1
        else:
            faces_found += 1
        cv2.imwrite(out_path, cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR))

    progress(1.0, desc="Assembling video...")
    output_path = os.path.join(tmp_root, "output.mp4")
    _assemble_video(frames_out, output_path, fps, audio_source=file_path)

    # The PNG dumps can be large; only output.mp4 has to outlive this call.
    shutil.rmtree(frames_in, ignore_errors=True)
    shutil.rmtree(frames_out, ignore_errors=True)
    return output_path, n_frames, faces_found, faces_missed


def _reage_video(file_path, source_age, target_age, progress, t0):
    """Enforce length limits, run the video pipeline in a temp dir, and build outputs."""
    fps, total_frames = _get_video_info(file_path)
    if not fps or fps <= 0:
        fps = 25.0
    duration = total_frames / fps

    if duration > MAX_VIDEO_SECONDS:
        raise gr.Error(f"Video is {duration:.1f}s (max {MAX_VIDEO_SECONDS}s). Please trim it.")
    if total_frames > MAX_FRAMES:
        raise gr.Error(f"Video has {total_frames} frames (max {MAX_FRAMES}).")

    tmp_root = tempfile.mkdtemp(prefix="reage_")
    try:
        output_path, n_frames, faces_found, faces_missed = _render_video(
            file_path, tmp_root, fps, source_age, target_age, progress
        )
    except gr.Error:
        shutil.rmtree(tmp_root, ignore_errors=True)
        raise
    except Exception as e:
        shutil.rmtree(tmp_root, ignore_errors=True)
        raise gr.Error(f"Video processing failed: {e}") from e

    elapsed = time.time() - t0
    speed = n_frames / max(elapsed, 0.01)
    info = (f"Done in {elapsed:.1f}s | {n_frames} frames at {speed:.1f} fps | "
            f"Faces: {faces_found} found, {faces_missed} skipped | "
            f"{source_age} -> {target_age} years")
    return None, output_path, info


def process(input_file, source_age, target_age, progress=gr.Progress()):
    """Re-age the face in an uploaded image, or in every frame of a video.

    Args:
        input_file: a PIL image (from ``gr.Image``), or a path / file object
            with a ``.name`` path (from ``gr.File``) to an image or video.
            Files are routed by extension using ``VIDEO_EXTS``.
        source_age: the subject's current age in years (cast to int).
        target_age: the desired age in years (cast to int).
        progress: Gradio progress tracker, updated per video frame.

    Returns:
        ``(image, video_path, info)``. Exactly one of ``image`` (a PIL image)
        and ``video_path`` (an MP4 path) is set; ``info`` is a summary string.

    Raises:
        gr.Error: on missing input, no face in an image, a video over the
            length limits, an unreadable file, or any video pipeline failure.
    """
    if input_file is None:
        raise gr.Error("Please upload an image or video.")

    t0 = time.time()
    source_age, target_age = int(source_age), int(target_age)

    if isinstance(input_file, Image.Image):
        image_rgb = np.array(input_file.convert("RGB"))
        return _reage_still(
            image_rgb, source_age, target_age, t0,
            "No face detected. Please upload a clear photo with a visible face.",
        )

    if isinstance(input_file, (str, os.PathLike)):
        file_path = os.fspath(input_file)
    else:
        # Some (older) Gradio versions pass gr.File values as tempfile
        # wrappers whose .name attribute holds the path.
        file_path = getattr(input_file, "name", None) or str(input_file)
    ext = os.path.splitext(file_path)[1].lower()

    if ext in VIDEO_EXTS:
        return _reage_video(file_path, source_age, target_age, progress, t0)

    image_rgb = _read_image_rgb(file_path)
    return _reage_still(image_rgb, source_age, target_age, t0, "No face detected.")
|
|
|
|
| |
| |
| |
with gr.Blocks(title="Face Re-Aging (CPU)") as demo:
    gr.Markdown(
        "# Face Re-Aging (CPU)\n"
        "Upload an **image or video** to age or de-age faces. "
        f"Videos: max {MAX_VIDEO_SECONDS}s, ~0.5-2 fps on CPU."
    )

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Drop Image or Video Here",
                file_types=["image", "video"],
            )
            img_input = gr.Image(
                type="pil", label="Or paste/capture an image",
                visible=True,
            )
            src_age = gr.Slider(minimum=5, maximum=95, value=25, step=1,
                                label="Source Age (current)")
            tgt_age = gr.Slider(minimum=5, maximum=95, value=65, step=1,
                                label="Target Age (desired)")
            btn = gr.Button("Re-Age", variant="primary", size="lg")

        with gr.Column():
            img_output = gr.Image(type="pil", label="Result (Image)")
            vid_output = gr.Video(label="Result (Video)")
            info_box = gr.Textbox(label="Info", interactive=False)

    def on_submit_file(file_obj, source_age, target_age, progress=gr.Progress()):
        """Re-age an uploaded file (image or video)."""
        if file_obj is None:
            raise gr.Error("Please upload a file.")
        return process(file_obj, source_age, target_age, progress)

    def on_submit_image(image, source_age, target_age, progress=gr.Progress()):
        """Re-age a pasted/captured image; a no-op when the image is cleared."""
        # .change also fires when the user clears the image. Leave the outputs
        # as they are instead of raising an error popup.
        if image is None:
            return gr.update(), gr.update(), gr.update()
        return process(image, source_age, target_age, progress)

    def on_submit(file_obj, image, source_age, target_age, progress=gr.Progress()):
        """Button handler: use the uploaded file if any, else the pasted image."""
        if file_obj is not None:
            return process(file_obj, source_age, target_age, progress)
        if image is not None:
            return process(image, source_age, target_age, progress)
        raise gr.Error("Please upload an image or video.")

    btn.click(
        fn=on_submit,
        inputs=[file_input, img_input, src_age, tgt_age],
        outputs=[img_output, vid_output, info_box],
    )

    # A pasted or captured image is processed as soon as it arrives.
    img_input.change(
        fn=on_submit_image,
        inputs=[img_input, src_age, tgt_age],
        outputs=[img_output, vid_output, info_box],
    )

    gr.Markdown(
        "**Model:** `face_reaging.onnx` (118 MB) from "
        "[VisoMaster-Fusion](https://github.com/VisoMasterFusion/VisoMaster-Fusion) | "
        "Based on [Disney FRAN](https://studios.disneyresearch.com/2022/11/30/production-ready-face-re-aging-for-visual-effects/)"
    )
|
|
if __name__ == "__main__":
    # NOTE(review): passing `theme` to launch() (rather than gr.Blocks) relies
    # on the installed Gradio version accepting it there; confirm.
    demo.launch(
        show_error=True,
        ssr_mode=False,
        theme="NoCrypt/miku",
    )
|
|