Files
Codex 44384a90c7 Use default Codex approval reviewer
Stop forcing approvalsReviewer=user on thread and turn start so reviewer routing follows Codex app-server defaults, matching CLI behavior.
2026-05-28 10:23:59 +00:00

631 lines
15 KiB
Go

package codexapp
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"github.com/gorilla/websocket"
)
// Client talks to the Codex app-server using JSON-RPC-style messages carried
// over a WebSocket that is dialed through a Unix domain socket. Create one
// with New; it connects lazily via Connect or EnsureConnected.
type Client struct {
// socketPath is the Unix socket every WebSocket dial goes through.
socketPath string
// version is reported as clientInfo.version in the initialize handshake.
version string
// nextID is the most recently issued request ID; updated with atomic.AddInt64.
nextID int64
// mu guards conn, connected, pending, and logger, and is held for the whole
// of each socket write so writes are serialized.
mu sync.Mutex
// conn is the live connection, or nil when disconnected.
conn *websocket.Conn
// pending maps a request ID key (RequestID.Key) to the channel that will
// receive that request's response.
pending map[string]chan rpcResult
// events delivers notifications, server-initiated requests, and
// "connection/closed" events to the consumer (see Events).
events chan Event
// connected is set together with conn and cleared when the connection is
// closed or fails.
connected bool
// logger receives protocol trace lines; nil disables tracing.
logger *log.Logger
}
// Event is a message delivered on Client.Events: a server notification, a
// server-initiated request that needs an answer, or a "connection/closed"
// signal when the read loop ends.
type Event struct {
// ID is set only for server-initiated requests (ServerRequest == true); pass
// it to RespondServerRequest or RespondServerRequestError.
ID *RequestID
// Method is the message's method name, or "connection/closed" when the
// connection has failed.
Method string
// Params holds the undecoded params; use DecodeParams to unmarshal them.
Params json.RawMessage
// ServerRequest reports whether the server expects a response.
ServerRequest bool
// Err is set on "connection/closed" events to the error that ended the read
// loop.
Err error
}
// RequestID identifies a JSON-RPC request whose ID may be an integer or a
// string. value keeps the original (int64 or string) so it can be echoed back
// on the wire unchanged; key is a type-tagged string form ("i:<n>" or
// "s:<text>") that is safe to use as a map key or to persist.
type RequestID struct {
	value any
	key   string
}

// IntRequestID builds a RequestID from an integer ID.
func IntRequestID(id int64) RequestID {
	return RequestID{
		value: id,
		key:   "i:" + strconv.FormatInt(id, 10),
	}
}

// ParseRequestIDKey reverses RequestID.Key. Surrounding whitespace is ignored.
// Tagged keys ("i:<n>", "s:<text>") are decoded by tag; an untagged key is
// taken as an integer if it parses as one and as a string otherwise. A blank
// key is an error, as is an "i:" key whose number does not parse.
func ParseRequestIDKey(key string) (RequestID, error) {
	trimmed := strings.TrimSpace(key)
	if trimmed == "" {
		return RequestID{}, errors.New("request id is empty")
	}
	switch {
	case strings.HasPrefix(trimmed, "i:"):
		n, err := strconv.ParseInt(trimmed[len("i:"):], 10, 64)
		if err != nil {
			return RequestID{}, err
		}
		return IntRequestID(n), nil
	case strings.HasPrefix(trimmed, "s:"):
		text := trimmed[len("s:"):]
		return RequestID{value: text, key: "s:" + text}, nil
	}
	if n, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
		return IntRequestID(n), nil
	}
	return RequestID{value: trimmed, key: "s:" + trimmed}, nil
}

// ParseRequestID decodes the raw "id" member of an incoming message. The bool
// result is false (with a nil error) when the id is absent or JSON null. A
// JSON string becomes a string ID; anything else must decode as an int64.
func ParseRequestID(raw json.RawMessage) (RequestID, bool, error) {
	text := strings.TrimSpace(string(raw))
	switch {
	case text == "" || text == "null":
		return RequestID{}, false, nil
	case text[0] == '"':
		var s string
		if err := json.Unmarshal(raw, &s); err != nil {
			return RequestID{}, false, err
		}
		return RequestID{value: s, key: "s:" + s}, true, nil
	default:
		var n int64
		if err := json.Unmarshal(raw, &n); err != nil {
			return RequestID{}, false, err
		}
		return IntRequestID(n), true, nil
	}
}

// Key returns the type-tagged string form of the ID ("i:<n>" or "s:<text>").
func (id RequestID) Key() string { return id.key }

// Value returns the ID exactly as it should appear on the wire (int64 or
// string); it is nil for the zero RequestID.
func (id RequestID) Value() any { return id.value }
// RPCError is a JSON-RPC error object. It is returned from calls whose
// response carried an "error" member and is sent by RespondServerRequestError.
type RPCError struct {
	Code    int             `json:"code"`
	Message string          `json:"message"`
	Data    json.RawMessage `json:"data,omitempty"`
}

// Error implements the error interface as "<code>: <message>", or just the
// message when Code is zero.
func (e RPCError) Error() string {
	if e.Code != 0 {
		return strconv.Itoa(e.Code) + ": " + e.Message
	}
	return e.Message
}
// Thread holds the thread fields this client decodes from the "thread" member
// of thread/start, thread/resume, thread/read, and thread/fork results.
type Thread struct {
ID string `json:"id"`
SessionID string `json:"sessionId,omitempty"`
Name string `json:"name,omitempty"`
Preview string `json:"preview,omitempty"`
// CWD is the thread's working directory. StartThread, ResumeThread,
// ReadThread, and ForkThread fill it from the result's top-level "cwd" when
// the thread object leaves it empty.
CWD string `json:"cwd,omitempty"`
}
// Turn holds the fields decoded from the "turn" member of a turn/start result.
type Turn struct {
ID string `json:"id"`
// Status is passed through as sent by the server; its possible values are
// not enumerated here.
Status string `json:"status"`
}
// InputItem is one element of the input sent with turn/start. Type presumably
// selects which of Text, URL, or Path is meaningful (TODO confirm against the
// app-server schema); empty fields are omitted from the JSON.
type InputItem struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
URL string `json:"url,omitempty"`
Path string `json:"path,omitempty"`
}
// Model is one entry of the "data" list returned by model/list.
type Model struct {
// ID and Model are both decoded as sent; presumably ID is the catalog entry
// and Model the name to pass back as "model" (confirm with the server schema).
ID string `json:"id"`
Model string `json:"model"`
DisplayName string `json:"displayName"`
Description string `json:"description"`
IsDefault bool `json:"isDefault"`
DefaultReasoningEffort string `json:"defaultReasoningEffort"`
SupportedReasoningEfforts []ReasoningEffortOption `json:"supportedReasoningEfforts"`
}
// ReasoningEffortOption is one entry of Model.SupportedReasoningEfforts.
type ReasoningEffortOption struct {
Description string `json:"description"`
// ReasoningEffort is the effort value; StartTurn forwards such a value as
// "effort".
ReasoningEffort string `json:"reasoningEffort"`
}
// New returns a Client that dials the app-server through the Unix socket at
// socketPath and reports version during the initialize handshake. No
// connection is opened until Connect or EnsureConnected is called.
func New(socketPath, version string) *Client {
	c := new(Client)
	c.socketPath = socketPath
	c.version = version
	c.pending = make(map[string]chan rpcResult)
	// The read loop can queue up to 128 events before a send blocks.
	c.events = make(chan Event, 128)
	return c
}
// SetLogger installs the logger used for protocol tracing. Passing nil turns
// tracing off.
func (c *Client) SetLogger(logger *log.Logger) {
	c.mu.Lock()
	c.logger = logger
	c.mu.Unlock()
}
// Events returns the channel that receives server notifications,
// server-initiated requests (ServerRequest set), and a "connection/closed"
// event when the read loop ends. The read loop blocks while this channel's
// buffer is full, which also stalls delivery of call responses, so the
// consumer should drain it continuously.
func (c *Client) Events() <-chan Event {
return c.events
}
// EnsureConnected calls Connect unless a connection is already marked live, in
// which case it returns nil immediately.
func (c *Client) EnsureConnected(ctx context.Context) error {
	c.mu.Lock()
	alive := c.conn != nil && c.connected
	c.mu.Unlock()
	if !alive {
		return c.Connect(ctx)
	}
	return nil
}
// Connect dials the app-server's Unix socket, upgrades it to a WebSocket,
// starts the read loop, and performs the initialize/initialized handshake.
// It returns nil without doing anything if a connection is already installed.
//
// If a concurrent Connect installs a connection while this call is dialing,
// the freshly dialed socket is closed and nil is returned, exactly as the
// fast path at the top would have done. This keeps the other connection and
// its read loop from being overwritten and leaked. If the handshake fails,
// only the connection created by this call is torn down; one installed by
// someone else is left alone.
func (c *Client) Connect(ctx context.Context) error {
	c.mu.Lock()
	already := c.connected && c.conn != nil
	c.mu.Unlock()
	if already {
		return nil
	}
	dialer := websocket.Dialer{
		// The URL's host is a placeholder: every dial goes to the Unix socket.
		NetDialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(ctx, "unix", c.socketPath)
		},
	}
	conn, _, err := dialer.DialContext(ctx, "ws://codex-app-server/", http.Header{})
	if err != nil {
		return fmt.Errorf("connect codex app-server unix socket %s: %w", c.socketPath, err)
	}

	c.mu.Lock()
	if c.connected && c.conn != nil {
		// Lost a race with another Connect; keep its connection. No read loop
		// was started for ours, so closing it here is the only cleanup needed.
		c.mu.Unlock()
		_ = conn.Close() // best effort: the socket was never used
		return nil
	}
	c.conn = conn
	c.connected = true
	c.mu.Unlock()

	// The read loop must be running before initialize, or its response would
	// never be delivered.
	go c.readLoop(conn)
	if err := c.initialize(ctx); err != nil {
		c.mu.Lock()
		if c.conn == conn {
			c.conn = nil
			c.connected = false
		}
		c.mu.Unlock()
		_ = conn.Close() // the handshake error is what matters to the caller
		return err
	}
	return nil
}
// Close detaches the current connection, if any, marks the client
// disconnected, and closes the socket. Closing the socket should make that
// connection's read loop exit with a read error. Close returns nil when there
// was nothing to close.
func (c *Client) Close() error {
	c.mu.Lock()
	old := c.conn
	c.conn, c.connected = nil, false
	c.mu.Unlock()
	if old != nil {
		return old.Close()
	}
	return nil
}
// initialize performs the app-server handshake. It sends an "initialize"
// request that identifies this client and opts into the experimental API,
// then sends an "initialized" notification. The initialize result is
// discarded.
func (c *Client) initialize(ctx context.Context) error {
	info := map[string]any{
		"name":    "codex_telegram_bot",
		"title":   "Codex Telegram Bot",
		"version": c.version,
	}
	caps := map[string]any{"experimentalApi": true}
	var discard json.RawMessage
	err := c.call(ctx, "initialize", map[string]any{
		"clientInfo":   info,
		"capabilities": caps,
	}, &discard)
	if err != nil {
		return err
	}
	return c.notify("initialized", nil)
}
// threadWithCWD returns thread with its CWD defaulted to cwd when the thread
// object itself did not carry one. A non-empty CWD is never overwritten.
func threadWithCWD(thread Thread, cwd string) Thread {
	if thread.CWD != "" {
		return thread
	}
	thread.CWD = cwd
	return thread
}
// StartThread creates a thread rooted at cwd via "thread/start".
//   - The approval policy is always "on-request".
//   - sandbox is normalized by threadSandbox, so unrecognized values fall back
//     to "workspace-write".
//   - model is sent only when non-empty.
func (c *Client) StartThread(ctx context.Context, cwd, model, sandbox string) (Thread, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return Thread{}, err
	}
	params := make(map[string]any, 5)
	params["cwd"] = cwd
	params["approvalPolicy"] = "on-request"
	params["sandbox"] = threadSandbox(sandbox)
	params["serviceName"] = "codex_telegram_bot"
	if model != "" {
		params["model"] = model
	}
	var resp struct {
		Thread Thread `json:"thread"`
		CWD    string `json:"cwd"`
	}
	if err := c.call(ctx, "thread/start", params, &resp); err != nil {
		return Thread{}, err
	}
	return threadWithCWD(resp.Thread, resp.CWD), nil
}
// ResumeThread reopens threadID via "thread/resume" and returns the thread
// from the result, with CWD defaulted from the result's top-level "cwd".
func (c *Client) ResumeThread(ctx context.Context, threadID string) (Thread, error) {
	var resp struct {
		Thread Thread `json:"thread"`
		CWD    string `json:"cwd"`
	}
	err := c.EnsureConnected(ctx)
	if err == nil {
		err = c.call(ctx, "thread/resume", map[string]any{"threadId": threadID}, &resp)
	}
	if err != nil {
		return Thread{}, err
	}
	return threadWithCWD(resp.Thread, resp.CWD), nil
}
// ReadThread fetches threadID's metadata via "thread/read", asking the server
// not to include its turns.
func (c *Client) ReadThread(ctx context.Context, threadID string) (Thread, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return Thread{}, err
	}
	req := map[string]any{
		"threadId":     threadID,
		"includeTurns": false,
	}
	var resp struct {
		Thread Thread `json:"thread"`
		CWD    string `json:"cwd"`
	}
	if err := c.call(ctx, "thread/read", req, &resp); err != nil {
		return Thread{}, err
	}
	return threadWithCWD(resp.Thread, resp.CWD), nil
}
// ForkThread sends "thread/fork" for threadID and returns the thread from the
// result (presumably the newly created fork), with CWD defaulted from the
// result's top-level "cwd".
func (c *Client) ForkThread(ctx context.Context, threadID string) (Thread, error) {
	var resp struct {
		Thread Thread `json:"thread"`
		CWD    string `json:"cwd"`
	}
	err := c.EnsureConnected(ctx)
	if err == nil {
		err = c.call(ctx, "thread/fork", map[string]any{"threadId": threadID}, &resp)
	}
	if err != nil {
		return Thread{}, err
	}
	return threadWithCWD(resp.Thread, resp.CWD), nil
}
// ArchiveThread sends "thread/archive" for threadID. The result body is
// ignored.
func (c *Client) ArchiveThread(ctx context.Context, threadID string) error {
	err := c.EnsureConnected(ctx)
	if err != nil {
		return err
	}
	params := map[string]any{"threadId": threadID}
	var discard json.RawMessage
	return c.call(ctx, "thread/archive", params, &discard)
}
// SetThreadName renames threadID to name via "thread/name/set". The result
// body is ignored.
func (c *Client) SetThreadName(ctx context.Context, threadID, name string) error {
	err := c.EnsureConnected(ctx)
	if err != nil {
		return err
	}
	params := map[string]any{
		"threadId": threadID,
		"name":     name,
	}
	var discard json.RawMessage
	return c.call(ctx, "thread/name/set", params, &discard)
}
// StartTurn submits input to threadID as a new turn via "turn/start" and
// returns the turn from the result. The approval policy is always "on-request".
//
// cwd and a sandboxPolicy (built by SandboxPolicy from sandbox and cwd) are
// sent only when cwd is not blank; otherwise sandbox is ignored. model and
// reasoningEffort (sent as "effort") are sent only when non-empty.
func (c *Client) StartTurn(ctx context.Context, threadID, cwd, model, reasoningEffort, sandbox string, input []InputItem) (Turn, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return Turn{}, err
	}
	params := map[string]any{
		"threadId":       threadID,
		"input":          input,
		"approvalPolicy": "on-request",
	}
	// The workspace-write policy lists cwd as its writable root, which is
	// presumably why the policy is only sent together with a cwd.
	if strings.TrimSpace(cwd) != "" {
		params["cwd"] = cwd
		params["sandboxPolicy"] = SandboxPolicy(sandbox, cwd)
	}
	optional := [...]struct{ key, val string }{
		{"model", model},
		{"effort", reasoningEffort},
	}
	for _, o := range optional {
		if o.val != "" {
			params[o.key] = o.val
		}
	}
	var resp struct {
		Turn Turn `json:"turn"`
	}
	if err := c.call(ctx, "turn/start", params, &resp); err != nil {
		return Turn{}, err
	}
	return resp.Turn, nil
}
// InterruptTurn sends "turn/interrupt" for turnID on threadID. The result body
// is ignored.
func (c *Client) InterruptTurn(ctx context.Context, threadID, turnID string) error {
	err := c.EnsureConnected(ctx)
	if err != nil {
		return err
	}
	target := map[string]any{
		"threadId": threadID,
		"turnId":   turnID,
	}
	var discard json.RawMessage
	return c.call(ctx, "turn/interrupt", target, &discard)
}
// ListModels returns the non-hidden models from a single "model/list" request
// with a limit of 50. Any results beyond that first request (if the server
// paginates) are not fetched.
func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return nil, err
	}
	params := map[string]any{
		"limit":         50,
		"includeHidden": false,
	}
	var page struct {
		Data []Model `json:"data"`
	}
	if err := c.call(ctx, "model/list", params, &page); err != nil {
		return nil, err
	}
	return page.Data, nil
}
// RespondServerRequest answers a server-initiated request (an Event with
// ServerRequest set) with a successful result.
//
// A nil result is sent as an explicit JSON null. The envelope's Result field
// is tagged omitempty, so a nil interface would otherwise be dropped and the
// response would carry neither "result" nor "error", which is not a valid
// JSON-RPC response.
func (c *Client) RespondServerRequest(ctx context.Context, requestID RequestID, result any) error {
	if result == nil {
		result = json.RawMessage("null")
	}
	return c.respondServerRequest(ctx, responseEnvelope{ID: requestID.Value(), Result: result})
}
// RespondServerRequestError rejects a server-initiated request with a
// JSON-RPC error object carrying code and message.
func (c *Client) RespondServerRequestError(ctx context.Context, requestID RequestID, code int, message string) error {
	rpcErr := RPCError{Code: code, Message: message}
	resp := responseEnvelope{ID: requestID.Value(), Error: rpcErr}
	return c.respondServerRequest(ctx, resp)
}
// respondServerRequest writes response to the app-server. It returns ctx.Err()
// if ctx is done before the write finishes; the write itself is not cancelled.
//
// NOTE(review): EnsureConnected may open a fresh connection here, and a
// response sent on a new connection presumably cannot match the original
// request. Confirm whether that case should be an error instead.
func (c *Client) respondServerRequest(ctx context.Context, response responseEnvelope) error {
	if err := c.EnsureConnected(ctx); err != nil {
		return err
	}
	// write holds c.mu for the whole socket write, so it runs in a goroutine
	// that ctx can abandon. The one-slot buffer lets that goroutine finish
	// even when nobody is left to receive its result.
	result := make(chan error, 1)
	go func() { result <- c.write(response) }()
	select {
	case err := <-result:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}
// call sends a request for method with params and waits for the matching
// response or for ctx to be done.
//   - An error response is returned as an RPCError.
//   - Otherwise the result is unmarshaled into result, unless result is nil
//     or the server sent no result.
//
// The pending entry is removed when the write fails or ctx ends, so a late
// response is dropped.
func (c *Client) call(ctx context.Context, method string, params any, result any) error {
	id := atomic.AddInt64(&c.nextID, 1)
	key := IntRequestID(id).Key()
	reply := make(chan rpcResult, 1)
	forget := func() {
		c.mu.Lock()
		delete(c.pending, key)
		c.mu.Unlock()
	}

	c.mu.Lock()
	if c.conn == nil {
		c.mu.Unlock()
		return errors.New("codex app-server is not connected")
	}
	c.pending[key] = reply
	c.mu.Unlock()

	if err := c.write(requestEnvelope{ID: id, Method: method, Params: params}); err != nil {
		forget()
		return err
	}

	select {
	case <-ctx.Done():
		forget()
		return ctx.Err()
	case rpc := <-reply:
		switch {
		case rpc.err != nil:
			return rpc.err
		case result == nil || len(rpc.result) == 0:
			return nil
		default:
			return json.Unmarshal(rpc.result, result)
		}
	}
}
// notify sends a notification: a message with no id, so no response is
// awaited.
func (c *Client) notify(method string, params any) error {
	msg := notificationEnvelope{Method: method, Params: params}
	return c.write(msg)
}
// write JSON-encodes message onto the current connection. c.mu is held for
// the entire write, which serializes writers; gorilla/websocket supports only
// one concurrent writer per connection.
func (c *Client) write(message any) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	conn := c.conn
	if conn == nil {
		return errors.New("codex app-server is not connected")
	}
	return conn.WriteJSON(message)
}
// readLoop decodes messages from conn until a read or id-decoding error, which
// it passes to failConnection before returning. Each message is classified:
//   - method and id: a server-initiated request, forwarded as an Event with
//     ServerRequest set;
//   - id only: a response, handed to completeCall;
//   - method only: a notification, forwarded as an Event.
//
// Messages with neither are ignored. Sends on c.events block while its buffer
// is full.
func (c *Client) readLoop(conn *websocket.Conn) {
	for {
		var env incomingEnvelope
		if err := conn.ReadJSON(&env); err != nil {
			c.failConnection(conn, err)
			return
		}
		id, hasID, err := ParseRequestID(env.ID)
		if err != nil {
			c.failConnection(conn, fmt.Errorf("decode request id: %w", err))
			return
		}
		hasMethod := env.Method != ""
		switch {
		case hasMethod && hasID:
			c.logProtocolEvent("server request", env.Method, id.Key(), env.Params)
			c.events <- Event{ID: &id, Method: env.Method, Params: env.Params, ServerRequest: true}
		case hasID:
			c.completeCall(id, env.Result, env.Error)
		case hasMethod:
			if shouldLogNotification(env.Method) {
				c.logProtocolEvent("notification", env.Method, "", env.Params)
			}
			c.events <- Event{Method: env.Method, Params: env.Params}
		}
	}
}
// logProtocolEvent writes a one-line trace of an inbound message: its kind,
// its method, its id (when non-empty), and the size of its params in bytes.
// The params themselves are never logged. It does nothing when no logger is
// set.
func (c *Client) logProtocolEvent(kind, method, id string, params json.RawMessage) {
	c.mu.Lock()
	logger := c.logger
	c.mu.Unlock()
	switch {
	case logger == nil:
		return
	case id == "":
		logger.Printf("codex %s: method=%s params_bytes=%d", kind, method, len(params))
	default:
		logger.Printf("codex %s: method=%s id=%s params_bytes=%d", kind, method, id, len(params))
	}
}
// shouldLogNotification reports whether a notification method is traced. Only
// guardian approval-review start/completion, guardian warnings, and
// server-request resolution are; all other notifications are not.
func shouldLogNotification(method string) bool {
	return method == "item/guardianApprovalReview/started" ||
		method == "item/guardianApprovalReview/completed" ||
		method == "guardianWarning" ||
		method == "serverRequest/resolved"
}
// completeCall delivers a response to the call waiting on id, if there is one.
// The pending entry is removed under the lock, so each waiter gets at most one
// response; responses for unknown or abandoned ids are dropped. When rpcErr is
// set, the waiter receives it and result is ignored.
func (c *Client) completeCall(id RequestID, result json.RawMessage, rpcErr *RPCError) {
	key := id.Key()
	c.mu.Lock()
	waiter := c.pending[key]
	delete(c.pending, key)
	c.mu.Unlock()
	if waiter == nil {
		return
	}
	msg := rpcResult{result: result}
	if rpcErr != nil {
		msg = rpcResult{err: *rpcErr}
	}
	waiter <- msg
}
// failConnection handles the end of conn's read loop with err. It acts only
// when conn is still the current connection, or when no connection is current
// (for example after Close). In that case it:
//   - marks the client disconnected (when conn is current);
//   - fails every pending call with err;
//   - emits a "connection/closed" event carrying err.
//
// If a newer connection has already replaced conn, it does nothing. The
// pending map then holds the new connection's calls, and failing them would
// abort its in-flight requests, including its initialize handshake, with a
// stale error.
//
// NOTE(review): calls that were still waiting on the superseded connection are
// then left to time out via their own ctx.
func (c *Client) failConnection(conn *websocket.Conn, err error) {
	c.mu.Lock()
	switch c.conn {
	case conn:
		c.conn = nil
		c.connected = false
	case nil:
		// Already detached (e.g. by Close); still release any waiters below.
	default:
		// conn was superseded by a newer connection; leave its state alone.
		c.mu.Unlock()
		return
	}
	pending := c.pending
	c.pending = make(map[string]chan rpcResult)
	c.mu.Unlock()
	for _, ch := range pending {
		// Each channel has a one-slot buffer and was removed from the map under
		// the lock, so this send cannot block.
		ch <- rpcResult{err: err}
	}
	c.events <- Event{Method: "connection/closed", Err: err}
}
// DecodeParams unmarshals event.Params into a T. Empty params yield T's zero
// value and a nil error. On a decoding error, the partially decoded value is
// returned alongside the error.
func DecodeParams[T any](event Event) (T, error) {
	var out T
	if len(event.Params) > 0 {
		if err := json.Unmarshal(event.Params, &out); err != nil {
			return out, err
		}
	}
	return out, nil
}
// NormalizeSandbox maps a sandbox name to its canonical form:
// "workspace-write", "read-only", or "danger-full-access".
//   - Matching ignores case and surrounding whitespace.
//   - Spellings without hyphens are accepted.
//   - A blank value means "workspace-write".
//
// Any other value is an error.
func NormalizeSandbox(value string) (string, error) {
	canonical := map[string]string{
		"":                   "workspace-write",
		"workspace-write":    "workspace-write",
		"workspacewrite":     "workspace-write",
		"read-only":          "read-only",
		"readonly":           "read-only",
		"danger-full-access": "danger-full-access",
		"dangerfullaccess":   "danger-full-access",
	}
	if name, ok := canonical[strings.ToLower(strings.TrimSpace(value))]; ok {
		return name, nil
	}
	return "", fmt.Errorf("unsupported sandbox %q", value)
}
// SandboxPolicy builds the sandboxPolicy object sent with turn/start.
//   - "read-only" becomes {"type": "readOnly"}.
//   - "danger-full-access" becomes {"type": "dangerFullAccess"}.
//   - Anything else, including unrecognized names, becomes a workspaceWrite
//     policy that lists cwd as its writable root and sets networkAccess to
//     false.
func SandboxPolicy(sandbox, cwd string) map[string]any {
	normalized, err := NormalizeSandbox(sandbox)
	if err == nil && normalized == "read-only" {
		return map[string]any{"type": "readOnly"}
	}
	if err == nil && normalized == "danger-full-access" {
		return map[string]any{"type": "dangerFullAccess"}
	}
	return map[string]any{
		"type":          "workspaceWrite",
		"writableRoots": []string{cwd},
		"networkAccess": false,
	}
}
// threadSandbox returns the canonical sandbox name sent with thread/start,
// falling back to "workspace-write" when NormalizeSandbox rejects the input.
func threadSandbox(sandbox string) string {
	if normalized, err := NormalizeSandbox(sandbox); err == nil {
		return normalized
	}
	return "workspace-write"
}
// requestEnvelope is an outgoing client request; the server's response is
// matched back to it by ID. No "jsonrpc" version member is sent.
type requestEnvelope struct {
ID int64 `json:"id"`
Method string `json:"method"`
// Params is omitted from the JSON when nil.
Params any `json:"params,omitempty"`
}
// notificationEnvelope is an outgoing notification: no id, so no response is
// expected.
type notificationEnvelope struct {
Method string `json:"method"`
// Params is omitted from the JSON when nil.
Params any `json:"params,omitempty"`
}
// responseEnvelope is this client's reply to a server-initiated request. ID
// echoes the server's id value (int64 or string) unchanged.
type responseEnvelope struct {
ID any `json:"id"`
// Result and Error are both omitempty: a nil interface in either field is
// left out of the JSON entirely.
Result any `json:"result,omitempty"`
Error any `json:"error,omitempty"`
}
// incomingEnvelope decodes every inbound message shape: server requests,
// responses, and notifications. readLoop classifies each message by which of
// ID and Method are present.
type incomingEnvelope struct {
// ID is kept raw because it may be an integer, a string, or null; see
// ParseRequestID.
ID json.RawMessage `json:"id,omitempty"`
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *RPCError `json:"error,omitempty"`
}
// rpcResult is what a pending call receives: the raw result on success, or
// err (an RPCError or a connection failure) otherwise.
type rpcResult struct {
result json.RawMessage
err error
}