Files
codex-telegram-bot/internal/codexstate/state.go
2026-05-21 13:05:53 +00:00

222 lines
5.6 KiB
Go

package codexstate
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
_ "modernc.org/sqlite"
)
// ThreadCWDUpdate reports what SetThreadCWD looked at and changed.
type ThreadCWDUpdate struct {
// StateDB is the path of the Codex state SQLite database that was used.
StateDB string
// Rollout is the rollout JSONL path recorded for the thread in StateDB.
Rollout string
// Before is the cwd the state DB held for the thread before the call.
Before string
// After is the cleaned cwd requested by the caller; it equals Before when no
// change was needed.
After string
}
// SetThreadCWD points an existing Codex thread at a new working directory.
//
// threadID is trimmed and must be non-empty. cwd must be an absolute path other
// than a filesystem root; it is stored in cleaned form. The state DB is stateDB
// when non-blank, otherwise the newest state_*.sqlite under codexHome (falling
// back to $CODEX_HOME, then ~/.codex).
//
// When the DB already records the requested cwd, nothing is written. Otherwise
// the session_meta line of the thread's rollout file is rewritten first and the
// threads row second. If the DB update fails, the rollout's cwd is set back to
// the value the DB held before the call, so the two stores do not disagree.
//
// On success the result describes the paths involved and the before/after cwd.
// On error the result is the zero value.
func SetThreadCWD(ctx context.Context, codexHome, stateDB, threadID, cwd string) (ThreadCWDUpdate, error) {
	threadID = strings.TrimSpace(threadID)
	if threadID == "" {
		return ThreadCWDUpdate{}, errors.New("thread id is required")
	}
	clean, err := validateCWD(cwd)
	if err != nil {
		return ThreadCWDUpdate{}, err
	}
	dbPath, err := findStateDB(codexHome, stateDB)
	if err != nil {
		return ThreadCWDUpdate{}, err
	}
	row, err := readThreadState(ctx, dbPath, threadID)
	if err != nil {
		return ThreadCWDUpdate{}, err
	}
	result := ThreadCWDUpdate{StateDB: dbPath, Rollout: row.RolloutPath, Before: row.CWD, After: clean}
	if row.CWD == clean {
		return result, nil
	}
	// The rollout rewrite below is not context-aware, so honor cancellation
	// before starting it.
	if err := ctx.Err(); err != nil {
		return ThreadCWDUpdate{}, err
	}
	if err := updateRolloutCWD(row.RolloutPath, threadID, clean); err != nil {
		return ThreadCWDUpdate{}, err
	}
	if err := updateStateDBCWD(ctx, dbPath, threadID, clean); err != nil {
		// Best-effort rollback: without it the rollout would carry the new cwd
		// while the DB still holds the old one.
		if restoreErr := updateRolloutCWD(row.RolloutPath, threadID, row.CWD); restoreErr != nil {
			return ThreadCWDUpdate{}, fmt.Errorf("%w (restoring rollout cwd to %s also failed: %v)", err, row.CWD, restoreErr)
		}
		return ThreadCWDUpdate{}, err
	}
	return result, nil
}
// threadStateRow holds the columns readThreadState loads for one row of the
// threads table.
type threadStateRow struct {
// CWD is the value of the cwd column.
CWD string
// RolloutPath is the rollout_path column, whitespace-trimmed by readThreadState.
RolloutPath string
}
// validateCWD trims and cleans a caller-supplied working directory.
//
// It returns an error when cwd is blank, contains a NUL byte, is not absolute,
// or cleans to a filesystem root. Otherwise it returns the filepath.Clean form.
func validateCWD(cwd string) (string, error) {
	cwd = strings.TrimSpace(cwd)
	switch {
	case cwd == "":
		return "", errors.New("cwd is required")
	case strings.ContainsRune(cwd, 0):
		return "", errors.New("cwd contains a NUL byte")
	case !filepath.IsAbs(cwd):
		return "", fmt.Errorf("cwd must be absolute: %s", cwd)
	}
	clean := filepath.Clean(cwd)
	// For a cleaned absolute path, filepath.Dir returns its input unchanged only
	// at a root. That covers "/" on Unix and drive roots such as `C:\` on Windows,
	// which a comparison against filepath.Separator alone would miss.
	if filepath.Dir(clean) == clean {
		return "", errors.New("refusing to set cwd to filesystem root")
	}
	return clean, nil
}
// findStateDB resolves which Codex state database to operate on.
//
// A non-blank explicit path is trimmed, cleaned, and returned if it names an
// existing regular file. Otherwise the Codex home directory is codexHome, else
// $CODEX_HOME, else ~/.codex. The most recently modified regular file in it
// whose name matches state_*.sqlite is returned. On equal modification times
// the lexically first name wins.
func findStateDB(codexHome, explicit string) (string, error) {
	if explicit = strings.TrimSpace(explicit); explicit != "" {
		path := filepath.Clean(explicit)
		info, err := os.Stat(path)
		if err != nil {
			return "", fmt.Errorf("state DB %s: %w", path, err)
		}
		if !info.Mode().IsRegular() {
			return "", fmt.Errorf("state DB %s is not a regular file", path)
		}
		return path, nil
	}
	home := strings.TrimSpace(codexHome)
	if home == "" {
		if env := strings.TrimSpace(os.Getenv("CODEX_HOME")); env != "" {
			home = env
		} else if userHome, err := os.UserHomeDir(); err == nil {
			home = filepath.Join(userHome, ".codex")
		}
	}
	if home == "" {
		return "", errors.New("CODEX_HOME is not configured")
	}
	home = filepath.Clean(home)

	// List the directory and match base names instead of globbing the joined
	// path. Glob metacharacters in home itself ("[", "*", "?") then stay literal.
	// A missing directory is reported as "nothing found" below.
	entries, err := os.ReadDir(home)
	if err != nil && !errors.Is(err, os.ErrNotExist) {
		return "", fmt.Errorf("list %s: %w", home, err)
	}
	best := ""
	var bestMod time.Time
	for _, entry := range entries { // ReadDir returns entries sorted by name
		// The pattern is a constant and valid, so Match cannot return ErrBadPattern.
		if ok, _ := filepath.Match("state_*.sqlite", entry.Name()); !ok {
			continue
		}
		candidate := filepath.Join(home, entry.Name())
		// Stat follows symlinks. Skip anything that vanished or is not a regular
		// file, so the caller never hands the SQLite driver a bogus path.
		info, err := os.Stat(candidate)
		if err != nil || !info.Mode().IsRegular() {
			continue
		}
		if best == "" || info.ModTime().After(bestMod) {
			best, bestMod = candidate, info.ModTime()
		}
	}
	if best == "" {
		return "", fmt.Errorf("no state_*.sqlite found under %s", home)
	}
	return best, nil
}
// readThreadState loads the cwd and rollout_path columns of the threads row
// whose id is threadID from the SQLite database at dbPath.
//
// It returns a descriptive error when no such row exists or when the rollout
// path is NULL or blank. The returned RolloutPath is whitespace-trimmed. A NULL
// cwd is returned as "".
func readThreadState(ctx context.Context, dbPath, threadID string) (threadStateRow, error) {
	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return threadStateRow{}, err
	}
	defer db.Close()
	// Scan through NullString: scanning NULL into a plain string is a conversion
	// error in database/sql. That error would hide the clearer "no rollout path"
	// message below.
	var cwd, rollout sql.NullString
	err = db.QueryRowContext(ctx, "SELECT cwd, rollout_path FROM threads WHERE id = ?", threadID).Scan(&cwd, &rollout)
	if errors.Is(err, sql.ErrNoRows) {
		return threadStateRow{}, fmt.Errorf("thread not found in Codex state DB: %s", threadID)
	}
	if err != nil {
		return threadStateRow{}, err
	}
	row := threadStateRow{CWD: cwd.String, RolloutPath: strings.TrimSpace(rollout.String)}
	if row.RolloutPath == "" {
		return threadStateRow{}, fmt.Errorf("thread %s has no rollout path", threadID)
	}
	return row, nil
}
// updateStateDBCWD stores cwd on the threads row identified by threadID. It
// also sets updated_at (Unix seconds) and updated_at_ms (Unix milliseconds) to
// the current time. It returns an error unless exactly one row was changed.
func updateStateDBCWD(ctx context.Context, dbPath, threadID, cwd string) error {
	const stmt = "UPDATE threads SET cwd = ?, updated_at = ?, updated_at_ms = ? WHERE id = ?"

	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return err
	}
	defer db.Close()

	millis := time.Now().UnixMilli()
	res, err := db.ExecContext(ctx, stmt, cwd, millis/1000, millis, threadID)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case n != 1:
		return fmt.Errorf("updated %d Codex thread rows, expected 1", n)
	}
	return nil
}
// updateRolloutCWD sets payload.cwd in the first line of a rollout JSONL file.
//
// The first line must be a JSON object with "type":"session_meta" whose
// payload.id equals threadID; otherwise an error is returned and the file is
// left untouched. Only that line is re-encoded. Numeric literals keep their
// exact text, and the original line terminator ("\n" or "\r\n") is kept; a
// file without one gets "\n". Every byte after the first line is copied
// through verbatim.
//
// The new content is written to a temp file in the same directory, given the
// original permission bits, synced to disk, and renamed over path.
//
// NOTE(review): anything appended to the rollout between the read and the
// rename would be lost. Confirm callers only run this while the thread is idle.
func updateRolloutCWD(path, threadID, cwd string) error {
	path = filepath.Clean(strings.TrimSpace(path))
	data, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("open rollout %s: %w", path, err)
	}
	info, err := os.Stat(path)
	if err != nil {
		return err
	}

	head, rest, _ := strings.Cut(string(data), "\n")
	firstLine := strings.TrimRight(head, "\r")
	if strings.TrimSpace(firstLine) == "" {
		return fmt.Errorf("rollout JSONL is empty: %s", path)
	}
	eol := head[len(firstLine):] + "\n" // keep a CRLF terminator intact

	dec := json.NewDecoder(strings.NewReader(firstLine))
	// UseNumber keeps numbers as their literal text. Decoding them as float64
	// would silently change integers beyond 2^53 on re-encode.
	dec.UseNumber()
	var first map[string]any
	if err := dec.Decode(&first); err != nil {
		return fmt.Errorf("parse first rollout line: %w", err)
	}
	// Decode stops after one value. Reject trailing garbage like Unmarshal would.
	if trailing := strings.TrimSpace(firstLine[int(dec.InputOffset()):]); trailing != "" {
		return fmt.Errorf("parse first rollout line: unexpected trailing data %q", trailing)
	}
	payload, _ := first["payload"].(map[string]any)
	if first["type"] != "session_meta" || payload == nil || payload["id"] != threadID {
		return fmt.Errorf("first rollout line is not session_meta for %s", threadID)
	}
	payload["cwd"] = cwd // payload is the same map stored in first

	var encoded strings.Builder
	enc := json.NewEncoder(&encoded)
	// The rollout is not embedded in HTML, so write <, > and & literally rather
	// than as \u003c-style escapes.
	enc.SetEscapeHTML(false)
	if err := enc.Encode(first); err != nil {
		return err
	}
	newHead := strings.TrimSuffix(encoded.String(), "\n") + eol

	tmp, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".")
	if err != nil {
		return err
	}
	tmpPath := tmp.Name()
	committed := false
	defer func() {
		if !committed {
			_ = tmp.Close() // may already be closed; the error is irrelevant here
			_ = os.Remove(tmpPath)
		}
	}()
	if _, err := tmp.WriteString(newHead); err != nil {
		return err
	}
	if _, err := tmp.WriteString(rest); err != nil {
		return err
	}
	if err := tmp.Chmod(info.Mode().Perm()); err != nil {
		return err
	}
	// Flush to stable storage before the rename, so a crash cannot leave a
	// truncated or empty rollout in place of the original.
	if err := tmp.Sync(); err != nil {
		return err
	}
	if err := tmp.Close(); err != nil {
		return err
	}
	if err := os.Rename(tmpPath, path); err != nil {
		return err
	}
	committed = true
	return nil
}