150 lines
5.2 KiB
Python
Executable File
150 lines
5.2 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Set a Codex app-server thread cwd in Codex-side state.

Updates:
- <CODEX_HOME>/state_*.sqlite: threads.cwd for the thread id (also bumping
  updated_at / updated_at_ms)
- the rollout JSONL: payload.cwd of its first (session_meta) line, located via
  the thread's rollout_path in the state DB

This intentionally does not touch downstream client databases.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import shutil
|
|
import sqlite3
|
|
import tempfile
|
|
import time
|
|
from pathlib import Path
|
|
|
|
|
|
def codex_home(explicit: str | None) -> Path:
    """Return the Codex home directory.

    Precedence: the explicit ``--codex-home`` value, then the ``CODEX_HOME``
    environment variable, then ``~/.codex``. Empty strings count as unset,
    and a leading ``~`` is expanded whichever source wins.
    """
    chosen = explicit or os.environ.get("CODEX_HOME") or Path.home() / ".codex"
    return Path(chosen).expanduser()
|
|
|
|
|
|
def find_state_db(home: Path, explicit: str | None) -> Path:
    """Locate the Codex state SQLite database.

    Args:
        home: Codex home directory searched when *explicit* is not given.
        explicit: Optional path from ``--state-db``; ``~`` is expanded.

    Returns:
        The explicit path if given. Otherwise, the most recently modified
        regular file matching ``state_*.sqlite`` directly under *home*, with
        ties broken by file name (highest first).

    Raises:
        SystemExit: if the explicit path is missing or not a regular file, or
            if no candidate database exists under *home*.
    """
    if explicit:
        path = Path(explicit).expanduser()
        # Check up front: sqlite3.connect() would silently create a missing file.
        if not path.exists():
            raise SystemExit(f"state DB does not exist: {path}")
        if not path.is_file():
            raise SystemExit(f"state DB is not a regular file: {path}")
        return path

    candidates = []
    for candidate in home.glob("state_*.sqlite"):
        try:
            if not candidate.is_file():
                continue  # skip directories and dangling symlinks
            mtime = candidate.stat().st_mtime
        except OSError:
            continue  # vanished (or became unreadable) between glob() and stat()
        candidates.append((mtime, candidate.name, candidate))
    if not candidates:
        raise SystemExit(f"no state_*.sqlite found under {home}; pass --codex-home or --state-db")
    # Newest mtime wins; the name component makes ties deterministic.
    return max(candidates)[2]
|
|
|
|
|
|
def validate_cwd(cwd: str, require_exists: bool) -> str:
    """Normalize and validate the requested thread cwd.

    Args:
        cwd: Target directory; ``~`` is expanded and the result must be absolute.
        require_exists: When true, the normalized path must be an existing
            directory.

    Returns:
        The normalized absolute path as a string. Symlinks and ``..``
        components are resolved the same way whether or not the path exists.

    Raises:
        SystemExit: if the path is relative, cannot be resolved, normalizes to
            a filesystem root, or (with *require_exists*) is not an existing
            directory.
    """
    path = Path(cwd).expanduser()
    if not path.is_absolute():
        raise SystemExit(f"cwd must be absolute: {cwd}")
    # Non-strict resolve() also normalizes paths that do not exist yet, so a
    # value such as "/missing/.." cannot slip past the root check below.
    try:
        resolved = path.resolve()
    except (OSError, RuntimeError) as err:  # e.g. a symlink loop
        raise SystemExit(f"cannot resolve cwd {cwd}: {err}") from err
    clean = str(resolved)
    # A root is its own parent ("/" on POSIX, a drive anchor on Windows).
    if resolved.parent == resolved:
        raise SystemExit("refusing to set cwd to filesystem root")
    if require_exists and not resolved.is_dir():
        raise SystemExit(f"cwd is not an existing directory: {clean}")
    return clean
|
|
|
|
|
|
def read_thread(conn: sqlite3.Connection, thread_id: str) -> sqlite3.Row:
    """Fetch the ``threads`` row for *thread_id*.

    Side effect: sets ``conn.row_factory`` to ``sqlite3.Row``, so any later
    queries on *conn* also yield ``Row`` objects.

    Returns:
        A ``sqlite3.Row`` with ``id``, ``cwd``, ``rollout_path``,
        ``updated_at`` and ``updated_at_ms``.

    Raises:
        SystemExit: if no thread with that id exists.
    """
    query = "select id, cwd, rollout_path, updated_at, updated_at_ms from threads where id = ?"
    conn.row_factory = sqlite3.Row
    found = conn.execute(query, (thread_id,)).fetchone()
    if found is not None:
        return found
    raise SystemExit(f"thread not found in Codex state DB: {thread_id}")
|
|
|
|
|
|
def update_state_db(db: Path, thread_id: str, cwd: str) -> sqlite3.Row:
    """Set ``threads.cwd`` for *thread_id* and bump its update timestamps.

    ``updated_at`` receives whole seconds and ``updated_at_ms`` milliseconds,
    both taken from a single clock reading. The change is committed only if
    exactly one row matched.

    Returns:
        The thread row as re-read after the commit.

    Raises:
        SystemExit: if the UPDATE did not affect exactly one row; nothing is
            committed in that case.
    """
    stamp_ms = int(time.time() * 1000)
    conn = sqlite3.connect(str(db), timeout=10)
    try:
        # The connection context manager commits on normal exit and rolls
        # back if the row-count check below raises.
        with conn:
            cursor = conn.execute(
                "update threads set cwd = ?, updated_at = ?, updated_at_ms = ? where id = ?",
                (cwd, stamp_ms // 1000, stamp_ms, thread_id),
            )
            if cursor.rowcount != 1:
                raise SystemExit(f"updated {cursor.rowcount} Codex thread rows, expected 1")
        return read_thread(conn, thread_id)
    finally:
        conn.close()
|
|
|
|
|
|
def update_rollout(path: Path, thread_id: str, cwd: str, backup: bool) -> None:
    """Rewrite ``payload.cwd`` in the first (session_meta) line of a rollout JSONL.

    Only the first line is re-serialized; every byte after its newline is
    written back unchanged. The new content goes to a temporary file in the
    same directory, is fsync'ed, and then atomically replaces *path*. The
    original file's permission bits are kept.

    NOTE(review): a process appending to the rollout while this runs could have
    its writes lost by the replace; presumably the thread should be idle.

    Args:
        path: Rollout JSONL file (taken from the state DB's ``rollout_path``).
        thread_id: Expected ``payload.id`` of the session_meta line.
        cwd: New working directory to store.
        backup: When true, copy the original to ``<name>.bak.<unix-seconds>``
            next to it before rewriting, and print the backup path.

    Raises:
        SystemExit: if the file is missing or empty, its first line is not
            valid JSON, or that line is not the session_meta record for
            *thread_id*.
    """
    if not path.exists():
        raise SystemExit(f"rollout JSONL does not exist: {path}")
    # Decode raw bytes rather than using text mode + splitlines(). Text mode
    # rewrites "\r\n"/"\r" terminators, and splitlines() also splits on
    # U+2028/U+2029, "\x0c", "\x85", etc. JSON writers may leave those raw
    # inside strings. Only "\n" terminates a JSONL record.
    text = path.read_bytes().decode("utf-8")
    if not text:
        raise SystemExit(f"rollout JSONL is empty: {path}")
    first_line, _, rest = text.partition("\n")

    try:
        first = json.loads(first_line)
    except json.JSONDecodeError as err:
        raise SystemExit(f"first rollout line is not valid JSON ({err}): {path}") from err
    payload = first.get("payload") if isinstance(first, dict) else None
    if (
        not isinstance(payload, dict)
        or first.get("type") != "session_meta"
        or payload.get("id") != thread_id
    ):
        raise SystemExit(f"first rollout line is not session_meta for {thread_id}: {path}")

    payload["cwd"] = cwd  # payload is the same dict object held by first["payload"]
    # Keep the first line's terminator (LF or CRLF); a newline-less file gets "\n".
    eol = "\r\n" if first_line.endswith("\r") else "\n"
    new_text = json.dumps(first, separators=(",", ":")) + eol + rest

    if backup:
        backup_path = path.with_name(path.name + f".bak.{int(time.time())}")
        shutil.copy2(path, backup_path)
        print(f"backup: {backup_path}")

    # Binary-mode fd (no text=True) plus newline="" means no newline translation
    # at any layer, so the bytes land exactly as built above.
    fd, tmp_name = tempfile.mkstemp(dir=str(path.parent), prefix=path.name + ".")
    tmp = Path(tmp_name)
    try:
        with os.fdopen(fd, "w", encoding="utf-8", newline="") as handle:
            handle.write(new_text)
            handle.flush()
            # Make the data durable before the rename makes it visible.
            os.fsync(handle.fileno())
        os.chmod(tmp, path.stat().st_mode & 0o777)
        os.replace(tmp, path)
    finally:
        if tmp.exists():
            tmp.unlink()
|
|
|
|
|
|
def main() -> int:
    """Command-line entry point; returns the process exit status.

    Prints the state DB, thread id, current cwd and rollout path. With
    ``--verify-only`` nothing is modified, and the exit status is 0 when the
    stored cwd already equals the normalized requested cwd, else 1.

    Otherwise the rollout JSONL is rewritten first (with a backup unless
    ``--no-backup``) and the state DB second. This way the failure-prone
    rollout checks (missing file, mismatched session_meta) abort before the
    DB has been changed.
    """
    parser = argparse.ArgumentParser(description="Set Codex thread cwd in Codex-side state")
    parser.add_argument("thread_id")
    parser.add_argument("cwd")
    parser.add_argument("--codex-home", help="Codex home directory; defaults to CODEX_HOME or ~/.codex")
    parser.add_argument("--state-db", help="Path to Codex state_*.sqlite; defaults to newest under Codex home")
    parser.add_argument("--allow-missing-cwd", action="store_true", help="Allow cwd path that does not exist yet")
    parser.add_argument("--no-backup", action="store_true", help="Do not create rollout JSONL backup")
    parser.add_argument("--verify-only", action="store_true", help="Only print current Codex cwd for thread")
    args = parser.parse_args()

    home = codex_home(args.codex_home)
    db = find_state_db(home, args.state_db)
    cwd = validate_cwd(args.cwd, require_exists=not args.allow_missing_cwd)

    conn = sqlite3.connect(str(db), timeout=10)
    try:
        row = read_thread(conn, args.thread_id)
    finally:
        conn.close()

    # rollout_path may be NULL; don't let that block --verify-only reporting.
    rollout_value = row["rollout_path"]
    rollout = Path(rollout_value).expanduser() if rollout_value else None
    print(f"state_db: {db}")
    print(f"thread_id: {args.thread_id}")
    print(f"before_cwd: {row['cwd']}")
    print(f"rollout: {rollout if rollout is not None else '<missing>'}")

    if args.verify_only:
        return 0 if row["cwd"] == cwd else 1

    if rollout is None:
        raise SystemExit(f"thread has no rollout_path in Codex state DB: {args.thread_id}")

    # Rollout first: its validation is the likelier failure, and it is backed up.
    update_rollout(rollout, args.thread_id, cwd, backup=not args.no_backup)
    after = update_state_db(db, args.thread_id, cwd)
    print(f"after_cwd: {after['cwd']}")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # main() returns the exit status; SystemExit hands it to the interpreter.
    exit_status = main()
    raise SystemExit(exit_status)
|