Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Concurrency test for benchmark environment using WebSockets. | |
| Each WebSocket connection gets its own dedicated environment session, | |
| enabling true concurrent execution across multiple sessions. | |
| Run the server first: | |
| cd benchmark && uvicorn server.app:app --port 8000 | |
| Then run this script: | |
| python test_concurrency.py --requests 100 --wait 1.0 | |
| python test_concurrency.py -n 100 -w 1 --url wss://your-server.hf.space | |
| """ | |
| import argparse | |
| import asyncio | |
| import json | |
| import time | |
| from dataclasses import dataclass | |
| try: | |
| import websockets | |
| except ImportError: | |
| print("Install websockets: pip install websockets") | |
| raise | |
@dataclass
class RequestResult:
    """Result from a single WebSocket request.

    Declared as a dataclass so that the generated ``__init__`` accepts the
    keyword arguments used when a session builds its result; without the
    decorator the annotations below create no constructor parameters and
    ``RequestResult(request_id=...)`` raises ``TypeError``.
    """

    # Index of the session within the test run (supplied by the caller).
    request_id: int
    # Wait time, in seconds, that was sent to the server in the step message.
    wait_requested: float
    # Wait time reported back in the step observation.
    waited_seconds: float
    # Client-side wall-clock duration of the whole session, in seconds.
    elapsed: float
    # "pid" field of the step observation (0 when absent).
    pid: int
    # "session_hash" field of the step observation ("" when absent).
    session_hash: str
    # "host_url" field of the step observation ("" when absent).
    host_url: str
def convert_to_ws_url(url: str) -> str:
    """Convert a server URL to the WebSocket endpoint URL.

    ``http://`` maps to ``ws://`` and ``https://`` to ``wss://``; ``ws://``
    and ``wss://`` are kept; a URL without any of these schemes is treated as
    a bare host and gets ``ws://``. Trailing slashes are stripped and the
    ``/ws`` path is appended unless the URL already ends with it.

    The ``/ws`` check is applied uniformly to every input form and only to
    the part after the scheme, so ``http://host/ws`` is not turned into
    ``.../ws/ws`` and a host literally named ``ws`` still gets its path.
    """
    url = url.rstrip("/")

    # (input scheme prefix, WebSocket scheme prefix)
    scheme_map = (
        ("http://", "ws://"),
        ("https://", "wss://"),
        ("ws://", "ws://"),
        ("wss://", "wss://"),
    )
    for prefix, ws_prefix in scheme_map:
        if url.startswith(prefix):
            scheme, rest = ws_prefix, url[len(prefix):]
            break
    else:
        # No recognised scheme: assume a bare host[:port][/path].
        scheme, rest = "ws://", url

    if not rest.endswith("/ws"):
        rest += "/ws"
    return scheme + rest
async def ws_session(
    ws_url: str,
    request_id: int,
    wait_seconds: float,
    timeout: float = 60.0,
) -> RequestResult:
    """
    Drive one full WebSocket session: connect, reset, step, close.

    Sends a ``reset`` message, then a ``step`` message carrying
    ``wait_seconds``, then a ``close`` message. ``timeout`` bounds the
    connection open and each individual receive.

    Raises:
        RuntimeError: if the reset or step reply has ``type == "error"``.
    """
    t0 = time.perf_counter()

    async with websockets.connect(ws_url, open_timeout=timeout) as conn:
        # Reset first so the server initialises the environment.
        reset_payload = {"type": "reset", "data": {}}
        await conn.send(json.dumps(reset_payload))
        raw_reset = await asyncio.wait_for(conn.recv(), timeout)
        reset_response = json.loads(raw_reset)
        if reset_response.get("type") == "error":
            raise RuntimeError(f"Reset error: {reset_response}")

        # One step that asks the server to wait for the requested time.
        step_payload = {"type": "step", "data": {"wait_seconds": wait_seconds}}
        await conn.send(json.dumps(step_payload))
        raw_step = await asyncio.wait_for(conn.recv(), timeout)
        step_response = json.loads(raw_step)
        if step_response.get("type") == "error":
            raise RuntimeError(f"Step error: {step_response}")

        # Tell the server we are done before leaving the connection context.
        await conn.send(json.dumps({"type": "close"}))

    # Taken after the connection context has exited.
    duration = time.perf_counter() - t0

    step_data = step_response.get("data", {})
    observation = step_data.get("observation", {})

    return RequestResult(
        request_id=request_id,
        wait_requested=wait_seconds,
        waited_seconds=observation.get("waited_seconds", 0.0),
        elapsed=duration,
        pid=observation.get("pid", 0),
        session_hash=observation.get("session_hash", ""),
        host_url=observation.get("host_url", ""),
    )
async def run_concurrent_test(
    url: str,
    num_requests: int,
    wait_seconds: float,
    timeout: float = 120.0,
) -> dict:
    """Run concurrent WebSocket sessions and collect results.

    Launches ``num_requests`` sessions at once with ``asyncio.gather`` and
    summarises the outcome.

    Returns:
        ``{"error": "All requests failed"}`` when no session succeeded;
        otherwise a dict with the request/success/failure counts, the total
        and average elapsed time, the number of distinct pids, session hashes
        and host URLs seen, and up to ten of the distinct pids under
        ``"pids"``.
    """
    ws_url = convert_to_ws_url(url)
    print(f"WebSocket URL: {ws_url}")
    print(f"Running {num_requests} concurrent sessions, each waiting {wait_seconds}s...")
    print()

    start = time.perf_counter()
    # Launch all sessions concurrently
    tasks = [
        ws_session(ws_url, i, wait_seconds, timeout) for i in range(num_requests)
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    total_time = time.perf_counter() - start

    # Process results. Everything that is not a RequestResult counts as a
    # failure: gather(return_exceptions=True) can also hand back
    # BaseException instances such as asyncio.CancelledError, which an
    # isinstance(..., Exception) filter would silently drop from both lists.
    successful = [r for r in results if isinstance(r, RequestResult)]
    failed = [r for r in results if not isinstance(r, RequestResult)]

    if not successful:
        print("All requests failed!")
        for i, err in enumerate(failed[:5]):
            # repr() keeps the exception type visible; str() is empty for
            # exceptions raised without a message (e.g. timeouts).
            print(f" Error {i}: {err!r}")
        return {"error": "All requests failed"}

    avg_time = sum(r.elapsed for r in successful) / len(successful)
    unique_pids = {r.pid for r in successful}
    unique_sessions = {r.session_hash for r in successful}
    unique_hosts = {r.host_url for r in successful}

    return {
        "num_requests": num_requests,
        "successful": len(successful),
        "failed": len(failed),
        "wait_seconds": wait_seconds,
        "total_time": total_time,
        "avg_time": avg_time,
        "unique_pids": len(unique_pids),
        "unique_sessions": len(unique_sessions),
        "unique_hosts": len(unique_hosts),
        "pids": list(unique_pids)[:10],  # First 10 for display
    }
async def main():
    """Parse the command line, run the concurrency test and print a report.

    Prints nothing further when the test run reports an error (the run
    itself has already printed the failures).
    """
    cli = argparse.ArgumentParser(
        description="Test benchmark environment concurrency via WebSocket"
    )
    # (flags, keyword arguments) for each command-line option, in help order.
    option_specs = (
        (
            ("--requests", "-n"),
            {"type": int, "default": 10,
             "help": "Number of concurrent WebSocket sessions"},
        ),
        (
            ("--wait", "-w"),
            {"type": float, "default": 1.0,
             "help": "Wait time per request (seconds)"},
        ),
        (
            ("--url", "-u"),
            {"type": str, "default": "http://localhost:8000",
             "help": "Server URL (http/https/ws/wss)"},
        ),
        (
            ("--timeout", "-t"),
            {"type": float, "default": 120.0,
             "help": "Timeout per request (seconds)"},
        ),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    args = cli.parse_args()

    result = await run_concurrent_test(
        args.url, args.requests, args.wait, args.timeout
    )
    if "error" in result:
        return

    print(f"Successful: {result['successful']}/{result['num_requests']}")
    if result["failed"]:
        print(f"Failed: {result['failed']}")
    print(f"Total time: {result['total_time']:.3f}s")
    print(f"Avg time: {result['avg_time']:.3f}s")
    print(f"Unique PIDs: {result['unique_pids']}")
    print(f"Unique sessions: {result['unique_sessions']}")
    print(f"Unique hosts: {result['unique_hosts']}")

    # Concurrency metrics: with perfect overlap the whole run takes one wait
    # period; the ratio of summed wait time to wall time is the speed-up.
    ideal_time = args.wait
    actual_concurrency = (args.requests * args.wait) / result["total_time"]
    print()
    print(f"Ideal time (full concurrency): {ideal_time:.3f}s")
    print(f"Effective concurrency: {actual_concurrency:.1f}x")
# Script entry point: run the async main() on a fresh event loop when this
# file is executed directly (not when it is imported).
if __name__ == "__main__":
    asyncio.run(main())