Bake in thread directives
This commit is contained in:
221
internal/codexstate/state.go
Normal file
221
internal/codexstate/state.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package codexstate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type ThreadCWDUpdate struct {
|
||||
StateDB string
|
||||
Rollout string
|
||||
Before string
|
||||
After string
|
||||
}
|
||||
|
||||
func SetThreadCWD(ctx context.Context, codexHome, stateDB, threadID, cwd string) (ThreadCWDUpdate, error) {
|
||||
threadID = strings.TrimSpace(threadID)
|
||||
if threadID == "" {
|
||||
return ThreadCWDUpdate{}, errors.New("thread id is required")
|
||||
}
|
||||
clean, err := validateCWD(cwd)
|
||||
if err != nil {
|
||||
return ThreadCWDUpdate{}, err
|
||||
}
|
||||
dbPath, err := findStateDB(codexHome, stateDB)
|
||||
if err != nil {
|
||||
return ThreadCWDUpdate{}, err
|
||||
}
|
||||
|
||||
row, err := readThreadState(ctx, dbPath, threadID)
|
||||
if err != nil {
|
||||
return ThreadCWDUpdate{}, err
|
||||
}
|
||||
if row.CWD == clean {
|
||||
return ThreadCWDUpdate{StateDB: dbPath, Rollout: row.RolloutPath, Before: row.CWD, After: clean}, nil
|
||||
}
|
||||
if err := updateRolloutCWD(row.RolloutPath, threadID, clean); err != nil {
|
||||
return ThreadCWDUpdate{}, err
|
||||
}
|
||||
if err := updateStateDBCWD(ctx, dbPath, threadID, clean); err != nil {
|
||||
return ThreadCWDUpdate{}, err
|
||||
}
|
||||
return ThreadCWDUpdate{StateDB: dbPath, Rollout: row.RolloutPath, Before: row.CWD, After: clean}, nil
|
||||
}
|
||||
|
||||
type threadStateRow struct {
|
||||
CWD string
|
||||
RolloutPath string
|
||||
}
|
||||
|
||||
func validateCWD(cwd string) (string, error) {
|
||||
cwd = strings.TrimSpace(cwd)
|
||||
if cwd == "" {
|
||||
return "", errors.New("cwd is required")
|
||||
}
|
||||
if strings.ContainsRune(cwd, 0) {
|
||||
return "", errors.New("cwd contains a NUL byte")
|
||||
}
|
||||
if !filepath.IsAbs(cwd) {
|
||||
return "", fmt.Errorf("cwd must be absolute: %s", cwd)
|
||||
}
|
||||
clean := filepath.Clean(cwd)
|
||||
if clean == string(filepath.Separator) {
|
||||
return "", errors.New("refusing to set cwd to filesystem root")
|
||||
}
|
||||
return clean, nil
|
||||
}
|
||||
|
||||
func findStateDB(codexHome, explicit string) (string, error) {
|
||||
if strings.TrimSpace(explicit) != "" {
|
||||
path := filepath.Clean(strings.TrimSpace(explicit))
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return "", fmt.Errorf("state DB %s: %w", path, err)
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
home := strings.TrimSpace(codexHome)
|
||||
if home == "" {
|
||||
if env := strings.TrimSpace(os.Getenv("CODEX_HOME")); env != "" {
|
||||
home = env
|
||||
} else if userHome, err := os.UserHomeDir(); err == nil {
|
||||
home = filepath.Join(userHome, ".codex")
|
||||
}
|
||||
}
|
||||
if home == "" {
|
||||
return "", errors.New("CODEX_HOME is not configured")
|
||||
}
|
||||
matches, err := filepath.Glob(filepath.Join(filepath.Clean(home), "state_*.sqlite"))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return "", fmt.Errorf("no state_*.sqlite found under %s", filepath.Clean(home))
|
||||
}
|
||||
best := matches[0]
|
||||
bestInfo, _ := os.Stat(best)
|
||||
for _, candidate := range matches[1:] {
|
||||
info, err := os.Stat(candidate)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if bestInfo == nil || info.ModTime().After(bestInfo.ModTime()) {
|
||||
best = candidate
|
||||
bestInfo = info
|
||||
}
|
||||
}
|
||||
return best, nil
|
||||
}
|
||||
|
||||
func readThreadState(ctx context.Context, dbPath, threadID string) (threadStateRow, error) {
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
if err != nil {
|
||||
return threadStateRow{}, err
|
||||
}
|
||||
defer db.Close()
|
||||
var row threadStateRow
|
||||
err = db.QueryRowContext(ctx, "SELECT cwd, rollout_path FROM threads WHERE id = ?", threadID).Scan(&row.CWD, &row.RolloutPath)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return threadStateRow{}, fmt.Errorf("thread not found in Codex state DB: %s", threadID)
|
||||
}
|
||||
if err != nil {
|
||||
return threadStateRow{}, err
|
||||
}
|
||||
row.RolloutPath = strings.TrimSpace(row.RolloutPath)
|
||||
if row.RolloutPath == "" {
|
||||
return threadStateRow{}, fmt.Errorf("thread %s has no rollout path", threadID)
|
||||
}
|
||||
return row, nil
|
||||
}
|
||||
|
||||
func updateStateDBCWD(ctx context.Context, dbPath, threadID, cwd string) error {
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
nowMS := time.Now().UnixMilli()
|
||||
result, err := db.ExecContext(ctx, "UPDATE threads SET cwd = ?, updated_at = ?, updated_at_ms = ? WHERE id = ?", cwd, nowMS/1000, nowMS, threadID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
changed, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if changed != 1 {
|
||||
return fmt.Errorf("updated %d Codex thread rows, expected 1", changed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateRolloutCWD(path, threadID, cwd string) error {
|
||||
path = filepath.Clean(strings.TrimSpace(path))
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open rollout %s: %w", path, err)
|
||||
}
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lines := strings.SplitAfter(string(data), "\n")
|
||||
if len(lines) == 0 || strings.TrimSpace(lines[0]) == "" {
|
||||
return fmt.Errorf("rollout JSONL is empty: %s", path)
|
||||
}
|
||||
firstLine := strings.TrimRight(lines[0], "\r\n")
|
||||
|
||||
var first map[string]any
|
||||
if err := json.Unmarshal([]byte(firstLine), &first); err != nil {
|
||||
return fmt.Errorf("parse first rollout line: %w", err)
|
||||
}
|
||||
payload, _ := first["payload"].(map[string]any)
|
||||
if first["type"] != "session_meta" || payload == nil || payload["id"] != threadID {
|
||||
return fmt.Errorf("first rollout line is not session_meta for %s", threadID)
|
||||
}
|
||||
payload["cwd"] = cwd
|
||||
first["payload"] = payload
|
||||
updatedFirst, err := json.Marshal(first)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lines[0] = string(updatedFirst) + "\n"
|
||||
|
||||
tmp, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
ok := false
|
||||
defer func() {
|
||||
if !ok {
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
for _, line := range lines {
|
||||
if _, err := tmp.WriteString(line); err != nil {
|
||||
_ = tmp.Close()
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := tmp.Chmod(info.Mode().Perm()); err != nil {
|
||||
_ = tmp.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
return err
|
||||
}
|
||||
ok = true
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user