""" Waypoint v1.5 Real-Time World Model Demo — gradio.Server + WebSocket Edition Uses gradio.Server for ZeroGPU-compatible game start/stop, and a raw WebSocket for real-time binary JPEG frame streaming + control input. No Gradio UI polling. Multi-user safe: every endpoint is keyed by a per-client `session_id` so concurrent players never share seed images, frame queues, or status messages. """ # Check for ZeroGPU environment - must be before other imports try: import spaces IS_ZERO_GPU = True print("ZeroGPU environment detected") except ImportError: IS_ZERO_GPU = False print("Running in standard GPU mode") import asyncio import io import os import queue import random import struct import contextvars import threading import time import uuid from collections import deque from dataclasses import dataclass, field from multiprocessing import Queue from typing import Dict, Optional, Set, Tuple import torch from PIL import Image from fastapi import WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse from gradio import Server from diffusers.modular_pipelines import ModularPipeline from diffusers.utils import load_image # --- Configuration --- MODEL_ID = os.environ.get("MODEL_PATH", "diffusers-internal-dev/waypoint_v1_5_modular_aot") IMAGE_WIDTH = 1024 IMAGE_HEIGHT = 512 MAX_FRAMES_BEFORE_RESET = 4096 JPEG_QUALITY = 80 SESSION_IDLE_TIMEOUT = 600 # seconds; janitor reaps abandoned sessions torch.set_float32_matmul_precision("medium") # --- Load Pipeline on CPU at Startup --- print(f"Loading pipeline from {MODEL_ID}...") pipe = ModularPipeline.from_pretrained(MODEL_ID, trust_remote_code=True) pipe.load_components("transformer", trust_remote_code=True, torch_dtype=torch.bfloat16) pipe.transformer.apply_inference_patches() pipe.load_components( ["vae", "text_encoder", "tokenizer"], trust_remote_code=True, torch_dtype=torch.bfloat16, ) from huggingface_hub import hf_hub_download import torch._inductor.codecache pt2_path = hf_hub_download(repo_id=MODEL_ID, 
filename="transformer/transformer.pt2") constants_path = hf_hub_download(repo_id=MODEL_ID, filename="transformer/transformer-constants.pt") print("Pipeline and AOT artifacts downloaded.") _aot_initialized = False _aot_init_lock = threading.Lock() def _ensure_aot_on_gpu(): global _aot_initialized if _aot_initialized: return with _aot_init_lock: if _aot_initialized: return broadcast_status("Moving pipeline to GPU...") print("Initializing AOT model on GPU...") pipe.to("cuda") broadcast_status("Setting up KV cache...") pipe.blocks.sub_blocks["before_denoise"].sub_blocks["setup_kv_cache"]._setup_kv_cache( pipe.transformer, torch.device("cuda"), torch.bfloat16 ) broadcast_status("Loading AOT-compiled model...") aot_model = torch._inductor.aoti_load_package(pt2_path) constants_map = torch.load(constants_path, map_location="cuda", weights_only=True) aot_model.load_constants(constants_map, check_full_update=True, user_managed=True) pipe.transformer.forward = aot_model _aot_initialized = True broadcast_status("AOT model ready!") print("AOT model loaded successfully!") _SEED_BASE = "https://github.com/Overworldai/Biome/blob/main/seeds" _SEED_NAMES = [ "abandoned_city_runner.jpg", "alien_command_center.jpg", "ancient_standing_stones.jpg", "default.jpg", "desert_outpost_hangar.jpg", "enchanted_swamp_torch.jpg", "frozen_crystal_cavern.jpg", "frozen_valley_sniper.jpg", "highland_castle_loch.jpg", "mountain_ruins_gun.jpg", "shattered_cockpit_nebula.jpg", "shipwreck_shore_revolver.jpg", "snowy_forest_tracks.jpg", "stormy_countryside_rifle.jpg", "sunken_city_depths.jpg", ] SEED_FRAME_URLS = [f"{_SEED_BASE}/{name}?raw=true" for name in _SEED_NAMES] def load_seed_frame(url: str, target_size: Tuple[int, int] = (IMAGE_HEIGHT, IMAGE_WIDTH)) -> Image.Image: img = load_image(url) return img.resize((target_size[1], target_size[0]), Image.BILINEAR) # --- Command Types --- @dataclass class GenerateCommand: buttons: Set[int] mouse: Tuple[float, float] prompt: str @dataclass class ResetCommand: 
seed_url: Optional[str] = None prompt: str = "An explorable world" @dataclass class StopCommand: pass # --- Session --- @dataclass class GameSession: session_id: str command_queue: Queue frame_queue: queue.Queue stop_event: threading.Event status_queue: queue.Queue worker_thread: Optional[threading.Thread] = None frame_times: deque = field(default_factory=lambda: deque(maxlen=30)) last_active: float = field(default_factory=time.time) def touch(self): self.last_active = time.time() def stop(self): self.stop_event.set() try: self.command_queue.put_nowait(StopCommand()) except Exception: pass if self.worker_thread and self.worker_thread.is_alive(): self.worker_thread.join(timeout=3.0) # Per-client state, keyed by session_id (UUID generated by the browser). _sessions: Dict[str, GameSession] = {} _sessions_lock = threading.Lock() # Contextvar carrying the active session's status queue. The worker thread # inherits the value via contextvars.copy_context(), so broadcast_status() # always lands in the right session's queue without explicit threading. 
_current_status_queue: contextvars.ContextVar[Optional[queue.Queue]] = contextvars.ContextVar(
    "waypoint_status_queue", default=None
)


def broadcast_status(msg: str):
    """Push a status message to the current session's WebSocket consumer.

    Does nothing when no status queue is bound in the current context, and
    drops the message when the queue is full.
    """
    q = _current_status_queue.get()
    if q is None:
        return
    try:
        q.put_nowait(msg)
    except queue.Full:
        # Status updates are advisory; losing one beats blocking the worker.
        pass


def _get_session(session_id: str) -> Optional[GameSession]:
    """Return the registered session for `session_id`, or None."""
    with _sessions_lock:
        return _sessions.get(session_id)


def _drop_session(session_id: str) -> Optional[GameSession]:
    """Unregister and return the session for `session_id`, or None."""
    with _sessions_lock:
        return _sessions.pop(session_id, None)


def _reap_idle_sessions():
    """Background janitor: forget sessions that are both dead and idle.

    Once a minute, a session is removed from the registry only when its
    worker thread is gone (never started or no longer alive) AND it has seen
    no activity for more than SESSION_IDLE_TIMEOUT seconds. Sessions with a
    live worker are never reaped here.
    """
    while True:
        time.sleep(60)
        now = time.time()
        to_drop = []
        with _sessions_lock:
            for sid, sess in list(_sessions.items()):
                worker_dead = sess.worker_thread is None or not sess.worker_thread.is_alive()
                idle = (now - sess.last_active) > SESSION_IDLE_TIMEOUT
                if worker_dead and idle:
                    to_drop.append(sid)
            for sid in to_drop:
                _sessions.pop(sid, None)
        if to_drop:
            print(f"Janitor reaped {len(to_drop)} idle session(s)")


threading.Thread(target=_reap_idle_sessions, daemon=True).start()


def gpu_worker_thread(seed_url, prompt, command_queue, frame_queue, stop_event, frame_times):
    """Pump frames from the GPU generator into `frame_queue` until stopped.

    Args:
        seed_url: Seed image URL for the first frame.
        prompt: Initial text prompt.
        command_queue: Queue the generator reads control commands from.
        frame_queue: Receives `(frame, frame_count, fps)` tuples; only the
            newest frame is kept.
        stop_event: Set by the caller to request shutdown; always set on exit.
        frame_times: Sliding window of frame timestamps used for the FPS figure.

    Any failure, including one raised while creating the generator, is handled
    here. Errors that do not look like a GPU-time expiry are also forwarded to
    the client as an `error:` status message.
    """
    try:
        gen = create_gpu_game_loop(
            command_queue,
            initial_seed_url=seed_url,
            initial_prompt=prompt,
        )
        while not stop_event.is_set():
            frame, frame_count = next(gen)

            frame_times.append(time.time())
            fps = 0.0
            if len(frame_times) >= 2:
                elapsed = frame_times[-1] - frame_times[0]
                fps = (len(frame_times) - 1) / elapsed if elapsed > 0 else 0.0

            # Drop stale frames, keep only latest
            while not frame_queue.empty():
                try:
                    frame_queue.get_nowait()
                except queue.Empty:
                    break
            frame_queue.put_nowait((frame, frame_count, round(fps, 1)))
    except StopIteration:
        print("Generator exhausted")
    except Exception as e:
        reason = str(e).lower()
        if "aborted" in reason or "duration" in reason:
            print(f"GPU time expired: {e}")
        else:
            print(f"Worker error: {e}")
            broadcast_status(f"error:{e}")
    finally:
        stop_event.set()
        print("Worker thread finished")


def _start_world(prompt, seed_url=None):
    """Run the pipeline from a seed image and return the fresh pipeline state.

    A random entry of SEED_FRAME_URLS is used when `seed_url` is empty.
    """
    seed_image = load_seed_frame(seed_url or random.choice(SEED_FRAME_URLS))
    return pipe(
        prompt=prompt,
        image=seed_image,
        button=set(),
        mouse=(0.0, 0.0),
        output_type="pil",
    )


def create_gpu_game_loop(command_queue, initial_seed_url=None, initial_prompt="An explorable world"):
    """Create the frame generator that runs inside the GPU allocation.

    Args:
        command_queue: Source of GenerateCommand / ResetCommand / StopCommand.
        initial_seed_url: Seed image for the first world; random when empty.
        initial_prompt: Prompt for the first world.

    Returns:
        A generator yielding `(frame, frame_count)` tuples. It finishes when a
        StopCommand is read from `command_queue`.
    """
    # `spaces` is only importable on ZeroGPU; elsewhere run the loop undecorated
    # instead of failing with a NameError.
    gpu_decorator = spaces.GPU(duration=120) if IS_ZERO_GPU else (lambda fn: fn)

    @gpu_decorator
    def gpu_game_loop():
        broadcast_status("GPU allocated! Loading model...")
        _ensure_aot_on_gpu()
        n_frames = pipe.transformer.config.n_frames
        print(f"Model loaded! (n_frames={n_frames})")

        broadcast_status("Loading seed image...")
        seed_image = load_seed_frame(initial_seed_url or random.choice(SEED_FRAME_URLS))

        broadcast_status("Generating first frame...")
        state = pipe(
            prompt=initial_prompt,
            image=seed_image,
            button=set(),
            mouse=(0.0, 0.0),
            output_type="pil",
        )
        frame_count = 0
        for frame in state.values.get("images"):
            frame_count += 1
            yield (frame, frame_count)

        current_buttons = set()
        current_mouse = (0.0, 0.0)
        current_prompt = initial_prompt

        while True:
            # Drain every pending command; the newest control input wins.
            stop_requested = False
            reset_command = None
            while True:
                try:
                    command = command_queue.get_nowait()
                except queue.Empty:
                    break
                except Exception as exc:
                    # Reading is best effort: keep generating with the last
                    # known input, but do not swallow the failure silently.
                    print(f"Command queue read failed: {exc}")
                    break
                if isinstance(command, StopCommand):
                    stop_requested = True
                    break
                if isinstance(command, ResetCommand):
                    reset_command = command
                elif isinstance(command, GenerateCommand):
                    current_buttons = command.buttons
                    current_mouse = command.mouse
                    current_prompt = command.prompt

            if stop_requested:
                break

            if reset_command is not None:
                state = _start_world(reset_command.prompt, reset_command.seed_url)
                frame_count = 0
                current_prompt = reset_command.prompt
                for frame in state.values.get("images"):
                    frame_count += 1
                    yield (frame, frame_count)
                continue

            state = pipe(
                state,
                prompt=current_prompt,
                button=current_buttons,
                mouse=current_mouse,
                image=None,
                output_type="pil",
            )
            for frame in state.values.get("images"):
                frame_count += 1
                yield (frame, frame_count)

            # Restart from a random seed shortly before the frame budget is hit.
            if frame_count >= MAX_FRAMES_BEFORE_RESET - 2:
                print(f"Auto-reset at frame {frame_count}")
                state = _start_world(current_prompt)
                frame_count = 0
                for frame in state.values.get("images"):
                    frame_count += 1
                    yield (frame, frame_count)

        print("GPU session ended")

    return gpu_game_loop()


# --- App ---
app = Server()


@app.api(name="start_game")
def start_game(session_id: str = "", seed_url: str = "", prompt: str = "An explorable world") -> str:
    """Start a new game session for `session_id`. Returns the session_id used.

    A UUID is generated when `session_id` is empty. Any existing session under
    the same id is stopped first.
    """
    if not session_id:
        session_id = str(uuid.uuid4())

    # Tear down any prior session for this client (e.g. a reload mid-game).
    prior = _drop_session(session_id)
    if prior is not None:
        prior.stop()

    # If no seed selected, pick a random one from the curated set so the
    # player always lands somewhere instead of falling through to the model's
    # default.
    # NOTE(review): a non-empty seed_url is client-supplied and is passed to
    # load_image() unchanged; consider restricting it to SEED_FRAME_URLS if the
    # front-end never sends anything else.
    if not seed_url:
        seed_url = random.choice(SEED_FRAME_URLS)

    command_queue = Queue()
    frame_queue = queue.Queue(maxsize=2)
    stop_event = threading.Event()
    status_queue = queue.Queue(maxsize=20)
    frame_times = deque(maxlen=30)

    session = GameSession(
        session_id=session_id,
        command_queue=command_queue,
        frame_queue=frame_queue,
        stop_event=stop_event,
        status_queue=status_queue,
        frame_times=frame_times,
    )
    with _sessions_lock:
        _sessions[session_id] = session

    # Bind the status queue into the contextvar so the worker thread (and the
    # GPU function it spawns) inherits it via copy_context().
    token = _current_status_queue.set(status_queue)
    try:
        broadcast_status("Requesting GPU from ZeroGPU...")
        ctx = contextvars.copy_context()
        worker = threading.Thread(
            target=ctx.run,
            args=(
                gpu_worker_thread,
                seed_url,
                prompt or "An explorable world",
                command_queue,
                frame_queue,
                stop_event,
                frame_times,
            ),
            daemon=True,
        )
        session.worker_thread = worker
        worker.start()
    finally:
        # Only the copied context keeps the binding; restore the caller's.
        _current_status_queue.reset(token)

    return session_id


@app.api(name="stop_game")
def stop_game(session_id: str = "") -> str:
    """Stop the active game session for the given client.

    Returns "no_session" when `session_id` is empty, otherwise "stopped"
    (whether or not a session was registered under that id).
    """
    if not session_id:
        return "no_session"
    session = _drop_session(session_id)
    if session is not None:
        session.stop()
    return "stopped"


@app.api(name="get_worlds")
def get_worlds() -> list:
    """Return the list of seed world URLs."""
    return SEED_FRAME_URLS


def _encode_frame_packet(frame, count, fps):
    """Build one binary frame message: 8-byte header followed by JPEG bytes.

    The header is two big-endian unsigned 32-bit integers: the frame counter
    and the FPS value multiplied by ten (truncated to an int).
    """
    buf = io.BytesIO()
    frame.save(buf, format="JPEG", quality=JPEG_QUALITY)
    return struct.pack(">II", count, int(fps * 10)) + buf.getvalue()


@app.websocket("/ws")
async def game_ws(websocket: WebSocket, session_id: str = ""):
    """
    Real-time game WebSocket. Requires `?session_id=...` query param matching
    the value passed to /start_game.

    Server -> client: JSON `status` / `error` / `session_ended` messages and
    binary frame packets. Client -> server: JSON `control` and `reset` messages.
    """
    await websocket.accept()

    if not session_id:
        await websocket.send_json({"type": "error", "message": "missing session_id"})
        await websocket.close(code=1008)
        return

    loop = asyncio.get_running_loop()

    async def send_frames():
        # Session object that has already been told "session_ended"; tracked
        # per object so a restarted session on this socket is notified again.
        ended_notified_for = None
        while True:
            session = _get_session(session_id)
            if session is None:
                await asyncio.sleep(0.05)
                continue

            # Drain any status messages for this session (init progress / errors).
            try:
                status_msg = session.status_queue.get_nowait()
                if status_msg.startswith("error:"):
                    await websocket.send_json({"type": "error", "message": status_msg[6:]})
                    break
                await websocket.send_json({"type": "status", "message": status_msg})
            except queue.Empty:
                pass
            except (WebSocketDisconnect, RuntimeError):
                break

            if session.stop_event.is_set():
                if ended_notified_for is not session:
                    try:
                        await websocket.send_json({"type": "session_ended"})
                    except (WebSocketDisconnect, RuntimeError):
                        break
                    ended_notified_for = session
                await asyncio.sleep(0.5)
                continue

            try:
                # Blocking get with a short timeout, off the event loop.
                frame, count, fps = await loop.run_in_executor(
                    None, lambda s=session: s.frame_queue.get(timeout=0.05)
                )
                await websocket.send_bytes(_encode_frame_packet(frame, count, fps))
                session.touch()
            except queue.Empty:
                pass
            except (WebSocketDisconnect, RuntimeError):
                break

    async def receive_controls():
        while True:
            try:
                data = await websocket.receive_json()
                session = _get_session(session_id)
                if session is None:
                    continue
                session.touch()

                msg_type = data.get("type", "control")
                if msg_type == "control":
                    buttons = set(data.get("buttons", []))
                    mouse = (data.get("mouse_x", 0.0), data.get("mouse_y", 0.0))
                    prompt = data.get("prompt", "An explorable world")
                    session.command_queue.put(
                        GenerateCommand(buttons=buttons, mouse=mouse, prompt=prompt)
                    )
                elif msg_type == "reset":
                    seed_url = data.get("seed_url") or random.choice(SEED_FRAME_URLS)
                    prompt = data.get("prompt", "An explorable world")
                    session.command_queue.put(ResetCommand(seed_url=seed_url, prompt=prompt))
            except WebSocketDisconnect:
                break
            except Exception as exc:
                # Malformed payloads and transport errors end the receiver.
                print(f"WebSocket receive failed: {exc}")
                break

    # The sender can idle forever without touching the socket (no session, or
    # session already ended), so it would never notice a disconnect by itself.
    # The receiver does notice; cancel the sender as soon as the receiver ends.
    sender = asyncio.create_task(send_frames())
    try:
        await receive_controls()
    finally:
        sender.cancel()
        await asyncio.gather(sender, return_exceptions=True)


@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve `index.html` from the directory containing this file."""
    html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(html_path, "r", encoding="utf-8") as f:
        return f.read()


# Avoid ZeroGPU "no GPU function" error
if IS_ZERO_GPU:
    spaces.GPU(lambda: None)

app.launch(server_name="0.0.0.0", server_port=7860)