package codexstate

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"

	_ "modernc.org/sqlite" // registers the "sqlite" database/sql driver
)

// stateDBBusyTimeoutPragma makes the state-DB write wait up to 5s for a lock
// held by another connection instead of failing immediately with SQLITE_BUSY.
const stateDBBusyTimeoutPragma = "PRAGMA busy_timeout = 5000"

// ThreadCWDUpdate reports the outcome of SetThreadCWD.
type ThreadCWDUpdate struct {
	StateDB string // path of the state database that was read
	Rollout string // rollout JSONL path recorded for the thread in the state DB
	Before  string // cwd stored in the state DB before the call
	After   string // cleaned cwd requested by the caller
}

// SetThreadCWD changes the working directory recorded for a Codex thread.
//
// It rewrites payload.cwd in the session_meta line at the top of the thread's
// rollout file and the cwd column of the thread's row in the state database.
// stateDB may name the database explicitly; otherwise the newest
// state_*.sqlite under codexHome ($CODEX_HOME, then ~/.codex, when empty) is
// used. If the DB already holds the requested cwd nothing is written. If the
// DB update fails after the rollout was rewritten, the rollout's original
// first line is put back so the two stores keep agreeing.
//
// The rollout is replaced via rename: lines another process appends between
// the read and the rename, or writes through an already-open descriptor, do
// not reach the new file. NOTE(review): presumably this should not be run
// while a Codex session is writing to the thread.
func SetThreadCWD(ctx context.Context, codexHome, stateDB, threadID, cwd string) (ThreadCWDUpdate, error) {
	threadID = strings.TrimSpace(threadID)
	if threadID == "" {
		return ThreadCWDUpdate{}, errors.New("thread id is required")
	}
	clean, err := validateCWD(cwd)
	if err != nil {
		return ThreadCWDUpdate{}, err
	}
	dbPath, err := findStateDB(codexHome, stateDB)
	if err != nil {
		return ThreadCWDUpdate{}, err
	}
	row, err := readThreadState(ctx, dbPath, threadID)
	if err != nil {
		return ThreadCWDUpdate{}, err
	}
	result := ThreadCWDUpdate{StateDB: dbPath, Rollout: row.RolloutPath, Before: row.CWD, After: clean}
	if row.CWD == clean {
		return result, nil
	}
	// Don't touch the rollout if the caller has already given up; the DB
	// update would fail right after and force a restore.
	if err := ctx.Err(); err != nil {
		return ThreadCWDUpdate{}, err
	}
	originalFirst, err := rewriteRolloutFirstLine(row.RolloutPath, func(line string) (string, error) {
		return setSessionMetaCWD(line, threadID, clean)
	})
	if err != nil {
		return ThreadCWDUpdate{}, err
	}
	if err := updateStateDBCWD(ctx, dbPath, threadID, clean); err != nil {
		// updateStateDBCWD only commits on success, so the DB row is unchanged:
		// put the rollout's original session_meta line back to match it.
		restore := func(string) (string, error) { return originalFirst, nil }
		if _, restoreErr := rewriteRolloutFirstLine(row.RolloutPath, restore); restoreErr != nil {
			return ThreadCWDUpdate{}, fmt.Errorf("%w (restoring rollout %s also failed: %v)", err, row.RolloutPath, restoreErr)
		}
		return ThreadCWDUpdate{}, err
	}
	return result, nil
}

// threadStateRow holds the columns SetThreadCWD reads from the threads table.
type threadStateRow struct {
	CWD         string // "" when the column is NULL
	RolloutPath string // whitespace-trimmed; never empty when returned
}

// validateCWD trims cwd and returns its cleaned form. It rejects empty input,
// NUL bytes, relative paths and filesystem (or volume) roots.
func validateCWD(cwd string) (string, error) {
	cwd = strings.TrimSpace(cwd)
	if cwd == "" {
		return "", errors.New("cwd is required")
	}
	if strings.ContainsRune(cwd, 0) {
		return "", errors.New("cwd contains a NUL byte")
	}
	if !filepath.IsAbs(cwd) {
		return "", fmt.Errorf("cwd must be absolute: %s", cwd)
	}
	clean := filepath.Clean(cwd)
	// For a cleaned absolute path, filepath.Dir returns its input only at a
	// root. This also covers volume roots like `C:\`, which are not equal to
	// the bare separator.
	if filepath.Dir(clean) == clean {
		return "", errors.New("refusing to set cwd to filesystem root")
	}
	return clean, nil
}

// findStateDB locates the Codex state database.
//
// A non-blank explicit path is cleaned and must name an existing
// non-directory. Otherwise the Codex home is codexHome, else $CODEX_HOME, else
// ~/.codex. The most recently modified state_*.sqlite file in it is returned.
// Ties go to the lexically first name.
func findStateDB(codexHome, explicit string) (string, error) {
	if explicit = strings.TrimSpace(explicit); explicit != "" {
		path := filepath.Clean(explicit)
		info, err := os.Stat(path)
		if err != nil {
			return "", fmt.Errorf("state DB %s: %w", path, err)
		}
		if info.IsDir() {
			return "", fmt.Errorf("state DB %s is a directory", path)
		}
		return path, nil
	}
	home := strings.TrimSpace(codexHome)
	if home == "" {
		if env := strings.TrimSpace(os.Getenv("CODEX_HOME")); env != "" {
			home = env
		} else if userHome, err := os.UserHomeDir(); err == nil {
			home = filepath.Join(userHome, ".codex")
		}
	}
	if home == "" {
		return "", errors.New("CODEX_HOME is not configured")
	}
	dir := filepath.Clean(home)

	// List the directory and match bare names rather than globbing the joined
	// path: filepath.Glob would treat '*', '?' or '[' inside dir itself as
	// pattern syntax.
	entries, err := os.ReadDir(dir)
	if err != nil && !errors.Is(err, os.ErrNotExist) {
		return "", fmt.Errorf("list %s: %w", dir, err)
	}
	best := ""
	var bestMod time.Time
	for _, entry := range entries { // os.ReadDir returns entries sorted by name
		// The pattern is a valid constant, so Match cannot return an error.
		if matched, _ := filepath.Match("state_*.sqlite", entry.Name()); !matched {
			continue
		}
		candidate := filepath.Join(dir, entry.Name())
		info, err := os.Stat(candidate) // follows symlinks
		if err != nil || info.IsDir() {
			continue
		}
		if best == "" || info.ModTime().After(bestMod) {
			best, bestMod = candidate, info.ModTime()
		}
	}
	if best == "" {
		return "", fmt.Errorf("no state_*.sqlite found under %s", dir)
	}
	return best, nil
}

// readThreadState loads the cwd and rollout path of threadID from the threads
// table. It fails if the thread is missing or has no rollout path.
func readThreadState(ctx context.Context, dbPath, threadID string) (threadStateRow, error) {
	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return threadStateRow{}, err
	}
	defer db.Close()

	// Scan through NullString so a NULL column gets a clear error below rather
	// than a driver conversion failure (the schema's nullability is not visible here).
	var cwd, rolloutPath sql.NullString
	err = db.QueryRowContext(ctx, "SELECT cwd, rollout_path FROM threads WHERE id = ?", threadID).Scan(&cwd, &rolloutPath)
	if errors.Is(err, sql.ErrNoRows) {
		return threadStateRow{}, fmt.Errorf("thread not found in Codex state DB: %s", threadID)
	}
	if err != nil {
		return threadStateRow{}, err
	}
	row := threadStateRow{CWD: cwd.String, RolloutPath: strings.TrimSpace(rolloutPath.String)}
	if row.RolloutPath == "" {
		return threadStateRow{}, fmt.Errorf("thread %s has no rollout path", threadID)
	}
	return row, nil
}

// updateStateDBCWD sets cwd and the updated_at (seconds) / updated_at_ms
// (milliseconds) timestamps on threadID's row. The change is committed only if
// exactly one row matched, so on any error the table is left untouched.
func updateStateDBCWD(ctx context.Context, dbPath, threadID, cwd string) error {
	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return err
	}
	defer db.Close()

	// Pin one connection: busy_timeout is per-connection and must apply to the
	// connection that runs the transaction below.
	conn, err := db.Conn(ctx)
	if err != nil {
		return err
	}
	defer conn.Close()
	if _, err := conn.ExecContext(ctx, stateDBBusyTimeoutPragma); err != nil {
		return err
	}

	tx, err := conn.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback after a successful Commit is a no-op returning ErrTxDone.
	defer func() { _ = tx.Rollback() }()

	nowMS := time.Now().UnixMilli()
	result, err := tx.ExecContext(ctx, "UPDATE threads SET cwd = ?, updated_at = ?, updated_at_ms = ? WHERE id = ?", cwd, nowMS/1000, nowMS, threadID)
	if err != nil {
		return err
	}
	changed, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if changed != 1 {
		return fmt.Errorf("updated %d Codex thread rows, expected 1", changed)
	}
	return tx.Commit()
}

// updateRolloutCWD sets payload.cwd in the session_meta line at the top of the
// rollout file at path, after checking that the line belongs to threadID.
func updateRolloutCWD(path, threadID, cwd string) error {
	_, err := rewriteRolloutFirstLine(path, func(line string) (string, error) {
		return setSessionMetaCWD(line, threadID, cwd)
	})
	return err
}

// rewriteRolloutFirstLine replaces the first line of the JSONL file at path
// with edit(firstLine) and returns the original first line (without its
// terminator).
//
// The line terminator ("\r\n" is kept; otherwise "\n"), every following byte,
// and the file's permission bits are preserved.
func rewriteRolloutFirstLine(path string, edit func(line string) (string, error)) (string, error) {
	path = filepath.Clean(strings.TrimSpace(path))
	data, err := os.ReadFile(path)
	if err != nil {
		return "", fmt.Errorf("open rollout %s: %w", path, err)
	}
	info, err := os.Stat(path)
	if err != nil {
		return "", err
	}
	first, rest, _ := strings.Cut(string(data), "\n")
	ending := "\n"
	if strings.HasSuffix(first, "\r") {
		first = strings.TrimSuffix(first, "\r")
		ending = "\r\n"
	}
	if strings.TrimSpace(first) == "" {
		return "", fmt.Errorf("rollout JSONL is empty: %s", path)
	}
	updated, err := edit(first)
	if err != nil {
		return "", err
	}
	if err := replaceRolloutFile(path, updated+ending+rest, info.Mode().Perm()); err != nil {
		return "", err
	}
	return first, nil
}

// setSessionMetaCWD returns line with payload.cwd set to cwd.
//
// It fails unless line is a JSON object whose "type" is "session_meta" and
// whose payload.id equals threadID. Other fields are carried over as raw JSON,
// so numbers keep their exact text. Decoding into map[string]any would turn
// them into float64 and corrupt integers above 2^53. The result is compact,
// with keys sorted in the top-level object and in payload.
func setSessionMetaCWD(line, threadID, cwd string) (string, error) {
	var record map[string]json.RawMessage
	if err := json.Unmarshal([]byte(line), &record); err != nil {
		return "", fmt.Errorf("parse first rollout line: %w", err)
	}
	notMeta := fmt.Errorf("first rollout line is not session_meta for %s", threadID)
	// Unmarshal of a missing (nil) RawMessage errors, which covers absent keys.
	var recordType string
	if err := json.Unmarshal(record["type"], &recordType); err != nil || recordType != "session_meta" {
		return "", notMeta
	}
	var payload map[string]json.RawMessage
	if err := json.Unmarshal(record["payload"], &payload); err != nil || payload == nil {
		return "", notMeta
	}
	var id string
	if err := json.Unmarshal(payload["id"], &id); err != nil || id != threadID {
		return "", notMeta
	}

	encodedCWD, err := marshalRolloutJSON(cwd)
	if err != nil {
		return "", err
	}
	payload["cwd"] = encodedCWD
	encodedPayload, err := marshalRolloutJSON(payload)
	if err != nil {
		return "", err
	}
	record["payload"] = encodedPayload
	out, err := marshalRolloutJSON(record)
	if err != nil {
		return "", err
	}
	return string(out), nil
}

// marshalRolloutJSON encodes v as compact JSON without HTML-escaping '<', '>'
// and '&', so paths are written verbatim instead of as \u003c-style escapes.
func marshalRolloutJSON(v any) (json.RawMessage, error) {
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(v); err != nil {
		return nil, err
	}
	// Encode terminates its output with a newline; the caller adds its own.
	return bytes.TrimSuffix(buf.Bytes(), []byte("\n")), nil
}

// replaceRolloutFile writes contents to a temporary file in path's directory,
// applies perm, syncs it, and renames it over path. On POSIX filesystems
// readers therefore see either the old or the new file. The temporary file is
// removed if any step fails.
func replaceRolloutFile(path, contents string, perm os.FileMode) error {
	tmp, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".")
	if err != nil {
		return err
	}
	tmpPath := tmp.Name()
	renamed := false
	defer func() {
		if !renamed {
			_ = os.Remove(tmpPath)
		}
	}()
	if _, err := tmp.WriteString(contents); err != nil {
		_ = tmp.Close()
		return err
	}
	if err := tmp.Chmod(perm); err != nil {
		_ = tmp.Close()
		return err
	}
	// Flush the data before the rename makes it visible, so a crash right
	// after the rename is less likely to expose a truncated rollout.
	if err := tmp.Sync(); err != nil {
		_ = tmp.Close()
		return err
	}
	if err := tmp.Close(); err != nil {
		return err
	}
	if err := os.Rename(tmpPath, path); err != nil {
		return err
	}
	renamed = true
	return nil
}