Support string Codex request IDs
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -22,19 +23,77 @@ type Client struct {
|
||||
|
||||
mu sync.Mutex
|
||||
conn *websocket.Conn
|
||||
pending map[int64]chan rpcResult
|
||||
pending map[string]chan rpcResult
|
||||
events chan Event
|
||||
connected bool
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
ID *int64
|
||||
ID *RequestID
|
||||
Method string
|
||||
Params json.RawMessage
|
||||
ServerRequest bool
|
||||
Err error
|
||||
}
|
||||
|
||||
type RequestID struct {
|
||||
value any
|
||||
key string
|
||||
}
|
||||
|
||||
func IntRequestID(id int64) RequestID {
|
||||
return RequestID{value: id, key: "i:" + strconv.FormatInt(id, 10)}
|
||||
}
|
||||
|
||||
func ParseRequestIDKey(key string) (RequestID, error) {
|
||||
key = strings.TrimSpace(key)
|
||||
if strings.HasPrefix(key, "i:") {
|
||||
id, err := strconv.ParseInt(strings.TrimPrefix(key, "i:"), 10, 64)
|
||||
if err != nil {
|
||||
return RequestID{}, err
|
||||
}
|
||||
return IntRequestID(id), nil
|
||||
}
|
||||
if strings.HasPrefix(key, "s:") {
|
||||
value := strings.TrimPrefix(key, "s:")
|
||||
return RequestID{value: value, key: "s:" + value}, nil
|
||||
}
|
||||
if id, err := strconv.ParseInt(key, 10, 64); err == nil {
|
||||
return IntRequestID(id), nil
|
||||
}
|
||||
if key == "" {
|
||||
return RequestID{}, errors.New("request id is empty")
|
||||
}
|
||||
return RequestID{value: key, key: "s:" + key}, nil
|
||||
}
|
||||
|
||||
func ParseRequestID(raw json.RawMessage) (RequestID, bool, error) {
|
||||
trimmed := strings.TrimSpace(string(raw))
|
||||
if trimmed == "" || trimmed == "null" {
|
||||
return RequestID{}, false, nil
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "\"") {
|
||||
var value string
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return RequestID{}, false, err
|
||||
}
|
||||
return RequestID{value: value, key: "s:" + value}, true, nil
|
||||
}
|
||||
var value int64
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return RequestID{}, false, err
|
||||
}
|
||||
return IntRequestID(value), true, nil
|
||||
}
|
||||
|
||||
func (id RequestID) Key() string {
|
||||
return id.key
|
||||
}
|
||||
|
||||
func (id RequestID) Value() any {
|
||||
return id.value
|
||||
}
|
||||
|
||||
type RPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
@@ -87,7 +146,7 @@ func New(socketPath, version string) *Client {
|
||||
return &Client{
|
||||
socketPath: socketPath,
|
||||
version: version,
|
||||
pending: make(map[int64]chan rpcResult),
|
||||
pending: make(map[string]chan rpcResult),
|
||||
events: make(chan Event, 128),
|
||||
}
|
||||
}
|
||||
@@ -315,13 +374,24 @@ func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
|
||||
return result.Data, nil
|
||||
}
|
||||
|
||||
func (c *Client) RespondServerRequest(ctx context.Context, requestID int64, result any) error {
|
||||
func (c *Client) RespondServerRequest(ctx context.Context, requestID RequestID, result any) error {
|
||||
return c.respondServerRequest(ctx, responseEnvelope{ID: requestID.Value(), Result: result})
|
||||
}
|
||||
|
||||
func (c *Client) RespondServerRequestError(ctx context.Context, requestID RequestID, code int, message string) error {
|
||||
return c.respondServerRequest(ctx, responseEnvelope{
|
||||
ID: requestID.Value(),
|
||||
Error: RPCError{Code: code, Message: message},
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) respondServerRequest(ctx context.Context, response responseEnvelope) error {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- c.write(responseEnvelope{ID: requestID, Result: result})
|
||||
done <- c.write(response)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -340,12 +410,12 @@ func (c *Client) call(ctx context.Context, method string, params any, result any
|
||||
c.mu.Unlock()
|
||||
return errors.New("codex app-server is not connected")
|
||||
}
|
||||
c.pending[id] = ch
|
||||
c.pending[IntRequestID(id).Key()] = ch
|
||||
c.mu.Unlock()
|
||||
|
||||
if err := c.write(requestEnvelope{ID: id, Method: method, Params: params}); err != nil {
|
||||
c.mu.Lock()
|
||||
delete(c.pending, id)
|
||||
delete(c.pending, IntRequestID(id).Key())
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
@@ -353,7 +423,7 @@ func (c *Client) call(ctx context.Context, method string, params any, result any
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.mu.Lock()
|
||||
delete(c.pending, id)
|
||||
delete(c.pending, IntRequestID(id).Key())
|
||||
c.mu.Unlock()
|
||||
return ctx.Err()
|
||||
case rpc := <-ch:
|
||||
@@ -387,17 +457,22 @@ func (c *Client) readLoop(conn *websocket.Conn) {
|
||||
c.failConnection(conn, err)
|
||||
return
|
||||
}
|
||||
if env.Method != "" && env.ID != nil {
|
||||
id, hasID, err := ParseRequestID(env.ID)
|
||||
if err != nil {
|
||||
c.failConnection(conn, fmt.Errorf("decode request id: %w", err))
|
||||
return
|
||||
}
|
||||
if env.Method != "" && hasID {
|
||||
c.events <- Event{
|
||||
ID: env.ID,
|
||||
ID: &id,
|
||||
Method: env.Method,
|
||||
Params: env.Params,
|
||||
ServerRequest: true,
|
||||
}
|
||||
continue
|
||||
}
|
||||
if env.ID != nil {
|
||||
c.completeCall(*env.ID, env.Result, env.Error)
|
||||
if hasID {
|
||||
c.completeCall(id, env.Result, env.Error)
|
||||
continue
|
||||
}
|
||||
if env.Method != "" {
|
||||
@@ -409,10 +484,10 @@ func (c *Client) readLoop(conn *websocket.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) completeCall(id int64, result json.RawMessage, rpcErr *RPCError) {
|
||||
func (c *Client) completeCall(id RequestID, result json.RawMessage, rpcErr *RPCError) {
|
||||
c.mu.Lock()
|
||||
ch := c.pending[id]
|
||||
delete(c.pending, id)
|
||||
ch := c.pending[id.Key()]
|
||||
delete(c.pending, id.Key())
|
||||
c.mu.Unlock()
|
||||
if ch == nil {
|
||||
return
|
||||
@@ -431,7 +506,7 @@ func (c *Client) failConnection(conn *websocket.Conn, err error) {
|
||||
c.connected = false
|
||||
}
|
||||
pending := c.pending
|
||||
c.pending = make(map[int64]chan rpcResult)
|
||||
c.pending = make(map[string]chan rpcResult)
|
||||
c.mu.Unlock()
|
||||
|
||||
for _, ch := range pending {
|
||||
@@ -501,13 +576,13 @@ type notificationEnvelope struct {
|
||||
}
|
||||
|
||||
type responseEnvelope struct {
|
||||
ID int64 `json:"id"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
ID any `json:"id"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type incomingEnvelope struct {
|
||||
ID *int64 `json:"id,omitempty"`
|
||||
ID json.RawMessage `json:"id,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
|
||||
@@ -60,7 +60,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
|
||||
}
|
||||
|
||||
if err := conn.WriteJSON(map[string]any{
|
||||
"id": 99,
|
||||
"id": "approval-99",
|
||||
"method": "item/commandExecution/requestApproval",
|
||||
"params": map[string]any{"threadId": "thr_1"},
|
||||
}); err != nil {
|
||||
@@ -138,7 +138,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
if response["id"].(float64) != 99 || response["result"] != "accept" {
|
||||
if response["id"] != "approval-99" || response["result"] != "accept" {
|
||||
payload, _ := json.Marshal(response)
|
||||
serverDone <- unexpectedMessage("approval response", string(payload))
|
||||
return
|
||||
@@ -156,11 +156,13 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
var approvalRequestID RequestID
|
||||
select {
|
||||
case event := <-client.Events():
|
||||
if !event.ServerRequest || event.ID == nil || *event.ID != 99 {
|
||||
if !event.ServerRequest || event.ID == nil || event.ID.Key() != "s:approval-99" {
|
||||
t.Fatalf("unexpected event: %+v", event)
|
||||
}
|
||||
approvalRequestID = *event.ID
|
||||
case <-ctx.Done():
|
||||
t.Fatal(ctx.Err())
|
||||
}
|
||||
@@ -182,7 +184,7 @@ func TestClientWebSocketUnixJSONRPC(t *testing.T) {
|
||||
if err := client.SetThreadName(ctx, "thr_1", "Short title"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := client.RespondServerRequest(ctx, 99, "accept"); err != nil {
|
||||
if err := client.RespondServerRequest(ctx, approvalRequestID, "accept"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -1134,7 +1134,7 @@ func (b *Bot) handleApprovalCallback(ctx context.Context, callback *CallbackQuer
|
||||
if approval.Status != "pending" {
|
||||
return b.tg.AnswerCallbackQuery(ctx, callback.ID, "Already resolved.")
|
||||
}
|
||||
requestID, err := strconv.ParseInt(approval.CodexRequestID, 10, 64)
|
||||
requestID, err := codexapp.ParseRequestIDKey(approval.CodexRequestID)
|
||||
if err != nil {
|
||||
return b.tg.AnswerCallbackQuery(ctx, callback.ID, "Invalid request id.")
|
||||
}
|
||||
@@ -1169,6 +1169,9 @@ func (b *Bot) handleCodexEvents(ctx context.Context) {
|
||||
if event.ServerRequest {
|
||||
if err := b.handleCodexServerRequest(ctx, event); err != nil {
|
||||
b.logger.Printf("server request %s: %v", event.Method, err)
|
||||
if event.ID != nil {
|
||||
_ = b.codex.RespondServerRequestError(ctx, *event.ID, -32603, err.Error())
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -1710,14 +1713,13 @@ func (b *Bot) syncThreadWorkspaceFromCWD(ctx context.Context, codexThreadID, cwd
|
||||
|
||||
func (b *Bot) handleCodexServerRequest(ctx context.Context, event codexapp.Event) error {
|
||||
if event.ID == nil {
|
||||
return nil
|
||||
return errors.New("server request missing id")
|
||||
}
|
||||
switch event.Method {
|
||||
case "item/commandExecution/requestApproval", "item/fileChange/requestApproval", "item/permissions/requestApproval":
|
||||
case "execCommandApproval", "applyPatchApproval":
|
||||
default:
|
||||
b.logger.Printf("unhandled server request: %s", event.Method)
|
||||
return nil
|
||||
return fmt.Errorf("unsupported Codex server request: %s", event.Method)
|
||||
}
|
||||
var params struct {
|
||||
ThreadID string `json:"threadId"`
|
||||
@@ -1747,7 +1749,7 @@ func (b *Bot) handleCodexServerRequest(ctx context.Context, event codexapp.Event
|
||||
kind := event.Method
|
||||
approval, err := b.store.UpsertPendingApproval(ctx, store.PendingApproval{
|
||||
TelegramUserID: thread.TelegramUserID,
|
||||
CodexRequestID: strconv.FormatInt(*event.ID, 10),
|
||||
CodexRequestID: event.ID.Key(),
|
||||
CodexThreadID: threadID,
|
||||
TurnID: params.TurnID,
|
||||
ItemID: itemID,
|
||||
|
||||
Reference in New Issue
Block a user