Sync thread workspaces from Codex cwd
This commit is contained in:
@@ -53,6 +53,7 @@ type Thread struct {
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Preview string `json:"preview,omitempty"`
|
||||
CWD string `json:"cwd,omitempty"`
|
||||
}
|
||||
|
||||
type Turn struct {
|
||||
@@ -168,6 +169,13 @@ func (c *Client) initialize(ctx context.Context) error {
|
||||
return c.notify("initialized", nil)
|
||||
}
|
||||
|
||||
func threadWithCWD(thread Thread, cwd string) Thread {
|
||||
if thread.CWD == "" {
|
||||
thread.CWD = cwd
|
||||
}
|
||||
return thread
|
||||
}
|
||||
|
||||
func (c *Client) StartThread(ctx context.Context, cwd, model, sandbox string) (Thread, error) {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return Thread{}, err
|
||||
@@ -183,11 +191,12 @@ func (c *Client) StartThread(ctx context.Context, cwd, model, sandbox string) (T
|
||||
}
|
||||
var result struct {
|
||||
Thread Thread `json:"thread"`
|
||||
CWD string `json:"cwd"`
|
||||
}
|
||||
if err := c.call(ctx, "thread/start", params, &result); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
return result.Thread, nil
|
||||
return threadWithCWD(result.Thread, result.CWD), nil
|
||||
}
|
||||
|
||||
func (c *Client) ResumeThread(ctx context.Context, threadID string) (Thread, error) {
|
||||
@@ -196,11 +205,12 @@ func (c *Client) ResumeThread(ctx context.Context, threadID string) (Thread, err
|
||||
}
|
||||
var result struct {
|
||||
Thread Thread `json:"thread"`
|
||||
CWD string `json:"cwd"`
|
||||
}
|
||||
if err := c.call(ctx, "thread/resume", map[string]any{"threadId": threadID}, &result); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
return result.Thread, nil
|
||||
return threadWithCWD(result.Thread, result.CWD), nil
|
||||
}
|
||||
|
||||
func (c *Client) ReadThread(ctx context.Context, threadID string) (Thread, error) {
|
||||
@@ -209,6 +219,7 @@ func (c *Client) ReadThread(ctx context.Context, threadID string) (Thread, error
|
||||
}
|
||||
var result struct {
|
||||
Thread Thread `json:"thread"`
|
||||
CWD string `json:"cwd"`
|
||||
}
|
||||
if err := c.call(ctx, "thread/read", map[string]any{
|
||||
"threadId": threadID,
|
||||
@@ -216,7 +227,7 @@ func (c *Client) ReadThread(ctx context.Context, threadID string) (Thread, error
|
||||
}, &result); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
return result.Thread, nil
|
||||
return threadWithCWD(result.Thread, result.CWD), nil
|
||||
}
|
||||
|
||||
func (c *Client) ForkThread(ctx context.Context, threadID string) (Thread, error) {
|
||||
@@ -225,11 +236,12 @@ func (c *Client) ForkThread(ctx context.Context, threadID string) (Thread, error
|
||||
}
|
||||
var result struct {
|
||||
Thread Thread `json:"thread"`
|
||||
CWD string `json:"cwd"`
|
||||
}
|
||||
if err := c.call(ctx, "thread/fork", map[string]any{"threadId": threadID}, &result); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
return result.Thread, nil
|
||||
return threadWithCWD(result.Thread, result.CWD), nil
|
||||
}
|
||||
|
||||
func (c *Client) ArchiveThread(ctx context.Context, threadID string) error {
|
||||
|
||||
@@ -79,6 +79,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
|
||||
if err := conn.WriteJSON(map[string]any{
|
||||
"id": start["id"],
|
||||
"result": map[string]any{
|
||||
"cwd": "/tmp/project",
|
||||
"thread": map[string]any{"id": "thr_1", "preview": "test"},
|
||||
},
|
||||
}); err != nil {
|
||||
@@ -104,7 +105,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
|
||||
if err := conn.WriteJSON(map[string]any{
|
||||
"id": readThread["id"],
|
||||
"result": map[string]any{
|
||||
"thread": map[string]any{"id": "thr_1", "name": "Read title", "preview": "test"},
|
||||
"thread": map[string]any{"id": "thr_1", "name": "Read title", "preview": "test", "cwd": "/tmp/project"},
|
||||
},
|
||||
}); err != nil {
|
||||
serverDone <- err
|
||||
@@ -167,14 +168,14 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if thread.ID != "thr_1" {
|
||||
if thread.ID != "thr_1" || thread.CWD != "/tmp/project" {
|
||||
t.Fatalf("unexpected thread: %+v", thread)
|
||||
}
|
||||
readThread, err := client.ReadThread(ctx, "thr_1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if readThread.ID != "thr_1" || readThread.Name != "Read title" {
|
||||
if readThread.ID != "thr_1" || readThread.Name != "Read title" || readThread.CWD != "/tmp/project" {
|
||||
t.Fatalf("unexpected read thread: %+v", readThread)
|
||||
}
|
||||
if err := client.SetThreadName(ctx, "thr_1", "Short title"); err != nil {
|
||||
|
||||
@@ -445,6 +445,13 @@ func (s *Store) SyncThreadTitleByCodexID(ctx context.Context, codexThreadID, tit
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) SyncThreadWorkspace(ctx context.Context, telegramUserID, id, workspaceID int64) error {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
UPDATE threads SET workspace_id = ?
|
||||
WHERE telegram_user_id = ? AND id = ?`, workspaceID, telegramUserID, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) UpsertPendingApproval(ctx context.Context, approval PendingApproval) (PendingApproval, error) {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
INSERT INTO pending_approvals (
|
||||
|
||||
@@ -176,6 +176,20 @@ func TestRenameThread(t *testing.T) {
|
||||
if thread.Title != "synced codex title" {
|
||||
t.Fatalf("title = %q", thread.Title)
|
||||
}
|
||||
ws2, err := st.AddWorkspace(ctx, t.TempDir(), "repo2", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := st.SyncThreadWorkspace(ctx, 42, thread.ID, ws2.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
thread, err = st.GetThreadByID(ctx, 42, thread.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if thread.WorkspaceID != ws2.ID {
|
||||
t.Fatalf("workspace id = %d, want %d", thread.WorkspaceID, ws2.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspacePath(t *testing.T) {
|
||||
|
||||
@@ -270,7 +270,7 @@ func (b *Bot) sendResumeChoices(ctx context.Context, userID, chatID int64, page
|
||||
_, err := b.tg.SendMessage(ctx, chatID, text, SendMessageOptions{})
|
||||
return err
|
||||
}
|
||||
threads = b.syncThreadTitles(ctx, threads)
|
||||
threads = b.syncThreadStates(ctx, threads)
|
||||
hasNext := len(threads) > resumeThreadPageSize
|
||||
if hasNext {
|
||||
threads = threads[:resumeThreadPageSize]
|
||||
@@ -303,13 +303,16 @@ func (b *Bot) resumeThreadByID(ctx context.Context, userID, chatID int64, id int
|
||||
if err != nil {
|
||||
return b.sendError(ctx, chatID, "Could not resume Codex thread", err)
|
||||
}
|
||||
thread, err = b.applyCodexThreadTitle(ctx, thread, resumed)
|
||||
thread, err = b.applyCodexThreadState(ctx, thread, resumed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := b.store.SetActiveThread(ctx, userID, thread.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := b.store.SetSessionWorkspace(ctx, userID, thread.WorkspaceID); err != nil {
|
||||
return err
|
||||
}
|
||||
text := fmt.Sprintf("Active thread ID %d: %s", thread.ID, threadDisplayTitle(thread))
|
||||
if messageID != 0 {
|
||||
_, err = b.tg.EditMessageText(ctx, chatID, messageID, EscapeHTML(text), EditMessageTextOptions{ParseMode: "HTML"})
|
||||
@@ -357,7 +360,7 @@ func (b *Bot) renameThread(ctx context.Context, userID, chatID int64, session st
|
||||
return b.sendError(ctx, chatID, "Could not rename Codex thread", err)
|
||||
}
|
||||
if codexThread, readErr := b.codex.ReadThread(ctx, thread.CodexThreadID); readErr == nil {
|
||||
thread, err = b.applyCodexThreadTitle(ctx, thread, codexThread)
|
||||
thread, err = b.applyCodexThreadState(ctx, thread, codexThread)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -381,14 +384,23 @@ func (b *Bot) forkThread(ctx context.Context, userID, chatID int64, session stor
|
||||
if err != nil {
|
||||
return b.sendError(ctx, chatID, "Could not fork Codex thread", err)
|
||||
}
|
||||
workspaceID := thread.WorkspaceID
|
||||
if workspace, ok, workspaceErr := b.workspaceForCodexCWD(ctx, forked.CWD); workspaceErr == nil && ok {
|
||||
workspaceID = workspace.ID
|
||||
} else if workspaceErr != nil {
|
||||
b.logger.Printf("sync fork cwd %s: %v", forked.CWD, workspaceErr)
|
||||
}
|
||||
title := codexThreadTitle(forked, "fork of ID "+strconv.FormatInt(thread.ID, 10))
|
||||
local, err := b.store.CreateThread(ctx, userID, forked.ID, thread.WorkspaceID, title)
|
||||
local, err := b.store.CreateThread(ctx, userID, forked.ID, workspaceID, title)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := b.store.SetActiveThread(ctx, userID, local.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := b.store.SetSessionWorkspace(ctx, userID, local.WorkspaceID); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = b.tg.SendMessage(ctx, chatID, fmt.Sprintf("Forked active thread to #%d.", local.ID), SendMessageOptions{})
|
||||
return err
|
||||
}
|
||||
@@ -430,10 +442,13 @@ func (b *Bot) sendStatus(ctx context.Context, userID, chatID int64, session stor
|
||||
if session.ActiveThreadID != 0 {
|
||||
thread = fmt.Sprintf("ID %d", session.ActiveThreadID)
|
||||
if active, err := b.store.GetThreadByID(ctx, userID, session.ActiveThreadID); err == nil {
|
||||
if synced, syncErr := b.syncThreadTitle(ctx, active); syncErr == nil {
|
||||
if synced, syncErr := b.syncThreadState(ctx, active); syncErr == nil {
|
||||
active = synced
|
||||
} else {
|
||||
b.logger.Printf("sync status thread title %s: %v", active.CodexThreadID, syncErr)
|
||||
b.logger.Printf("sync status thread state %s: %v", active.CodexThreadID, syncErr)
|
||||
}
|
||||
if ws, wsErr := b.store.GetWorkspaceByID(ctx, active.WorkspaceID); wsErr == nil {
|
||||
workspace = fmt.Sprintf("%s (%s)", ws.Label, ws.Path)
|
||||
}
|
||||
thread = fmt.Sprintf("ID %d: %s", active.ID, threadDisplayTitle(active))
|
||||
}
|
||||
@@ -726,19 +741,52 @@ func codexThreadTitle(thread codexapp.Thread, fallback string) string {
|
||||
return normalizeThreadTitle(fallback)
|
||||
}
|
||||
|
||||
func (b *Bot) applyCodexThreadTitle(ctx context.Context, thread store.Thread, codexThread codexapp.Thread) (store.Thread, error) {
|
||||
func (b *Bot) applyCodexThreadState(ctx context.Context, thread store.Thread, codexThread codexapp.Thread) (store.Thread, error) {
|
||||
title := codexThreadTitle(codexThread, "")
|
||||
if title == thread.Title {
|
||||
return thread, nil
|
||||
if title != thread.Title {
|
||||
if err := b.store.SyncThreadTitle(ctx, thread.TelegramUserID, thread.ID, title); err != nil {
|
||||
return thread, err
|
||||
}
|
||||
thread.Title = title
|
||||
}
|
||||
if err := b.store.SyncThreadTitle(ctx, thread.TelegramUserID, thread.ID, title); err != nil {
|
||||
workspace, ok, err := b.workspaceForCodexCWD(ctx, codexThread.CWD)
|
||||
if err != nil {
|
||||
return thread, err
|
||||
}
|
||||
thread.Title = title
|
||||
if ok && workspace.ID != thread.WorkspaceID {
|
||||
if err := b.store.SyncThreadWorkspace(ctx, thread.TelegramUserID, thread.ID, workspace.ID); err != nil {
|
||||
return thread, err
|
||||
}
|
||||
thread.WorkspaceID = workspace.ID
|
||||
}
|
||||
return thread, nil
|
||||
}
|
||||
|
||||
func (b *Bot) syncThreadTitle(ctx context.Context, thread store.Thread) (store.Thread, error) {
|
||||
func (b *Bot) workspaceForCodexCWD(ctx context.Context, cwd string) (store.Workspace, bool, error) {
|
||||
cwd = strings.TrimSpace(cwd)
|
||||
if cwd == "" {
|
||||
return store.Workspace{}, false, nil
|
||||
}
|
||||
clean, err := store.ValidateWorkspacePath(cwd)
|
||||
if err != nil {
|
||||
return store.Workspace{}, false, err
|
||||
}
|
||||
workspace, err := b.store.GetWorkspaceByPath(ctx, clean)
|
||||
if err == nil {
|
||||
return workspace, true, nil
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return store.Workspace{}, false, err
|
||||
}
|
||||
label := filepath.Base(clean)
|
||||
workspace, err = b.store.AddWorkspace(ctx, clean, label, false)
|
||||
if err != nil {
|
||||
return store.Workspace{}, false, err
|
||||
}
|
||||
return workspace, true, nil
|
||||
}
|
||||
|
||||
func (b *Bot) syncThreadState(ctx context.Context, thread store.Thread) (store.Thread, error) {
|
||||
if thread.CodexThreadID == "" {
|
||||
return thread, nil
|
||||
}
|
||||
@@ -746,14 +794,14 @@ func (b *Bot) syncThreadTitle(ctx context.Context, thread store.Thread) (store.T
|
||||
if err != nil {
|
||||
return thread, err
|
||||
}
|
||||
return b.applyCodexThreadTitle(ctx, thread, codexThread)
|
||||
return b.applyCodexThreadState(ctx, thread, codexThread)
|
||||
}
|
||||
|
||||
func (b *Bot) syncThreadTitles(ctx context.Context, threads []store.Thread) []store.Thread {
|
||||
func (b *Bot) syncThreadStates(ctx context.Context, threads []store.Thread) []store.Thread {
|
||||
for i := range threads {
|
||||
synced, err := b.syncThreadTitle(ctx, threads[i])
|
||||
synced, err := b.syncThreadState(ctx, threads[i])
|
||||
if err != nil {
|
||||
b.logger.Printf("sync thread title %s: %v", threads[i].CodexThreadID, err)
|
||||
b.logger.Printf("sync thread state %s: %v", threads[i].CodexThreadID, err)
|
||||
continue
|
||||
}
|
||||
threads[i] = synced
|
||||
@@ -770,26 +818,35 @@ func (b *Bot) createNewThread(ctx context.Context, userID, chatID int64, session
|
||||
if err != nil {
|
||||
return store.Thread{}, store.Workspace{}, b.sendError(ctx, chatID, "Could not start Codex thread", err)
|
||||
}
|
||||
title := codexThreadTitle(codexThread, workspace.Label)
|
||||
thread, err := b.store.CreateThread(ctx, userID, codexThread.ID, workspace.ID, title)
|
||||
threadWorkspace := workspace
|
||||
if codexWorkspace, ok, workspaceErr := b.workspaceForCodexCWD(ctx, codexThread.CWD); workspaceErr == nil && ok {
|
||||
threadWorkspace = codexWorkspace
|
||||
} else if workspaceErr != nil {
|
||||
b.logger.Printf("sync new thread cwd %s: %v", codexThread.CWD, workspaceErr)
|
||||
}
|
||||
title := codexThreadTitle(codexThread, threadWorkspace.Label)
|
||||
thread, err := b.store.CreateThread(ctx, userID, codexThread.ID, threadWorkspace.ID, title)
|
||||
if err != nil {
|
||||
return store.Thread{}, store.Workspace{}, err
|
||||
}
|
||||
if err := b.store.SetActiveThread(ctx, userID, thread.ID); err != nil {
|
||||
return store.Thread{}, store.Workspace{}, err
|
||||
}
|
||||
_, err = b.tg.SendMessage(ctx, chatID, fmt.Sprintf("New thread #%d in %s.", thread.ID, workspace.Label), SendMessageOptions{})
|
||||
return thread, workspace, err
|
||||
if err := b.store.SetSessionWorkspace(ctx, userID, thread.WorkspaceID); err != nil {
|
||||
return store.Thread{}, store.Workspace{}, err
|
||||
}
|
||||
_, err = b.tg.SendMessage(ctx, chatID, fmt.Sprintf("New thread #%d in %s.", thread.ID, threadWorkspace.Label), SendMessageOptions{})
|
||||
return thread, threadWorkspace, err
|
||||
}
|
||||
|
||||
func (b *Bot) ensureThread(ctx context.Context, userID, chatID int64, session store.Session) (store.Thread, store.Workspace, error) {
|
||||
if session.ActiveThreadID != 0 {
|
||||
thread, err := b.store.GetThreadByID(ctx, userID, session.ActiveThreadID)
|
||||
if err == nil && !thread.Archived {
|
||||
if synced, syncErr := b.syncThreadTitle(ctx, thread); syncErr == nil {
|
||||
if synced, syncErr := b.syncThreadState(ctx, thread); syncErr == nil {
|
||||
thread = synced
|
||||
} else {
|
||||
b.logger.Printf("sync active thread title %s: %v", thread.CodexThreadID, syncErr)
|
||||
b.logger.Printf("sync active thread state %s: %v", thread.CodexThreadID, syncErr)
|
||||
}
|
||||
workspace, err := b.store.GetWorkspaceByID(ctx, thread.WorkspaceID)
|
||||
return thread, workspace, err
|
||||
|
||||
Reference in New Issue
Block a user