Support string Codex request IDs

This commit is contained in:
Codex
2026-05-25 06:04:31 +00:00
parent b46c4beb86
commit ab5cc4fbfe
3 changed files with 108 additions and 29 deletions

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"net"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
@@ -22,19 +23,77 @@ type Client struct {
mu sync.Mutex
conn *websocket.Conn
pending map[int64]chan rpcResult
pending map[string]chan rpcResult
events chan Event
connected bool
}
type Event struct {
ID *int64
ID *RequestID
Method string
Params json.RawMessage
ServerRequest bool
Err error
}
type RequestID struct {
value any
key string
}
func IntRequestID(id int64) RequestID {
return RequestID{value: id, key: "i:" + strconv.FormatInt(id, 10)}
}
func ParseRequestIDKey(key string) (RequestID, error) {
key = strings.TrimSpace(key)
if strings.HasPrefix(key, "i:") {
id, err := strconv.ParseInt(strings.TrimPrefix(key, "i:"), 10, 64)
if err != nil {
return RequestID{}, err
}
return IntRequestID(id), nil
}
if strings.HasPrefix(key, "s:") {
value := strings.TrimPrefix(key, "s:")
return RequestID{value: value, key: "s:" + value}, nil
}
if id, err := strconv.ParseInt(key, 10, 64); err == nil {
return IntRequestID(id), nil
}
if key == "" {
return RequestID{}, errors.New("request id is empty")
}
return RequestID{value: key, key: "s:" + key}, nil
}
func ParseRequestID(raw json.RawMessage) (RequestID, bool, error) {
trimmed := strings.TrimSpace(string(raw))
if trimmed == "" || trimmed == "null" {
return RequestID{}, false, nil
}
if strings.HasPrefix(trimmed, "\"") {
var value string
if err := json.Unmarshal(raw, &value); err != nil {
return RequestID{}, false, err
}
return RequestID{value: value, key: "s:" + value}, true, nil
}
var value int64
if err := json.Unmarshal(raw, &value); err != nil {
return RequestID{}, false, err
}
return IntRequestID(value), true, nil
}
func (id RequestID) Key() string {
return id.key
}
func (id RequestID) Value() any {
return id.value
}
type RPCError struct {
Code int `json:"code"`
Message string `json:"message"`
@@ -87,7 +146,7 @@ func New(socketPath, version string) *Client {
return &Client{
socketPath: socketPath,
version: version,
pending: make(map[int64]chan rpcResult),
pending: make(map[string]chan rpcResult),
events: make(chan Event, 128),
}
}
@@ -315,13 +374,24 @@ func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
return result.Data, nil
}
func (c *Client) RespondServerRequest(ctx context.Context, requestID int64, result any) error {
func (c *Client) RespondServerRequest(ctx context.Context, requestID RequestID, result any) error {
return c.respondServerRequest(ctx, responseEnvelope{ID: requestID.Value(), Result: result})
}
func (c *Client) RespondServerRequestError(ctx context.Context, requestID RequestID, code int, message string) error {
return c.respondServerRequest(ctx, responseEnvelope{
ID: requestID.Value(),
Error: RPCError{Code: code, Message: message},
})
}
func (c *Client) respondServerRequest(ctx context.Context, response responseEnvelope) error {
if err := c.EnsureConnected(ctx); err != nil {
return err
}
done := make(chan error, 1)
go func() {
done <- c.write(responseEnvelope{ID: requestID, Result: result})
done <- c.write(response)
}()
select {
case <-ctx.Done():
@@ -340,12 +410,12 @@ func (c *Client) call(ctx context.Context, method string, params any, result any
c.mu.Unlock()
return errors.New("codex app-server is not connected")
}
c.pending[id] = ch
c.pending[IntRequestID(id).Key()] = ch
c.mu.Unlock()
if err := c.write(requestEnvelope{ID: id, Method: method, Params: params}); err != nil {
c.mu.Lock()
delete(c.pending, id)
delete(c.pending, IntRequestID(id).Key())
c.mu.Unlock()
return err
}
@@ -353,7 +423,7 @@ func (c *Client) call(ctx context.Context, method string, params any, result any
select {
case <-ctx.Done():
c.mu.Lock()
delete(c.pending, id)
delete(c.pending, IntRequestID(id).Key())
c.mu.Unlock()
return ctx.Err()
case rpc := <-ch:
@@ -387,17 +457,22 @@ func (c *Client) readLoop(conn *websocket.Conn) {
c.failConnection(conn, err)
return
}
if env.Method != "" && env.ID != nil {
id, hasID, err := ParseRequestID(env.ID)
if err != nil {
c.failConnection(conn, fmt.Errorf("decode request id: %w", err))
return
}
if env.Method != "" && hasID {
c.events <- Event{
ID: env.ID,
ID: &id,
Method: env.Method,
Params: env.Params,
ServerRequest: true,
}
continue
}
if env.ID != nil {
c.completeCall(*env.ID, env.Result, env.Error)
if hasID {
c.completeCall(id, env.Result, env.Error)
continue
}
if env.Method != "" {
@@ -409,10 +484,10 @@ func (c *Client) readLoop(conn *websocket.Conn) {
}
}
func (c *Client) completeCall(id int64, result json.RawMessage, rpcErr *RPCError) {
func (c *Client) completeCall(id RequestID, result json.RawMessage, rpcErr *RPCError) {
c.mu.Lock()
ch := c.pending[id]
delete(c.pending, id)
ch := c.pending[id.Key()]
delete(c.pending, id.Key())
c.mu.Unlock()
if ch == nil {
return
@@ -431,7 +506,7 @@ func (c *Client) failConnection(conn *websocket.Conn, err error) {
c.connected = false
}
pending := c.pending
c.pending = make(map[int64]chan rpcResult)
c.pending = make(map[string]chan rpcResult)
c.mu.Unlock()
for _, ch := range pending {
@@ -501,13 +576,13 @@ type notificationEnvelope struct {
}
type responseEnvelope struct {
ID int64 `json:"id"`
Result any `json:"result,omitempty"`
Error any `json:"error,omitempty"`
ID any `json:"id"`
Result any `json:"result,omitempty"`
Error any `json:"error,omitempty"`
}
type incomingEnvelope struct {
ID *int64 `json:"id,omitempty"`
ID json.RawMessage `json:"id,omitempty"`
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`

View File

@@ -60,7 +60,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
}
if err := conn.WriteJSON(map[string]any{
"id": 99,
"id": "approval-99",
"method": "item/commandExecution/requestApproval",
"params": map[string]any{"threadId": "thr_1"},
}); err != nil {
@@ -138,7 +138,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
serverDone <- err
return
}
if response["id"].(float64) != 99 || response["result"] != "accept" {
if response["id"] != "approval-99" || response["result"] != "accept" {
payload, _ := json.Marshal(response)
serverDone <- unexpectedMessage("approval response", string(payload))
return
@@ -156,11 +156,13 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
}
defer client.Close()
var approvalRequestID RequestID
select {
case event := <-client.Events():
if !event.ServerRequest || event.ID == nil || *event.ID != 99 {
if !event.ServerRequest || event.ID == nil || event.ID.Key() != "s:approval-99" {
t.Fatalf("unexpected event: %+v", event)
}
approvalRequestID = *event.ID
case <-ctx.Done():
t.Fatal(ctx.Err())
}
@@ -182,7 +184,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
if err := client.SetThreadName(ctx, "thr_1", "Short title"); err != nil {
t.Fatal(err)
}
if err := client.RespondServerRequest(ctx, 99, "accept"); err != nil {
if err := client.RespondServerRequest(ctx, approvalRequestID, "accept"); err != nil {
t.Fatal(err)
}

View File

@@ -1134,7 +1134,7 @@ func (b *Bot) handleApprovalCallback(ctx context.Context, callback *CallbackQuer
if approval.Status != "pending" {
return b.tg.AnswerCallbackQuery(ctx, callback.ID, "Already resolved.")
}
requestID, err := strconv.ParseInt(approval.CodexRequestID, 10, 64)
requestID, err := codexapp.ParseRequestIDKey(approval.CodexRequestID)
if err != nil {
return b.tg.AnswerCallbackQuery(ctx, callback.ID, "Invalid request id.")
}
@@ -1169,6 +1169,9 @@ func (b *Bot) handleCodexEvents(ctx context.Context) {
if event.ServerRequest {
if err := b.handleCodexServerRequest(ctx, event); err != nil {
b.logger.Printf("server request %s: %v", event.Method, err)
if event.ID != nil {
_ = b.codex.RespondServerRequestError(ctx, *event.ID, -32603, err.Error())
}
}
continue
}
@@ -1710,14 +1713,13 @@ func (b *Bot) syncThreadWorkspaceFromCWD(ctx context.Context, codexThreadID, cwd
func (b *Bot) handleCodexServerRequest(ctx context.Context, event codexapp.Event) error {
if event.ID == nil {
return nil
return errors.New("server request missing id")
}
switch event.Method {
case "item/commandExecution/requestApproval", "item/fileChange/requestApproval", "item/permissions/requestApproval":
case "execCommandApproval", "applyPatchApproval":
default:
b.logger.Printf("unhandled server request: %s", event.Method)
return nil
return fmt.Errorf("unsupported Codex server request: %s", event.Method)
}
var params struct {
ThreadID string `json:"threadId"`
@@ -1747,7 +1749,7 @@ func (b *Bot) handleCodexServerRequest(ctx context.Context, event codexapp.Event
kind := event.Method
approval, err := b.store.UpsertPendingApproval(ctx, store.PendingApproval{
TelegramUserID: thread.TelegramUserID,
CodexRequestID: strconv.FormatInt(*event.ID, 10),
CodexRequestID: event.ID.Key(),
CodexThreadID: threadID,
TurnID: params.TurnID,
ItemID: itemID,