diff --git a/internal/codexapp/client.go b/internal/codexapp/client.go index db3daa9..3d741e1 100644 --- a/internal/codexapp/client.go +++ b/internal/codexapp/client.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/http" + "strconv" "strings" "sync" "sync/atomic" @@ -22,19 +23,77 @@ type Client struct { mu sync.Mutex conn *websocket.Conn - pending map[int64]chan rpcResult + pending map[string]chan rpcResult events chan Event connected bool } type Event struct { - ID *int64 + ID *RequestID Method string Params json.RawMessage ServerRequest bool Err error } +type RequestID struct { + value any + key string +} + +func IntRequestID(id int64) RequestID { + return RequestID{value: id, key: "i:" + strconv.FormatInt(id, 10)} +} + +func ParseRequestIDKey(key string) (RequestID, error) { + key = strings.TrimSpace(key) + if strings.HasPrefix(key, "i:") { + id, err := strconv.ParseInt(strings.TrimPrefix(key, "i:"), 10, 64) + if err != nil { + return RequestID{}, err + } + return IntRequestID(id), nil + } + if strings.HasPrefix(key, "s:") { + value := strings.TrimPrefix(key, "s:") + return RequestID{value: value, key: "s:" + value}, nil + } + if id, err := strconv.ParseInt(key, 10, 64); err == nil { + return IntRequestID(id), nil + } + if key == "" { + return RequestID{}, errors.New("request id is empty") + } + return RequestID{value: key, key: "s:" + key}, nil +} + +func ParseRequestID(raw json.RawMessage) (RequestID, bool, error) { + trimmed := strings.TrimSpace(string(raw)) + if trimmed == "" || trimmed == "null" { + return RequestID{}, false, nil + } + if strings.HasPrefix(trimmed, "\"") { + var value string + if err := json.Unmarshal(raw, &value); err != nil { + return RequestID{}, false, err + } + return RequestID{value: value, key: "s:" + value}, true, nil + } + var value int64 + if err := json.Unmarshal(raw, &value); err != nil { + return RequestID{}, false, err + } + return IntRequestID(value), true, nil +} + +func (id RequestID) Key() string { + return id.key +} + +func (id RequestID) Value() any { + return id.value +} + type RPCError struct { Code int `json:"code"` Message string `json:"message"` @@ -87,7 +146,7 @@ func New(socketPath, version string) *Client { return &Client{ socketPath: socketPath, version: version, - pending: make(map[int64]chan rpcResult), + pending: make(map[string]chan rpcResult), events: make(chan Event, 128), } } @@ -315,13 +374,24 @@ func (c *Client) ListModels(ctx context.Context) ([]Model, error) { return result.Data, nil } -func (c *Client) RespondServerRequest(ctx context.Context, requestID int64, result any) error { +func (c *Client) RespondServerRequest(ctx context.Context, requestID RequestID, result any) error { + return c.respondServerRequest(ctx, responseEnvelope{ID: requestID.Value(), Result: result}) +} + +func (c *Client) RespondServerRequestError(ctx context.Context, requestID RequestID, code int, message string) error { + return c.respondServerRequest(ctx, responseEnvelope{ + ID: requestID.Value(), + Error: RPCError{Code: code, Message: message}, + }) +} + +func (c *Client) respondServerRequest(ctx context.Context, response responseEnvelope) error { if err := c.EnsureConnected(ctx); err != nil { return err } done := make(chan error, 1) go func() { - done <- c.write(responseEnvelope{ID: requestID, Result: result}) + done <- c.write(response) }() select { case <-ctx.Done(): @@ -340,12 +410,12 @@ func (c *Client) call(ctx context.Context, method string, params any, result any c.mu.Unlock() return errors.New("codex app-server is not connected") } - c.pending[id] = ch + c.pending[IntRequestID(id).Key()] = ch c.mu.Unlock() if err := c.write(requestEnvelope{ID: id, Method: method, Params: params}); err != nil { c.mu.Lock() - delete(c.pending, id) + delete(c.pending, IntRequestID(id).Key()) c.mu.Unlock() return err } @@ -353,7 +423,7 @@ func (c *Client) call(ctx context.Context, method string, params any, result any select { case <-ctx.Done(): c.mu.Lock() - delete(c.pending, id) + delete(c.pending, IntRequestID(id).Key()) c.mu.Unlock() return ctx.Err() case rpc := <-ch: @@ -387,17 +457,22 @@ func (c *Client) readLoop(conn *websocket.Conn) { c.failConnection(conn, err) return } - if env.Method != "" && env.ID != nil { + id, hasID, err := ParseRequestID(env.ID) + if err != nil { + c.failConnection(conn, fmt.Errorf("decode request id: %w", err)) + return + } + if env.Method != "" && hasID { c.events <- Event{ - ID: env.ID, + ID: &id, Method: env.Method, Params: env.Params, ServerRequest: true, } continue } - if env.ID != nil { - c.completeCall(*env.ID, env.Result, env.Error) + if hasID { + c.completeCall(id, env.Result, env.Error) continue } if env.Method != "" { @@ -409,10 +484,10 @@ func (c *Client) readLoop(conn *websocket.Conn) { } } -func (c *Client) completeCall(id int64, result json.RawMessage, rpcErr *RPCError) { +func (c *Client) completeCall(id RequestID, result json.RawMessage, rpcErr *RPCError) { c.mu.Lock() - ch := c.pending[id] - delete(c.pending, id) + ch := c.pending[id.Key()] + delete(c.pending, id.Key()) c.mu.Unlock() if ch == nil { return @@ -431,7 +506,7 @@ func (c *Client) failConnection(conn *websocket.Conn, err error) { c.connected = false } pending := c.pending - c.pending = make(map[int64]chan rpcResult) + c.pending = make(map[string]chan rpcResult) c.mu.Unlock() for _, ch := range pending { @@ -501,13 +576,13 @@ type notificationEnvelope struct { } type responseEnvelope struct { - ID int64 `json:"id"` - Result any `json:"result,omitempty"` - Error any `json:"error,omitempty"` + ID any `json:"id"` + Result any `json:"result,omitempty"` + Error any `json:"error,omitempty"` } type incomingEnvelope struct { - ID *int64 `json:"id,omitempty"` + ID json.RawMessage `json:"id,omitempty"` Method string `json:"method,omitempty"` Params json.RawMessage `json:"params,omitempty"` Result json.RawMessage `json:"result,omitempty"` diff --git a/internal/codexapp/client_test.go b/internal/codexapp/client_test.go index 8dfe4ca..3cb18b0 100644 --- a/internal/codexapp/client_test.go +++ b/internal/codexapp/client_test.go @@ -60,7 +60,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) { } if err := conn.WriteJSON(map[string]any{ - "id": 99, + "id": "approval-99", "method": "item/commandExecution/requestApproval", "params": map[string]any{"threadId": "thr_1"}, }); err != nil { @@ -138,7 +138,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) { serverDone <- err return } - if response["id"].(float64) != 99 || response["result"] != "accept" { + if response["id"] != "approval-99" || response["result"] != "accept" { payload, _ := json.Marshal(response) serverDone <- unexpectedMessage("approval response", string(payload)) return @@ -156,11 +156,13 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) { } defer client.Close() + var approvalRequestID RequestID select { case event := <-client.Events(): - if !event.ServerRequest || event.ID == nil || *event.ID != 99 { + if !event.ServerRequest || event.ID == nil || event.ID.Key() != "s:approval-99" { t.Fatalf("unexpected event: %+v", event) } + approvalRequestID = *event.ID case <-ctx.Done(): t.Fatal(ctx.Err()) } @@ -182,7 +184,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) { if err := client.SetThreadName(ctx, "thr_1", "Short title"); err != nil { t.Fatal(err) } - if err := client.RespondServerRequest(ctx, 99, "accept"); err != nil { + if err := client.RespondServerRequest(ctx, approvalRequestID, "accept"); err != nil { t.Fatal(err) } diff --git a/internal/telegram/bot.go b/internal/telegram/bot.go index e51b283..3a7a7d4 100644 --- a/internal/telegram/bot.go +++ b/internal/telegram/bot.go @@ -1134,7 +1134,7 @@ func (b *Bot) handleApprovalCallback(ctx context.Context, callback *CallbackQuer if approval.Status != "pending" { return b.tg.AnswerCallbackQuery(ctx, callback.ID, "Already resolved.") } - requestID, err := strconv.ParseInt(approval.CodexRequestID, 10, 64) + requestID, err := codexapp.ParseRequestIDKey(approval.CodexRequestID) if err != nil { return b.tg.AnswerCallbackQuery(ctx, callback.ID, "Invalid request id.") } @@ -1169,6 +1169,9 @@ func (b *Bot) handleCodexEvents(ctx context.Context) { if event.ServerRequest { if err := b.handleCodexServerRequest(ctx, event); err != nil { b.logger.Printf("server request %s: %v", event.Method, err) + if event.ID != nil { + _ = b.codex.RespondServerRequestError(ctx, *event.ID, -32603, err.Error()) + } } continue } @@ -1710,14 +1713,13 @@ func (b *Bot) syncThreadWorkspaceFromCWD(ctx context.Context, codexThreadID, cwd func (b *Bot) handleCodexServerRequest(ctx context.Context, event codexapp.Event) error { if event.ID == nil { - return nil + return errors.New("server request missing id") } switch event.Method { case "item/commandExecution/requestApproval", "item/fileChange/requestApproval", "item/permissions/requestApproval": case "execCommandApproval", "applyPatchApproval": default: - b.logger.Printf("unhandled server request: %s", event.Method) - return nil + return fmt.Errorf("unsupported Codex server request: %s", event.Method) } var params struct { ThreadID string `json:"threadId"` @@ -1747,7 +1749,7 @@ func (b *Bot) handleCodexServerRequest(ctx context.Context, event codexapp.Event kind := event.Method approval, err := b.store.UpsertPendingApproval(ctx, store.PendingApproval{ TelegramUserID: thread.TelegramUserID, - CodexRequestID: strconv.FormatInt(*event.ID, 10), + CodexRequestID: event.ID.Key(), CodexThreadID: threadID, TurnID: params.TurnID, ItemID: itemID,