# waypoint-1-5 / app.py — Waypoint v1.5 real-time world model demo (Hugging Face Space)
# Upstream commit: "Fix per user isolation (#4)" (1d676b9), by lapp0
"""
Waypoint v1.5 Real-Time World Model Demo — gradio.Server + WebSocket Edition
Uses gradio.Server for ZeroGPU-compatible game start/stop, and a raw WebSocket
for real-time binary JPEG frame streaming + control input. No Gradio UI polling.
Multi-user safe: every endpoint is keyed by a per-client `session_id` so
concurrent players never share seed images, frame queues, or status messages.
"""
# Check for ZeroGPU environment - must be before other imports
try:
    # The `spaces` package only exists inside Hugging Face ZeroGPU Spaces;
    # its absence means we are on a standard, always-on GPU.
    import spaces
    IS_ZERO_GPU = True
    print("ZeroGPU environment detected")
except ImportError:
    IS_ZERO_GPU = False
    print("Running in standard GPU mode")
import asyncio
import io
import os
import queue
import random
import struct
import contextvars
import threading
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from multiprocessing import Queue
from typing import Dict, Optional, Set, Tuple
import torch
from PIL import Image
from fastapi import WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from gradio import Server
from diffusers.modular_pipelines import ModularPipeline
from diffusers.utils import load_image
# --- Configuration ---
MODEL_ID = os.environ.get("MODEL_PATH", "diffusers-internal-dev/waypoint_v1_5_modular_aot")
IMAGE_WIDTH = 1024  # generated frame width in pixels
IMAGE_HEIGHT = 512  # generated frame height in pixels
MAX_FRAMES_BEFORE_RESET = 4096  # frame budget before the game loop re-seeds itself
JPEG_QUALITY = 80  # JPEG quality for frames streamed over the WebSocket
SESSION_IDLE_TIMEOUT = 600  # seconds; janitor reaps abandoned sessions
# Allow TF32-class matmul speedups; exact fp32 precision is not needed here.
torch.set_float32_matmul_precision("medium")
# --- Load Pipeline on CPU at Startup ---
# Components are loaded on CPU at import time; the move to CUDA is deferred
# to _ensure_aot_on_gpu() so ZeroGPU only allocates a GPU per session.
print(f"Loading pipeline from {MODEL_ID}...")
pipe = ModularPipeline.from_pretrained(MODEL_ID, trust_remote_code=True)
pipe.load_components("transformer", trust_remote_code=True, torch_dtype=torch.bfloat16)
pipe.transformer.apply_inference_patches()
pipe.load_components(
    ["vae", "text_encoder", "tokenizer"],
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
from huggingface_hub import hf_hub_download
import torch._inductor.codecache
# Fetch the AOT-compiled transformer package and its constants now, so the
# first GPU session does not pay the download cost.
pt2_path = hf_hub_download(repo_id=MODEL_ID, filename="transformer/transformer.pt2")
constants_path = hf_hub_download(repo_id=MODEL_ID, filename="transformer/transformer-constants.pt")
print("Pipeline and AOT artifacts downloaded.")
# One-time GPU/AOT initialization guard shared by all sessions.
_aot_initialized = False
_aot_init_lock = threading.Lock()
def _ensure_aot_on_gpu():
    """Lazily move the pipeline to CUDA and install the AOT-compiled transformer.

    Runs at most once per process. Uses double-checked locking: the cheap
    flag test avoids taking the lock once initialization is done, while the
    re-check under the lock stops a second thread that was waiting on it
    from initializing again. Safe to call from every session's GPU worker.
    """
    global _aot_initialized
    if _aot_initialized:
        return
    with _aot_init_lock:
        if _aot_initialized:
            # Another thread finished initialization while we waited.
            return
        broadcast_status("Moving pipeline to GPU...")
        print("Initializing AOT model on GPU...")
        pipe.to("cuda")
        broadcast_status("Setting up KV cache...")
        pipe.blocks.sub_blocks["before_denoise"].sub_blocks["setup_kv_cache"]._setup_kv_cache(
            pipe.transformer, torch.device("cuda"), torch.bfloat16
        )
        broadcast_status("Loading AOT-compiled model...")
        aot_model = torch._inductor.aoti_load_package(pt2_path)
        constants_map = torch.load(constants_path, map_location="cuda", weights_only=True)
        aot_model.load_constants(constants_map, check_full_update=True, user_managed=True)
        # Route all transformer forward calls through the AOT package.
        pipe.transformer.forward = aot_model
        _aot_initialized = True
        broadcast_status("AOT model ready!")
        print("AOT model loaded successfully!")
# Curated seed images hosted in the Overworldai/Biome GitHub repo; the
# `?raw=true` suffix makes GitHub serve the raw JPEG bytes.
_SEED_BASE = "https://github.com/Overworldai/Biome/blob/main/seeds"
_SEED_NAMES = [
    "abandoned_city_runner.jpg",
    "alien_command_center.jpg",
    "ancient_standing_stones.jpg",
    "default.jpg",
    "desert_outpost_hangar.jpg",
    "enchanted_swamp_torch.jpg",
    "frozen_crystal_cavern.jpg",
    "frozen_valley_sniper.jpg",
    "highland_castle_loch.jpg",
    "mountain_ruins_gun.jpg",
    "shattered_cockpit_nebula.jpg",
    "shipwreck_shore_revolver.jpg",
    "snowy_forest_tracks.jpg",
    "stormy_countryside_rifle.jpg",
    "sunken_city_depths.jpg",
]
SEED_FRAME_URLS = [f"{_SEED_BASE}/{name}?raw=true" for name in _SEED_NAMES]
def load_seed_frame(url: str, target_size: Tuple[int, int] = (IMAGE_HEIGHT, IMAGE_WIDTH)) -> Image.Image:
    """Download a seed image and resize it to `target_size`.

    `target_size` is (height, width); PIL's resize wants (width, height),
    hence the swap below.
    """
    height, width = target_size
    return load_image(url).resize((width, height), Image.BILINEAR)
# --- Command Types ---
@dataclass
class GenerateCommand:
    """Latest control state from the browser; consumed by the GPU game loop."""
    buttons: Set[int]  # currently pressed button codes, as sent by the client
    mouse: Tuple[float, float]  # (mouse_x, mouse_y) control values
    prompt: str  # text prompt steering generation
@dataclass
class ResetCommand:
    """Ask the game loop to restart from a fresh seed image."""
    seed_url: Optional[str] = None  # None -> a random curated seed is chosen
    prompt: str = "An explorable world"  # prompt for the re-seeded world
@dataclass
class StopCommand:
    """Sentinel telling the game loop to shut down."""
    pass
# --- Session ---
@dataclass
class GameSession:
    """All mutable per-client state for one running game.

    One instance exists per browser `session_id`, so concurrent players
    never share queues, events, or timing data.
    """
    session_id: str
    command_queue: Queue
    frame_queue: queue.Queue
    stop_event: threading.Event
    status_queue: queue.Queue
    worker_thread: Optional[threading.Thread] = None
    frame_times: deque = field(default_factory=lambda: deque(maxlen=30))
    last_active: float = field(default_factory=time.time)

    def touch(self):
        """Record activity so the idle janitor does not reap this session."""
        self.last_active = time.time()

    def stop(self):
        """Signal the worker to quit and wait briefly for it to exit."""
        self.stop_event.set()
        # Best effort: wake a worker that may be blocked on the command queue.
        try:
            self.command_queue.put_nowait(StopCommand())
        except Exception:
            pass
        worker = self.worker_thread
        if worker is not None and worker.is_alive():
            worker.join(timeout=3.0)
# Per-client state, keyed by session_id (UUID generated by the browser).
_sessions: Dict[str, GameSession] = {}
_sessions_lock = threading.Lock()  # guards every read/write of _sessions
# Contextvar carrying the active session's status queue. The worker thread
# inherits the value via contextvars.copy_context(), so broadcast_status()
# always lands in the right session's queue without explicit threading.
_current_status_queue: contextvars.ContextVar[Optional[queue.Queue]] = contextvars.ContextVar(
    "waypoint_status_queue", default=None
)
def broadcast_status(msg: str):
    """Push a status message to the current session's WebSocket consumer.

    No-op when no session is bound to the current context; silently drops
    the message if the session's status queue is full.
    """
    target = _current_status_queue.get()
    if target is None:
        return
    try:
        target.put_nowait(msg)
    except queue.Full:
        pass
def _get_session(session_id: str) -> Optional[GameSession]:
    """Look up the live session for `session_id`, or None if unknown."""
    with _sessions_lock:
        session = _sessions.get(session_id)
    return session
def _drop_session(session_id: str) -> Optional[GameSession]:
    """Remove and return the session for `session_id` (None if absent)."""
    with _sessions_lock:
        session = _sessions.pop(session_id, None)
    return session
def _reap_idle_sessions():
    """Background janitor: drops sessions whose worker has died AND that have
    been idle longer than SESSION_IDLE_TIMEOUT.

    Both conditions are required, so a session with a live worker is never
    popped and no running thread is orphaned by the reap.
    """
    while True:
        time.sleep(60)
        now = time.time()
        to_drop = []
        with _sessions_lock:
            for sid, sess in list(_sessions.items()):
                worker_dead = sess.worker_thread is None or not sess.worker_thread.is_alive()
                idle = (now - sess.last_active) > SESSION_IDLE_TIMEOUT
                if worker_dead and idle:
                    to_drop.append(sid)
            for sid in to_drop:
                _sessions.pop(sid, None)
        if to_drop:
            print(f"Janitor reaped {len(to_drop)} idle session(s)")
# Daemon thread: never blocks process shutdown.
threading.Thread(target=_reap_idle_sessions, daemon=True).start()
def gpu_worker_thread(seed_url, prompt, command_queue, frame_queue, stop_event, frame_times):
    """Background thread: run one session's GPU game loop until it stops.

    Pulls (frame, frame_count) pairs from the generator built by
    create_gpu_game_loop(), computes a rolling FPS over the bounded
    `frame_times` deque, and publishes only the freshest frame into
    `frame_queue` for the WebSocket sender. Always sets `stop_event` on
    exit so the consumer side knows the session ended.
    """
    try:
        gen = create_gpu_game_loop(
            command_queue,
            initial_seed_url=seed_url,
            initial_prompt=prompt,
        )
        while not stop_event.is_set():
            try:
                frame, frame_count = next(gen)
                now = time.time()
                frame_times.append(now)
                fps = 0.0
                if len(frame_times) >= 2:
                    # FPS over the window: N samples span N-1 intervals.
                    elapsed = frame_times[-1] - frame_times[0]
                    fps = (len(frame_times) - 1) / elapsed if elapsed > 0 else 0.0
                # Drop stale frames, keep only latest
                while not frame_queue.empty():
                    try:
                        frame_queue.get_nowait()
                    except queue.Empty:
                        break
                frame_queue.put_nowait((frame, frame_count, round(fps, 1)))
            except StopIteration:
                print("Generator exhausted")
                break
            except Exception as e:
                # ZeroGPU time-limit aborts surface as exceptions mentioning
                # "aborted" or "duration"; other failures are reported to the
                # client via the status channel.
                if "aborted" in str(e).lower() or "duration" in str(e).lower():
                    print(f"GPU time expired: {e}")
                else:
                    print(f"Worker error: {e}")
                    broadcast_status(f"error:{e}")
                break
    finally:
        stop_event.set()
        print("Worker thread finished")
def create_gpu_game_loop(command_queue, initial_seed_url=None, initial_prompt="An explorable world"):
    """Build and start the per-session GPU game generator.

    Returns a *running* generator yielding (PIL.Image, frame_count) tuples.
    Commands arriving on `command_queue` steer generation: GenerateCommand
    updates the control state, ResetCommand re-seeds the world, StopCommand
    ends the loop.

    Fixes vs. the original:
      * the bare `except:` around `command_queue.get_nowait()` swallowed
        every exception (including real bugs); it is narrowed to
        `queue.Empty`, the only exception the drain loop should absorb;
      * `spaces.GPU` is applied only when IS_ZERO_GPU is true — in standard
        GPU mode the `spaces` module does not exist, so the unconditional
        decorator would raise NameError.
    """
    def gpu_game_loop():
        broadcast_status("GPU allocated! Loading model...")
        _ensure_aot_on_gpu()
        n_frames = pipe.transformer.config.n_frames
        print(f"Model loaded! (n_frames={n_frames})")
        broadcast_status("Loading seed image...")
        if initial_seed_url:
            seed_image = load_seed_frame(initial_seed_url)
        else:
            seed_image = load_seed_frame(random.choice(SEED_FRAME_URLS))
        broadcast_status("Generating first frame...")
        state = pipe(prompt=initial_prompt, image=seed_image, button=set(), mouse=(0.0, 0.0), output_type="pil")
        frame_count = 0
        for frame in state.values.get("images"):
            frame_count += 1
            yield (frame, frame_count)
        current_buttons = set()
        current_mouse = (0.0, 0.0)
        current_prompt = initial_prompt
        while True:
            stop_requested = False
            reset_command = None
            # Drain all pending commands; only the most recent control state
            # matters, and a reset/stop overrides plain controls.
            while True:
                try:
                    command = command_queue.get_nowait()
                except queue.Empty:
                    break
                if isinstance(command, StopCommand):
                    stop_requested = True
                    break
                elif isinstance(command, ResetCommand):
                    reset_command = command
                elif isinstance(command, GenerateCommand):
                    current_buttons = command.buttons
                    current_mouse = command.mouse
                    current_prompt = command.prompt
            if stop_requested:
                break
            if reset_command is not None:
                # Re-seed the world without releasing the GPU.
                if reset_command.seed_url:
                    seed_img = load_seed_frame(reset_command.seed_url)
                else:
                    seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))
                state = pipe(prompt=reset_command.prompt, image=seed_img, button=set(), mouse=(0.0, 0.0), output_type="pil")
                frame_count = 0
                current_prompt = reset_command.prompt
                for frame in state.values.get("images"):
                    frame_count += 1
                    yield (frame, frame_count)
                continue
            state = pipe(state, prompt=current_prompt, button=current_buttons, mouse=current_mouse, image=None, output_type="pil")
            for frame in state.values.get("images"):
                frame_count += 1
                yield (frame, frame_count)
            if frame_count >= MAX_FRAMES_BEFORE_RESET - 2:
                # Frame budget exhausted: restart from a fresh random seed
                # while keeping the player's current prompt.
                print(f"Auto-reset at frame {frame_count}")
                seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))
                state = pipe(prompt=current_prompt, image=seed_img, button=set(), mouse=(0.0, 0.0), output_type="pil")
                frame_count = 0
                for frame in state.values.get("images"):
                    frame_count += 1
                    yield (frame, frame_count)
        print("GPU session ended")
    if IS_ZERO_GPU:
        # Only ZeroGPU has the `spaces` module; wrap the loop so a GPU is
        # allocated for the lifetime of the generator.
        gpu_game_loop = spaces.GPU(duration=120)(gpu_game_loop)
    return gpu_game_loop()
# --- App ---
# gradio.Server hosts the @app.api endpoints plus raw FastAPI-style routes
# (the /ws websocket and the / homepage below).
app = Server()
@app.api(name="start_game")
def start_game(session_id: str = "", seed_url: str = "", prompt: str = "An explorable world") -> str:
    """Start a new game session for `session_id`. Returns the session_id used."""
    sid = session_id or str(uuid.uuid4())
    # A page reload mid-game reuses the same id; shut the old session down
    # before starting a replacement.
    old = _drop_session(sid)
    if old is not None:
        old.stop()
    # Guarantee the player lands in one of the curated worlds even when the
    # browser did not pick a seed.
    chosen_seed = seed_url or random.choice(SEED_FRAME_URLS)
    sess = GameSession(
        session_id=sid,
        command_queue=Queue(),
        frame_queue=queue.Queue(maxsize=2),
        stop_event=threading.Event(),
        status_queue=queue.Queue(maxsize=20),
        frame_times=deque(maxlen=30),
    )
    with _sessions_lock:
        _sessions[sid] = sess
    # Bind the status queue into the contextvar so the worker thread (and the
    # @spaces.GPU function it spawns) inherits it via copy_context().
    token = _current_status_queue.set(sess.status_queue)
    try:
        broadcast_status("Requesting GPU from ZeroGPU...")
        snapshot = contextvars.copy_context()
        sess.worker_thread = threading.Thread(
            target=snapshot.run,
            args=(
                gpu_worker_thread,
                chosen_seed,
                prompt or "An explorable world",
                sess.command_queue,
                sess.frame_queue,
                sess.stop_event,
                sess.frame_times,
            ),
            daemon=True,
        )
        sess.worker_thread.start()
    finally:
        _current_status_queue.reset(token)
    return sid
@app.api(name="stop_game")
def stop_game(session_id: str = "") -> str:
    """Stop the active game session for the given client."""
    if not session_id:
        return "no_session"
    # Remove the session first so no new frames are consumed, then stop it.
    doomed = _drop_session(session_id)
    if doomed is not None:
        doomed.stop()
    return "stopped"
@app.api(name="get_worlds")
def get_worlds() -> list:
    """Return the list of seed world URLs.

    NOTE: returns the module-level list itself, not a copy; callers must
    treat it as read-only.
    """
    return SEED_FRAME_URLS
@app.websocket("/ws")
async def game_ws(websocket: WebSocket, session_id: str = ""):
    """
    Real-time game WebSocket. Requires `?session_id=...` query param matching
    the value passed to /start_game.

    Runs two coroutines concurrently until disconnect:
      * send_frames      — forwards status JSON and binary JPEG frames
                           (8-byte header: frame count + fps*10, big-endian)
      * receive_controls — applies incoming control/reset messages to the
                           session's command queue

    Fix vs. the original: uses asyncio.get_running_loop() instead of the
    deprecated asyncio.get_event_loop() (we are always inside a coroutine
    here, so the running loop is the right one).
    """
    await websocket.accept()
    if not session_id:
        await websocket.send_json({"type": "error", "message": "missing session_id"})
        await websocket.close(code=1008)
        return
    loop = asyncio.get_running_loop()
    async def send_frames():
        session_ended_sent = False
        while True:
            session = _get_session(session_id)
            # Drain any status messages for this session (init progress / errors).
            if session is not None:
                try:
                    status_msg = session.status_queue.get_nowait()
                    if status_msg.startswith("error:"):
                        await websocket.send_json({"type": "error", "message": status_msg[6:]})
                        break
                    await websocket.send_json({"type": "status", "message": status_msg})
                except queue.Empty:
                    pass
                except (WebSocketDisconnect, RuntimeError):
                    break
            if session is None:
                # start_game may not have registered the session yet; poll.
                await asyncio.sleep(0.05)
                continue
            if session.stop_event.is_set():
                # Tell the client once, then idle until a new session appears.
                if not session_ended_sent:
                    try:
                        await websocket.send_json({"type": "session_ended"})
                    except (WebSocketDisconnect, RuntimeError):
                        break
                    session_ended_sent = True
                await asyncio.sleep(0.5)
                continue
            try:
                # Block in a worker thread so the event loop stays responsive;
                # the short timeout lets us re-check session state regularly.
                result = await loop.run_in_executor(
                    None, lambda s=session: s.frame_queue.get(timeout=0.05)
                )
                frame, count, fps = result
                buf = io.BytesIO()
                frame.save(buf, format="JPEG", quality=JPEG_QUALITY)
                jpeg_bytes = buf.getvalue()
                # fps is sent as an int scaled by 10 to keep one decimal place.
                header = struct.pack(">II", count, int(fps * 10))
                await websocket.send_bytes(header + jpeg_bytes)
                session.touch()
            except queue.Empty:
                pass
            except (WebSocketDisconnect, RuntimeError):
                break
    async def receive_controls():
        while True:
            try:
                data = await websocket.receive_json()
                session = _get_session(session_id)
                if session is None:
                    continue
                session.touch()
                msg_type = data.get("type", "control")
                if msg_type == "control":
                    buttons = set(data.get("buttons", []))
                    mouse = (data.get("mouse_x", 0.0), data.get("mouse_y", 0.0))
                    prompt = data.get("prompt", "An explorable world")
                    session.command_queue.put(GenerateCommand(buttons=buttons, mouse=mouse, prompt=prompt))
                elif msg_type == "reset":
                    seed_url = data.get("seed_url") or random.choice(SEED_FRAME_URLS)
                    prompt = data.get("prompt", "An explorable world")
                    session.command_queue.put(ResetCommand(
                        seed_url=seed_url,
                        prompt=prompt,
                    ))
            except WebSocketDisconnect:
                break
            except Exception:
                # Malformed JSON or a closed socket: end the control loop.
                break
    try:
        await asyncio.gather(send_frames(), receive_controls())
    except WebSocketDisconnect:
        pass
@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve the single-page client UI from index.html next to this file."""
    here = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(here, "index.html"), "r", encoding="utf-8") as f:
        return f.read()
# Avoid ZeroGPU "no GPU function" error
if IS_ZERO_GPU:
    # Registering at least one @spaces.GPU-decorated callable satisfies the
    # ZeroGPU startup check; the lambda itself is never invoked.
    spaces.GPU(lambda: None)
app.launch(server_name="0.0.0.0", server_port=7860)