From a73f88fe5e22b395f743b7f2117edb0a732513d1 Mon Sep 17 00:00:00 2001 From: Codex Date: Thu, 21 May 2026 09:56:47 +0000 Subject: [PATCH] Sync thread workspaces from Codex cwd --- internal/codexapp/client.go | 20 ++++-- internal/codexapp/client_test.go | 7 ++- internal/store/store.go | 7 +++ internal/store/store_test.go | 14 +++++ internal/telegram/bot.go | 101 ++++++++++++++++++++++++------- 5 files changed, 120 insertions(+), 29 deletions(-) diff --git a/internal/codexapp/client.go b/internal/codexapp/client.go index b65c601..91780d7 100644 --- a/internal/codexapp/client.go +++ b/internal/codexapp/client.go @@ -53,6 +53,7 @@ type Thread struct { SessionID string `json:"sessionId,omitempty"` Name string `json:"name,omitempty"` Preview string `json:"preview,omitempty"` + CWD string `json:"cwd,omitempty"` } type Turn struct { @@ -168,6 +169,13 @@ func (c *Client) initialize(ctx context.Context) error { return c.notify("initialized", nil) } +func threadWithCWD(thread Thread, cwd string) Thread { + if thread.CWD == "" { + thread.CWD = cwd + } + return thread +} + func (c *Client) StartThread(ctx context.Context, cwd, model, sandbox string) (Thread, error) { if err := c.EnsureConnected(ctx); err != nil { return Thread{}, err @@ -183,11 +191,12 @@ func (c *Client) StartThread(ctx context.Context, cwd, model, sandbox string) (T } var result struct { Thread Thread `json:"thread"` + CWD string `json:"cwd"` } if err := c.call(ctx, "thread/start", params, &result); err != nil { return Thread{}, err } - return result.Thread, nil + return threadWithCWD(result.Thread, result.CWD), nil } func (c *Client) ResumeThread(ctx context.Context, threadID string) (Thread, error) { @@ -196,11 +205,12 @@ func (c *Client) ResumeThread(ctx context.Context, threadID string) (Thread, err } var result struct { Thread Thread `json:"thread"` + CWD string `json:"cwd"` } if err := c.call(ctx, "thread/resume", map[string]any{"threadId": threadID}, &result); err != nil { return Thread{}, err } - return result.Thread, nil + return threadWithCWD(result.Thread, result.CWD), nil } func (c *Client) ReadThread(ctx context.Context, threadID string) (Thread, error) { @@ -209,6 +219,7 @@ func (c *Client) ReadThread(ctx context.Context, threadID string) (Thread, error } var result struct { Thread Thread `json:"thread"` + CWD string `json:"cwd"` } if err := c.call(ctx, "thread/read", map[string]any{ "threadId": threadID, @@ -216,7 +227,7 @@ func (c *Client) ReadThread(ctx context.Context, threadID string) (Thread, error }, &result); err != nil { return Thread{}, err } - return result.Thread, nil + return threadWithCWD(result.Thread, result.CWD), nil } func (c *Client) ForkThread(ctx context.Context, threadID string) (Thread, error) { @@ -225,11 +236,12 @@ func (c *Client) ForkThread(ctx context.Context, threadID string) (Thread, error } var result struct { Thread Thread `json:"thread"` + CWD string `json:"cwd"` } if err := c.call(ctx, "thread/fork", map[string]any{"threadId": threadID}, &result); err != nil { return Thread{}, err } - return result.Thread, nil + return threadWithCWD(result.Thread, result.CWD), nil } func (c *Client) ArchiveThread(ctx context.Context, threadID string) error { diff --git a/internal/codexapp/client_test.go b/internal/codexapp/client_test.go index f4bc85b..a57e3a2 100644 --- a/internal/codexapp/client_test.go +++ b/internal/codexapp/client_test.go @@ -79,6 +79,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) { if err := conn.WriteJSON(map[string]any{ "id": start["id"], "result": map[string]any{ + "cwd": "/tmp/project", "thread": map[string]any{"id": "thr_1", "preview": "test"}, }, }); err != nil { @@ -104,7 +105,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) { if err := conn.WriteJSON(map[string]any{ "id": readThread["id"], "result": map[string]any{ - "thread": map[string]any{"id": "thr_1", "name": "Read title", "preview": "test"}, + "thread": map[string]any{"id": "thr_1", "name": "Read title", "preview": "test", "cwd": "/tmp/project"}, }, }); err != nil { serverDone <- err @@ -167,14 +168,14 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) { if err != nil { t.Fatal(err) } - if thread.ID != "thr_1" { + if thread.ID != "thr_1" || thread.CWD != "/tmp/project" { t.Fatalf("unexpected thread: %+v", thread) } readThread, err := client.ReadThread(ctx, "thr_1") if err != nil { t.Fatal(err) } - if readThread.ID != "thr_1" || readThread.Name != "Read title" { + if readThread.ID != "thr_1" || readThread.Name != "Read title" || readThread.CWD != "/tmp/project" { t.Fatalf("unexpected read thread: %+v", readThread) } if err := client.SetThreadName(ctx, "thr_1", "Short title"); err != nil { diff --git a/internal/store/store.go b/internal/store/store.go index 24ada99..d7cc7e1 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -445,6 +445,13 @@ func (s *Store) SyncThreadTitleByCodexID(ctx context.Context, codexThreadID, tit return err } +func (s *Store) SyncThreadWorkspace(ctx context.Context, telegramUserID, id, workspaceID int64) error { + _, err := s.db.ExecContext(ctx, ` +UPDATE threads SET workspace_id = ? +WHERE telegram_user_id = ? AND id = ?`, workspaceID, telegramUserID, id) + return err +} + func (s *Store) UpsertPendingApproval(ctx context.Context, approval PendingApproval) (PendingApproval, error) { _, err := s.db.ExecContext(ctx, ` INSERT INTO pending_approvals ( diff --git a/internal/store/store_test.go b/internal/store/store_test.go index aba02a3..ca99599 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -176,6 +176,20 @@ func TestRenameThread(t *testing.T) { if thread.Title != "synced codex title" { t.Fatalf("title = %q", thread.Title) } + ws2, err := st.AddWorkspace(ctx, t.TempDir(), "repo2", false) + if err != nil { + t.Fatal(err) + } + if err := st.SyncThreadWorkspace(ctx, 42, thread.ID, ws2.ID); err != nil { + t.Fatal(err) + } + thread, err = st.GetThreadByID(ctx, 42, thread.ID) + if err != nil { + t.Fatal(err) + } + if thread.WorkspaceID != ws2.ID { + t.Fatalf("workspace id = %d, want %d", thread.WorkspaceID, ws2.ID) + } } func TestValidateWorkspacePath(t *testing.T) { diff --git a/internal/telegram/bot.go b/internal/telegram/bot.go index a4adc70..9374071 100644 --- a/internal/telegram/bot.go +++ b/internal/telegram/bot.go @@ -270,7 +270,7 @@ func (b *Bot) sendResumeChoices(ctx context.Context, userID, chatID int64, page _, err := b.tg.SendMessage(ctx, chatID, text, SendMessageOptions{}) return err } - threads = b.syncThreadTitles(ctx, threads) + threads = b.syncThreadStates(ctx, threads) hasNext := len(threads) > resumeThreadPageSize if hasNext { threads = threads[:resumeThreadPageSize] @@ -303,13 +303,16 @@ func (b *Bot) resumeThreadByID(ctx context.Context, userID, chatID int64, id int if err != nil { return b.sendError(ctx, chatID, "Could not resume Codex thread", err) } - thread, err = b.applyCodexThreadTitle(ctx, thread, resumed) + thread, err = b.applyCodexThreadState(ctx, thread, resumed) if err != nil { return err } if err := b.store.SetActiveThread(ctx, userID, thread.ID); err != nil { return err } + if err := b.store.SetSessionWorkspace(ctx, userID, thread.WorkspaceID); err != nil { + return err + } text := fmt.Sprintf("Active thread ID %d: %s", thread.ID, threadDisplayTitle(thread)) if messageID != 0 { _, err = b.tg.EditMessageText(ctx, chatID, messageID, EscapeHTML(text), EditMessageTextOptions{ParseMode: "HTML"}) @@ -357,7 +360,7 @@ func (b *Bot) renameThread(ctx context.Context, userID, chatID int64, session st return b.sendError(ctx, chatID, "Could not rename Codex thread", err) } if codexThread, readErr := b.codex.ReadThread(ctx, thread.CodexThreadID); readErr == nil { - thread, err = b.applyCodexThreadTitle(ctx, thread, codexThread) + thread, err = b.applyCodexThreadState(ctx, thread, codexThread) if err != nil { return err } @@ -381,14 +384,23 @@ func (b *Bot) forkThread(ctx context.Context, userID, chatID int64, session stor if err != nil { return b.sendError(ctx, chatID, "Could not fork Codex thread", err) } + workspaceID := thread.WorkspaceID + if workspace, ok, workspaceErr := b.workspaceForCodexCWD(ctx, forked.CWD); workspaceErr == nil && ok { + workspaceID = workspace.ID + } else if workspaceErr != nil { + b.logger.Printf("sync fork cwd %s: %v", forked.CWD, workspaceErr) + } title := codexThreadTitle(forked, "fork of ID "+strconv.FormatInt(thread.ID, 10)) - local, err := b.store.CreateThread(ctx, userID, forked.ID, thread.WorkspaceID, title) + local, err := b.store.CreateThread(ctx, userID, forked.ID, workspaceID, title) if err != nil { return err } if err := b.store.SetActiveThread(ctx, userID, local.ID); err != nil { return err } + if err := b.store.SetSessionWorkspace(ctx, userID, local.WorkspaceID); err != nil { + return err + } _, err = b.tg.SendMessage(ctx, chatID, fmt.Sprintf("Forked active thread to #%d.", local.ID), SendMessageOptions{}) return err } @@ -430,10 +442,13 @@ func (b *Bot) sendStatus(ctx context.Context, userID, chatID int64, session stor if session.ActiveThreadID != 0 { thread = fmt.Sprintf("ID %d", session.ActiveThreadID) if active, err := b.store.GetThreadByID(ctx, userID, session.ActiveThreadID); err == nil { - if synced, syncErr := b.syncThreadTitle(ctx, active); syncErr == nil { + if synced, syncErr := b.syncThreadState(ctx, active); syncErr == nil { active = synced } else { - b.logger.Printf("sync status thread title %s: %v", active.CodexThreadID, syncErr) + b.logger.Printf("sync status thread state %s: %v", active.CodexThreadID, syncErr) + } + if ws, wsErr := b.store.GetWorkspaceByID(ctx, active.WorkspaceID); wsErr == nil { + workspace = fmt.Sprintf("%s (%s)", ws.Label, ws.Path) } thread = fmt.Sprintf("ID %d: %s", active.ID, threadDisplayTitle(active)) } @@ -726,19 +741,52 @@ func codexThreadTitle(thread codexapp.Thread, fallback string) string { return normalizeThreadTitle(fallback) } -func (b *Bot) applyCodexThreadTitle(ctx context.Context, thread store.Thread, codexThread codexapp.Thread) (store.Thread, error) { +func (b *Bot) applyCodexThreadState(ctx context.Context, thread store.Thread, codexThread codexapp.Thread) (store.Thread, error) { title := codexThreadTitle(codexThread, "") - if title == thread.Title { - return thread, nil + if title != thread.Title { + if err := b.store.SyncThreadTitle(ctx, thread.TelegramUserID, thread.ID, title); err != nil { + return thread, err + } + thread.Title = title } - if err := b.store.SyncThreadTitle(ctx, thread.TelegramUserID, thread.ID, title); err != nil { + workspace, ok, err := b.workspaceForCodexCWD(ctx, codexThread.CWD) + if err != nil { return thread, err } - thread.Title = title + if ok && workspace.ID != thread.WorkspaceID { + if err := b.store.SyncThreadWorkspace(ctx, thread.TelegramUserID, thread.ID, workspace.ID); err != nil { + return thread, err + } + thread.WorkspaceID = workspace.ID + } return thread, nil } -func (b *Bot) syncThreadTitle(ctx context.Context, thread store.Thread) (store.Thread, error) { +func (b *Bot) workspaceForCodexCWD(ctx context.Context, cwd string) (store.Workspace, bool, error) { + cwd = strings.TrimSpace(cwd) + if cwd == "" { + return store.Workspace{}, false, nil + } + clean, err := store.ValidateWorkspacePath(cwd) + if err != nil { + return store.Workspace{}, false, err + } + workspace, err := b.store.GetWorkspaceByPath(ctx, clean) + if err == nil { + return workspace, true, nil + } + if !errors.Is(err, sql.ErrNoRows) { + return store.Workspace{}, false, err + } + label := filepath.Base(clean) + workspace, err = b.store.AddWorkspace(ctx, clean, label, false) + if err != nil { + return store.Workspace{}, false, err + } + return workspace, true, nil +} + +func (b *Bot) syncThreadState(ctx context.Context, thread store.Thread) (store.Thread, error) { if thread.CodexThreadID == "" { return thread, nil } @@ -746,14 +794,14 @@ func (b *Bot) syncThreadTitle(ctx context.Context, thread store.Thread) (store.T if err != nil { return thread, err } - return b.applyCodexThreadTitle(ctx, thread, codexThread) + return b.applyCodexThreadState(ctx, thread, codexThread) } -func (b *Bot) syncThreadTitles(ctx context.Context, threads []store.Thread) []store.Thread { +func (b *Bot) syncThreadStates(ctx context.Context, threads []store.Thread) []store.Thread { for i := range threads { - synced, err := b.syncThreadTitle(ctx, threads[i]) + synced, err := b.syncThreadState(ctx, threads[i]) if err != nil { - b.logger.Printf("sync thread title %s: %v", threads[i].CodexThreadID, err) + b.logger.Printf("sync thread state %s: %v", threads[i].CodexThreadID, err) continue } threads[i] = synced @@ -770,26 +818,35 @@ func (b *Bot) createNewThread(ctx context.Context, userID, chatID int64, session if err != nil { return store.Thread{}, store.Workspace{}, b.sendError(ctx, chatID, "Could not start Codex thread", err) } - title := codexThreadTitle(codexThread, workspace.Label) - thread, err := b.store.CreateThread(ctx, userID, codexThread.ID, workspace.ID, title) + threadWorkspace := workspace + if codexWorkspace, ok, workspaceErr := b.workspaceForCodexCWD(ctx, codexThread.CWD); workspaceErr == nil && ok { + threadWorkspace = codexWorkspace + } else if workspaceErr != nil { + b.logger.Printf("sync new thread cwd %s: %v", codexThread.CWD, workspaceErr) + } + title := codexThreadTitle(codexThread, threadWorkspace.Label) + thread, err := b.store.CreateThread(ctx, userID, codexThread.ID, threadWorkspace.ID, title) if err != nil { return store.Thread{}, store.Workspace{}, err } if err := b.store.SetActiveThread(ctx, userID, thread.ID); err != nil { return store.Thread{}, store.Workspace{}, err } - _, err = b.tg.SendMessage(ctx, chatID, fmt.Sprintf("New thread #%d in %s.", thread.ID, workspace.Label), SendMessageOptions{}) - return thread, workspace, err + if err := b.store.SetSessionWorkspace(ctx, userID, thread.WorkspaceID); err != nil { + return store.Thread{}, store.Workspace{}, err + } + _, err = b.tg.SendMessage(ctx, chatID, fmt.Sprintf("New thread #%d in %s.", thread.ID, threadWorkspace.Label), SendMessageOptions{}) + return thread, threadWorkspace, err } func (b *Bot) ensureThread(ctx context.Context, userID, chatID int64, session store.Session) (store.Thread, store.Workspace, error) { if session.ActiveThreadID != 0 { thread, err := b.store.GetThreadByID(ctx, userID, session.ActiveThreadID) if err == nil && !thread.Archived { - if synced, syncErr := b.syncThreadTitle(ctx, thread); syncErr == nil { + if synced, syncErr := b.syncThreadState(ctx, thread); syncErr == nil { thread = synced } else { - b.logger.Printf("sync active thread title %s: %v", thread.CodexThreadID, syncErr) + b.logger.Printf("sync active thread state %s: %v", thread.CodexThreadID, syncErr) } workspace, err := b.store.GetWorkspaceByID(ctx, thread.WorkspaceID) return thread, workspace, err