Files
local-mcp/tests/test_wakeup.py
2026-03-27 03:58:57 +08:00

116 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
tests/test_wakeup.py
Timing and correctness tests for the get_user_request wait loop.
Test 1 - Immediate wakeup:
Verifies the asyncio.Event wakes a waiting get_user_request call soon after a
new instruction is enqueued (design target ~10 ms; the script checks that the
call returns within PASS_THRESH seconds), even when min_wait_seconds has not
elapsed yet.
Test 2 - Generation safety (concurrent calls):
Simulates two overlapping calls for the same agent_id. The OLDER call must
NOT consume the instruction; only the NEWER (active) call should receive it.
"""
import asyncio
import sys
import threading
import time
sys.path.insert(0, ".")
from app.database import init_db
from app.services import instruction_service
from app.services.config_service import update_config, get_config
# Delay (seconds) before the background thread enqueues the test instruction.
WAKEUP_DELAY: float = 1.5
# Value written to default_wait_seconds for the duration of the run; an
# early wakeup should return long before this many seconds pass.
MIN_WAIT: int = 8
# A call that returns in under this many seconds counts as a PASS for Test 1.
PASS_THRESH: float = 4.0
def run():
    """Run both wakeup checks against the local database, then restore config.

    Initialises the database at ``data/local_mcp.sqlite3``, temporarily sets
    ``default_wait_seconds`` to ``MIN_WAIT`` and runs:

    * Test 1 (immediate wakeup): an instruction created from a background
      thread after ``WAKEUP_DELAY`` seconds must make ``get_user_request``
      return in under ``PASS_THRESH`` seconds.
    * Test 2 (generation safety): of two overlapping ``get_user_request``
      calls for the same agent_id, only the newer one may receive the
      instruction.

    Each check prints its outcome and calls ``sys.exit(1)`` on failure. The
    ``default_wait_seconds`` value that was configured before the run is
    written back whether the checks pass, fail, or raise.
    """
    init_db("data/local_mcp.sqlite3")

    # ── Test 1: Immediate wakeup ───────────────────────────────────────────
    async def _test1():
        """Check that a newly enqueued instruction wakes the waiting call early."""
        await instruction_service.init_wakeup()
        t0 = time.monotonic()

        def _add():
            # Runs in a worker thread so this blocking sleep does not stall
            # the event loop that is awaiting get_user_request.
            time.sleep(WAKEUP_DELAY)
            instruction_service.create_instruction("Wakeup-timing-test")
            print(f"[T1 thread] instruction added t={time.monotonic()-t0:.2f}s")

        threading.Thread(target=_add, daemon=True).start()
        # Imported here (after init_db) as in the original script.
        from app.mcp_server import get_user_request
        # NOTE(review): wait_seconds=0 presumably means "use the configured
        # default_wait_seconds" (MIN_WAIT) — confirm in app.mcp_server.
        result = await get_user_request(agent_id="timing-test", wait_seconds=0)
        elapsed = time.monotonic() - t0
        print(f"[T1] Tool returned t={elapsed:.2f}s result_type={result['result_type']}")
        if elapsed < PASS_THRESH:
            print(f"[T1] PASS woke up at {elapsed:.2f}s (min_wait={MIN_WAIT}s)")
        else:
            print(f"[T1] FAIL took {elapsed:.2f}s — wakeup did not fire in time")
            sys.exit(1)

    # ── Test 2: Generation safety ──────────────────────────────────────────
    # Call 1 (old) starts waiting. Before any instruction arrives, Call 2
    # (new) also starts. Then an instruction is added. Only Call 2 should
    # receive it; Call 1 should step aside and return empty.
    async def _test2():
        """Check that only the newest of two overlapping calls gets the instruction."""
        await instruction_service.init_wakeup()
        t0 = time.monotonic()
        from app.mcp_server import get_user_request
        results = {}

        async def _call1():
            results["call1"] = await get_user_request(agent_id="gen-test", wait_seconds=0)

        async def _call2():
            # Slight delay so Call 1 starts first and registers gen=1
            await asyncio.sleep(0.2)
            results["call2"] = await get_user_request(agent_id="gen-test", wait_seconds=0)

        def _add():
            # Same delay as Test 1, so both calls are already waiting.
            time.sleep(WAKEUP_DELAY)
            instruction_service.create_instruction("Generation-safety-test")
            print(f"[T2 thread] instruction added t={time.monotonic()-t0:.2f}s")

        threading.Thread(target=_add, daemon=True).start()
        await asyncio.gather(_call1(), _call2())
        r1 = results.get("call1", {})
        r2 = results.get("call2", {})
        print(f"[T2] call1 result_type={r1.get('result_type')} waited={r1.get('waited_seconds')}s")
        print(f"[T2] call2 result_type={r2.get('result_type')} waited={r2.get('waited_seconds')}s")
        if r2.get("result_type") == "instruction" and r1.get("result_type") != "instruction":
            print("[T2] PASS only the newest call received the instruction")
        else:
            print("[T2] FAIL unexpected result distribution")
            sys.exit(1)

    # Remember the current setting so the persistent config is put back
    # exactly as it was, rather than overwritten with a hard-coded value.
    original_wait = get_config().default_wait_seconds
    update_config(default_wait_seconds=MIN_WAIT)
    try:
        cfg = get_config()
        print(f"min_wait_seconds = {cfg.default_wait_seconds} (wakeup in {WAKEUP_DELAY}s)")
        print()
        asyncio.run(_test1())
        print()
        asyncio.run(_test2())
    finally:
        # Also runs when a check calls sys.exit(1) or raises, so a failed
        # run no longer leaves the shortened wait time in the config.
        update_config(default_wait_seconds=original_wait)
    print("\nAll tests passed.")
if __name__ == "__main__":
    # Script entry point. run() returns None on success (exit status 0) and
    # calls sys.exit(1) itself when a check fails.
    sys.exit(run())