Support string Codex request IDs

This commit is contained in:
Codex
2026-05-25 06:04:31 +00:00
parent b46c4beb86
commit ab5cc4fbfe
3 changed files with 108 additions and 29 deletions

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -22,19 +23,77 @@ type Client struct {
mu sync.Mutex mu sync.Mutex
conn *websocket.Conn conn *websocket.Conn
pending map[int64]chan rpcResult pending map[string]chan rpcResult
events chan Event events chan Event
connected bool connected bool
} }
type Event struct { type Event struct {
ID *int64 ID *RequestID
Method string Method string
Params json.RawMessage Params json.RawMessage
ServerRequest bool ServerRequest bool
Err error Err error
} }
type RequestID struct {
value any
key string
}
func IntRequestID(id int64) RequestID {
return RequestID{value: id, key: "i:" + strconv.FormatInt(id, 10)}
}
func ParseRequestIDKey(key string) (RequestID, error) {
key = strings.TrimSpace(key)
if strings.HasPrefix(key, "i:") {
id, err := strconv.ParseInt(strings.TrimPrefix(key, "i:"), 10, 64)
if err != nil {
return RequestID{}, err
}
return IntRequestID(id), nil
}
if strings.HasPrefix(key, "s:") {
value := strings.TrimPrefix(key, "s:")
return RequestID{value: value, key: "s:" + value}, nil
}
if id, err := strconv.ParseInt(key, 10, 64); err == nil {
return IntRequestID(id), nil
}
if key == "" {
return RequestID{}, errors.New("request id is empty")
}
return RequestID{value: key, key: "s:" + key}, nil
}
func ParseRequestID(raw json.RawMessage) (RequestID, bool, error) {
trimmed := strings.TrimSpace(string(raw))
if trimmed == "" || trimmed == "null" {
return RequestID{}, false, nil
}
if strings.HasPrefix(trimmed, "\"") {
var value string
if err := json.Unmarshal(raw, &value); err != nil {
return RequestID{}, false, err
}
return RequestID{value: value, key: "s:" + value}, true, nil
}
var value int64
if err := json.Unmarshal(raw, &value); err != nil {
return RequestID{}, false, err
}
return IntRequestID(value), true, nil
}
func (id RequestID) Key() string {
return id.key
}
func (id RequestID) Value() any {
return id.value
}
type RPCError struct { type RPCError struct {
Code int `json:"code"` Code int `json:"code"`
Message string `json:"message"` Message string `json:"message"`
@@ -87,7 +146,7 @@ func New(socketPath, version string) *Client {
return &Client{ return &Client{
socketPath: socketPath, socketPath: socketPath,
version: version, version: version,
pending: make(map[int64]chan rpcResult), pending: make(map[string]chan rpcResult),
events: make(chan Event, 128), events: make(chan Event, 128),
} }
} }
@@ -315,13 +374,24 @@ func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
return result.Data, nil return result.Data, nil
} }
func (c *Client) RespondServerRequest(ctx context.Context, requestID int64, result any) error { func (c *Client) RespondServerRequest(ctx context.Context, requestID RequestID, result any) error {
return c.respondServerRequest(ctx, responseEnvelope{ID: requestID.Value(), Result: result})
}
func (c *Client) RespondServerRequestError(ctx context.Context, requestID RequestID, code int, message string) error {
return c.respondServerRequest(ctx, responseEnvelope{
ID: requestID.Value(),
Error: RPCError{Code: code, Message: message},
})
}
func (c *Client) respondServerRequest(ctx context.Context, response responseEnvelope) error {
if err := c.EnsureConnected(ctx); err != nil { if err := c.EnsureConnected(ctx); err != nil {
return err return err
} }
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
done <- c.write(responseEnvelope{ID: requestID, Result: result}) done <- c.write(response)
}() }()
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -340,12 +410,12 @@ func (c *Client) call(ctx context.Context, method string, params any, result any
c.mu.Unlock() c.mu.Unlock()
return errors.New("codex app-server is not connected") return errors.New("codex app-server is not connected")
} }
c.pending[id] = ch c.pending[IntRequestID(id).Key()] = ch
c.mu.Unlock() c.mu.Unlock()
if err := c.write(requestEnvelope{ID: id, Method: method, Params: params}); err != nil { if err := c.write(requestEnvelope{ID: id, Method: method, Params: params}); err != nil {
c.mu.Lock() c.mu.Lock()
delete(c.pending, id) delete(c.pending, IntRequestID(id).Key())
c.mu.Unlock() c.mu.Unlock()
return err return err
} }
@@ -353,7 +423,7 @@ func (c *Client) call(ctx context.Context, method string, params any, result any
select { select {
case <-ctx.Done(): case <-ctx.Done():
c.mu.Lock() c.mu.Lock()
delete(c.pending, id) delete(c.pending, IntRequestID(id).Key())
c.mu.Unlock() c.mu.Unlock()
return ctx.Err() return ctx.Err()
case rpc := <-ch: case rpc := <-ch:
@@ -387,17 +457,22 @@ func (c *Client) readLoop(conn *websocket.Conn) {
c.failConnection(conn, err) c.failConnection(conn, err)
return return
} }
if env.Method != "" && env.ID != nil { id, hasID, err := ParseRequestID(env.ID)
if err != nil {
c.failConnection(conn, fmt.Errorf("decode request id: %w", err))
return
}
if env.Method != "" && hasID {
c.events <- Event{ c.events <- Event{
ID: env.ID, ID: &id,
Method: env.Method, Method: env.Method,
Params: env.Params, Params: env.Params,
ServerRequest: true, ServerRequest: true,
} }
continue continue
} }
if env.ID != nil { if hasID {
c.completeCall(*env.ID, env.Result, env.Error) c.completeCall(id, env.Result, env.Error)
continue continue
} }
if env.Method != "" { if env.Method != "" {
@@ -409,10 +484,10 @@ func (c *Client) readLoop(conn *websocket.Conn) {
} }
} }
func (c *Client) completeCall(id int64, result json.RawMessage, rpcErr *RPCError) { func (c *Client) completeCall(id RequestID, result json.RawMessage, rpcErr *RPCError) {
c.mu.Lock() c.mu.Lock()
ch := c.pending[id] ch := c.pending[id.Key()]
delete(c.pending, id) delete(c.pending, id.Key())
c.mu.Unlock() c.mu.Unlock()
if ch == nil { if ch == nil {
return return
@@ -431,7 +506,7 @@ func (c *Client) failConnection(conn *websocket.Conn, err error) {
c.connected = false c.connected = false
} }
pending := c.pending pending := c.pending
c.pending = make(map[int64]chan rpcResult) c.pending = make(map[string]chan rpcResult)
c.mu.Unlock() c.mu.Unlock()
for _, ch := range pending { for _, ch := range pending {
@@ -501,13 +576,13 @@ type notificationEnvelope struct {
} }
type responseEnvelope struct { type responseEnvelope struct {
ID int64 `json:"id"` ID any `json:"id"`
Result any `json:"result,omitempty"` Result any `json:"result,omitempty"`
Error any `json:"error,omitempty"` Error any `json:"error,omitempty"`
} }
type incomingEnvelope struct { type incomingEnvelope struct {
ID *int64 `json:"id,omitempty"` ID json.RawMessage `json:"id,omitempty"`
Method string `json:"method,omitempty"` Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"` Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"` Result json.RawMessage `json:"result,omitempty"`

View File

@@ -60,7 +60,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
} }
if err := conn.WriteJSON(map[string]any{ if err := conn.WriteJSON(map[string]any{
"id": 99, "id": "approval-99",
"method": "item/commandExecution/requestApproval", "method": "item/commandExecution/requestApproval",
"params": map[string]any{"threadId": "thr_1"}, "params": map[string]any{"threadId": "thr_1"},
}); err != nil { }); err != nil {
@@ -138,7 +138,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
serverDone <- err serverDone <- err
return return
} }
if response["id"].(float64) != 99 || response["result"] != "accept" { if response["id"] != "approval-99" || response["result"] != "accept" {
payload, _ := json.Marshal(response) payload, _ := json.Marshal(response)
serverDone <- unexpectedMessage("approval response", string(payload)) serverDone <- unexpectedMessage("approval response", string(payload))
return return
@@ -156,11 +156,13 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
} }
defer client.Close() defer client.Close()
var approvalRequestID RequestID
select { select {
case event := <-client.Events(): case event := <-client.Events():
if !event.ServerRequest || event.ID == nil || *event.ID != 99 { if !event.ServerRequest || event.ID == nil || event.ID.Key() != "s:approval-99" {
t.Fatalf("unexpected event: %+v", event) t.Fatalf("unexpected event: %+v", event)
} }
approvalRequestID = *event.ID
case <-ctx.Done(): case <-ctx.Done():
t.Fatal(ctx.Err()) t.Fatal(ctx.Err())
} }
@@ -182,7 +184,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
if err := client.SetThreadName(ctx, "thr_1", "Short title"); err != nil { if err := client.SetThreadName(ctx, "thr_1", "Short title"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := client.RespondServerRequest(ctx, 99, "accept"); err != nil { if err := client.RespondServerRequest(ctx, approvalRequestID, "accept"); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -1134,7 +1134,7 @@ func (b *Bot) handleApprovalCallback(ctx context.Context, callback *CallbackQuer
if approval.Status != "pending" { if approval.Status != "pending" {
return b.tg.AnswerCallbackQuery(ctx, callback.ID, "Already resolved.") return b.tg.AnswerCallbackQuery(ctx, callback.ID, "Already resolved.")
} }
requestID, err := strconv.ParseInt(approval.CodexRequestID, 10, 64) requestID, err := codexapp.ParseRequestIDKey(approval.CodexRequestID)
if err != nil { if err != nil {
return b.tg.AnswerCallbackQuery(ctx, callback.ID, "Invalid request id.") return b.tg.AnswerCallbackQuery(ctx, callback.ID, "Invalid request id.")
} }
@@ -1169,6 +1169,9 @@ func (b *Bot) handleCodexEvents(ctx context.Context) {
if event.ServerRequest { if event.ServerRequest {
if err := b.handleCodexServerRequest(ctx, event); err != nil { if err := b.handleCodexServerRequest(ctx, event); err != nil {
b.logger.Printf("server request %s: %v", event.Method, err) b.logger.Printf("server request %s: %v", event.Method, err)
if event.ID != nil {
_ = b.codex.RespondServerRequestError(ctx, *event.ID, -32603, err.Error())
}
} }
continue continue
} }
@@ -1710,14 +1713,13 @@ func (b *Bot) syncThreadWorkspaceFromCWD(ctx context.Context, codexThreadID, cwd
func (b *Bot) handleCodexServerRequest(ctx context.Context, event codexapp.Event) error { func (b *Bot) handleCodexServerRequest(ctx context.Context, event codexapp.Event) error {
if event.ID == nil { if event.ID == nil {
return nil return errors.New("server request missing id")
} }
switch event.Method { switch event.Method {
case "item/commandExecution/requestApproval", "item/fileChange/requestApproval", "item/permissions/requestApproval": case "item/commandExecution/requestApproval", "item/fileChange/requestApproval", "item/permissions/requestApproval":
case "execCommandApproval", "applyPatchApproval": case "execCommandApproval", "applyPatchApproval":
default: default:
b.logger.Printf("unhandled server request: %s", event.Method) return fmt.Errorf("unsupported Codex server request: %s", event.Method)
return nil
} }
var params struct { var params struct {
ThreadID string `json:"threadId"` ThreadID string `json:"threadId"`
@@ -1747,7 +1749,7 @@ func (b *Bot) handleCodexServerRequest(ctx context.Context, event codexapp.Event
kind := event.Method kind := event.Method
approval, err := b.store.UpsertPendingApproval(ctx, store.PendingApproval{ approval, err := b.store.UpsertPendingApproval(ctx, store.PendingApproval{
TelegramUserID: thread.TelegramUserID, TelegramUserID: thread.TelegramUserID,
CodexRequestID: strconv.FormatInt(*event.ID, 10), CodexRequestID: event.ID.Key(),
CodexThreadID: threadID, CodexThreadID: threadID,
TurnID: params.TurnID, TurnID: params.TurnID,
ItemID: itemID, ItemID: itemID,