Initial codex telegram bot source
This commit is contained in:
479
internal/codexapp/client.go
Normal file
479
internal/codexapp/client.go
Normal file
@@ -0,0 +1,479 @@
|
||||
package codexapp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
socketPath string
|
||||
version string
|
||||
|
||||
nextID int64
|
||||
|
||||
mu sync.Mutex
|
||||
conn *websocket.Conn
|
||||
pending map[int64]chan rpcResult
|
||||
events chan Event
|
||||
connected bool
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
ID *int64
|
||||
Method string
|
||||
Params json.RawMessage
|
||||
ServerRequest bool
|
||||
Err error
|
||||
}
|
||||
|
||||
type RPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func (e RPCError) Error() string {
|
||||
if e.Code == 0 {
|
||||
return e.Message
|
||||
}
|
||||
return fmt.Sprintf("%d: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
type Thread struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Preview string `json:"preview,omitempty"`
|
||||
}
|
||||
|
||||
type Turn struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type InputItem struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Path string `json:"path,omitempty"`
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Description string `json:"description"`
|
||||
IsDefault bool `json:"isDefault"`
|
||||
DefaultReasoningEffort string `json:"defaultReasoningEffort"`
|
||||
SupportedReasoningEfforts []ReasoningEffortOption `json:"supportedReasoningEfforts"`
|
||||
}
|
||||
|
||||
type ReasoningEffortOption struct {
|
||||
Description string `json:"description"`
|
||||
ReasoningEffort string `json:"reasoningEffort"`
|
||||
}
|
||||
|
||||
func New(socketPath, version string) *Client {
|
||||
return &Client{
|
||||
socketPath: socketPath,
|
||||
version: version,
|
||||
pending: make(map[int64]chan rpcResult),
|
||||
events: make(chan Event, 128),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Events() <-chan Event {
|
||||
return c.events
|
||||
}
|
||||
|
||||
func (c *Client) EnsureConnected(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
connected := c.connected && c.conn != nil
|
||||
c.mu.Unlock()
|
||||
if connected {
|
||||
return nil
|
||||
}
|
||||
return c.Connect(ctx)
|
||||
}
|
||||
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
if c.connected && c.conn != nil {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
dialer := websocket.Dialer{
|
||||
NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "unix", c.socketPath)
|
||||
},
|
||||
}
|
||||
conn, _, err := dialer.DialContext(ctx, "ws://codex-app-server/", http.Header{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect codex app-server unix socket %s: %w", c.socketPath, err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = conn
|
||||
c.connected = true
|
||||
c.mu.Unlock()
|
||||
|
||||
go c.readLoop(conn)
|
||||
|
||||
if err := c.initialize(ctx); err != nil {
|
||||
_ = c.Close()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.conn = nil
|
||||
c.connected = false
|
||||
c.mu.Unlock()
|
||||
if conn == nil {
|
||||
return nil
|
||||
}
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
func (c *Client) initialize(ctx context.Context) error {
|
||||
params := map[string]any{
|
||||
"clientInfo": map[string]any{
|
||||
"name": "codex_telegram_bot",
|
||||
"title": "Codex Telegram Bot",
|
||||
"version": c.version,
|
||||
},
|
||||
"capabilities": map[string]any{
|
||||
"experimentalApi": true,
|
||||
},
|
||||
}
|
||||
var ignored json.RawMessage
|
||||
if err := c.call(ctx, "initialize", params, &ignored); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.notify("initialized", nil)
|
||||
}
|
||||
|
||||
func (c *Client) StartThread(ctx context.Context, cwd, model, sandbox string) (Thread, error) {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
params := map[string]any{
|
||||
"cwd": cwd,
|
||||
"approvalPolicy": "on-request",
|
||||
"sandbox": threadSandbox(sandbox),
|
||||
"serviceName": "codex_telegram_bot",
|
||||
}
|
||||
if model != "" {
|
||||
params["model"] = model
|
||||
}
|
||||
var result struct {
|
||||
Thread Thread `json:"thread"`
|
||||
}
|
||||
if err := c.call(ctx, "thread/start", params, &result); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
return result.Thread, nil
|
||||
}
|
||||
|
||||
func (c *Client) ResumeThread(ctx context.Context, threadID string) (Thread, error) {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
var result struct {
|
||||
Thread Thread `json:"thread"`
|
||||
}
|
||||
if err := c.call(ctx, "thread/resume", map[string]any{"threadId": threadID}, &result); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
return result.Thread, nil
|
||||
}
|
||||
|
||||
func (c *Client) ForkThread(ctx context.Context, threadID string) (Thread, error) {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
var result struct {
|
||||
Thread Thread `json:"thread"`
|
||||
}
|
||||
if err := c.call(ctx, "thread/fork", map[string]any{"threadId": threadID}, &result); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
return result.Thread, nil
|
||||
}
|
||||
|
||||
func (c *Client) ArchiveThread(ctx context.Context, threadID string) error {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
var ignored json.RawMessage
|
||||
return c.call(ctx, "thread/archive", map[string]any{"threadId": threadID}, &ignored)
|
||||
}
|
||||
|
||||
func (c *Client) StartTurn(ctx context.Context, threadID, cwd, model, reasoningEffort, sandbox string, input []InputItem) (Turn, error) {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return Turn{}, err
|
||||
}
|
||||
params := map[string]any{
|
||||
"threadId": threadID,
|
||||
"input": input,
|
||||
"cwd": cwd,
|
||||
"approvalPolicy": "on-request",
|
||||
"sandboxPolicy": SandboxPolicy(sandbox, cwd),
|
||||
}
|
||||
if model != "" {
|
||||
params["model"] = model
|
||||
}
|
||||
if reasoningEffort != "" {
|
||||
params["effort"] = reasoningEffort
|
||||
}
|
||||
var result struct {
|
||||
Turn Turn `json:"turn"`
|
||||
}
|
||||
if err := c.call(ctx, "turn/start", params, &result); err != nil {
|
||||
return Turn{}, err
|
||||
}
|
||||
return result.Turn, nil
|
||||
}
|
||||
|
||||
func (c *Client) InterruptTurn(ctx context.Context, threadID, turnID string) error {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
var ignored json.RawMessage
|
||||
return c.call(ctx, "turn/interrupt", map[string]any{
|
||||
"threadId": threadID,
|
||||
"turnId": turnID,
|
||||
}, &ignored)
|
||||
}
|
||||
|
||||
func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result struct {
|
||||
Data []Model `json:"data"`
|
||||
}
|
||||
if err := c.call(ctx, "model/list", map[string]any{"limit": 50, "includeHidden": false}, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.Data, nil
|
||||
}
|
||||
|
||||
func (c *Client) RespondServerRequest(ctx context.Context, requestID int64, result any) error {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- c.write(responseEnvelope{ID: requestID, Result: result})
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) call(ctx context.Context, method string, params any, result any) error {
|
||||
id := atomic.AddInt64(&c.nextID, 1)
|
||||
ch := make(chan rpcResult, 1)
|
||||
|
||||
c.mu.Lock()
|
||||
if c.conn == nil {
|
||||
c.mu.Unlock()
|
||||
return errors.New("codex app-server is not connected")
|
||||
}
|
||||
c.pending[id] = ch
|
||||
c.mu.Unlock()
|
||||
|
||||
if err := c.write(requestEnvelope{ID: id, Method: method, Params: params}); err != nil {
|
||||
c.mu.Lock()
|
||||
delete(c.pending, id)
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.mu.Lock()
|
||||
delete(c.pending, id)
|
||||
c.mu.Unlock()
|
||||
return ctx.Err()
|
||||
case rpc := <-ch:
|
||||
if rpc.err != nil {
|
||||
return rpc.err
|
||||
}
|
||||
if result == nil || len(rpc.result) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(rpc.result, result)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) notify(method string, params any) error {
|
||||
return c.write(notificationEnvelope{Method: method, Params: params})
|
||||
}
|
||||
|
||||
func (c *Client) write(message any) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.conn == nil {
|
||||
return errors.New("codex app-server is not connected")
|
||||
}
|
||||
return c.conn.WriteJSON(message)
|
||||
}
|
||||
|
||||
func (c *Client) readLoop(conn *websocket.Conn) {
|
||||
for {
|
||||
var env incomingEnvelope
|
||||
if err := conn.ReadJSON(&env); err != nil {
|
||||
c.failConnection(conn, err)
|
||||
return
|
||||
}
|
||||
if env.Method != "" && env.ID != nil {
|
||||
c.events <- Event{
|
||||
ID: env.ID,
|
||||
Method: env.Method,
|
||||
Params: env.Params,
|
||||
ServerRequest: true,
|
||||
}
|
||||
continue
|
||||
}
|
||||
if env.ID != nil {
|
||||
c.completeCall(*env.ID, env.Result, env.Error)
|
||||
continue
|
||||
}
|
||||
if env.Method != "" {
|
||||
c.events <- Event{
|
||||
Method: env.Method,
|
||||
Params: env.Params,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) completeCall(id int64, result json.RawMessage, rpcErr *RPCError) {
|
||||
c.mu.Lock()
|
||||
ch := c.pending[id]
|
||||
delete(c.pending, id)
|
||||
c.mu.Unlock()
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
if rpcErr != nil {
|
||||
ch <- rpcResult{err: *rpcErr}
|
||||
return
|
||||
}
|
||||
ch <- rpcResult{result: result}
|
||||
}
|
||||
|
||||
func (c *Client) failConnection(conn *websocket.Conn, err error) {
|
||||
c.mu.Lock()
|
||||
if c.conn == conn {
|
||||
c.conn = nil
|
||||
c.connected = false
|
||||
}
|
||||
pending := c.pending
|
||||
c.pending = make(map[int64]chan rpcResult)
|
||||
c.mu.Unlock()
|
||||
|
||||
for _, ch := range pending {
|
||||
ch <- rpcResult{err: err}
|
||||
}
|
||||
c.events <- Event{Method: "connection/closed", Err: err}
|
||||
}
|
||||
|
||||
func DecodeParams[T any](event Event) (T, error) {
|
||||
var value T
|
||||
if len(event.Params) == 0 {
|
||||
return value, nil
|
||||
}
|
||||
err := json.Unmarshal(event.Params, &value)
|
||||
return value, err
|
||||
}
|
||||
|
||||
func NormalizeSandbox(value string) (string, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "", "workspace-write", "workspacewrite":
|
||||
return "workspace-write", nil
|
||||
case "read-only", "readonly":
|
||||
return "read-only", nil
|
||||
case "danger-full-access", "dangerfullaccess":
|
||||
return "danger-full-access", nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported sandbox %q", value)
|
||||
}
|
||||
}
|
||||
|
||||
func SandboxPolicy(sandbox, cwd string) map[string]any {
|
||||
normalized, err := NormalizeSandbox(sandbox)
|
||||
if err != nil {
|
||||
normalized = "workspace-write"
|
||||
}
|
||||
switch normalized {
|
||||
case "read-only":
|
||||
return map[string]any{"type": "readOnly"}
|
||||
case "danger-full-access":
|
||||
return map[string]any{"type": "dangerFullAccess"}
|
||||
default:
|
||||
return map[string]any{
|
||||
"type": "workspaceWrite",
|
||||
"writableRoots": []string{cwd},
|
||||
"networkAccess": false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func threadSandbox(sandbox string) string {
|
||||
normalized, err := NormalizeSandbox(sandbox)
|
||||
if err != nil {
|
||||
return "workspace-write"
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
type requestEnvelope struct {
|
||||
ID int64 `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params any `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type notificationEnvelope struct {
|
||||
Method string `json:"method"`
|
||||
Params any `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type responseEnvelope struct {
|
||||
ID int64 `json:"id"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type incomingEnvelope struct {
|
||||
ID *int64 `json:"id,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *RPCError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type rpcResult struct {
|
||||
result json.RawMessage
|
||||
err error
|
||||
}
|
||||
155
internal/codexapp/client_test.go
Normal file
155
internal/codexapp/client_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package codexapp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func TestClientWebSocketUnixJSONRPC(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "codex.sock")
|
||||
serverDone := make(chan error, 1)
|
||||
ln, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(socketPath)
|
||||
|
||||
upgrader := websocket.Upgrader{}
|
||||
server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var initialize map[string]any
|
||||
if err := conn.ReadJSON(&initialize); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
if initialize["method"] != "initialize" {
|
||||
serverDone <- unexpectedMessage("initialize", initialize["method"])
|
||||
return
|
||||
}
|
||||
if err := conn.WriteJSON(map[string]any{"id": initialize["id"], "result": map[string]any{"userAgent": "test"}}); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
|
||||
var initialized map[string]any
|
||||
if err := conn.ReadJSON(&initialized); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
if initialized["method"] != "initialized" {
|
||||
serverDone <- unexpectedMessage("initialized", initialized["method"])
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.WriteJSON(map[string]any{
|
||||
"id": 99,
|
||||
"method": "item/commandExecution/requestApproval",
|
||||
"params": map[string]any{"threadId": "thr_1"},
|
||||
}); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
|
||||
var start map[string]any
|
||||
if err := conn.ReadJSON(&start); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
if start["method"] != "thread/start" {
|
||||
serverDone <- unexpectedMessage("thread/start", start["method"])
|
||||
return
|
||||
}
|
||||
if err := conn.WriteJSON(map[string]any{
|
||||
"id": start["id"],
|
||||
"result": map[string]any{
|
||||
"thread": map[string]any{"id": "thr_1", "preview": "test"},
|
||||
},
|
||||
}); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
|
||||
var response map[string]any
|
||||
if err := conn.ReadJSON(&response); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
if response["id"].(float64) != 99 || response["result"] != "accept" {
|
||||
payload, _ := json.Marshal(response)
|
||||
serverDone <- unexpectedMessage("approval response", string(payload))
|
||||
return
|
||||
}
|
||||
serverDone <- nil
|
||||
})}
|
||||
defer server.Close()
|
||||
go func() {
|
||||
_ = server.Serve(ln)
|
||||
}()
|
||||
|
||||
client := New(socketPath, "test")
|
||||
if err := client.Connect(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
select {
|
||||
case event := <-client.Events():
|
||||
if !event.ServerRequest || event.ID == nil || *event.ID != 99 {
|
||||
t.Fatalf("unexpected event: %+v", event)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatal(ctx.Err())
|
||||
}
|
||||
|
||||
thread, err := client.StartThread(ctx, "/tmp/project", "", "workspace-write")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if thread.ID != "thr_1" {
|
||||
t.Fatalf("unexpected thread: %+v", thread)
|
||||
}
|
||||
if err := client.RespondServerRequest(ctx, 99, "accept"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-serverDone:
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatal(ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
type errUnexpected string
|
||||
|
||||
func (e errUnexpected) Error() string {
|
||||
return "unexpected " + string(e)
|
||||
}
|
||||
|
||||
func unexpectedMessage(want string, got any) error {
|
||||
return errUnexpected("message: want " + want + ", got " + jsonString(got))
|
||||
}
|
||||
|
||||
func jsonString(value any) string {
|
||||
data, _ := json.Marshal(value)
|
||||
return string(data)
|
||||
}
|
||||
18
internal/codexapp/turn.go
Normal file
18
internal/codexapp/turn.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package codexapp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
func (c *Client) SteerTurn(ctx context.Context, threadID, expectedTurnID string, input []InputItem) error {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
var ignored json.RawMessage
|
||||
return c.call(ctx, "turn/steer", map[string]any{
|
||||
"threadId": threadID,
|
||||
"expectedTurnId": expectedTurnID,
|
||||
"input": input,
|
||||
}, &ignored)
|
||||
}
|
||||
Reference in New Issue
Block a user