#!/usr/bin/env python3
"""Set a Codex app-server thread cwd in Codex-side state.

Updates:
- ``<CODEX_HOME>/state_*.sqlite``: ``threads.cwd`` for the thread id
- the rollout JSONL's first ``session_meta`` line: ``payload.cwd``, locating
  the file via ``rollout_path`` from the state DB

The rollout is validated before anything is written, so a malformed rollout
aborts the run without leaving the state DB and the rollout disagreeing.

This intentionally does not touch downstream client databases.
"""

from __future__ import annotations

import argparse
import json
import os
import shutil
import sqlite3
import tempfile
import time
from pathlib import Path


def codex_home(explicit: str | None) -> Path:
    """Return the Codex home directory.

    Uses *explicit* if given, else ``$CODEX_HOME``, else ``~/.codex``;
    ``~`` is expanded in every case.
    """
    if explicit:
        return Path(explicit).expanduser()
    return Path(os.environ.get("CODEX_HOME") or Path.home() / ".codex").expanduser()


def find_state_db(home: Path, explicit: str | None) -> Path:
    """Return the path of the Codex state DB.

    An *explicit* path must exist. Otherwise the most recently modified
    ``state_*.sqlite`` directly under *home* is chosen.

    Raises:
        SystemExit: the explicit path is missing, or no candidate exists.
    """
    if explicit:
        path = Path(explicit).expanduser()
        if not path.exists():
            raise SystemExit(f"state DB does not exist: {path}")
        return path
    candidates = list(home.glob("state_*.sqlite"))
    if not candidates:
        raise SystemExit(f"no state_*.sqlite found under {home}; pass --codex-home or --state-db")
    return max(candidates, key=lambda p: p.stat().st_mtime)


def validate_cwd(cwd: str, require_exists: bool) -> str:
    """Normalize *cwd* and check it is an acceptable thread working directory.

    The path must be absolute after ``~`` expansion. Existing paths are fully
    resolved (symlinks followed). Non-existent paths are lexically normalized,
    so ``..``/``.`` segments cannot slip the filesystem root past the check.

    Raises:
        SystemExit: the path is relative, is ``/``, or (with *require_exists*)
            is not an existing directory.
    """
    path = Path(cwd).expanduser()
    if not path.is_absolute():
        raise SystemExit(f"cwd must be absolute: {cwd}")
    clean = str(path.resolve()) if path.exists() else os.path.normpath(path)
    if clean == "/":
        raise SystemExit("refusing to set cwd to filesystem root")
    if require_exists and not Path(clean).is_dir():
        raise SystemExit(f"cwd is not an existing directory: {clean}")
    return clean


def read_thread(conn: sqlite3.Connection, thread_id: str) -> sqlite3.Row:
    """Fetch the ``threads`` row (id, cwd, rollout_path, timestamps) for *thread_id*.

    Side effect: sets ``conn.row_factory`` to ``sqlite3.Row`` so the result
    (and any later query on *conn*) can be indexed by column name.

    Raises:
        SystemExit: no row has that id.
    """
    conn.row_factory = sqlite3.Row
    row = conn.execute(
        "select id, cwd, rollout_path, updated_at, updated_at_ms from threads where id = ?",
        (thread_id,),
    ).fetchone()
    if row is None:
        raise SystemExit(f"thread not found in Codex state DB: {thread_id}")
    return row


def update_state_db(db: Path, thread_id: str, cwd: str) -> sqlite3.Row:
    """Set ``threads.cwd`` for *thread_id* and bump its update timestamps.

    ``updated_at`` gets the current time in seconds and ``updated_at_ms`` the
    same instant in milliseconds. Returns the row re-read after the commit.

    Raises:
        SystemExit: the UPDATE touched a number of rows other than one
            (nothing is committed in that case).
    """
    now_ms = int(time.time() * 1000)
    now_s = now_ms // 1000
    conn = sqlite3.connect(str(db), timeout=10)
    try:
        changed = conn.execute(
            "update threads set cwd = ?, updated_at = ?, updated_at_ms = ? where id = ?",
            (cwd, now_s, now_ms, thread_id),
        ).rowcount
        if changed != 1:
            # Closing without commit() discards the pending UPDATE.
            raise SystemExit(f"updated {changed} Codex thread rows, expected 1")
        conn.commit()
        return read_thread(conn, thread_id)
    finally:
        conn.close()


def _read_first_line(src, path: Path) -> bytes:
    """Read the first line (terminator included) from binary file *src*; reject an empty file."""
    line = src.readline()
    if not line:
        raise SystemExit(f"rollout JSONL is empty: {path}")
    return line


def _patch_session_meta(line: bytes, thread_id: str, cwd: str, path: Path) -> bytes:
    """Return rollout first *line* with ``payload.cwd`` replaced by *cwd*.

    The line must be a JSON object with ``type == "session_meta"`` and
    ``payload.id == thread_id``. The original line terminator is kept (``\\n``
    is used if there was none). Non-ASCII text is emitted as raw UTF-8 rather
    than ``\\uXXXX`` escapes, so other fields keep their byte form.

    Raises:
        SystemExit: the line is not valid UTF-8 JSON or is not the expected
            ``session_meta`` record.
    """
    body = line.rstrip(b"\r\n")
    terminator = line[len(body):] or b"\n"
    try:
        first = json.loads(body.decode("utf-8"))
    except ValueError as err:  # covers UnicodeDecodeError and JSONDecodeError
        raise SystemExit(f"first rollout line is not valid JSON: {path}: {err}") from err
    if not isinstance(first, dict):
        raise SystemExit(f"first rollout line is not session_meta for {thread_id}: {path}")
    payload = first.get("payload") or {}
    if (
        first.get("type") != "session_meta"
        or not isinstance(payload, dict)
        or payload.get("id") != thread_id
    ):
        raise SystemExit(f"first rollout line is not session_meta for {thread_id}: {path}")
    payload["cwd"] = cwd
    first["payload"] = payload
    return json.dumps(first, ensure_ascii=False, separators=(",", ":")).encode("utf-8") + terminator


def _validate_rollout(path: Path, thread_id: str, cwd: str) -> None:
    """Run the same read-side checks as ``update_rollout`` without writing anything."""
    if not path.exists():
        raise SystemExit(f"rollout JSONL does not exist: {path}")
    with path.open("rb") as src:
        first_line = _read_first_line(src, path)
    _patch_session_meta(first_line, thread_id, cwd, path)


def update_rollout(path: Path, thread_id: str, cwd: str, backup: bool) -> None:
    """Set ``payload.cwd`` in the rollout's first ``session_meta`` line.

    Only the first line is re-serialized. Every following byte is copied
    verbatim (no newline translation, no re-decoding). The new content is
    written to a temp file in the same directory, fsynced, given the original
    file's permission bits, then atomically swapped in with ``os.replace``.

    Args:
        path: rollout JSONL file.
        thread_id: expected ``payload.id`` of the first line.
        cwd: new working directory to store.
        backup: if true, copy the original to ``<name>.bak.<unix-seconds>``
            (after validation, before rewriting) and print its path.

    Raises:
        SystemExit: missing/empty file, or first line not the expected
            ``session_meta`` record. The file is left untouched in that case.
    """
    if not path.exists():
        raise SystemExit(f"rollout JSONL does not exist: {path}")
    fd, tmp_name = tempfile.mkstemp(dir=str(path.parent), prefix=path.name + ".")
    tmp = Path(tmp_name)
    try:
        # Both handles are closed before os.replace (required on Windows).
        with os.fdopen(fd, "wb") as dst, path.open("rb") as src:
            patched = _patch_session_meta(_read_first_line(src, path), thread_id, cwd, path)
            if backup:
                backup_path = path.with_name(path.name + f".bak.{int(time.time())}")
                shutil.copy2(path, backup_path)
                print(f"backup: {backup_path}")
            dst.write(patched)
            shutil.copyfileobj(src, dst)  # remainder of the file, byte-for-byte
            dst.flush()
            os.fsync(dst.fileno())
        # mkstemp creates the file 0600; restore the original permission bits.
        os.chmod(tmp, path.stat().st_mode & 0o777)
        os.replace(tmp, path)
    finally:
        if tmp.exists():
            tmp.unlink()


def main() -> int:
    """CLI entry point.

    Prints the state DB, thread id, current cwd and rollout path. With
    ``--verify-only`` it returns 0 if the DB cwd already equals the requested
    (normalized) cwd, else 1. Otherwise it validates the rollout, updates the
    state DB, rewrites the rollout, and returns 0.
    """
    parser = argparse.ArgumentParser(description="Set Codex thread cwd in Codex-side state")
    parser.add_argument("thread_id")
    parser.add_argument("cwd")
    parser.add_argument("--codex-home", help="Codex home directory; defaults to CODEX_HOME or ~/.codex")
    parser.add_argument("--state-db", help="Path to Codex state_*.sqlite; defaults to newest under Codex home")
    parser.add_argument("--allow-missing-cwd", action="store_true", help="Allow cwd path that does not exist yet")
    parser.add_argument("--no-backup", action="store_true", help="Do not create rollout JSONL backup")
    parser.add_argument("--verify-only", action="store_true", help="Only print current Codex cwd for thread")
    args = parser.parse_args()

    home = codex_home(args.codex_home)
    db = find_state_db(home, args.state_db)
    cwd = validate_cwd(args.cwd, require_exists=not args.allow_missing_cwd)

    conn = sqlite3.connect(str(db), timeout=10)
    try:
        row = read_thread(conn, args.thread_id)
    finally:
        conn.close()

    rollout_path = row["rollout_path"]
    rollout = Path(rollout_path).expanduser() if rollout_path else None
    print(f"state_db: {db}")
    print(f"thread_id: {args.thread_id}")
    print(f"before_cwd: {row['cwd']}")
    print(f"rollout: {rollout}")
    if args.verify_only:
        return 0 if row["cwd"] == cwd else 1

    if rollout is None:
        raise SystemExit(f"thread has no rollout_path in Codex state DB: {args.thread_id}")
    # Validate the rollout before the DB commit. Otherwise a bad rollout would
    # leave the DB already changed while the rollout keeps the old cwd.
    _validate_rollout(rollout, args.thread_id, cwd)
    after = update_state_db(db, args.thread_id, cwd)
    update_rollout(rollout, args.thread_id, cwd, backup=not args.no_backup)
    print(f"after_cwd: {after['cwd']}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())