Support string Codex request IDs

This commit is contained in:
Codex
2026-05-25 06:04:31 +00:00
parent b46c4beb86
commit ab5cc4fbfe
3 changed files with 108 additions and 29 deletions

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"net"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
@@ -22,19 +23,77 @@ type Client struct {
mu sync.Mutex
conn *websocket.Conn
pending map[int64]chan rpcResult
pending map[string]chan rpcResult
events chan Event
connected bool
}
type Event struct {
ID *int64
ID *RequestID
Method string
Params json.RawMessage
ServerRequest bool
Err error
}
type RequestID struct {
value any
key string
}
func IntRequestID(id int64) RequestID {
return RequestID{value: id, key: "i:" + strconv.FormatInt(id, 10)}
}
func ParseRequestIDKey(key string) (RequestID, error) {
key = strings.TrimSpace(key)
if strings.HasPrefix(key, "i:") {
id, err := strconv.ParseInt(strings.TrimPrefix(key, "i:"), 10, 64)
if err != nil {
return RequestID{}, err
}
return IntRequestID(id), nil
}
if strings.HasPrefix(key, "s:") {
value := strings.TrimPrefix(key, "s:")
return RequestID{value: value, key: "s:" + value}, nil
}
if id, err := strconv.ParseInt(key, 10, 64); err == nil {
return IntRequestID(id), nil
}
if key == "" {
return RequestID{}, errors.New("request id is empty")
}
return RequestID{value: key, key: "s:" + key}, nil
}
func ParseRequestID(raw json.RawMessage) (RequestID, bool, error) {
trimmed := strings.TrimSpace(string(raw))
if trimmed == "" || trimmed == "null" {
return RequestID{}, false, nil
}
if strings.HasPrefix(trimmed, "\"") {
var value string
if err := json.Unmarshal(raw, &value); err != nil {
return RequestID{}, false, err
}
return RequestID{value: value, key: "s:" + value}, true, nil
}
var value int64
if err := json.Unmarshal(raw, &value); err != nil {
return RequestID{}, false, err
}
return IntRequestID(value), true, nil
}
func (id RequestID) Key() string {
return id.key
}
func (id RequestID) Value() any {
return id.value
}
type RPCError struct {
Code int `json:"code"`
Message string `json:"message"`
@@ -87,7 +146,7 @@ func New(socketPath, version string) *Client {
return &Client{
socketPath: socketPath,
version: version,
pending: make(map[int64]chan rpcResult),
pending: make(map[string]chan rpcResult),
events: make(chan Event, 128),
}
}
@@ -315,13 +374,24 @@ func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
return result.Data, nil
}
func (c *Client) RespondServerRequest(ctx context.Context, requestID int64, result any) error {
func (c *Client) RespondServerRequest(ctx context.Context, requestID RequestID, result any) error {
return c.respondServerRequest(ctx, responseEnvelope{ID: requestID.Value(), Result: result})
}
func (c *Client) RespondServerRequestError(ctx context.Context, requestID RequestID, code int, message string) error {
return c.respondServerRequest(ctx, responseEnvelope{
ID: requestID.Value(),
Error: RPCError{Code: code, Message: message},
})
}
func (c *Client) respondServerRequest(ctx context.Context, response responseEnvelope) error {
if err := c.EnsureConnected(ctx); err != nil {
return err
}
done := make(chan error, 1)
go func() {
done <- c.write(responseEnvelope{ID: requestID, Result: result})
done <- c.write(response)
}()
select {
case <-ctx.Done():
@@ -340,12 +410,12 @@ func (c *Client) call(ctx context.Context, method string, params any, result any
c.mu.Unlock()
return errors.New("codex app-server is not connected")
}
c.pending[id] = ch
c.pending[IntRequestID(id).Key()] = ch
c.mu.Unlock()
if err := c.write(requestEnvelope{ID: id, Method: method, Params: params}); err != nil {
c.mu.Lock()
delete(c.pending, id)
delete(c.pending, IntRequestID(id).Key())
c.mu.Unlock()
return err
}
@@ -353,7 +423,7 @@ func (c *Client) call(ctx context.Context, method string, params any, result any
select {
case <-ctx.Done():
c.mu.Lock()
delete(c.pending, id)
delete(c.pending, IntRequestID(id).Key())
c.mu.Unlock()
return ctx.Err()
case rpc := <-ch:
@@ -387,17 +457,22 @@ func (c *Client) readLoop(conn *websocket.Conn) {
c.failConnection(conn, err)
return
}
if env.Method != "" && env.ID != nil {
id, hasID, err := ParseRequestID(env.ID)
if err != nil {
c.failConnection(conn, fmt.Errorf("decode request id: %w", err))
return
}
if env.Method != "" && hasID {
c.events <- Event{
ID: env.ID,
ID: &id,
Method: env.Method,
Params: env.Params,
ServerRequest: true,
}
continue
}
if env.ID != nil {
c.completeCall(*env.ID, env.Result, env.Error)
if hasID {
c.completeCall(id, env.Result, env.Error)
continue
}
if env.Method != "" {
@@ -409,10 +484,10 @@ func (c *Client) readLoop(conn *websocket.Conn) {
}
}
func (c *Client) completeCall(id int64, result json.RawMessage, rpcErr *RPCError) {
func (c *Client) completeCall(id RequestID, result json.RawMessage, rpcErr *RPCError) {
c.mu.Lock()
ch := c.pending[id]
delete(c.pending, id)
ch := c.pending[id.Key()]
delete(c.pending, id.Key())
c.mu.Unlock()
if ch == nil {
return
@@ -431,7 +506,7 @@ func (c *Client) failConnection(conn *websocket.Conn, err error) {
c.connected = false
}
pending := c.pending
c.pending = make(map[int64]chan rpcResult)
c.pending = make(map[string]chan rpcResult)
c.mu.Unlock()
for _, ch := range pending {
@@ -501,13 +576,13 @@ type notificationEnvelope struct {
}
type responseEnvelope struct {
ID int64 `json:"id"`
Result any `json:"result,omitempty"`
Error any `json:"error,omitempty"`
ID any `json:"id"`
Result any `json:"result,omitempty"`
Error any `json:"error,omitempty"`
}
type incomingEnvelope struct {
ID *int64 `json:"id,omitempty"`
ID json.RawMessage `json:"id,omitempty"`
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`