521 lines
12 KiB
Go
521 lines
12 KiB
Go
package codexapp
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
type Client struct {
|
|
socketPath string
|
|
version string
|
|
|
|
nextID int64
|
|
|
|
mu sync.Mutex
|
|
conn *websocket.Conn
|
|
pending map[int64]chan rpcResult
|
|
events chan Event
|
|
connected bool
|
|
}
|
|
|
|
type Event struct {
|
|
ID *int64
|
|
Method string
|
|
Params json.RawMessage
|
|
ServerRequest bool
|
|
Err error
|
|
}
|
|
|
|
type RPCError struct {
|
|
Code int `json:"code"`
|
|
Message string `json:"message"`
|
|
Data json.RawMessage `json:"data,omitempty"`
|
|
}
|
|
|
|
// Error implements error. A non-zero code is prefixed to the message as
// "<code>: <message>"; a zero code yields the bare message.
func (e RPCError) Error() string {
	if e.Code != 0 {
		return fmt.Sprintf("%d: %s", e.Code, e.Message)
	}
	return e.Message
}
|
|
|
|
// Thread is the subset of an app-server thread object this client reads.
// CWD may be back-filled from the enclosing response by threadWithCWD.
type Thread struct {
	ID        string `json:"id"`
	SessionID string `json:"sessionId,omitempty"`
	Name      string `json:"name,omitempty"`
	Preview   string `json:"preview,omitempty"`
	CWD       string `json:"cwd,omitempty"`
}
|
|
|
|
// Turn is the subset of an app-server turn object returned by turn/start.
type Turn struct {
	ID     string `json:"id"`
	Status string `json:"status"`
}
|
|
|
|
// InputItem is one element of the "input" array sent with turn/start.
// Type selects which of Text, URL or Path is meaningful; the unused ones
// are omitted from the JSON.
type InputItem struct {
	Type string `json:"type"`
	Text string `json:"text,omitempty"`
	URL  string `json:"url,omitempty"`
	Path string `json:"path,omitempty"`
}
|
|
|
|
type Model struct {
|
|
ID string `json:"id"`
|
|
Model string `json:"model"`
|
|
DisplayName string `json:"displayName"`
|
|
Description string `json:"description"`
|
|
IsDefault bool `json:"isDefault"`
|
|
DefaultReasoningEffort string `json:"defaultReasoningEffort"`
|
|
SupportedReasoningEfforts []ReasoningEffortOption `json:"supportedReasoningEfforts"`
|
|
}
|
|
|
|
// ReasoningEffortOption describes one reasoning effort a Model supports.
type ReasoningEffortOption struct {
	Description     string `json:"description"`
	ReasoningEffort string `json:"reasoningEffort"`
}
|
|
|
|
func New(socketPath, version string) *Client {
|
|
return &Client{
|
|
socketPath: socketPath,
|
|
version: version,
|
|
pending: make(map[int64]chan rpcResult),
|
|
events: make(chan Event, 128),
|
|
}
|
|
}
|
|
|
|
func (c *Client) Events() <-chan Event {
|
|
return c.events
|
|
}
|
|
|
|
// EnsureConnected is a cheap pre-flight for RPC methods: it returns nil if
// a connection is already recorded and otherwise delegates to Connect.
func (c *Client) EnsureConnected(ctx context.Context) error {
	c.mu.Lock()
	ready := c.connected && c.conn != nil
	c.mu.Unlock()

	if !ready {
		return c.Connect(ctx)
	}
	return nil
}
|
|
|
|
// Connect dials the app-server's unix socket, upgrades it to a websocket,
// starts the read loop and performs the initialize/initialized handshake.
// It returns nil without dialing when a connection is already recorded.
//
// If another goroutine installs a connection while this call is dialing, the
// freshly dialed connection is closed and the existing one is kept, so no
// socket or read loop is leaked. Note that c.connected is set before the
// handshake finishes, so concurrent callers may observe a connection whose
// initialize call is still in flight.
func (c *Client) Connect(ctx context.Context) error {
	c.mu.Lock()
	if c.connected && c.conn != nil {
		c.mu.Unlock()
		return nil
	}
	c.mu.Unlock()

	dialer := websocket.Dialer{
		// The URL host below is only a placeholder: every dial is routed to
		// the configured unix socket regardless of network/addr.
		NetDialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(ctx, "unix", c.socketPath)
		},
	}
	conn, _, err := dialer.DialContext(ctx, "ws://codex-app-server/", http.Header{})
	if err != nil {
		return fmt.Errorf("connect codex app-server unix socket %s: %w", c.socketPath, err)
	}

	c.mu.Lock()
	if c.conn != nil {
		// Lost a race with a concurrent Connect: keep the installed
		// connection and discard ours. The close error is irrelevant since
		// this connection was never used.
		c.mu.Unlock()
		_ = conn.Close()
		return nil
	}
	c.conn = conn
	c.connected = true
	c.mu.Unlock()

	go c.readLoop(conn)

	if err := c.initialize(ctx); err != nil {
		// Tear down only the connection this call created; c.conn may have
		// been replaced by a Close followed by another Connect meanwhile.
		c.mu.Lock()
		if c.conn == conn {
			c.conn = nil
			c.connected = false
		}
		c.mu.Unlock()
		// Closing makes readLoop exit; its error is secondary to err.
		_ = conn.Close()
		return err
	}
	return nil
}
|
|
|
|
func (c *Client) Close() error {
|
|
c.mu.Lock()
|
|
conn := c.conn
|
|
c.conn = nil
|
|
c.connected = false
|
|
c.mu.Unlock()
|
|
if conn == nil {
|
|
return nil
|
|
}
|
|
return conn.Close()
|
|
}
|
|
|
|
// initialize performs the app-server handshake: an "initialize" request
// carrying client identity and the experimentalApi capability (its result
// is discarded), followed by an "initialized" notification.
func (c *Client) initialize(ctx context.Context) error {
	clientInfo := map[string]any{
		"name":    "codex_telegram_bot",
		"title":   "Codex Telegram Bot",
		"version": c.version,
	}
	capabilities := map[string]any{"experimentalApi": true}
	handshake := map[string]any{
		"clientInfo":   clientInfo,
		"capabilities": capabilities,
	}

	var discard json.RawMessage
	err := c.call(ctx, "initialize", handshake, &discard)
	if err != nil {
		return err
	}
	return c.notify("initialized", nil)
}
|
|
|
|
func threadWithCWD(thread Thread, cwd string) Thread {
|
|
if thread.CWD == "" {
|
|
thread.CWD = cwd
|
|
}
|
|
return thread
|
|
}
|
|
|
|
// StartThread sends "thread/start" for cwd with approvalPolicy "on-request"
// and the normalized sandbox mode (unknown values fall back to
// "workspace-write"). model is included only when non-empty. The returned
// thread's CWD falls back to the response's top-level cwd.
func (c *Client) StartThread(ctx context.Context, cwd, model, sandbox string) (Thread, error) {
	err := c.EnsureConnected(ctx)
	if err != nil {
		return Thread{}, err
	}

	req := map[string]any{
		"serviceName":    "codex_telegram_bot",
		"sandbox":        threadSandbox(sandbox),
		"approvalPolicy": "on-request",
		"cwd":            cwd,
	}
	if model != "" {
		req["model"] = model
	}

	var resp struct {
		Thread Thread `json:"thread"`
		CWD    string `json:"cwd"`
	}
	err = c.call(ctx, "thread/start", req, &resp)
	if err != nil {
		return Thread{}, err
	}
	return threadWithCWD(resp.Thread, resp.CWD), nil
}
|
|
|
|
// ResumeThread sends "thread/resume" for threadID and returns the thread
// from the response, with CWD back-filled from the response's cwd.
func (c *Client) ResumeThread(ctx context.Context, threadID string) (Thread, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return Thread{}, err
	}

	req := map[string]any{"threadId": threadID}
	var out struct {
		Thread Thread `json:"thread"`
		CWD    string `json:"cwd"`
	}
	err := c.call(ctx, "thread/resume", req, &out)
	if err != nil {
		return Thread{}, err
	}
	return threadWithCWD(out.Thread, out.CWD), nil
}
|
|
|
|
// ReadThread sends "thread/read" for threadID without requesting turns and
// returns the thread, with CWD back-filled from the response's cwd.
func (c *Client) ReadThread(ctx context.Context, threadID string) (Thread, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return Thread{}, err
	}

	req := map[string]any{
		"includeTurns": false,
		"threadId":     threadID,
	}
	var out struct {
		Thread Thread `json:"thread"`
		CWD    string `json:"cwd"`
	}
	err := c.call(ctx, "thread/read", req, &out)
	if err != nil {
		return Thread{}, err
	}
	return threadWithCWD(out.Thread, out.CWD), nil
}
|
|
|
|
// ForkThread sends "thread/fork" for threadID and returns the thread the
// server reports, with CWD back-filled from the response's cwd.
func (c *Client) ForkThread(ctx context.Context, threadID string) (Thread, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return Thread{}, err
	}

	req := map[string]any{"threadId": threadID}
	var out struct {
		Thread Thread `json:"thread"`
		CWD    string `json:"cwd"`
	}
	err := c.call(ctx, "thread/fork", req, &out)
	if err != nil {
		return Thread{}, err
	}
	return threadWithCWD(out.Thread, out.CWD), nil
}
|
|
|
|
// ArchiveThread sends "thread/archive" for threadID, discarding the result.
func (c *Client) ArchiveThread(ctx context.Context, threadID string) error {
	if err := c.EnsureConnected(ctx); err != nil {
		return err
	}
	req := map[string]any{"threadId": threadID}
	var discard json.RawMessage
	return c.call(ctx, "thread/archive", req, &discard)
}
|
|
|
|
// SetThreadName sends "thread/name/set" to rename threadID to name,
// discarding the result.
func (c *Client) SetThreadName(ctx context.Context, threadID, name string) error {
	if err := c.EnsureConnected(ctx); err != nil {
		return err
	}
	req := map[string]any{
		"name":     name,
		"threadId": threadID,
	}
	var discard json.RawMessage
	return c.call(ctx, "thread/name/set", req, &discard)
}
|
|
|
|
// StartTurn sends "turn/start" on threadID with the given input and
// approvalPolicy "on-request".
//
// cwd and a sandboxPolicy derived from sandbox are sent only when cwd is not
// blank (whitespace-only counts as blank); otherwise the sandbox argument is
// not sent at all. model and reasoningEffort ("effort") are sent only when
// non-empty.
func (c *Client) StartTurn(ctx context.Context, threadID, cwd, model, reasoningEffort, sandbox string, input []InputItem) (Turn, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return Turn{}, err
	}

	req := map[string]any{
		"approvalPolicy": "on-request",
		"input":          input,
		"threadId":       threadID,
	}
	if strings.TrimSpace(cwd) != "" {
		req["cwd"] = cwd
		req["sandboxPolicy"] = SandboxPolicy(sandbox, cwd)
	}
	for key, value := range map[string]string{"model": model, "effort": reasoningEffort} {
		if value != "" {
			req[key] = value
		}
	}

	var out struct {
		Turn Turn `json:"turn"`
	}
	err := c.call(ctx, "turn/start", req, &out)
	if err != nil {
		return Turn{}, err
	}
	return out.Turn, nil
}
|
|
|
|
// InterruptTurn sends "turn/interrupt" for turnID on threadID, discarding
// the result.
func (c *Client) InterruptTurn(ctx context.Context, threadID, turnID string) error {
	if err := c.EnsureConnected(ctx); err != nil {
		return err
	}
	req := map[string]any{
		"turnId":   turnID,
		"threadId": threadID,
	}
	var discard json.RawMessage
	return c.call(ctx, "turn/interrupt", req, &discard)
}
|
|
|
|
// ListModels sends "model/list" asking for at most 50 non-hidden models and
// returns the "data" array. Only one request is made; any further pages the
// server may offer are not fetched.
func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
	if err := c.EnsureConnected(ctx); err != nil {
		return nil, err
	}

	req := map[string]any{
		"includeHidden": false,
		"limit":         50,
	}
	var page struct {
		Data []Model `json:"data"`
	}
	err := c.call(ctx, "model/list", req, &page)
	if err != nil {
		return nil, err
	}
	return page.Data, nil
}
|
|
|
|
// RespondServerRequest sends a response with the given result for the
// server-initiated request requestID.
//
// The write runs in a goroutine so the caller can stop waiting when ctx is
// done; in that case the write itself is not aborted and may still
// complete. The result channel is buffered so that goroutine never blocks.
func (c *Client) RespondServerRequest(ctx context.Context, requestID int64, result any) error {
	if err := c.EnsureConnected(ctx); err != nil {
		return err
	}

	reply := responseEnvelope{ID: requestID, Result: result}
	errc := make(chan error, 1)
	go func() {
		errc <- c.write(reply)
	}()

	select {
	case err := <-errc:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}
|
|
|
|
func (c *Client) call(ctx context.Context, method string, params any, result any) error {
|
|
id := atomic.AddInt64(&c.nextID, 1)
|
|
ch := make(chan rpcResult, 1)
|
|
|
|
c.mu.Lock()
|
|
if c.conn == nil {
|
|
c.mu.Unlock()
|
|
return errors.New("codex app-server is not connected")
|
|
}
|
|
c.pending[id] = ch
|
|
c.mu.Unlock()
|
|
|
|
if err := c.write(requestEnvelope{ID: id, Method: method, Params: params}); err != nil {
|
|
c.mu.Lock()
|
|
delete(c.pending, id)
|
|
c.mu.Unlock()
|
|
return err
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
c.mu.Lock()
|
|
delete(c.pending, id)
|
|
c.mu.Unlock()
|
|
return ctx.Err()
|
|
case rpc := <-ch:
|
|
if rpc.err != nil {
|
|
return rpc.err
|
|
}
|
|
if result == nil || len(rpc.result) == 0 {
|
|
return nil
|
|
}
|
|
return json.Unmarshal(rpc.result, result)
|
|
}
|
|
}
|
|
|
|
func (c *Client) notify(method string, params any) error {
|
|
return c.write(notificationEnvelope{Method: method, Params: params})
|
|
}
|
|
|
|
// write JSON-encodes message onto the current websocket connection.
//
// c.mu is held for the whole write, which serializes writers on the
// connection. NOTE(review): because completeCall and failConnection also
// take c.mu, a write that blocks on the network also stalls the read loop's
// response delivery; a dedicated write mutex would avoid that coupling.
func (c *Client) write(message any) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	conn := c.conn
	if conn == nil {
		return errors.New("codex app-server is not connected")
	}
	return conn.WriteJSON(message)
}
|
|
|
|
func (c *Client) readLoop(conn *websocket.Conn) {
|
|
for {
|
|
var env incomingEnvelope
|
|
if err := conn.ReadJSON(&env); err != nil {
|
|
c.failConnection(conn, err)
|
|
return
|
|
}
|
|
if env.Method != "" && env.ID != nil {
|
|
c.events <- Event{
|
|
ID: env.ID,
|
|
Method: env.Method,
|
|
Params: env.Params,
|
|
ServerRequest: true,
|
|
}
|
|
continue
|
|
}
|
|
if env.ID != nil {
|
|
c.completeCall(*env.ID, env.Result, env.Error)
|
|
continue
|
|
}
|
|
if env.Method != "" {
|
|
c.events <- Event{
|
|
Method: env.Method,
|
|
Params: env.Params,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) completeCall(id int64, result json.RawMessage, rpcErr *RPCError) {
|
|
c.mu.Lock()
|
|
ch := c.pending[id]
|
|
delete(c.pending, id)
|
|
c.mu.Unlock()
|
|
if ch == nil {
|
|
return
|
|
}
|
|
if rpcErr != nil {
|
|
ch <- rpcResult{err: *rpcErr}
|
|
return
|
|
}
|
|
ch <- rpcResult{result: result}
|
|
}
|
|
|
|
func (c *Client) failConnection(conn *websocket.Conn, err error) {
|
|
c.mu.Lock()
|
|
if c.conn == conn {
|
|
c.conn = nil
|
|
c.connected = false
|
|
}
|
|
pending := c.pending
|
|
c.pending = make(map[int64]chan rpcResult)
|
|
c.mu.Unlock()
|
|
|
|
for _, ch := range pending {
|
|
ch <- rpcResult{err: err}
|
|
}
|
|
c.events <- Event{Method: "connection/closed", Err: err}
|
|
}
|
|
|
|
// DecodeParams unmarshals event.Params into a value of type T. Empty params
// yield T's zero value and no error; on a decode error the (possibly
// partially filled) value is returned alongside the error.
func DecodeParams[T any](event Event) (T, error) {
	var decoded T
	if len(event.Params) > 0 {
		if err := json.Unmarshal(event.Params, &decoded); err != nil {
			return decoded, err
		}
	}
	return decoded, nil
}
|
|
|
|
// NormalizeSandbox maps a user-supplied sandbox name to its canonical
// kebab-case form. Matching ignores case and surrounding whitespace, accepts
// the hyphen-less spellings, and treats an empty value as "workspace-write".
// Unknown names return an error quoting the original value.
func NormalizeSandbox(value string) (string, error) {
	key := strings.ToLower(strings.TrimSpace(value))
	switch key {
	case "read-only", "readonly":
		return "read-only", nil
	case "danger-full-access", "dangerfullaccess":
		return "danger-full-access", nil
	case "", "workspace-write", "workspacewrite":
		return "workspace-write", nil
	}
	return "", fmt.Errorf("unsupported sandbox %q", value)
}
|
|
|
|
func SandboxPolicy(sandbox, cwd string) map[string]any {
|
|
normalized, err := NormalizeSandbox(sandbox)
|
|
if err != nil {
|
|
normalized = "workspace-write"
|
|
}
|
|
switch normalized {
|
|
case "read-only":
|
|
return map[string]any{"type": "readOnly"}
|
|
case "danger-full-access":
|
|
return map[string]any{"type": "dangerFullAccess"}
|
|
default:
|
|
return map[string]any{
|
|
"type": "workspaceWrite",
|
|
"writableRoots": []string{cwd},
|
|
"networkAccess": false,
|
|
}
|
|
}
|
|
}
|
|
|
|
func threadSandbox(sandbox string) string {
|
|
normalized, err := NormalizeSandbox(sandbox)
|
|
if err != nil {
|
|
return "workspace-write"
|
|
}
|
|
return normalized
|
|
}
|
|
|
|
// requestEnvelope is an outgoing JSON-RPC request; params is omitted when nil.
type requestEnvelope struct {
	ID     int64  `json:"id"`
	Method string `json:"method"`
	Params any    `json:"params,omitempty"`
}
|
|
|
|
// notificationEnvelope is an outgoing JSON-RPC notification (it has no id).
type notificationEnvelope struct {
	Method string `json:"method"`
	Params any    `json:"params,omitempty"`
}
|
|
|
|
// responseEnvelope is an outgoing reply to a server-initiated request.
// Result and Error are each omitted when nil.
type responseEnvelope struct {
	ID     int64 `json:"id"`
	Result any   `json:"result,omitempty"`
	Error  any   `json:"error,omitempty"`
}
|
|
|
|
type incomingEnvelope struct {
|
|
ID *int64 `json:"id,omitempty"`
|
|
Method string `json:"method,omitempty"`
|
|
Params json.RawMessage `json:"params,omitempty"`
|
|
Result json.RawMessage `json:"result,omitempty"`
|
|
Error *RPCError `json:"error,omitempty"`
|
|
}
|
|
|
|
type rpcResult struct {
|
|
result json.RawMessage
|
|
err error
|
|
}
|