Files
codex-telegram-bot/internal/codexapp/client.go
Codex c00ffb42f2 Route Codex approvals to Telegram
Force app-server turns to use the user approval reviewer so command approvals surface in the bot on Codex 0.134.

Add focused protocol logs for approval requests and guardian review events to diagnose silent approval stalls.
2026-05-28 09:57:43 +00:00

633 lines
15 KiB
Go

package codexapp
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"github.com/gorilla/websocket"
)
// Client is a JSON-RPC client for the Codex app-server. Messages are
// exchanged as JSON over a WebSocket connection dialed on a Unix socket.
type Client struct {
	// socketPath is the Unix socket path dialed by Connect.
	socketPath string

	// version is reported to the server as clientInfo.version in initialize.
	version string

	// nextID is incremented atomically to allocate outgoing request IDs.
	nextID int64

	// mu guards conn, connected, pending, and logger. write also holds it for
	// the whole WebSocket write, which serializes writers on the connection.
	mu sync.Mutex

	// conn is the current connection, or nil when disconnected.
	conn *websocket.Conn

	// pending maps RequestID.Key() of in-flight calls to the channel that
	// receives their response (or a connection failure).
	pending map[string]chan rpcResult

	// events is exposed by Events and carries server requests, notifications,
	// and "connection/closed" failures produced by the read loop.
	events chan Event

	// connected is set right after a successful dial (before initialize
	// completes) and cleared by Close or by a failure of the current conn.
	connected bool

	// logger, when non-nil, receives protocol diagnostics.
	logger *log.Logger
}
// Event is one message delivered on Client.Events. It is either a
// server-initiated request (ServerRequest true, ID set), a notification
// (ID nil), or a connection failure (Method "connection/closed", Err set).
type Event struct {
	// ID is the request id of a server-initiated request; nil otherwise.
	ID *RequestID
	// Method is the JSON-RPC method name, or "connection/closed" on failure.
	Method string
	// Params holds the undecoded params; see DecodeParams.
	Params json.RawMessage
	// ServerRequest is true when the server expects a response to ID.
	ServerRequest bool
	// Err is the connection error for "connection/closed" events.
	Err error
}
// RequestID is a JSON-RPC request identifier, which may be an integer or a
// string on the wire. value holds the decoded id (int64 or string), and key
// is a type-tagged form ("i:<n>" or "s:<text>") that is safe to use as a map
// key because integer and string ids can never collide.
type RequestID struct {
	value any
	key   string
}

// IntRequestID returns the RequestID for an integer id.
func IntRequestID(id int64) RequestID {
	tagged := "i:" + strconv.FormatInt(id, 10)
	return RequestID{key: tagged, value: id}
}

// ParseRequestIDKey reverses RequestID.Key. It accepts the tagged forms
// "i:<int>" and "s:<text>". An untagged key is read as an integer when it
// parses as one, and as a string otherwise. Surrounding whitespace is
// ignored, and a blank key is rejected.
func ParseRequestIDKey(key string) (RequestID, error) {
	key = strings.TrimSpace(key)
	switch {
	case strings.HasPrefix(key, "i:"):
		n, err := strconv.ParseInt(key[len("i:"):], 10, 64)
		if err != nil {
			return RequestID{}, err
		}
		return IntRequestID(n), nil
	case strings.HasPrefix(key, "s:"):
		text := key[len("s:"):]
		return RequestID{value: text, key: "s:" + text}, nil
	case key == "":
		return RequestID{}, errors.New("request id is empty")
	}
	if n, err := strconv.ParseInt(key, 10, 64); err == nil {
		return IntRequestID(n), nil
	}
	return RequestID{value: key, key: "s:" + key}, nil
}

// ParseRequestID decodes a raw JSON-RPC id. The bool result is false (with a
// nil error) when the id is absent or JSON null. A JSON string yields a
// string id; any other value must decode as an int64.
func ParseRequestID(raw json.RawMessage) (RequestID, bool, error) {
	text := strings.TrimSpace(string(raw))
	switch {
	case text == "", text == "null":
		return RequestID{}, false, nil
	case text[0] == '"':
		var s string
		if err := json.Unmarshal(raw, &s); err != nil {
			return RequestID{}, false, err
		}
		return RequestID{value: s, key: "s:" + s}, true, nil
	}
	var n int64
	if err := json.Unmarshal(raw, &n); err != nil {
		return RequestID{}, false, err
	}
	return IntRequestID(n), true, nil
}

// Key returns the type-tagged key ("i:<n>" or "s:<text>") of the id.
func (id RequestID) Key() string { return id.key }

// Value returns the id as it appears on the wire: an int64 or a string.
func (id RequestID) Value() any { return id.value }
// RPCError is a JSON-RPC error object. It is received in responses and sent
// by RespondServerRequestError, and it implements error.
type RPCError struct {
	Code    int             `json:"code"`
	Message string          `json:"message"`
	Data    json.RawMessage `json:"data,omitempty"`
}

// Error formats the error as "<code>: <message>". When Code is zero, only
// Message is returned.
func (e RPCError) Error() string {
	if e.Code != 0 {
		return fmt.Sprintf("%d: %s", e.Code, e.Message)
	}
	return e.Message
}
// Thread is the thread object returned by thread/start, thread/resume,
// thread/read, and thread/fork.
type Thread struct {
	ID        string `json:"id"`
	SessionID string `json:"sessionId,omitempty"`
	Name      string `json:"name,omitempty"`
	Preview   string `json:"preview,omitempty"`
	// CWD may be empty in the thread object; the Client methods fill it from
	// the response-level cwd in that case.
	CWD string `json:"cwd,omitempty"`
}

// Turn is the turn object returned by turn/start.
type Turn struct {
	ID     string `json:"id"`
	Status string `json:"status"`
}

// InputItem is one element of the input sent with turn/start.
// NOTE(review): which of Text, URL, and Path is meaningful presumably depends
// on Type; confirm against the app-server protocol.
type InputItem struct {
	Type string `json:"type"`
	Text string `json:"text,omitempty"`
	URL  string `json:"url,omitempty"`
	Path string `json:"path,omitempty"`
}

// Model describes one entry of the model/list response.
type Model struct {
	ID                        string                  `json:"id"`
	Model                     string                  `json:"model"`
	DisplayName               string                  `json:"displayName"`
	Description               string                  `json:"description"`
	IsDefault                 bool                    `json:"isDefault"`
	DefaultReasoningEffort    string                  `json:"defaultReasoningEffort"`
	SupportedReasoningEfforts []ReasoningEffortOption `json:"supportedReasoningEfforts"`
}

// ReasoningEffortOption is one entry of Model.SupportedReasoningEfforts.
type ReasoningEffortOption struct {
	Description     string `json:"description"`
	ReasoningEffort string `json:"reasoningEffort"`
}
// New returns a Client for the app-server listening on the Unix socket at
// socketPath. version is reported as clientInfo.version during initialize.
// No connection is made until Connect or EnsureConnected is called.
func New(socketPath, version string) *Client {
	c := new(Client)
	c.socketPath = socketPath
	c.version = version
	c.pending = map[string]chan rpcResult{}
	c.events = make(chan Event, 128)
	return c
}
// SetLogger installs logger for protocol diagnostics. Passing nil disables
// them.
func (c *Client) SetLogger(logger *log.Logger) {
	c.mu.Lock()
	c.logger = logger
	c.mu.Unlock()
}

// Events returns the channel of server requests, notifications, and
// connection failures. Every call returns the same channel.
func (c *Client) Events() <-chan Event {
	var out <-chan Event = c.events
	return out
}
// EnsureConnected returns nil immediately when the client already records a
// connection, and otherwise calls Connect. The connection is recorded before
// Connect's initialize handshake finishes, so a concurrent caller may
// observe it early.
func (c *Client) EnsureConnected(ctx context.Context) error {
	c.mu.Lock()
	alive := c.conn != nil && c.connected
	c.mu.Unlock()
	if !alive {
		return c.Connect(ctx)
	}
	return nil
}
// Connect dials the app-server WebSocket over the Unix socket, starts the
// read loop, and performs the initialize/initialized handshake. It is a no-op
// when a connection is already recorded. If the handshake fails, the
// connection is closed and the error is returned.
//
// If another Connect installs its connection while this one is dialing, the
// connection dialed here is closed and discarded. Without this check the
// earlier connection would be overwritten, leaving an orphaned socket whose
// read loop keeps delivering events and, when it fails, fails the pending
// calls of the live connection.
func (c *Client) Connect(ctx context.Context) error {
	c.mu.Lock()
	already := c.connected && c.conn != nil
	c.mu.Unlock()
	if already {
		return nil
	}
	dialer := websocket.Dialer{
		// The URL host below is only a placeholder: every dial goes to the socket.
		NetDialContext: func(dialCtx context.Context, _, _ string) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(dialCtx, "unix", c.socketPath)
		},
	}
	conn, _, err := dialer.DialContext(ctx, "ws://codex-app-server/", http.Header{})
	if err != nil {
		return fmt.Errorf("connect codex app-server unix socket %s: %w", c.socketPath, err)
	}
	c.mu.Lock()
	if c.connected && c.conn != nil {
		// Lost the race to a concurrent Connect; keep its connection.
		c.mu.Unlock()
		_ = conn.Close() // best effort: this socket was never used
		return nil
	}
	c.conn = conn
	c.connected = true
	c.mu.Unlock()
	go c.readLoop(conn)
	if err := c.initialize(ctx); err != nil {
		_ = c.Close() // the initialize error is more informative than a close error
		return err
	}
	return nil
}
// Close marks the client disconnected and closes the current connection, if
// there is one. It returns nil when there was nothing to close. Once the
// socket is closed, the read loop's next read fails and it then fails any
// pending calls.
func (c *Client) Close() error {
	c.mu.Lock()
	old := c.conn
	c.conn, c.connected = nil, false
	c.mu.Unlock()
	if old != nil {
		return old.Close()
	}
	return nil
}
// initialize performs the protocol handshake. It sends an "initialize" call
// that identifies the bot and opts into the experimental API, discards the
// result, and then sends the "initialized" notification.
func (c *Client) initialize(ctx context.Context) error {
	info := map[string]any{
		"name":    "codex_telegram_bot",
		"title":   "Codex Telegram Bot",
		"version": c.version,
	}
	caps := map[string]any{"experimentalApi": true}
	var discard json.RawMessage
	err := c.call(ctx, "initialize", map[string]any{
		"clientInfo":   info,
		"capabilities": caps,
	}, &discard)
	if err != nil {
		return err
	}
	return c.notify("initialized", nil)
}
// threadWithCWD fills in thread.CWD from cwd (the response-level cwd) when
// the thread object did not carry one. thread is passed by value, so the
// caller's copy is not modified.
func threadWithCWD(thread Thread, cwd string) Thread {
	if thread.CWD != "" {
		return thread
	}
	thread.CWD = cwd
	return thread
}
// StartThread creates a thread rooted at cwd via thread/start. The thread
// uses the "on-request" approval policy with the user as approvals reviewer,
// and the sandbox mode from threadSandbox(sandbox). model is sent only when
// non-empty. The returned Thread's CWD falls back to the response's cwd.
func (c *Client) StartThread(ctx context.Context, cwd, model, sandbox string) (Thread, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return Thread{}, err
	}
	params := make(map[string]any, 6)
	params["cwd"] = cwd
	params["approvalPolicy"] = "on-request"
	params["approvalsReviewer"] = "user"
	params["sandbox"] = threadSandbox(sandbox)
	params["serviceName"] = "codex_telegram_bot"
	if model != "" {
		params["model"] = model
	}
	var reply struct {
		Thread Thread `json:"thread"`
		CWD    string `json:"cwd"`
	}
	err := c.call(ctx, "thread/start", params, &reply)
	if err != nil {
		return Thread{}, err
	}
	return threadWithCWD(reply.Thread, reply.CWD), nil
}
// ResumeThread resumes the thread threadID via thread/resume. The returned
// Thread's CWD falls back to the response's cwd.
func (c *Client) ResumeThread(ctx context.Context, threadID string) (Thread, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return Thread{}, err
	}
	params := map[string]any{"threadId": threadID}
	var reply struct {
		Thread Thread `json:"thread"`
		CWD    string `json:"cwd"`
	}
	err := c.call(ctx, "thread/resume", params, &reply)
	if err != nil {
		return Thread{}, err
	}
	return threadWithCWD(reply.Thread, reply.CWD), nil
}
// ReadThread fetches the thread threadID via thread/read without its turns.
// The returned Thread's CWD falls back to the response's cwd.
func (c *Client) ReadThread(ctx context.Context, threadID string) (Thread, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return Thread{}, err
	}
	params := map[string]any{
		"threadId":     threadID,
		"includeTurns": false,
	}
	var reply struct {
		Thread Thread `json:"thread"`
		CWD    string `json:"cwd"`
	}
	err := c.call(ctx, "thread/read", params, &reply)
	if err != nil {
		return Thread{}, err
	}
	return threadWithCWD(reply.Thread, reply.CWD), nil
}
// ForkThread forks the thread threadID via thread/fork and returns the new
// thread. Its CWD falls back to the response's cwd.
func (c *Client) ForkThread(ctx context.Context, threadID string) (Thread, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return Thread{}, err
	}
	params := map[string]any{"threadId": threadID}
	var reply struct {
		Thread Thread `json:"thread"`
		CWD    string `json:"cwd"`
	}
	err := c.call(ctx, "thread/fork", params, &reply)
	if err != nil {
		return Thread{}, err
	}
	return threadWithCWD(reply.Thread, reply.CWD), nil
}
// ArchiveThread archives the thread threadID via thread/archive. The response
// body is discarded.
func (c *Client) ArchiveThread(ctx context.Context, threadID string) error {
	if err := c.EnsureConnected(ctx); err != nil {
		return err
	}
	params := map[string]any{"threadId": threadID}
	var discard json.RawMessage
	return c.call(ctx, "thread/archive", params, &discard)
}
// SetThreadName renames the thread threadID via thread/name/set. The response
// body is discarded.
func (c *Client) SetThreadName(ctx context.Context, threadID, name string) error {
	if err := c.EnsureConnected(ctx); err != nil {
		return err
	}
	params := map[string]any{"threadId": threadID, "name": name}
	var discard json.RawMessage
	return c.call(ctx, "thread/name/set", params, &discard)
}
// StartTurn starts a turn on threadID with input via turn/start.
//
// Every turn requests the "on-request" approval policy with the user as
// approvals reviewer. cwd is sent only when it is not blank, and model and
// reasoningEffort (sent as "effort") only when non-empty.
//
// A per-turn sandboxPolicy is derived from sandbox via SandboxPolicy. The
// workspace-write policy (also the fallback for unrecognised names) uses cwd
// as its writable root, so it is sent only when cwd is not blank. The
// read-only and danger-full-access policies do not depend on cwd, so they are
// always sent; an explicitly requested sandbox is never silently dropped.
func (c *Client) StartTurn(ctx context.Context, threadID, cwd, model, reasoningEffort, sandbox string, input []InputItem) (Turn, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return Turn{}, err
	}
	params := map[string]any{
		"threadId":          threadID,
		"input":             input,
		"approvalPolicy":    "on-request",
		"approvalsReviewer": "user",
	}
	hasCWD := strings.TrimSpace(cwd) != ""
	if hasCWD {
		params["cwd"] = cwd
	}
	// Without a cwd, a workspace-write policy would have an empty writable
	// root, so it is omitted and presumably the thread's own sandbox applies
	// (TODO confirm).
	if hasCWD || threadSandbox(sandbox) != "workspace-write" {
		params["sandboxPolicy"] = SandboxPolicy(sandbox, cwd)
	}
	if model != "" {
		params["model"] = model
	}
	if reasoningEffort != "" {
		params["effort"] = reasoningEffort
	}
	var result struct {
		Turn Turn `json:"turn"`
	}
	if err := c.call(ctx, "turn/start", params, &result); err != nil {
		return Turn{}, err
	}
	return result.Turn, nil
}
// InterruptTurn asks the server to interrupt turnID on threadID via
// turn/interrupt. The response body is discarded.
func (c *Client) InterruptTurn(ctx context.Context, threadID, turnID string) error {
	if err := c.EnsureConnected(ctx); err != nil {
		return err
	}
	params := map[string]any{"threadId": threadID, "turnId": turnID}
	var discard json.RawMessage
	return c.call(ctx, "turn/interrupt", params, &discard)
}
// ListModels returns the "data" array of a single model/list request made
// with limit 50 and hidden models excluded.
// NOTE(review): if model/list paginates, models beyond the first 50 are not
// fetched; confirm this is acceptable.
func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return nil, err
	}
	params := map[string]any{"limit": 50, "includeHidden": false}
	var page struct {
		Data []Model `json:"data"`
	}
	err := c.call(ctx, "model/list", params, &page)
	if err != nil {
		return nil, err
	}
	return page.Data, nil
}
// RespondServerRequest sends a successful response to the server-initiated
// request requestID, carrying result.
//
// A nil result is sent as an explicit JSON null. responseEnvelope.Result is
// tagged omitempty, so a nil interface would otherwise be dropped, producing
// a response with neither "result" nor "error", which JSON-RPC 2.0 forbids.
func (c *Client) RespondServerRequest(ctx context.Context, requestID RequestID, result any) error {
	if result == nil {
		result = json.RawMessage("null")
	}
	return c.respondServerRequest(ctx, responseEnvelope{ID: requestID.Value(), Result: result})
}
// RespondServerRequestError answers the server-initiated request requestID
// with a JSON-RPC error carrying code and message.
func (c *Client) RespondServerRequestError(ctx context.Context, requestID RequestID, code int, message string) error {
	rpcErr := RPCError{Code: code, Message: message}
	reply := responseEnvelope{ID: requestID.Value(), Error: rpcErr}
	return c.respondServerRequest(ctx, reply)
}
// respondServerRequest connects if needed and writes response. It returns
// the write error, or ctx.Err() if ctx ends first. The write runs in its own
// goroutine so the caller is not held past ctx. That goroutine still runs to
// completion, and its result channel is buffered so it never blocks on send.
func (c *Client) respondServerRequest(ctx context.Context, response responseEnvelope) error {
	if err := c.EnsureConnected(ctx); err != nil {
		return err
	}
	outcome := make(chan error, 1)
	go func() { outcome <- c.write(response) }()
	select {
	case err := <-outcome:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}
// call sends a JSON-RPC request for method with params and waits for the
// matching response. The response's result is decoded into result only when
// result is non-nil and the response carried a result.
//
// It returns the server's RPCError, a connection failure delivered by the
// read loop, a write error, or ctx.Err(). On write failure or cancellation,
// the pending entry is removed so a late response is ignored.
func (c *Client) call(ctx context.Context, method string, params any, result any) error {
	id := atomic.AddInt64(&c.nextID, 1)
	key := IntRequestID(id).Key()
	reply := make(chan rpcResult, 1)
	forget := func() {
		c.mu.Lock()
		delete(c.pending, key)
		c.mu.Unlock()
	}

	c.mu.Lock()
	if c.conn == nil {
		c.mu.Unlock()
		return errors.New("codex app-server is not connected")
	}
	c.pending[key] = reply
	c.mu.Unlock()

	if err := c.write(requestEnvelope{ID: id, Method: method, Params: params}); err != nil {
		forget()
		return err
	}
	select {
	case <-ctx.Done():
		forget()
		return ctx.Err()
	case rpc := <-reply:
		switch {
		case rpc.err != nil:
			return rpc.err
		case result == nil, len(rpc.result) == 0:
			return nil
		}
		return json.Unmarshal(rpc.result, result)
	}
}
// notify sends a JSON-RPC notification, which is a message without an id.
func (c *Client) notify(method string, params any) error {
	msg := notificationEnvelope{Method: method, Params: params}
	return c.write(msg)
}
// write encodes message as JSON onto the current connection. c.mu is held
// for the entire write, which prevents concurrent writers on the socket.
// NOTE(review): holding c.mu across network I/O also blocks every other c.mu
// user, including the read loop's completeCall and logProtocolEvent, while a
// write is stalled. A dedicated write mutex would decouple them.
func (c *Client) write(message any) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	conn := c.conn
	if conn == nil {
		return errors.New("codex app-server is not connected")
	}
	return conn.WriteJSON(message)
}
// readLoop reads messages from conn until a read fails, dispatching each one
// by shape:
//   - method and id: a server-initiated request, forwarded on c.events with
//     ServerRequest set (and logged);
//   - id only: the response to one of our calls, handed to completeCall;
//   - method only: a notification, forwarded on c.events (approval-related
//     ones are logged);
//   - neither, but with an error: an error response whose id was null or
//     absent. It cannot be matched to a pending call, so it is logged rather
//     than silently dropped.
//
// A read error or an undecodable id ends the loop via failConnection. Sends
// on c.events block while its buffer is full.
func (c *Client) readLoop(conn *websocket.Conn) {
	for {
		var env incomingEnvelope
		if err := conn.ReadJSON(&env); err != nil {
			c.failConnection(conn, err)
			return
		}
		// id is declared per iteration, so &id below is unique to each event.
		id, hasID, err := ParseRequestID(env.ID)
		if err != nil {
			c.failConnection(conn, fmt.Errorf("decode request id: %w", err))
			return
		}
		switch {
		case env.Method != "" && hasID:
			c.logProtocolEvent("server request", env.Method, id.Key(), env.Params)
			c.events <- Event{
				ID:            &id,
				Method:        env.Method,
				Params:        env.Params,
				ServerRequest: true,
			}
		case hasID:
			c.completeCall(id, env.Result, env.Error)
		case env.Method != "":
			if shouldLogNotification(env.Method) {
				c.logProtocolEvent("notification", env.Method, "", env.Params)
			}
			c.events <- Event{
				Method: env.Method,
				Params: env.Params,
			}
		case env.Error != nil:
			c.mu.Lock()
			logger := c.logger
			c.mu.Unlock()
			if logger != nil {
				logger.Printf("codex error response without id: code=%d message=%q", env.Error.Code, env.Error.Message)
			}
		}
	}
}
// logProtocolEvent writes one diagnostic line for an incoming message when a
// logger is set. Only the size of params is logged, never its content. id is
// included when it is non-empty.
func (c *Client) logProtocolEvent(kind, method, id string, params json.RawMessage) {
	c.mu.Lock()
	lg := c.logger
	c.mu.Unlock()
	switch {
	case lg == nil:
		return
	case id == "":
		lg.Printf("codex %s: method=%s params_bytes=%d", kind, method, len(params))
	default:
		lg.Printf("codex %s: method=%s id=%s params_bytes=%d", kind, method, id, len(params))
	}
}
// shouldLogNotification reports whether a notification method is one of the
// approval-related events that readLoop logs: guardian approval review start
// and completion, guardian warnings, and server-request resolution.
func shouldLogNotification(method string) bool {
	switch method {
	case "item/guardianApprovalReview/started",
		"item/guardianApprovalReview/completed",
		"guardianWarning",
		"serverRequest/resolved":
		return true
	}
	return false
}
// completeCall delivers a response to the call waiting on id. Responses for
// unknown ids (for example, calls already cancelled) are ignored. The entry
// is removed from c.pending under c.mu before the send, so this path sends
// at most once on a given channel.
func (c *Client) completeCall(id RequestID, result json.RawMessage, rpcErr *RPCError) {
	key := id.Key()
	c.mu.Lock()
	waiter := c.pending[key]
	delete(c.pending, key)
	c.mu.Unlock()
	if waiter == nil {
		return
	}
	outcome := rpcResult{result: result}
	if rpcErr != nil {
		outcome = rpcResult{err: *rpcErr}
	}
	waiter <- outcome
}
// failConnection handles the end of conn's read loop. If conn is still the
// current connection, the client is marked disconnected. Every pending call
// is then failed with err, and a "connection/closed" event carrying err is
// emitted on c.events.
// NOTE(review): pending calls are failed and the event is emitted even when
// conn has already been replaced by a newer connection. Calls issued on the
// newer connection would then also fail; confirm whether that is intended.
func (c *Client) failConnection(conn *websocket.Conn, err error) {
	c.mu.Lock()
	if conn == c.conn {
		c.conn, c.connected = nil, false
	}
	failed := c.pending
	c.pending = map[string]chan rpcResult{}
	c.mu.Unlock()

	for _, waiter := range failed {
		waiter <- rpcResult{err: err}
	}
	c.events <- Event{Method: "connection/closed", Err: err}
}
// DecodeParams unmarshals event.Params into a T. Empty params yield the zero
// T and a nil error. On a decode error, the partially decoded value is
// returned together with the error.
func DecodeParams[T any](event Event) (T, error) {
	var out T
	if len(event.Params) > 0 {
		if err := json.Unmarshal(event.Params, &out); err != nil {
			return out, err
		}
	}
	return out, nil
}
// NormalizeSandbox maps a sandbox name to its canonical form:
// "workspace-write" (also used for blank input), "read-only", or
// "danger-full-access". Case and surrounding whitespace are ignored, and the
// unhyphenated spellings are also accepted. Any other name is an error.
func NormalizeSandbox(value string) (string, error) {
	name := strings.ToLower(strings.TrimSpace(value))
	switch name {
	case "", "workspace-write", "workspacewrite":
		name = "workspace-write"
	case "read-only", "readonly":
		name = "read-only"
	case "danger-full-access", "dangerfullaccess":
		name = "danger-full-access"
	default:
		return "", fmt.Errorf("unsupported sandbox %q", value)
	}
	return name, nil
}

// SandboxPolicy builds the turn/start sandboxPolicy object for sandbox.
// Unrecognised names fall back to workspace-write, whose policy makes cwd the
// only writable root and disables network access.
func SandboxPolicy(sandbox, cwd string) map[string]any {
	mode := threadSandbox(sandbox)
	if mode == "read-only" {
		return map[string]any{"type": "readOnly"}
	}
	if mode == "danger-full-access" {
		return map[string]any{"type": "dangerFullAccess"}
	}
	return map[string]any{
		"type":          "workspaceWrite",
		"writableRoots": []string{cwd},
		"networkAccess": false,
	}
}

// threadSandbox returns the canonical sandbox name used by thread/start,
// falling back to "workspace-write" when sandbox is not recognised.
func threadSandbox(sandbox string) string {
	if normalized, err := NormalizeSandbox(sandbox); err == nil {
		return normalized
	}
	return "workspace-write"
}
// requestEnvelope is an outgoing JSON-RPC request sent by call. Params is
// omitted when it is a nil interface.
type requestEnvelope struct {
	ID     int64  `json:"id"`
	Method string `json:"method"`
	Params any    `json:"params,omitempty"`
}

// notificationEnvelope is an outgoing JSON-RPC notification (no id), sent by
// notify.
type notificationEnvelope struct {
	Method string `json:"method"`
	Params any    `json:"params,omitempty"`
}

// responseEnvelope answers a server-initiated request. ID is the original
// request id value (int64 or string). Result and Error are omitted when they
// are nil interfaces.
type responseEnvelope struct {
	ID     any `json:"id"`
	Result any `json:"result,omitempty"`
	Error  any `json:"error,omitempty"`
}

// incomingEnvelope decodes any message from the server (request, response,
// or notification). readLoop classifies it by which of ID and Method are
// present.
type incomingEnvelope struct {
	ID     json.RawMessage `json:"id,omitempty"`
	Method string          `json:"method,omitempty"`
	Params json.RawMessage `json:"params,omitempty"`
	Result json.RawMessage `json:"result,omitempty"`
	Error  *RPCError       `json:"error,omitempty"`
}

// rpcResult is what a pending call receives: the raw result on success, or
// err for an RPCError or a connection failure.
type rpcResult struct {
	result json.RawMessage
	err    error
}