Sync thread workspaces from Codex cwd

This commit is contained in:
Codex
2026-05-21 09:56:47 +00:00
parent 1d433038ab
commit a73f88fe5e
5 changed files with 120 additions and 29 deletions

View File

@@ -53,6 +53,7 @@ type Thread struct {
SessionID string `json:"sessionId,omitempty"`
Name string `json:"name,omitempty"`
Preview string `json:"preview,omitempty"`
CWD string `json:"cwd,omitempty"`
}
type Turn struct {
@@ -168,6 +169,13 @@ func (c *Client) initialize(ctx context.Context) error {
return c.notify("initialized", nil)
}
func threadWithCWD(thread Thread, cwd string) Thread {
if thread.CWD == "" {
thread.CWD = cwd
}
return thread
}
func (c *Client) StartThread(ctx context.Context, cwd, model, sandbox string) (Thread, error) {
if err := c.EnsureConnected(ctx); err != nil {
return Thread{}, err
@@ -183,11 +191,12 @@ func (c *Client) StartThread(ctx context.Context, cwd, model, sandbox string) (T
}
var result struct {
Thread Thread `json:"thread"`
CWD string `json:"cwd"`
}
if err := c.call(ctx, "thread/start", params, &result); err != nil {
return Thread{}, err
}
return result.Thread, nil
return threadWithCWD(result.Thread, result.CWD), nil
}
func (c *Client) ResumeThread(ctx context.Context, threadID string) (Thread, error) {
@@ -196,11 +205,12 @@ func (c *Client) ResumeThread(ctx context.Context, threadID string) (Thread, err
}
var result struct {
Thread Thread `json:"thread"`
CWD string `json:"cwd"`
}
if err := c.call(ctx, "thread/resume", map[string]any{"threadId": threadID}, &result); err != nil {
return Thread{}, err
}
return result.Thread, nil
return threadWithCWD(result.Thread, result.CWD), nil
}
func (c *Client) ReadThread(ctx context.Context, threadID string) (Thread, error) {
@@ -209,6 +219,7 @@ func (c *Client) ReadThread(ctx context.Context, threadID string) (Thread, error
}
var result struct {
Thread Thread `json:"thread"`
CWD string `json:"cwd"`
}
if err := c.call(ctx, "thread/read", map[string]any{
"threadId": threadID,
@@ -216,7 +227,7 @@ func (c *Client) ReadThread(ctx context.Context, threadID string) (Thread, error
}, &result); err != nil {
return Thread{}, err
}
return result.Thread, nil
return threadWithCWD(result.Thread, result.CWD), nil
}
func (c *Client) ForkThread(ctx context.Context, threadID string) (Thread, error) {
@@ -225,11 +236,12 @@ func (c *Client) ForkThread(ctx context.Context, threadID string) (Thread, error
}
var result struct {
Thread Thread `json:"thread"`
CWD string `json:"cwd"`
}
if err := c.call(ctx, "thread/fork", map[string]any{"threadId": threadID}, &result); err != nil {
return Thread{}, err
}
return result.Thread, nil
return threadWithCWD(result.Thread, result.CWD), nil
}
func (c *Client) ArchiveThread(ctx context.Context, threadID string) error {

View File

@@ -79,6 +79,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
if err := conn.WriteJSON(map[string]any{
"id": start["id"],
"result": map[string]any{
"cwd": "/tmp/project",
"thread": map[string]any{"id": "thr_1", "preview": "test"},
},
}); err != nil {
@@ -104,7 +105,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
if err := conn.WriteJSON(map[string]any{
"id": readThread["id"],
"result": map[string]any{
"thread": map[string]any{"id": "thr_1", "name": "Read title", "preview": "test"},
"thread": map[string]any{"id": "thr_1", "name": "Read title", "preview": "test", "cwd": "/tmp/project"},
},
}); err != nil {
serverDone <- err
@@ -167,14 +168,14 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if thread.ID != "thr_1" {
if thread.ID != "thr_1" || thread.CWD != "/tmp/project" {
t.Fatalf("unexpected thread: %+v", thread)
}
readThread, err := client.ReadThread(ctx, "thr_1")
if err != nil {
t.Fatal(err)
}
if readThread.ID != "thr_1" || readThread.Name != "Read title" {
if readThread.ID != "thr_1" || readThread.Name != "Read title" || readThread.CWD != "/tmp/project" {
t.Fatalf("unexpected read thread: %+v", readThread)
}
if err := client.SetThreadName(ctx, "thr_1", "Short title"); err != nil {

View File

@@ -445,6 +445,13 @@ func (s *Store) SyncThreadTitleByCodexID(ctx context.Context, codexThreadID, tit
return err
}
func (s *Store) SyncThreadWorkspace(ctx context.Context, telegramUserID, id, workspaceID int64) error {
_, err := s.db.ExecContext(ctx, `
UPDATE threads SET workspace_id = ?
WHERE telegram_user_id = ? AND id = ?`, workspaceID, telegramUserID, id)
return err
}
func (s *Store) UpsertPendingApproval(ctx context.Context, approval PendingApproval) (PendingApproval, error) {
_, err := s.db.ExecContext(ctx, `
INSERT INTO pending_approvals (

View File

@@ -176,6 +176,20 @@ func TestRenameThread(t *testing.T) {
if thread.Title != "synced codex title" {
t.Fatalf("title = %q", thread.Title)
}
ws2, err := st.AddWorkspace(ctx, t.TempDir(), "repo2", false)
if err != nil {
t.Fatal(err)
}
if err := st.SyncThreadWorkspace(ctx, 42, thread.ID, ws2.ID); err != nil {
t.Fatal(err)
}
thread, err = st.GetThreadByID(ctx, 42, thread.ID)
if err != nil {
t.Fatal(err)
}
if thread.WorkspaceID != ws2.ID {
t.Fatalf("workspace id = %d, want %d", thread.WorkspaceID, ws2.ID)
}
}
func TestValidateWorkspacePath(t *testing.T) {

View File

@@ -270,7 +270,7 @@ func (b *Bot) sendResumeChoices(ctx context.Context, userID, chatID int64, page
_, err := b.tg.SendMessage(ctx, chatID, text, SendMessageOptions{})
return err
}
threads = b.syncThreadTitles(ctx, threads)
threads = b.syncThreadStates(ctx, threads)
hasNext := len(threads) > resumeThreadPageSize
if hasNext {
threads = threads[:resumeThreadPageSize]
@@ -303,13 +303,16 @@ func (b *Bot) resumeThreadByID(ctx context.Context, userID, chatID int64, id int
if err != nil {
return b.sendError(ctx, chatID, "Could not resume Codex thread", err)
}
thread, err = b.applyCodexThreadTitle(ctx, thread, resumed)
thread, err = b.applyCodexThreadState(ctx, thread, resumed)
if err != nil {
return err
}
if err := b.store.SetActiveThread(ctx, userID, thread.ID); err != nil {
return err
}
if err := b.store.SetSessionWorkspace(ctx, userID, thread.WorkspaceID); err != nil {
return err
}
text := fmt.Sprintf("Active thread ID %d: %s", thread.ID, threadDisplayTitle(thread))
if messageID != 0 {
_, err = b.tg.EditMessageText(ctx, chatID, messageID, EscapeHTML(text), EditMessageTextOptions{ParseMode: "HTML"})
@@ -357,7 +360,7 @@ func (b *Bot) renameThread(ctx context.Context, userID, chatID int64, session st
return b.sendError(ctx, chatID, "Could not rename Codex thread", err)
}
if codexThread, readErr := b.codex.ReadThread(ctx, thread.CodexThreadID); readErr == nil {
thread, err = b.applyCodexThreadTitle(ctx, thread, codexThread)
thread, err = b.applyCodexThreadState(ctx, thread, codexThread)
if err != nil {
return err
}
@@ -381,14 +384,23 @@ func (b *Bot) forkThread(ctx context.Context, userID, chatID int64, session stor
if err != nil {
return b.sendError(ctx, chatID, "Could not fork Codex thread", err)
}
workspaceID := thread.WorkspaceID
if workspace, ok, workspaceErr := b.workspaceForCodexCWD(ctx, forked.CWD); workspaceErr == nil && ok {
workspaceID = workspace.ID
} else if workspaceErr != nil {
b.logger.Printf("sync fork cwd %s: %v", forked.CWD, workspaceErr)
}
title := codexThreadTitle(forked, "fork of ID "+strconv.FormatInt(thread.ID, 10))
local, err := b.store.CreateThread(ctx, userID, forked.ID, thread.WorkspaceID, title)
local, err := b.store.CreateThread(ctx, userID, forked.ID, workspaceID, title)
if err != nil {
return err
}
if err := b.store.SetActiveThread(ctx, userID, local.ID); err != nil {
return err
}
if err := b.store.SetSessionWorkspace(ctx, userID, local.WorkspaceID); err != nil {
return err
}
_, err = b.tg.SendMessage(ctx, chatID, fmt.Sprintf("Forked active thread to #%d.", local.ID), SendMessageOptions{})
return err
}
@@ -430,10 +442,13 @@ func (b *Bot) sendStatus(ctx context.Context, userID, chatID int64, session stor
if session.ActiveThreadID != 0 {
thread = fmt.Sprintf("ID %d", session.ActiveThreadID)
if active, err := b.store.GetThreadByID(ctx, userID, session.ActiveThreadID); err == nil {
if synced, syncErr := b.syncThreadTitle(ctx, active); syncErr == nil {
if synced, syncErr := b.syncThreadState(ctx, active); syncErr == nil {
active = synced
} else {
b.logger.Printf("sync status thread title %s: %v", active.CodexThreadID, syncErr)
b.logger.Printf("sync status thread state %s: %v", active.CodexThreadID, syncErr)
}
if ws, wsErr := b.store.GetWorkspaceByID(ctx, active.WorkspaceID); wsErr == nil {
workspace = fmt.Sprintf("%s (%s)", ws.Label, ws.Path)
}
thread = fmt.Sprintf("ID %d: %s", active.ID, threadDisplayTitle(active))
}
@@ -726,19 +741,52 @@ func codexThreadTitle(thread codexapp.Thread, fallback string) string {
return normalizeThreadTitle(fallback)
}
func (b *Bot) applyCodexThreadTitle(ctx context.Context, thread store.Thread, codexThread codexapp.Thread) (store.Thread, error) {
func (b *Bot) applyCodexThreadState(ctx context.Context, thread store.Thread, codexThread codexapp.Thread) (store.Thread, error) {
title := codexThreadTitle(codexThread, "")
if title == thread.Title {
return thread, nil
}
if title != thread.Title {
if err := b.store.SyncThreadTitle(ctx, thread.TelegramUserID, thread.ID, title); err != nil {
return thread, err
}
thread.Title = title
}
workspace, ok, err := b.workspaceForCodexCWD(ctx, codexThread.CWD)
if err != nil {
return thread, err
}
if ok && workspace.ID != thread.WorkspaceID {
if err := b.store.SyncThreadWorkspace(ctx, thread.TelegramUserID, thread.ID, workspace.ID); err != nil {
return thread, err
}
thread.WorkspaceID = workspace.ID
}
return thread, nil
}
func (b *Bot) syncThreadTitle(ctx context.Context, thread store.Thread) (store.Thread, error) {
func (b *Bot) workspaceForCodexCWD(ctx context.Context, cwd string) (store.Workspace, bool, error) {
cwd = strings.TrimSpace(cwd)
if cwd == "" {
return store.Workspace{}, false, nil
}
clean, err := store.ValidateWorkspacePath(cwd)
if err != nil {
return store.Workspace{}, false, err
}
workspace, err := b.store.GetWorkspaceByPath(ctx, clean)
if err == nil {
return workspace, true, nil
}
if !errors.Is(err, sql.ErrNoRows) {
return store.Workspace{}, false, err
}
label := filepath.Base(clean)
workspace, err = b.store.AddWorkspace(ctx, clean, label, false)
if err != nil {
return store.Workspace{}, false, err
}
return workspace, true, nil
}
func (b *Bot) syncThreadState(ctx context.Context, thread store.Thread) (store.Thread, error) {
if thread.CodexThreadID == "" {
return thread, nil
}
@@ -746,14 +794,14 @@ func (b *Bot) syncThreadTitle(ctx context.Context, thread store.Thread) (store.T
if err != nil {
return thread, err
}
return b.applyCodexThreadTitle(ctx, thread, codexThread)
return b.applyCodexThreadState(ctx, thread, codexThread)
}
func (b *Bot) syncThreadTitles(ctx context.Context, threads []store.Thread) []store.Thread {
func (b *Bot) syncThreadStates(ctx context.Context, threads []store.Thread) []store.Thread {
for i := range threads {
synced, err := b.syncThreadTitle(ctx, threads[i])
synced, err := b.syncThreadState(ctx, threads[i])
if err != nil {
b.logger.Printf("sync thread title %s: %v", threads[i].CodexThreadID, err)
b.logger.Printf("sync thread state %s: %v", threads[i].CodexThreadID, err)
continue
}
threads[i] = synced
@@ -770,26 +818,35 @@ func (b *Bot) createNewThread(ctx context.Context, userID, chatID int64, session
if err != nil {
return store.Thread{}, store.Workspace{}, b.sendError(ctx, chatID, "Could not start Codex thread", err)
}
title := codexThreadTitle(codexThread, workspace.Label)
thread, err := b.store.CreateThread(ctx, userID, codexThread.ID, workspace.ID, title)
threadWorkspace := workspace
if codexWorkspace, ok, workspaceErr := b.workspaceForCodexCWD(ctx, codexThread.CWD); workspaceErr == nil && ok {
threadWorkspace = codexWorkspace
} else if workspaceErr != nil {
b.logger.Printf("sync new thread cwd %s: %v", codexThread.CWD, workspaceErr)
}
title := codexThreadTitle(codexThread, threadWorkspace.Label)
thread, err := b.store.CreateThread(ctx, userID, codexThread.ID, threadWorkspace.ID, title)
if err != nil {
return store.Thread{}, store.Workspace{}, err
}
if err := b.store.SetActiveThread(ctx, userID, thread.ID); err != nil {
return store.Thread{}, store.Workspace{}, err
}
_, err = b.tg.SendMessage(ctx, chatID, fmt.Sprintf("New thread #%d in %s.", thread.ID, workspace.Label), SendMessageOptions{})
return thread, workspace, err
if err := b.store.SetSessionWorkspace(ctx, userID, thread.WorkspaceID); err != nil {
return store.Thread{}, store.Workspace{}, err
}
_, err = b.tg.SendMessage(ctx, chatID, fmt.Sprintf("New thread #%d in %s.", thread.ID, threadWorkspace.Label), SendMessageOptions{})
return thread, threadWorkspace, err
}
func (b *Bot) ensureThread(ctx context.Context, userID, chatID int64, session store.Session) (store.Thread, store.Workspace, error) {
if session.ActiveThreadID != 0 {
thread, err := b.store.GetThreadByID(ctx, userID, session.ActiveThreadID)
if err == nil && !thread.Archived {
if synced, syncErr := b.syncThreadTitle(ctx, thread); syncErr == nil {
if synced, syncErr := b.syncThreadState(ctx, thread); syncErr == nil {
thread = synced
} else {
b.logger.Printf("sync active thread title %s: %v", thread.CodexThreadID, syncErr)
b.logger.Printf("sync active thread state %s: %v", thread.CodexThreadID, syncErr)
}
workspace, err := b.store.GetWorkspaceByID(ctx, thread.WorkspaceID)
return thread, workspace, err