# openenv-benchmark-ws / test_concurrency.py
# Uploaded by burtenshaw (HF Staff) using huggingface_hub, commit 73edc95 (verified).
#!/usr/bin/env python3
"""
Concurrency test for benchmark environment using WebSockets.
Each WebSocket connection gets its own dedicated environment session,
enabling true concurrent execution across multiple sessions.
Run the server first:
cd benchmark && uvicorn server.app:app --port 8000
Then run this script:
python test_concurrency.py --requests 100 --wait 1.0
python test_concurrency.py -n 100 -w 1 --url wss://your-server.hf.space
"""
import argparse
import asyncio
import json
import time
from dataclasses import dataclass
try:
import websockets
except ImportError:
print("Install websockets: pip install websockets")
raise
@dataclass()
class RequestResult:
    """Outcome of one reset-then-step WebSocket session.

    Fields are filled from the client's own bookkeeping (id, requested wait,
    elapsed time) and from the observation returned by the step response.
    """

    # Caller-assigned identifier for this session.
    request_id: int
    # Wait duration the client asked the server to perform, in seconds.
    wait_requested: float
    # Wait duration reported back by the server in the step observation.
    waited_seconds: float
    # Client-measured wall-clock duration of the session, in seconds.
    elapsed: float
    # Process id reported by the server in the step observation.
    pid: int
    # Session identifier reported by the server in the step observation.
    session_hash: str
    # Host URL reported by the server in the step observation.
    host_url: str
def convert_to_ws_url(url: str) -> str:
    """Convert an HTTP(S)/WS(S)/bare-host URL into the server's ``/ws`` endpoint.

    ``http://`` maps to ``ws://`` and ``https://`` maps to ``wss://``.
    ``ws://`` / ``wss://`` are kept as-is, and anything without one of these
    schemes is treated as a bare host and given ``ws://``. Scheme matching is
    case-insensitive. Trailing slashes are stripped, and ``/ws`` is appended
    only when the URL does not already end with it, so the conversion is
    idempotent.

    >>> convert_to_ws_url("https://example.hf.space/")
    'wss://example.hf.space/ws'
    >>> convert_to_ws_url("http://localhost:8000/ws")
    'ws://localhost:8000/ws'
    """
    url = url.rstrip("/")
    lowered = url.lower()
    if lowered.startswith("http://"):
        url = "ws://" + url[len("http://"):]
    elif lowered.startswith("https://"):
        url = "wss://" + url[len("https://"):]
    elif not lowered.startswith(("ws://", "wss://")):
        # No recognised scheme: treat the input as a bare host[:port].
        url = "ws://" + url
    # Append the endpoint path once, regardless of which scheme we started from.
    return url if url.endswith("/ws") else url + "/ws"
async def ws_session(
    ws_url: str,
    request_id: int,
    wait_seconds: float,
    timeout: float = 60.0,
) -> RequestResult:
    """
    Run a complete WebSocket session: connect, reset, step, close.

    Each session gets its own environment instance on the server; a fresh
    connection is opened for every call.

    Args:
        ws_url: WebSocket endpoint to connect to.
        request_id: Caller-assigned identifier copied into the result.
        wait_seconds: Value sent as ``wait_seconds`` in the step payload.
        timeout: Seconds allowed for the opening handshake and for each
            individual ``recv`` (not a budget for the whole session).

    Returns:
        A ``RequestResult`` built from the step response's observation.
        Missing or null observation fields fall back to 0 / 0.0 / "".

    Raises:
        RuntimeError: if the server answers reset or step with a message
            whose ``type`` is ``"error"``.
        asyncio.TimeoutError: if a response does not arrive within ``timeout``.
        Errors from ``websockets.connect`` (handshake/connection failures)
        propagate unchanged.
    """
    start = time.perf_counter()
    async with websockets.connect(ws_url, open_timeout=timeout) as ws:
        # Reset to initialize this connection's environment.
        await ws.send(json.dumps({"type": "reset", "data": {}}))
        reset_response = json.loads(await asyncio.wait_for(ws.recv(), timeout))
        if reset_response.get("type") == "error":
            raise RuntimeError(f"Reset error: {reset_response}")

        # Step, asking the server to wait for the requested duration.
        await ws.send(
            json.dumps({"type": "step", "data": {"wait_seconds": wait_seconds}})
        )
        step_response = json.loads(await asyncio.wait_for(ws.recv(), timeout))
        if step_response.get("type") == "error":
            raise RuntimeError(f"Step error: {step_response}")

        # Best-effort close message: the step has already succeeded, so a
        # peer that dropped the connection must not turn this into a failure.
        try:
            await ws.send(json.dumps({"type": "close"}))
        except websockets.ConnectionClosed:
            pass
        elapsed = time.perf_counter() - start

        # `or {}` also guards against explicit JSON nulls, which
        # `.get(key, {})` would pass through and then crash on.
        obs = (step_response.get("data") or {}).get("observation") or {}
        return RequestResult(
            request_id=request_id,
            wait_requested=wait_seconds,
            waited_seconds=obs.get("waited_seconds", 0.0),
            elapsed=elapsed,
            pid=obs.get("pid", 0),
            session_hash=obs.get("session_hash", ""),
            host_url=obs.get("host_url", ""),
        )
async def run_concurrent_test(
    url: str,
    num_requests: int,
    wait_seconds: float,
    timeout: float = 120.0,
) -> dict:
    """Run concurrent WebSocket sessions and collect summary statistics.

    Args:
        url: Server URL in any form accepted by ``convert_to_ws_url``.
        num_requests: Number of sessions launched simultaneously.
        wait_seconds: Server-side wait requested by every session.
        timeout: Per-session timeout forwarded to ``ws_session``.

    Returns:
        If at least one session succeeded, a dict with these keys:
        ``num_requests``, ``successful``, ``failed``, ``wait_seconds``,
        ``total_time``, ``avg_time``, ``unique_pids``, ``unique_sessions``,
        ``unique_hosts`` and ``pids`` (up to 10 distinct PIDs, sorted).
        ``successful + failed == num_requests`` always holds.

        If no session succeeded, ``{"error": "All requests failed"}`` is
        returned after printing up to five error samples.
    """
    ws_url = convert_to_ws_url(url)
    print(f"WebSocket URL: {ws_url}")
    print(f"Running {num_requests} concurrent sessions, each waiting {wait_seconds}s...")
    print()
    start = time.perf_counter()
    # Launch all sessions at once. return_exceptions=True hands failures back
    # as result entries instead of aborting the whole gather on the first one.
    results = await asyncio.gather(
        *(ws_session(ws_url, i, wait_seconds, timeout) for i in range(num_requests)),
        return_exceptions=True,
    )
    total_time = time.perf_counter() - start

    successful = [r for r in results if isinstance(r, RequestResult)]
    # Anything that is not a result is a failure. Filtering on `Exception`
    # would silently drop asyncio.CancelledError, which is a BaseException.
    failed = [r for r in results if not isinstance(r, RequestResult)]

    if not successful:
        print("All requests failed!")
        for i, err in enumerate(failed[:5]):
            # Include the type: str() of e.g. TimeoutError is empty.
            print(f" Error {i}: {type(err).__name__}: {err}")
        return {"error": "All requests failed"}

    avg_time = sum(r.elapsed for r in successful) / len(successful)
    unique_pids = {r.pid for r in successful}
    unique_sessions = {r.session_hash for r in successful}
    unique_hosts = {r.host_url for r in successful}
    return {
        "num_requests": num_requests,
        "successful": len(successful),
        "failed": len(failed),
        "wait_seconds": wait_seconds,
        "total_time": total_time,
        "avg_time": avg_time,
        "unique_pids": len(unique_pids),
        "unique_sessions": len(unique_sessions),
        "unique_hosts": len(unique_hosts),
        "pids": sorted(unique_pids)[:10],  # Deterministic sample for display
    }
async def main():
    """Parse CLI arguments, run the concurrency test and print a report.

    Invalid arguments terminate via ``argparse`` (exit status 2). If every
    session fails, the process exits with status 1 so that scripted runs
    can detect the failure.
    """
    parser = argparse.ArgumentParser(
        description="Test benchmark environment concurrency via WebSocket"
    )
    parser.add_argument(
        "--requests", "-n", type=int, default=10,
        help="Number of concurrent WebSocket sessions",
    )
    parser.add_argument(
        "--wait", "-w", type=float, default=1.0,
        help="Wait time per request (seconds)",
    )
    parser.add_argument(
        "--url", "-u", type=str, default="http://localhost:8000",
        help="Server URL (http/https/ws/wss)",
    )
    parser.add_argument(
        "--timeout", "-t", type=float, default=120.0,
        help="Timeout per request (seconds)",
    )
    args = parser.parse_args()

    # Reject inputs that would produce a misleading report. Zero sessions
    # would be reported as "All requests failed!", and a negative wait would
    # yield a negative concurrency figure.
    if args.requests < 1:
        parser.error("--requests must be at least 1")
    if args.wait < 0:
        parser.error("--wait must be non-negative")
    if args.timeout <= 0:
        parser.error("--timeout must be positive")

    result = await run_concurrent_test(
        args.url, args.requests, args.wait, args.timeout
    )
    if "error" in result:
        # Failure details were already printed by run_concurrent_test.
        raise SystemExit(1)

    print(f"Successful: {result['successful']}/{result['num_requests']}")
    if result["failed"]:
        print(f"Failed: {result['failed']}")
    print(f"Total time: {result['total_time']:.3f}s")
    print(f"Avg time: {result['avg_time']:.3f}s")
    print(f"Unique PIDs: {result['unique_pids']}")
    print(f"Unique sessions: {result['unique_sessions']}")
    print(f"Unique hosts: {result['unique_hosts']}")

    # Concurrency metrics. With full concurrency every session finishes in
    # about one wait period. Only successful sessions received a step
    # response, so only they count toward the work actually completed.
    ideal_time = args.wait
    actual_concurrency = (result["successful"] * args.wait) / result["total_time"]
    print()
    print(f"Ideal time (full concurrency): {ideal_time:.3f}s")
    print(f"Effective concurrency: {actual_concurrency:.1f}x")
if __name__ == "__main__":
    # Script entry point: drive the async CLI to completion on a new event loop.
    asyncio.run(main())