211 lines
5.1 KiB
Go
211 lines
5.1 KiB
Go
package codexapp
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
func TestClientWebSocketUnixJSONRPC(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
socketPath := filepath.Join(t.TempDir(), "codex.sock")
|
|
serverDone := make(chan error, 1)
|
|
ln, err := net.Listen("unix", socketPath)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer os.Remove(socketPath)
|
|
|
|
upgrader := websocket.Upgrader{}
|
|
server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
serverDone <- err
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
var initialize map[string]any
|
|
if err := conn.ReadJSON(&initialize); err != nil {
|
|
serverDone <- err
|
|
return
|
|
}
|
|
if initialize["method"] != "initialize" {
|
|
serverDone <- unexpectedMessage("initialize", initialize["method"])
|
|
return
|
|
}
|
|
if err := conn.WriteJSON(map[string]any{"id": initialize["id"], "result": map[string]any{"userAgent": "test"}}); err != nil {
|
|
serverDone <- err
|
|
return
|
|
}
|
|
|
|
var initialized map[string]any
|
|
if err := conn.ReadJSON(&initialized); err != nil {
|
|
serverDone <- err
|
|
return
|
|
}
|
|
if initialized["method"] != "initialized" {
|
|
serverDone <- unexpectedMessage("initialized", initialized["method"])
|
|
return
|
|
}
|
|
|
|
if err := conn.WriteJSON(map[string]any{
|
|
"id": 99,
|
|
"method": "item/commandExecution/requestApproval",
|
|
"params": map[string]any{"threadId": "thr_1"},
|
|
}); err != nil {
|
|
serverDone <- err
|
|
return
|
|
}
|
|
|
|
var start map[string]any
|
|
if err := conn.ReadJSON(&start); err != nil {
|
|
serverDone <- err
|
|
return
|
|
}
|
|
if start["method"] != "thread/start" {
|
|
serverDone <- unexpectedMessage("thread/start", start["method"])
|
|
return
|
|
}
|
|
if err := conn.WriteJSON(map[string]any{
|
|
"id": start["id"],
|
|
"result": map[string]any{
|
|
"thread": map[string]any{"id": "thr_1", "preview": "test"},
|
|
},
|
|
}); err != nil {
|
|
serverDone <- err
|
|
return
|
|
}
|
|
|
|
var readThread map[string]any
|
|
if err := conn.ReadJSON(&readThread); err != nil {
|
|
serverDone <- err
|
|
return
|
|
}
|
|
if readThread["method"] != "thread/read" {
|
|
serverDone <- unexpectedMessage("thread/read", readThread["method"])
|
|
return
|
|
}
|
|
readParams := readThread["params"].(map[string]any)
|
|
if readParams["threadId"] != "thr_1" || readParams["includeTurns"] != false {
|
|
payload, _ := json.Marshal(readParams)
|
|
serverDone <- unexpectedMessage("thread/read params", string(payload))
|
|
return
|
|
}
|
|
if err := conn.WriteJSON(map[string]any{
|
|
"id": readThread["id"],
|
|
"result": map[string]any{
|
|
"thread": map[string]any{"id": "thr_1", "name": "Read title", "preview": "test"},
|
|
},
|
|
}); err != nil {
|
|
serverDone <- err
|
|
return
|
|
}
|
|
|
|
var setName map[string]any
|
|
if err := conn.ReadJSON(&setName); err != nil {
|
|
serverDone <- err
|
|
return
|
|
}
|
|
if setName["method"] != "thread/name/set" {
|
|
serverDone <- unexpectedMessage("thread/name/set", setName["method"])
|
|
return
|
|
}
|
|
params := setName["params"].(map[string]any)
|
|
if params["threadId"] != "thr_1" || params["name"] != "Short title" {
|
|
payload, _ := json.Marshal(params)
|
|
serverDone <- unexpectedMessage("thread/name/set params", string(payload))
|
|
return
|
|
}
|
|
if err := conn.WriteJSON(map[string]any{"id": setName["id"], "result": map[string]any{}}); err != nil {
|
|
serverDone <- err
|
|
return
|
|
}
|
|
|
|
var response map[string]any
|
|
if err := conn.ReadJSON(&response); err != nil {
|
|
serverDone <- err
|
|
return
|
|
}
|
|
if response["id"].(float64) != 99 || response["result"] != "accept" {
|
|
payload, _ := json.Marshal(response)
|
|
serverDone <- unexpectedMessage("approval response", string(payload))
|
|
return
|
|
}
|
|
serverDone <- nil
|
|
})}
|
|
defer server.Close()
|
|
go func() {
|
|
_ = server.Serve(ln)
|
|
}()
|
|
|
|
client := New(socketPath, "test")
|
|
if err := client.Connect(ctx); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer client.Close()
|
|
|
|
select {
|
|
case event := <-client.Events():
|
|
if !event.ServerRequest || event.ID == nil || *event.ID != 99 {
|
|
t.Fatalf("unexpected event: %+v", event)
|
|
}
|
|
case <-ctx.Done():
|
|
t.Fatal(ctx.Err())
|
|
}
|
|
|
|
thread, err := client.StartThread(ctx, "/tmp/project", "", "workspace-write")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if thread.ID != "thr_1" {
|
|
t.Fatalf("unexpected thread: %+v", thread)
|
|
}
|
|
readThread, err := client.ReadThread(ctx, "thr_1")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if readThread.ID != "thr_1" || readThread.Name != "Read title" {
|
|
t.Fatalf("unexpected read thread: %+v", readThread)
|
|
}
|
|
if err := client.SetThreadName(ctx, "thr_1", "Short title"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := client.RespondServerRequest(ctx, 99, "accept"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
select {
|
|
case err := <-serverDone:
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
case <-ctx.Done():
|
|
t.Fatal(ctx.Err())
|
|
}
|
|
}
|
|
|
|
// errUnexpected is a string-backed error used by the scripted test server;
// its message is the wrapped text prefixed with "unexpected ".
type errUnexpected string

// Error implements the error interface.
func (e errUnexpected) Error() string {
	const prefix = "unexpected "
	detail := string(e)
	return prefix + detail
}
|
|
|
|
func unexpectedMessage(want string, got any) error {
|
|
return errUnexpected("message: want " + want + ", got " + jsonString(got))
|
|
}
|
|
|
|
// jsonString renders value as compact JSON for use in diagnostic messages.
// If value cannot be marshaled (channels, funcs, NaN, ...), the marshal error
// is rendered instead of silently producing an empty string, so the
// diagnostic still explains what was received.
func jsonString(value any) string {
	data, err := json.Marshal(value)
	if err != nil {
		return "<unmarshalable: " + err.Error() + ">"
	}
	return string(data)
}
|