Files
codex-telegram-bot/internal/codexapp/client_test.go
2026-05-21 13:05:53 +00:00

213 lines
5.3 KiB
Go

package codexapp
import (
"context"
"encoding/json"
"net"
"net/http"
"os"
"path/filepath"
"testing"
"time"
"github.com/gorilla/websocket"
)
// TestClientWebSocketUnixJSONRPC drives the client against a scripted
// JSON-RPC peer served over a WebSocket on a Unix socket. The script covers
// the initialize/initialized handshake, a server-initiated approval request
// that the client must surface as an event, thread/start, thread/read,
// thread/name/set, and finally the client's reply to the approval request.
// The server side reports its first protocol mismatch (or nil on success)
// through serverDone.
func TestClientWebSocketUnixJSONRPC(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	socketPath := filepath.Join(t.TempDir(), "codex.sock")
	projectCWD := filepath.Join(t.TempDir(), "project")

	// serverDone carries the scripted server's single outcome. It is buffered
	// so the handler never blocks if the test has already stopped reading.
	serverDone := make(chan error, 1)

	ln, err := net.Listen("unix", socketPath)
	if err != nil {
		t.Fatal(err)
	}
	defer os.Remove(socketPath)

	// expect reads the next message and checks that it carries the given
	// JSON-RPC method.
	expect := func(conn *websocket.Conn, method string) (map[string]any, error) {
		var msg map[string]any
		if err := conn.ReadJSON(&msg); err != nil {
			return nil, err
		}
		if msg["method"] != method {
			return nil, unexpectedMessage(method, msg["method"])
		}
		return msg, nil
	}

	// script plays the server side of the exchange and returns the first
	// mismatch. Params and ids are checked with comma-ok assertions so a
	// malformed message is reported instead of panicking the handler (net/http
	// would recover such a panic, leaving serverDone silent).
	script := func(conn *websocket.Conn) error {
		initialize, err := expect(conn, "initialize")
		if err != nil {
			return err
		}
		if err := conn.WriteJSON(map[string]any{"id": initialize["id"], "result": map[string]any{"userAgent": "test"}}); err != nil {
			return err
		}
		if _, err := expect(conn, "initialized"); err != nil {
			return err
		}

		// Push a server-initiated request before the client issues any call;
		// the client must surface it as an event and answer it at the end.
		if err := conn.WriteJSON(map[string]any{
			"id":     99,
			"method": "item/commandExecution/requestApproval",
			"params": map[string]any{"threadId": "thr_1"},
		}); err != nil {
			return err
		}

		start, err := expect(conn, "thread/start")
		if err != nil {
			return err
		}
		if err := conn.WriteJSON(map[string]any{
			"id": start["id"],
			"result": map[string]any{
				"cwd":    projectCWD,
				"thread": map[string]any{"id": "thr_1", "preview": "test"},
			},
		}); err != nil {
			return err
		}

		readReq, err := expect(conn, "thread/read")
		if err != nil {
			return err
		}
		readParams, ok := readReq["params"].(map[string]any)
		if !ok || readParams["threadId"] != "thr_1" || readParams["includeTurns"] != false {
			return unexpectedMessage("thread/read params", readReq["params"])
		}
		if err := conn.WriteJSON(map[string]any{
			"id": readReq["id"],
			"result": map[string]any{
				"thread": map[string]any{"id": "thr_1", "name": "Read title", "preview": "test", "cwd": projectCWD},
			},
		}); err != nil {
			return err
		}

		setName, err := expect(conn, "thread/name/set")
		if err != nil {
			return err
		}
		nameParams, ok := setName["params"].(map[string]any)
		if !ok || nameParams["threadId"] != "thr_1" || nameParams["name"] != "Short title" {
			return unexpectedMessage("thread/name/set params", setName["params"])
		}
		if err := conn.WriteJSON(map[string]any{"id": setName["id"], "result": map[string]any{}}); err != nil {
			return err
		}

		// The last message must be the client's reply to request 99. JSON
		// numbers decode into float64 when the target is map[string]any.
		var response map[string]any
		if err := conn.ReadJSON(&response); err != nil {
			return err
		}
		if id, ok := response["id"].(float64); !ok || id != 99 || response["result"] != "accept" {
			return unexpectedMessage("approval response", response)
		}
		return nil
	}

	upgrader := websocket.Upgrader{}
	server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			serverDone <- err
			return
		}
		defer conn.Close()
		// Bound reads by the test deadline: server.Close does not track
		// hijacked connections, so nothing else would unblock this handler.
		if deadline, ok := ctx.Deadline(); ok {
			if err := conn.SetReadDeadline(deadline); err != nil {
				serverDone <- err
				return
			}
		}
		serverDone <- script(conn)
	})}
	defer server.Close()
	go func() {
		// Serve returns http.ErrServerClosed once server.Close runs; the
		// exchange's outcome is reported through serverDone instead.
		_ = server.Serve(ln)
	}()

	client := New(socketPath, "test")
	if err := client.Connect(ctx); err != nil {
		t.Fatal(err)
	}
	defer client.Close()

	select {
	case event := <-client.Events():
		if !event.ServerRequest || event.ID == nil || *event.ID != 99 {
			t.Fatalf("unexpected event: %+v", event)
		}
	case <-ctx.Done():
		t.Fatal(ctx.Err())
	}

	thread, err := client.StartThread(ctx, projectCWD, "", "workspace-write")
	if err != nil {
		t.Fatal(err)
	}
	if thread.ID != "thr_1" || thread.CWD != projectCWD {
		t.Fatalf("unexpected thread: %+v", thread)
	}

	readThread, err := client.ReadThread(ctx, "thr_1")
	if err != nil {
		t.Fatal(err)
	}
	if readThread.ID != "thr_1" || readThread.Name != "Read title" || readThread.CWD != projectCWD {
		t.Fatalf("unexpected read thread: %+v", readThread)
	}

	if err := client.SetThreadName(ctx, "thr_1", "Short title"); err != nil {
		t.Fatal(err)
	}
	if err := client.RespondServerRequest(ctx, 99, "accept"); err != nil {
		t.Fatal(err)
	}

	select {
	case err := <-serverDone:
		if err != nil {
			t.Fatal(err)
		}
	case <-ctx.Done():
		t.Fatal(ctx.Err())
	}
}
// errUnexpected is the error produced when the scripted peer receives a
// message that differs from what the script expects. Its string value is
// the description of the mismatch.
type errUnexpected string

// Error implements the error interface, prefixing the stored description
// with "unexpected ".
func (e errUnexpected) Error() string {
	const prefix = "unexpected "
	return prefix + string(e)
}
func unexpectedMessage(want string, got any) error {
return errUnexpected("message: want " + want + ", got " + jsonString(got))
}
// jsonString renders value as compact JSON for use in failure messages.
// If value cannot be encoded (for example a channel, a func, or a NaN
// float), it returns a bracketed description of the encoding error. An
// empty string would leave the failure message uninformative.
func jsonString(value any) string {
	data, err := json.Marshal(value)
	if err != nil {
		return "<unencodable: " + err.Error() + ">"
	}
	return string(data)
}