Spaces:
Running on Zero
Running on Zero
| """ | |
| Waypoint v1.5 Real-Time World Model Demo — gradio.Server + WebSocket Edition | |
| Uses gradio.Server for ZeroGPU-compatible game start/stop, and a raw WebSocket | |
| for real-time binary JPEG frame streaming + control input. No Gradio UI polling. | |
| Multi-user safe: every endpoint is keyed by a per-client `session_id` so | |
| concurrent players never share seed images, frame queues, or status messages. | |
| """ | |
# Check for ZeroGPU environment - must be before other imports.
# On HF ZeroGPU hardware the `spaces` package is installed; elsewhere the
# import fails and we fall back to standard (always-on) GPU mode.
try:
    import spaces
    IS_ZERO_GPU = True
    print("ZeroGPU environment detected")
except ImportError:
    IS_ZERO_GPU = False
    print("Running in standard GPU mode")
| import asyncio | |
| import io | |
| import os | |
| import queue | |
| import random | |
| import struct | |
| import contextvars | |
| import threading | |
| import time | |
| import uuid | |
| from collections import deque | |
| from dataclasses import dataclass, field | |
| from multiprocessing import Queue | |
| from typing import Dict, Optional, Set, Tuple | |
| import torch | |
| from PIL import Image | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| from fastapi.responses import HTMLResponse | |
| from gradio import Server | |
| from diffusers.modular_pipelines import ModularPipeline | |
| from diffusers.utils import load_image | |
# --- Configuration ---
# Model repo id; override with MODEL_PATH for local or alternate checkpoints.
MODEL_ID = os.environ.get("MODEL_PATH", "diffusers-internal-dev/waypoint_v1_5_modular_aot")
IMAGE_WIDTH = 1024   # generated frame width in pixels
IMAGE_HEIGHT = 512   # generated frame height in pixels
MAX_FRAMES_BEFORE_RESET = 4096  # re-seed the world shortly before this frame count
JPEG_QUALITY = 80    # JPEG quality for frames streamed over the WebSocket
SESSION_IDLE_TIMEOUT = 600  # seconds; janitor reaps abandoned sessions
# Trade a little fp32 matmul precision for speed (TF32-class kernels).
torch.set_float32_matmul_precision("medium")
# --- Load Pipeline on CPU at Startup ---
# Components are loaded on CPU here; the move to CUDA happens lazily in
# _ensure_aot_on_gpu() once a session actually needs the GPU.
print(f"Loading pipeline from {MODEL_ID}...")
pipe = ModularPipeline.from_pretrained(MODEL_ID, trust_remote_code=True)
pipe.load_components("transformer", trust_remote_code=True, torch_dtype=torch.bfloat16)
pipe.transformer.apply_inference_patches()
pipe.load_components(
    ["vae", "text_encoder", "tokenizer"],
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
# Pre-download the AOT-compiled transformer artifacts at startup so the first
# GPU allocation doesn't pay the network cost.
from huggingface_hub import hf_hub_download
import torch._inductor.codecache
pt2_path = hf_hub_download(repo_id=MODEL_ID, filename="transformer/transformer.pt2")
constants_path = hf_hub_download(repo_id=MODEL_ID, filename="transformer/transformer-constants.pt")
print("Pipeline and AOT artifacts downloaded.")
# One-time GPU/AOT initialization state, guarded for concurrent sessions.
_aot_initialized = False
_aot_init_lock = threading.Lock()


def _ensure_aot_on_gpu():
    """Move the pipeline to CUDA and swap in the AOT-compiled transformer.

    Idempotent and thread-safe via double-checked locking: the first caller
    pays the full initialization cost while holding the lock; every later
    caller returns immediately on the unlocked fast path.
    """
    global _aot_initialized
    if _aot_initialized:
        return
    with _aot_init_lock:
        # Re-check under the lock: another thread may have finished init
        # while we were waiting.
        if _aot_initialized:
            return
        broadcast_status("Moving pipeline to GPU...")
        print("Initializing AOT model on GPU...")
        pipe.to("cuda")
        broadcast_status("Setting up KV cache...")
        # Pre-allocate the transformer's KV cache on the GPU via the
        # pipeline's setup block (private API of the modular pipeline).
        pipe.blocks.sub_blocks["before_denoise"].sub_blocks["setup_kv_cache"]._setup_kv_cache(
            pipe.transformer, torch.device("cuda"), torch.bfloat16
        )
        broadcast_status("Loading AOT-compiled model...")
        aot_model = torch._inductor.aoti_load_package(pt2_path)
        constants_map = torch.load(constants_path, map_location="cuda", weights_only=True)
        aot_model.load_constants(constants_map, check_full_update=True, user_managed=True)
        # Replace the eager forward with the AOT-compiled callable.
        pipe.transformer.forward = aot_model
        _aot_initialized = True
        broadcast_status("AOT model ready!")
        print("AOT model loaded successfully!")
# Curated seed images hosted in the Biome GitHub repo. The `?raw=true`
# suffix makes GitHub serve the binary image instead of the HTML blob page.
_SEED_BASE = "https://github.com/Overworldai/Biome/blob/main/seeds"
_SEED_NAMES = [
    "abandoned_city_runner.jpg",
    "alien_command_center.jpg",
    "ancient_standing_stones.jpg",
    "default.jpg",
    "desert_outpost_hangar.jpg",
    "enchanted_swamp_torch.jpg",
    "frozen_crystal_cavern.jpg",
    "frozen_valley_sniper.jpg",
    "highland_castle_loch.jpg",
    "mountain_ruins_gun.jpg",
    "shattered_cockpit_nebula.jpg",
    "shipwreck_shore_revolver.jpg",
    "snowy_forest_tracks.jpg",
    "stormy_countryside_rifle.jpg",
    "sunken_city_depths.jpg",
]
SEED_FRAME_URLS = [f"{_SEED_BASE}/{name}?raw=true" for name in _SEED_NAMES]
def load_seed_frame(url: str, target_size: Tuple[int, int] = (IMAGE_HEIGHT, IMAGE_WIDTH)) -> Image.Image:
    """Download a seed image and resize it to the model's frame size.

    `target_size` is (height, width); PIL's resize wants (width, height),
    hence the swap below.
    """
    height, width = target_size
    seed = load_image(url)
    return seed.resize((width, height), Image.BILINEAR)
# --- Command Types ---
# BUGFIX: these classes carried dataclass-style field annotations and are
# instantiated with keyword arguments (see the WebSocket handler and
# GameSession.stop), but lacked the @dataclass decorator — so e.g.
# GenerateCommand(buttons=...) raised TypeError at runtime.
@dataclass
class GenerateCommand:
    """Per-step control input forwarded from the browser."""
    buttons: Set[int]              # currently-held button ids
    mouse: Tuple[float, float]     # (x, y) mouse delta/position from the client
    prompt: str                    # text prompt steering the world model


@dataclass
class ResetCommand:
    """Restart the world from a seed image (random seed when seed_url is None)."""
    seed_url: Optional[str] = None
    prompt: str = "An explorable world"


@dataclass
class StopCommand:
    """Sentinel telling the GPU loop to shut down."""
    pass
# --- Session ---
# BUGFIX: GameSession uses field(default_factory=...) and is constructed with
# keyword arguments in start_game(), both of which require the @dataclass
# decorator that was missing — construction raised TypeError at runtime.
@dataclass
class GameSession:
    """All mutable state for one connected player, keyed by session_id."""
    session_id: str
    command_queue: Queue            # commands from the WebSocket into the GPU loop
    frame_queue: queue.Queue        # latest rendered frame(s) for the WebSocket sender
    stop_event: threading.Event     # set when the worker should / did stop
    status_queue: queue.Queue       # init-progress and error strings for the client
    worker_thread: Optional[threading.Thread] = None
    frame_times: deque = field(default_factory=lambda: deque(maxlen=30))  # FPS window
    last_active: float = field(default_factory=time.time)  # for idle reaping

    def touch(self):
        """Record client activity so the idle janitor keeps this session alive."""
        self.last_active = time.time()

    def stop(self):
        """Signal the worker to stop and wait briefly for it to exit."""
        self.stop_event.set()
        try:
            # Best-effort wake-up in case the GPU loop is draining commands.
            self.command_queue.put_nowait(StopCommand())
        except Exception:
            pass
        if self.worker_thread and self.worker_thread.is_alive():
            self.worker_thread.join(timeout=3.0)
# Per-client state, keyed by session_id (UUID generated by the browser).
# All reads/writes of _sessions must hold _sessions_lock.
_sessions: Dict[str, GameSession] = {}
_sessions_lock = threading.Lock()
# Contextvar carrying the active session's status queue. The worker thread
# inherits the value via contextvars.copy_context(), so broadcast_status()
# always lands in the right session's queue without explicit threading.
_current_status_queue: contextvars.ContextVar[Optional[queue.Queue]] = contextvars.ContextVar(
    "waypoint_status_queue", default=None
)
def broadcast_status(msg: str):
    """Push a status message to the current session's WebSocket consumer.

    A no-op when no session is bound to the current context; messages are
    best-effort and silently dropped if the consumer's queue is full.
    """
    status_q = _current_status_queue.get()
    if status_q is not None:
        try:
            status_q.put_nowait(msg)
        except queue.Full:
            pass
def _get_session(session_id: str) -> Optional[GameSession]:
    """Look up the live session for `session_id`, or None if absent."""
    with _sessions_lock:
        session = _sessions.get(session_id)
    return session
def _drop_session(session_id: str) -> Optional[GameSession]:
    """Atomically remove and return the session for `session_id` (None if missing)."""
    with _sessions_lock:
        session = _sessions.pop(session_id, None)
    return session
def _reap_idle_sessions():
    """Background janitor loop.

    Every minute, drop sessions whose worker thread has exited AND that have
    been inactive longer than SESSION_IDLE_TIMEOUT. Runs forever in a
    daemon thread.
    """
    while True:
        time.sleep(60)
        cutoff = time.time() - SESSION_IDLE_TIMEOUT
        with _sessions_lock:
            dead = [
                sid
                for sid, sess in _sessions.items()
                if (sess.worker_thread is None or not sess.worker_thread.is_alive())
                and sess.last_active < cutoff
            ]
            for sid in dead:
                del _sessions[sid]
        if dead:
            print(f"Janitor reaped {len(dead)} idle session(s)")


threading.Thread(target=_reap_idle_sessions, daemon=True).start()
def gpu_worker_thread(seed_url, prompt, command_queue, frame_queue, stop_event, frame_times):
    """Thread body: drive the GPU frame generator and publish frames.

    Pulls frames from the generator until `stop_event` is set, the generator
    is exhausted, or an error occurs. Always sets `stop_event` on exit so the
    WebSocket side can detect the session ending.
    """
    try:
        gen = create_gpu_game_loop(
            command_queue,
            initial_seed_url=seed_url,
            initial_prompt=prompt,
        )
        while not stop_event.is_set():
            try:
                frame, frame_count = next(gen)
                now = time.time()
                frame_times.append(now)
                # Rolling FPS over the bounded deque window (maxlen=30).
                fps = 0.0
                if len(frame_times) >= 2:
                    elapsed = frame_times[-1] - frame_times[0]
                    fps = (len(frame_times) - 1) / elapsed if elapsed > 0 else 0.0
                # Drop stale frames, keep only latest. Safe only because this
                # thread is the sole producer; the consumer only gets.
                while not frame_queue.empty():
                    try:
                        frame_queue.get_nowait()
                    except queue.Empty:
                        break
                frame_queue.put_nowait((frame, frame_count, round(fps, 1)))
            except StopIteration:
                print("Generator exhausted")
                break
            except Exception as e:
                # NOTE(review): ZeroGPU expiry is detected by sniffing the
                # exception message for "aborted"/"duration" — confirm these
                # strings against the current `spaces` release.
                if "aborted" in str(e).lower() or "duration" in str(e).lower():
                    print(f"GPU time expired: {e}")
                else:
                    print(f"Worker error: {e}")
                    broadcast_status(f"error:{e}")
                break
    finally:
        # Guarantees the WebSocket loop sees the session as ended.
        stop_event.set()
        print("Worker thread finished")
def create_gpu_game_loop(command_queue, initial_seed_url=None, initial_prompt="An explorable world"):
    """Create the generator that runs the world model and yields frames.

    Args:
        command_queue: queue of Stop/Reset/GenerateCommand objects from the client.
        initial_seed_url: seed image URL; a random curated seed when falsy.
        initial_prompt: text prompt used until the client overrides it.

    Returns:
        A generator yielding (PIL.Image, frame_count) tuples. All GPU setup
        happens lazily on the first next() call.
    """
    def gpu_game_loop():
        broadcast_status("GPU allocated! Loading model...")
        _ensure_aot_on_gpu()
        n_frames = pipe.transformer.config.n_frames
        print(f"Model loaded! (n_frames={n_frames})")
        broadcast_status("Loading seed image...")
        if initial_seed_url:
            seed_image = load_seed_frame(initial_seed_url)
        else:
            seed_image = load_seed_frame(random.choice(SEED_FRAME_URLS))
        broadcast_status("Generating first frame...")
        # Initial call seeds the autoregressive state from the image.
        state = pipe(prompt=initial_prompt, image=seed_image, button=set(), mouse=(0.0, 0.0), output_type="pil")
        frame_count = 0
        for frame in state.values.get("images"):
            frame_count += 1
            yield (frame, frame_count)
        current_buttons = set()
        current_mouse = (0.0, 0.0)
        current_prompt = initial_prompt
        while True:
            stop_requested = False
            reset_command = None
            # Drain every queued command: the latest control state wins, and a
            # StopCommand short-circuits the drain immediately.
            while True:
                try:
                    command = command_queue.get_nowait()
                except queue.Empty:
                    # BUGFIX: was a bare `except:` that silently swallowed every
                    # exception type (including KeyboardInterrupt/SystemExit);
                    # only an empty queue should end the drain loop.
                    break
                if isinstance(command, StopCommand):
                    stop_requested = True
                    break
                elif isinstance(command, ResetCommand):
                    reset_command = command
                elif isinstance(command, GenerateCommand):
                    current_buttons = command.buttons
                    current_mouse = command.mouse
                    current_prompt = command.prompt
            if stop_requested:
                break
            if reset_command is not None:
                # Restart from a fresh seed image, discarding old state.
                if reset_command.seed_url:
                    seed_img = load_seed_frame(reset_command.seed_url)
                else:
                    seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))
                state = pipe(prompt=reset_command.prompt, image=seed_img, button=set(), mouse=(0.0, 0.0), output_type="pil")
                frame_count = 0
                current_prompt = reset_command.prompt
                for frame in state.values.get("images"):
                    frame_count += 1
                    yield (frame, frame_count)
                continue
            # Advance the world one step with the latest control state.
            state = pipe(state, prompt=current_prompt, button=current_buttons, mouse=current_mouse, image=None, output_type="pil")
            for frame in state.values.get("images"):
                frame_count += 1
                yield (frame, frame_count)
            # NOTE(review): the -2 margin presumably keeps the model's bounded
            # temporal context from overflowing — confirm against n_frames.
            if frame_count >= MAX_FRAMES_BEFORE_RESET - 2:
                print(f"Auto-reset at frame {frame_count}")
                seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))
                state = pipe(prompt=current_prompt, image=seed_img, button=set(), mouse=(0.0, 0.0), output_type="pil")
                frame_count = 0
                for frame in state.values.get("images"):
                    frame_count += 1
                    yield (frame, frame_count)
        print("GPU session ended")
    return gpu_game_loop()
# --- App ---
# NOTE(review): the endpoint functions below (start_game, game_ws, homepage,
# ...) are never explicitly registered on `app` in this chunk — presumably
# gradio.Server wires them by convention or the registration lives elsewhere
# in the repo; confirm.
app = Server()
def start_game(session_id: str = "", seed_url: str = "", prompt: str = "An explorable world") -> str:
    """Start a new game session for `session_id`. Returns the session_id used.

    Generates a fresh UUID when the client sent no id, replaces any previous
    session for the same id, and spawns a daemon worker thread that runs the
    GPU game loop.
    """
    sid = session_id or str(uuid.uuid4())
    # Tear down any prior session for this client (e.g. a reload mid-game).
    old = _drop_session(sid)
    if old is not None:
        old.stop()
    # Without an explicit seed, land the player in a random curated world
    # instead of falling through to the model's default.
    chosen_seed = seed_url or random.choice(SEED_FRAME_URLS)
    cmd_q = Queue()
    frames = queue.Queue(maxsize=2)
    stopper = threading.Event()
    statuses = queue.Queue(maxsize=20)
    times = deque(maxlen=30)
    session = GameSession(
        session_id=sid,
        command_queue=cmd_q,
        frame_queue=frames,
        stop_event=stopper,
        status_queue=statuses,
        frame_times=times,
    )
    with _sessions_lock:
        _sessions[sid] = session
    # Bind the status queue into the contextvar so the worker thread (and the
    # @spaces.GPU function it spawns) inherits it via copy_context().
    token = _current_status_queue.set(statuses)
    try:
        broadcast_status("Requesting GPU from ZeroGPU...")
        ctx = contextvars.copy_context()
        worker = threading.Thread(
            target=ctx.run,
            args=(
                gpu_worker_thread,
                chosen_seed,
                prompt or "An explorable world",
                cmd_q,
                frames,
                stopper,
                times,
            ),
            daemon=True,
        )
        session.worker_thread = worker
        worker.start()
    finally:
        # Restore this request's context so later calls don't leak the queue.
        _current_status_queue.reset(token)
    return sid
def stop_game(session_id: str = "") -> str:
    """Stop the active game session for the given client.

    Returns "no_session" when no id was supplied, otherwise "stopped"
    (even if the id matched no live session).
    """
    if not session_id:
        return "no_session"
    victim = _drop_session(session_id)
    if victim:
        victim.stop()
    return "stopped"
def get_worlds() -> list:
    """Expose the curated seed-world URLs for the frontend's world picker."""
    return SEED_FRAME_URLS
async def game_ws(websocket: WebSocket, session_id: str = ""):
    """
    Real-time game WebSocket. Requires `?session_id=...` query param matching
    the value passed to /start_game.

    Wire protocol: binary messages are an 8-byte big-endian header
    (frame_count, fps*10 as two uint32s) followed by JPEG bytes; JSON
    messages carry status/error/session_ended events outbound and
    control/reset commands inbound.
    """
    await websocket.accept()
    if not session_id:
        await websocket.send_json({"type": "error", "message": "missing session_id"})
        await websocket.close(code=1008)  # 1008 = policy violation
        return
    loop = asyncio.get_event_loop()

    async def send_frames():
        # Outbound loop: status messages and binary frames until disconnect.
        session_ended_sent = False
        while True:
            session = _get_session(session_id)
            # Drain any status messages for this session (init progress / errors).
            if session is not None:
                try:
                    status_msg = session.status_queue.get_nowait()
                    if status_msg.startswith("error:"):
                        # "error:" prefix set by the worker; strip it for the client.
                        await websocket.send_json({"type": "error", "message": status_msg[6:]})
                        break
                    await websocket.send_json({"type": "status", "message": status_msg})
                except queue.Empty:
                    pass
                except (WebSocketDisconnect, RuntimeError):
                    break
            if session is None:
                # Client may connect before /start_game registers the session.
                await asyncio.sleep(0.05)
                continue
            if session.stop_event.is_set():
                # Worker ended (stop, error, or GPU expiry): notify once, then idle.
                if not session_ended_sent:
                    try:
                        await websocket.send_json({"type": "session_ended"})
                    except (WebSocketDisconnect, RuntimeError):
                        break
                    session_ended_sent = True
                await asyncio.sleep(0.5)
                continue
            try:
                # Blocking queue.get runs in the default executor so the event
                # loop stays responsive; the 50 ms timeout keeps this polling.
                result = await loop.run_in_executor(
                    None, lambda s=session: s.frame_queue.get(timeout=0.05)
                )
                frame, count, fps = result
                buf = io.BytesIO()
                frame.save(buf, format="JPEG", quality=JPEG_QUALITY)
                jpeg_bytes = buf.getvalue()
                # fps is scaled by 10 so one decimal survives the uint32 header.
                header = struct.pack(">II", count, int(fps * 10))
                await websocket.send_bytes(header + jpeg_bytes)
                session.touch()
            except queue.Empty:
                pass
            except (WebSocketDisconnect, RuntimeError):
                break

    async def receive_controls():
        # Inbound loop: translate client JSON into commands for the GPU loop.
        while True:
            try:
                data = await websocket.receive_json()
                session = _get_session(session_id)
                if session is None:
                    continue
                session.touch()
                msg_type = data.get("type", "control")
                if msg_type == "control":
                    buttons = set(data.get("buttons", []))
                    mouse = (data.get("mouse_x", 0.0), data.get("mouse_y", 0.0))
                    prompt = data.get("prompt", "An explorable world")
                    session.command_queue.put(GenerateCommand(buttons=buttons, mouse=mouse, prompt=prompt))
                elif msg_type == "reset":
                    seed_url = data.get("seed_url") or random.choice(SEED_FRAME_URLS)
                    prompt = data.get("prompt", "An explorable world")
                    session.command_queue.put(ResetCommand(
                        seed_url=seed_url,
                        prompt=prompt,
                    ))
            except WebSocketDisconnect:
                break
            except Exception:
                # Malformed JSON or a closed socket: end the receive loop.
                break
    try:
        # NOTE(review): gather() waits for BOTH coroutines; if only one exits
        # the other keeps running until the socket actually disconnects.
        await asyncio.gather(send_frames(), receive_controls())
    except WebSocketDisconnect:
        pass
async def homepage():
    """Serve the static index.html that ships alongside this script."""
    here = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(here, "index.html"), "r", encoding="utf-8") as f:
        return f.read()
# Avoid ZeroGPU "no GPU function" error: registering any @spaces.GPU callable
# (even a no-op) satisfies the platform's startup check.
if IS_ZERO_GPU:
    spaces.GPU(lambda: None)
# Bind to all interfaces on the standard HF Spaces port.
app.launch(server_name="0.0.0.0", server_port=7860)