73 lines
1.9 KiB
Go
73 lines
1.9 KiB
Go
package codexstate
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
func TestSetThreadCWD(t *testing.T) {
|
|
ctx := context.Background()
|
|
dir := t.TempDir()
|
|
rollout := filepath.Join(dir, "rollout.jsonl")
|
|
first := map[string]any{"type": "session_meta", "payload": map[string]any{"id": "thr_1", "cwd": "/old"}}
|
|
encoded, err := json.Marshal(first)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := os.WriteFile(rollout, []byte(string(encoded)+"\n{\"type\":\"other\"}\n"), 0o644); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
dbPath := filepath.Join(dir, "state_test.sqlite")
|
|
db, err := sql.Open("sqlite", dbPath)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if _, err := db.ExecContext(ctx, "CREATE TABLE threads (id TEXT PRIMARY KEY, cwd TEXT, rollout_path TEXT, updated_at INTEGER, updated_at_ms INTEGER)"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if _, err := db.ExecContext(ctx, "INSERT INTO threads (id, cwd, rollout_path) VALUES (?, ?, ?)", "thr_1", "/old", rollout); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
_ = db.Close()
|
|
|
|
result, err := SetThreadCWD(ctx, "", dbPath, "thr_1", "/new/path")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result.Before != "/old" || result.After != "/new/path" {
|
|
t.Fatalf("unexpected result: %+v", result)
|
|
}
|
|
|
|
db, err = sql.Open("sqlite", dbPath)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer db.Close()
|
|
var cwd string
|
|
if err := db.QueryRowContext(ctx, "SELECT cwd FROM threads WHERE id = ?", "thr_1").Scan(&cwd); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if cwd != "/new/path" {
|
|
t.Fatalf("db cwd = %q", cwd)
|
|
}
|
|
data, err := os.ReadFile(rollout)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var updated map[string]any
|
|
if err := json.Unmarshal([]byte(strings.SplitN(string(data), "\n", 2)[0]), &updated); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
payload := updated["payload"].(map[string]any)
|
|
if payload["cwd"] != "/new/path" {
|
|
t.Fatalf("rollout cwd = %v", payload["cwd"])
|
|
}
|
|
}
|