Support string Codex request IDs
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -22,19 +23,77 @@ type Client struct {
|
||||
|
||||
mu sync.Mutex
|
||||
conn *websocket.Conn
|
||||
pending map[int64]chan rpcResult
|
||||
pending map[string]chan rpcResult
|
||||
events chan Event
|
||||
connected bool
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
ID *int64
|
||||
ID *RequestID
|
||||
Method string
|
||||
Params json.RawMessage
|
||||
ServerRequest bool
|
||||
Err error
|
||||
}
|
||||
|
||||
type RequestID struct {
|
||||
value any
|
||||
key string
|
||||
}
|
||||
|
||||
func IntRequestID(id int64) RequestID {
|
||||
return RequestID{value: id, key: "i:" + strconv.FormatInt(id, 10)}
|
||||
}
|
||||
|
||||
func ParseRequestIDKey(key string) (RequestID, error) {
|
||||
key = strings.TrimSpace(key)
|
||||
if strings.HasPrefix(key, "i:") {
|
||||
id, err := strconv.ParseInt(strings.TrimPrefix(key, "i:"), 10, 64)
|
||||
if err != nil {
|
||||
return RequestID{}, err
|
||||
}
|
||||
return IntRequestID(id), nil
|
||||
}
|
||||
if strings.HasPrefix(key, "s:") {
|
||||
value := strings.TrimPrefix(key, "s:")
|
||||
return RequestID{value: value, key: "s:" + value}, nil
|
||||
}
|
||||
if id, err := strconv.ParseInt(key, 10, 64); err == nil {
|
||||
return IntRequestID(id), nil
|
||||
}
|
||||
if key == "" {
|
||||
return RequestID{}, errors.New("request id is empty")
|
||||
}
|
||||
return RequestID{value: key, key: "s:" + key}, nil
|
||||
}
|
||||
|
||||
func ParseRequestID(raw json.RawMessage) (RequestID, bool, error) {
|
||||
trimmed := strings.TrimSpace(string(raw))
|
||||
if trimmed == "" || trimmed == "null" {
|
||||
return RequestID{}, false, nil
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "\"") {
|
||||
var value string
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return RequestID{}, false, err
|
||||
}
|
||||
return RequestID{value: value, key: "s:" + value}, true, nil
|
||||
}
|
||||
var value int64
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return RequestID{}, false, err
|
||||
}
|
||||
return IntRequestID(value), true, nil
|
||||
}
|
||||
|
||||
func (id RequestID) Key() string {
|
||||
return id.key
|
||||
}
|
||||
|
||||
func (id RequestID) Value() any {
|
||||
return id.value
|
||||
}
|
||||
|
||||
type RPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
@@ -87,7 +146,7 @@ func New(socketPath, version string) *Client {
|
||||
return &Client{
|
||||
socketPath: socketPath,
|
||||
version: version,
|
||||
pending: make(map[int64]chan rpcResult),
|
||||
pending: make(map[string]chan rpcResult),
|
||||
events: make(chan Event, 128),
|
||||
}
|
||||
}
|
||||
@@ -315,13 +374,24 @@ func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
|
||||
return result.Data, nil
|
||||
}
|
||||
|
||||
func (c *Client) RespondServerRequest(ctx context.Context, requestID int64, result any) error {
|
||||
func (c *Client) RespondServerRequest(ctx context.Context, requestID RequestID, result any) error {
|
||||
return c.respondServerRequest(ctx, responseEnvelope{ID: requestID.Value(), Result: result})
|
||||
}
|
||||
|
||||
func (c *Client) RespondServerRequestError(ctx context.Context, requestID RequestID, code int, message string) error {
|
||||
return c.respondServerRequest(ctx, responseEnvelope{
|
||||
ID: requestID.Value(),
|
||||
Error: RPCError{Code: code, Message: message},
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) respondServerRequest(ctx context.Context, response responseEnvelope) error {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- c.write(responseEnvelope{ID: requestID, Result: result})
|
||||
done <- c.write(response)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -340,12 +410,12 @@ func (c *Client) call(ctx context.Context, method string, params any, result any
|
||||
c.mu.Unlock()
|
||||
return errors.New("codex app-server is not connected")
|
||||
}
|
||||
c.pending[id] = ch
|
||||
c.pending[IntRequestID(id).Key()] = ch
|
||||
c.mu.Unlock()
|
||||
|
||||
if err := c.write(requestEnvelope{ID: id, Method: method, Params: params}); err != nil {
|
||||
c.mu.Lock()
|
||||
delete(c.pending, id)
|
||||
delete(c.pending, IntRequestID(id).Key())
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
@@ -353,7 +423,7 @@ func (c *Client) call(ctx context.Context, method string, params any, result any
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.mu.Lock()
|
||||
delete(c.pending, id)
|
||||
delete(c.pending, IntRequestID(id).Key())
|
||||
c.mu.Unlock()
|
||||
return ctx.Err()
|
||||
case rpc := <-ch:
|
||||
@@ -387,17 +457,22 @@ func (c *Client) readLoop(conn *websocket.Conn) {
|
||||
c.failConnection(conn, err)
|
||||
return
|
||||
}
|
||||
if env.Method != "" && env.ID != nil {
|
||||
id, hasID, err := ParseRequestID(env.ID)
|
||||
if err != nil {
|
||||
c.failConnection(conn, fmt.Errorf("decode request id: %w", err))
|
||||
return
|
||||
}
|
||||
if env.Method != "" && hasID {
|
||||
c.events <- Event{
|
||||
ID: env.ID,
|
||||
ID: &id,
|
||||
Method: env.Method,
|
||||
Params: env.Params,
|
||||
ServerRequest: true,
|
||||
}
|
||||
continue
|
||||
}
|
||||
if env.ID != nil {
|
||||
c.completeCall(*env.ID, env.Result, env.Error)
|
||||
if hasID {
|
||||
c.completeCall(id, env.Result, env.Error)
|
||||
continue
|
||||
}
|
||||
if env.Method != "" {
|
||||
@@ -409,10 +484,10 @@ func (c *Client) readLoop(conn *websocket.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) completeCall(id int64, result json.RawMessage, rpcErr *RPCError) {
|
||||
func (c *Client) completeCall(id RequestID, result json.RawMessage, rpcErr *RPCError) {
|
||||
c.mu.Lock()
|
||||
ch := c.pending[id]
|
||||
delete(c.pending, id)
|
||||
ch := c.pending[id.Key()]
|
||||
delete(c.pending, id.Key())
|
||||
c.mu.Unlock()
|
||||
if ch == nil {
|
||||
return
|
||||
@@ -431,7 +506,7 @@ func (c *Client) failConnection(conn *websocket.Conn, err error) {
|
||||
c.connected = false
|
||||
}
|
||||
pending := c.pending
|
||||
c.pending = make(map[int64]chan rpcResult)
|
||||
c.pending = make(map[string]chan rpcResult)
|
||||
c.mu.Unlock()
|
||||
|
||||
for _, ch := range pending {
|
||||
@@ -501,13 +576,13 @@ type notificationEnvelope struct {
|
||||
}
|
||||
|
||||
type responseEnvelope struct {
|
||||
ID int64 `json:"id"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
ID any `json:"id"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type incomingEnvelope struct {
|
||||
ID *int64 `json:"id,omitempty"`
|
||||
ID json.RawMessage `json:"id,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
|
||||
Reference in New Issue
Block a user