Initial codex telegram bot source
This commit is contained in:
479
internal/codexapp/client.go
Normal file
479
internal/codexapp/client.go
Normal file
@@ -0,0 +1,479 @@
|
||||
package codexapp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
socketPath string
|
||||
version string
|
||||
|
||||
nextID int64
|
||||
|
||||
mu sync.Mutex
|
||||
conn *websocket.Conn
|
||||
pending map[int64]chan rpcResult
|
||||
events chan Event
|
||||
connected bool
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
ID *int64
|
||||
Method string
|
||||
Params json.RawMessage
|
||||
ServerRequest bool
|
||||
Err error
|
||||
}
|
||||
|
||||
type RPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func (e RPCError) Error() string {
|
||||
if e.Code == 0 {
|
||||
return e.Message
|
||||
}
|
||||
return fmt.Sprintf("%d: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
type Thread struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Preview string `json:"preview,omitempty"`
|
||||
}
|
||||
|
||||
type Turn struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type InputItem struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Path string `json:"path,omitempty"`
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Description string `json:"description"`
|
||||
IsDefault bool `json:"isDefault"`
|
||||
DefaultReasoningEffort string `json:"defaultReasoningEffort"`
|
||||
SupportedReasoningEfforts []ReasoningEffortOption `json:"supportedReasoningEfforts"`
|
||||
}
|
||||
|
||||
type ReasoningEffortOption struct {
|
||||
Description string `json:"description"`
|
||||
ReasoningEffort string `json:"reasoningEffort"`
|
||||
}
|
||||
|
||||
func New(socketPath, version string) *Client {
|
||||
return &Client{
|
||||
socketPath: socketPath,
|
||||
version: version,
|
||||
pending: make(map[int64]chan rpcResult),
|
||||
events: make(chan Event, 128),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Events() <-chan Event {
|
||||
return c.events
|
||||
}
|
||||
|
||||
func (c *Client) EnsureConnected(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
connected := c.connected && c.conn != nil
|
||||
c.mu.Unlock()
|
||||
if connected {
|
||||
return nil
|
||||
}
|
||||
return c.Connect(ctx)
|
||||
}
|
||||
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
if c.connected && c.conn != nil {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
dialer := websocket.Dialer{
|
||||
NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "unix", c.socketPath)
|
||||
},
|
||||
}
|
||||
conn, _, err := dialer.DialContext(ctx, "ws://codex-app-server/", http.Header{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect codex app-server unix socket %s: %w", c.socketPath, err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = conn
|
||||
c.connected = true
|
||||
c.mu.Unlock()
|
||||
|
||||
go c.readLoop(conn)
|
||||
|
||||
if err := c.initialize(ctx); err != nil {
|
||||
_ = c.Close()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.conn = nil
|
||||
c.connected = false
|
||||
c.mu.Unlock()
|
||||
if conn == nil {
|
||||
return nil
|
||||
}
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
func (c *Client) initialize(ctx context.Context) error {
|
||||
params := map[string]any{
|
||||
"clientInfo": map[string]any{
|
||||
"name": "codex_telegram_bot",
|
||||
"title": "Codex Telegram Bot",
|
||||
"version": c.version,
|
||||
},
|
||||
"capabilities": map[string]any{
|
||||
"experimentalApi": true,
|
||||
},
|
||||
}
|
||||
var ignored json.RawMessage
|
||||
if err := c.call(ctx, "initialize", params, &ignored); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.notify("initialized", nil)
|
||||
}
|
||||
|
||||
func (c *Client) StartThread(ctx context.Context, cwd, model, sandbox string) (Thread, error) {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
params := map[string]any{
|
||||
"cwd": cwd,
|
||||
"approvalPolicy": "on-request",
|
||||
"sandbox": threadSandbox(sandbox),
|
||||
"serviceName": "codex_telegram_bot",
|
||||
}
|
||||
if model != "" {
|
||||
params["model"] = model
|
||||
}
|
||||
var result struct {
|
||||
Thread Thread `json:"thread"`
|
||||
}
|
||||
if err := c.call(ctx, "thread/start", params, &result); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
return result.Thread, nil
|
||||
}
|
||||
|
||||
func (c *Client) ResumeThread(ctx context.Context, threadID string) (Thread, error) {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
var result struct {
|
||||
Thread Thread `json:"thread"`
|
||||
}
|
||||
if err := c.call(ctx, "thread/resume", map[string]any{"threadId": threadID}, &result); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
return result.Thread, nil
|
||||
}
|
||||
|
||||
func (c *Client) ForkThread(ctx context.Context, threadID string) (Thread, error) {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
var result struct {
|
||||
Thread Thread `json:"thread"`
|
||||
}
|
||||
if err := c.call(ctx, "thread/fork", map[string]any{"threadId": threadID}, &result); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
return result.Thread, nil
|
||||
}
|
||||
|
||||
func (c *Client) ArchiveThread(ctx context.Context, threadID string) error {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
var ignored json.RawMessage
|
||||
return c.call(ctx, "thread/archive", map[string]any{"threadId": threadID}, &ignored)
|
||||
}
|
||||
|
||||
func (c *Client) StartTurn(ctx context.Context, threadID, cwd, model, reasoningEffort, sandbox string, input []InputItem) (Turn, error) {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return Turn{}, err
|
||||
}
|
||||
params := map[string]any{
|
||||
"threadId": threadID,
|
||||
"input": input,
|
||||
"cwd": cwd,
|
||||
"approvalPolicy": "on-request",
|
||||
"sandboxPolicy": SandboxPolicy(sandbox, cwd),
|
||||
}
|
||||
if model != "" {
|
||||
params["model"] = model
|
||||
}
|
||||
if reasoningEffort != "" {
|
||||
params["effort"] = reasoningEffort
|
||||
}
|
||||
var result struct {
|
||||
Turn Turn `json:"turn"`
|
||||
}
|
||||
if err := c.call(ctx, "turn/start", params, &result); err != nil {
|
||||
return Turn{}, err
|
||||
}
|
||||
return result.Turn, nil
|
||||
}
|
||||
|
||||
func (c *Client) InterruptTurn(ctx context.Context, threadID, turnID string) error {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
var ignored json.RawMessage
|
||||
return c.call(ctx, "turn/interrupt", map[string]any{
|
||||
"threadId": threadID,
|
||||
"turnId": turnID,
|
||||
}, &ignored)
|
||||
}
|
||||
|
||||
func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result struct {
|
||||
Data []Model `json:"data"`
|
||||
}
|
||||
if err := c.call(ctx, "model/list", map[string]any{"limit": 50, "includeHidden": false}, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.Data, nil
|
||||
}
|
||||
|
||||
func (c *Client) RespondServerRequest(ctx context.Context, requestID int64, result any) error {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- c.write(responseEnvelope{ID: requestID, Result: result})
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) call(ctx context.Context, method string, params any, result any) error {
|
||||
id := atomic.AddInt64(&c.nextID, 1)
|
||||
ch := make(chan rpcResult, 1)
|
||||
|
||||
c.mu.Lock()
|
||||
if c.conn == nil {
|
||||
c.mu.Unlock()
|
||||
return errors.New("codex app-server is not connected")
|
||||
}
|
||||
c.pending[id] = ch
|
||||
c.mu.Unlock()
|
||||
|
||||
if err := c.write(requestEnvelope{ID: id, Method: method, Params: params}); err != nil {
|
||||
c.mu.Lock()
|
||||
delete(c.pending, id)
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.mu.Lock()
|
||||
delete(c.pending, id)
|
||||
c.mu.Unlock()
|
||||
return ctx.Err()
|
||||
case rpc := <-ch:
|
||||
if rpc.err != nil {
|
||||
return rpc.err
|
||||
}
|
||||
if result == nil || len(rpc.result) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(rpc.result, result)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) notify(method string, params any) error {
|
||||
return c.write(notificationEnvelope{Method: method, Params: params})
|
||||
}
|
||||
|
||||
func (c *Client) write(message any) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.conn == nil {
|
||||
return errors.New("codex app-server is not connected")
|
||||
}
|
||||
return c.conn.WriteJSON(message)
|
||||
}
|
||||
|
||||
func (c *Client) readLoop(conn *websocket.Conn) {
|
||||
for {
|
||||
var env incomingEnvelope
|
||||
if err := conn.ReadJSON(&env); err != nil {
|
||||
c.failConnection(conn, err)
|
||||
return
|
||||
}
|
||||
if env.Method != "" && env.ID != nil {
|
||||
c.events <- Event{
|
||||
ID: env.ID,
|
||||
Method: env.Method,
|
||||
Params: env.Params,
|
||||
ServerRequest: true,
|
||||
}
|
||||
continue
|
||||
}
|
||||
if env.ID != nil {
|
||||
c.completeCall(*env.ID, env.Result, env.Error)
|
||||
continue
|
||||
}
|
||||
if env.Method != "" {
|
||||
c.events <- Event{
|
||||
Method: env.Method,
|
||||
Params: env.Params,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) completeCall(id int64, result json.RawMessage, rpcErr *RPCError) {
|
||||
c.mu.Lock()
|
||||
ch := c.pending[id]
|
||||
delete(c.pending, id)
|
||||
c.mu.Unlock()
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
if rpcErr != nil {
|
||||
ch <- rpcResult{err: *rpcErr}
|
||||
return
|
||||
}
|
||||
ch <- rpcResult{result: result}
|
||||
}
|
||||
|
||||
func (c *Client) failConnection(conn *websocket.Conn, err error) {
|
||||
c.mu.Lock()
|
||||
if c.conn == conn {
|
||||
c.conn = nil
|
||||
c.connected = false
|
||||
}
|
||||
pending := c.pending
|
||||
c.pending = make(map[int64]chan rpcResult)
|
||||
c.mu.Unlock()
|
||||
|
||||
for _, ch := range pending {
|
||||
ch <- rpcResult{err: err}
|
||||
}
|
||||
c.events <- Event{Method: "connection/closed", Err: err}
|
||||
}
|
||||
|
||||
func DecodeParams[T any](event Event) (T, error) {
|
||||
var value T
|
||||
if len(event.Params) == 0 {
|
||||
return value, nil
|
||||
}
|
||||
err := json.Unmarshal(event.Params, &value)
|
||||
return value, err
|
||||
}
|
||||
|
||||
func NormalizeSandbox(value string) (string, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "", "workspace-write", "workspacewrite":
|
||||
return "workspace-write", nil
|
||||
case "read-only", "readonly":
|
||||
return "read-only", nil
|
||||
case "danger-full-access", "dangerfullaccess":
|
||||
return "danger-full-access", nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported sandbox %q", value)
|
||||
}
|
||||
}
|
||||
|
||||
func SandboxPolicy(sandbox, cwd string) map[string]any {
|
||||
normalized, err := NormalizeSandbox(sandbox)
|
||||
if err != nil {
|
||||
normalized = "workspace-write"
|
||||
}
|
||||
switch normalized {
|
||||
case "read-only":
|
||||
return map[string]any{"type": "readOnly"}
|
||||
case "danger-full-access":
|
||||
return map[string]any{"type": "dangerFullAccess"}
|
||||
default:
|
||||
return map[string]any{
|
||||
"type": "workspaceWrite",
|
||||
"writableRoots": []string{cwd},
|
||||
"networkAccess": false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func threadSandbox(sandbox string) string {
|
||||
normalized, err := NormalizeSandbox(sandbox)
|
||||
if err != nil {
|
||||
return "workspace-write"
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
type requestEnvelope struct {
|
||||
ID int64 `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params any `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type notificationEnvelope struct {
|
||||
Method string `json:"method"`
|
||||
Params any `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type responseEnvelope struct {
|
||||
ID int64 `json:"id"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error any `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type incomingEnvelope struct {
|
||||
ID *int64 `json:"id,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *RPCError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type rpcResult struct {
|
||||
result json.RawMessage
|
||||
err error
|
||||
}
|
||||
155
internal/codexapp/client_test.go
Normal file
155
internal/codexapp/client_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package codexapp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func TestClientWebSocketUnixJSONRPC(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "codex.sock")
|
||||
serverDone := make(chan error, 1)
|
||||
ln, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(socketPath)
|
||||
|
||||
upgrader := websocket.Upgrader{}
|
||||
server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var initialize map[string]any
|
||||
if err := conn.ReadJSON(&initialize); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
if initialize["method"] != "initialize" {
|
||||
serverDone <- unexpectedMessage("initialize", initialize["method"])
|
||||
return
|
||||
}
|
||||
if err := conn.WriteJSON(map[string]any{"id": initialize["id"], "result": map[string]any{"userAgent": "test"}}); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
|
||||
var initialized map[string]any
|
||||
if err := conn.ReadJSON(&initialized); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
if initialized["method"] != "initialized" {
|
||||
serverDone <- unexpectedMessage("initialized", initialized["method"])
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.WriteJSON(map[string]any{
|
||||
"id": 99,
|
||||
"method": "item/commandExecution/requestApproval",
|
||||
"params": map[string]any{"threadId": "thr_1"},
|
||||
}); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
|
||||
var start map[string]any
|
||||
if err := conn.ReadJSON(&start); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
if start["method"] != "thread/start" {
|
||||
serverDone <- unexpectedMessage("thread/start", start["method"])
|
||||
return
|
||||
}
|
||||
if err := conn.WriteJSON(map[string]any{
|
||||
"id": start["id"],
|
||||
"result": map[string]any{
|
||||
"thread": map[string]any{"id": "thr_1", "preview": "test"},
|
||||
},
|
||||
}); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
|
||||
var response map[string]any
|
||||
if err := conn.ReadJSON(&response); err != nil {
|
||||
serverDone <- err
|
||||
return
|
||||
}
|
||||
if response["id"].(float64) != 99 || response["result"] != "accept" {
|
||||
payload, _ := json.Marshal(response)
|
||||
serverDone <- unexpectedMessage("approval response", string(payload))
|
||||
return
|
||||
}
|
||||
serverDone <- nil
|
||||
})}
|
||||
defer server.Close()
|
||||
go func() {
|
||||
_ = server.Serve(ln)
|
||||
}()
|
||||
|
||||
client := New(socketPath, "test")
|
||||
if err := client.Connect(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
select {
|
||||
case event := <-client.Events():
|
||||
if !event.ServerRequest || event.ID == nil || *event.ID != 99 {
|
||||
t.Fatalf("unexpected event: %+v", event)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatal(ctx.Err())
|
||||
}
|
||||
|
||||
thread, err := client.StartThread(ctx, "/tmp/project", "", "workspace-write")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if thread.ID != "thr_1" {
|
||||
t.Fatalf("unexpected thread: %+v", thread)
|
||||
}
|
||||
if err := client.RespondServerRequest(ctx, 99, "accept"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-serverDone:
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatal(ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
type errUnexpected string
|
||||
|
||||
func (e errUnexpected) Error() string {
|
||||
return "unexpected " + string(e)
|
||||
}
|
||||
|
||||
func unexpectedMessage(want string, got any) error {
|
||||
return errUnexpected("message: want " + want + ", got " + jsonString(got))
|
||||
}
|
||||
|
||||
func jsonString(value any) string {
|
||||
data, _ := json.Marshal(value)
|
||||
return string(data)
|
||||
}
|
||||
18
internal/codexapp/turn.go
Normal file
18
internal/codexapp/turn.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package codexapp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
func (c *Client) SteerTurn(ctx context.Context, threadID, expectedTurnID string, input []InputItem) error {
|
||||
if err := c.EnsureConnected(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
var ignored json.RawMessage
|
||||
return c.call(ctx, "turn/steer", map[string]any{
|
||||
"threadId": threadID,
|
||||
"expectedTurnId": expectedTurnID,
|
||||
"input": input,
|
||||
}, &ignored)
|
||||
}
|
||||
61
internal/config/config.go
Normal file
61
internal/config/config.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
TelegramToken string
|
||||
DatabasePath string
|
||||
CodexSocketPath string
|
||||
UploadDir string
|
||||
DefaultModel string
|
||||
DefaultSandbox string
|
||||
PollTimeout time.Duration
|
||||
AppVersion string
|
||||
}
|
||||
|
||||
func Load() (Config, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
TelegramToken: os.Getenv("TELEGRAM_BOT_TOKEN"),
|
||||
DatabasePath: envOr("DB_PATH", filepath.Join(wd, "data", "bot.db")),
|
||||
CodexSocketPath: envOr("HOST_CODEX_SOCKET", filepath.Join(wd, "run", "codex.sock")),
|
||||
UploadDir: envOr("HOST_UPLOAD_DIR", filepath.Join(wd, "uploads")),
|
||||
DefaultModel: os.Getenv("DEFAULT_MODEL"),
|
||||
DefaultSandbox: envOr("DEFAULT_SANDBOX", "workspace-write"),
|
||||
PollTimeout: durationSeconds("POLL_TIMEOUT_SECONDS", 30),
|
||||
AppVersion: envOr("APP_VERSION", "0.1.0"),
|
||||
}
|
||||
if cfg.TelegramToken == "" {
|
||||
return Config{}, errors.New("TELEGRAM_BOT_TOKEN is required")
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func envOr(key, fallback string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func durationSeconds(key string, fallback int) time.Duration {
|
||||
raw := os.Getenv(key)
|
||||
if raw == "" {
|
||||
return time.Duration(fallback) * time.Second
|
||||
}
|
||||
value, err := strconv.Atoi(raw)
|
||||
if err != nil || value <= 0 {
|
||||
return time.Duration(fallback) * time.Second
|
||||
}
|
||||
return time.Duration(value) * time.Second
|
||||
}
|
||||
94
internal/store/migrations.go
Normal file
94
internal/store/migrations.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package store
|
||||
|
||||
type migration struct {
|
||||
Name string
|
||||
SQL string
|
||||
}
|
||||
|
||||
var migrations = []migration{
|
||||
{
|
||||
Name: "initial_state",
|
||||
SQL: `
|
||||
CREATE TABLE allowed_users (
|
||||
telegram_user_id INTEGER PRIMARY KEY,
|
||||
username TEXT NOT NULL DEFAULT '',
|
||||
notes TEXT NOT NULL DEFAULT '',
|
||||
added_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE TABLE workspaces (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
path TEXT NOT NULL UNIQUE,
|
||||
label TEXT NOT NULL,
|
||||
is_default INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE TABLE sessions (
|
||||
telegram_user_id INTEGER PRIMARY KEY,
|
||||
active_thread_id INTEGER,
|
||||
active_workspace_id INTEGER,
|
||||
model TEXT NOT NULL DEFAULT '',
|
||||
sandbox TEXT NOT NULL DEFAULT 'workspace-write',
|
||||
active_turn_id TEXT NOT NULL DEFAULT '',
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
FOREIGN KEY(active_thread_id) REFERENCES threads(id) ON DELETE SET NULL,
|
||||
FOREIGN KEY(active_workspace_id) REFERENCES workspaces(id) ON DELETE SET NULL
|
||||
);
|
||||
|
||||
CREATE TABLE threads (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
telegram_user_id INTEGER NOT NULL,
|
||||
codex_thread_id TEXT NOT NULL UNIQUE,
|
||||
workspace_id INTEGER NOT NULL,
|
||||
title TEXT NOT NULL DEFAULT '',
|
||||
archived INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
FOREIGN KEY(workspace_id) REFERENCES workspaces(id) ON DELETE RESTRICT
|
||||
);
|
||||
|
||||
CREATE INDEX idx_threads_user_updated ON threads(telegram_user_id, archived, updated_at DESC);
|
||||
|
||||
CREATE TABLE pending_approvals (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
telegram_user_id INTEGER NOT NULL,
|
||||
codex_request_id TEXT NOT NULL,
|
||||
codex_thread_id TEXT NOT NULL,
|
||||
turn_id TEXT NOT NULL DEFAULT '',
|
||||
item_id TEXT NOT NULL DEFAULT '',
|
||||
kind TEXT NOT NULL,
|
||||
payload_json TEXT NOT NULL,
|
||||
message_chat_id INTEGER NOT NULL DEFAULT 0,
|
||||
message_id INTEGER NOT NULL DEFAULT 0,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
resolved_at TEXT NOT NULL DEFAULT '',
|
||||
UNIQUE(telegram_user_id, codex_request_id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_pending_approvals_status ON pending_approvals(telegram_user_id, status);
|
||||
|
||||
CREATE TABLE audit_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
telegram_user_id INTEGER NOT NULL DEFAULT 0,
|
||||
action TEXT NOT NULL,
|
||||
details TEXT NOT NULL DEFAULT '',
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE INDEX idx_audit_user_created ON audit_log(telegram_user_id, created_at DESC);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Name: "session_reasoning_effort",
|
||||
SQL: "ALTER TABLE sessions ADD COLUMN reasoning_effort TEXT NOT NULL DEFAULT ''",
|
||||
},
|
||||
{
|
||||
Name: "session_settings_message",
|
||||
SQL: `
|
||||
ALTER TABLE sessions ADD COLUMN settings_chat_id INTEGER NOT NULL DEFAULT 0;
|
||||
ALTER TABLE sessions ADD COLUMN settings_message_id INTEGER NOT NULL DEFAULT 0;
|
||||
`,
|
||||
},
|
||||
}
|
||||
516
internal/store/store.go
Normal file
516
internal/store/store.go
Normal file
@@ -0,0 +1,516 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
type AllowedUser struct {
|
||||
TelegramUserID int64
|
||||
Username string
|
||||
Notes string
|
||||
AddedAt string
|
||||
}
|
||||
|
||||
type Workspace struct {
|
||||
ID int64
|
||||
Path string
|
||||
Label string
|
||||
IsDefault bool
|
||||
CreatedAt string
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
TelegramUserID int64
|
||||
ActiveThreadID int64
|
||||
ActiveWorkspaceID int64
|
||||
Model string
|
||||
ReasoningEffort string
|
||||
Sandbox string
|
||||
ActiveTurnID string
|
||||
SettingsChatID int64
|
||||
SettingsMessageID int
|
||||
UpdatedAt string
|
||||
}
|
||||
|
||||
type Thread struct {
|
||||
ID int64
|
||||
TelegramUserID int64
|
||||
CodexThreadID string
|
||||
WorkspaceID int64
|
||||
Title string
|
||||
Archived bool
|
||||
CreatedAt string
|
||||
UpdatedAt string
|
||||
}
|
||||
|
||||
type PendingApproval struct {
|
||||
ID int64
|
||||
TelegramUserID int64
|
||||
CodexRequestID string
|
||||
CodexThreadID string
|
||||
TurnID string
|
||||
ItemID string
|
||||
Kind string
|
||||
PayloadJSON string
|
||||
MessageChatID int64
|
||||
MessageID int
|
||||
Status string
|
||||
CreatedAt string
|
||||
ResolvedAt string
|
||||
}
|
||||
|
||||
func Open(ctx context.Context, path string) (*Store, error) {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetMaxOpenConns(1)
|
||||
|
||||
s := &Store{db: db}
|
||||
if err := s.configure(ctx); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
if err := s.migrate(ctx); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Store) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
func (s *Store) configure(ctx context.Context) error {
|
||||
statements := []string{
|
||||
"PRAGMA foreign_keys = ON",
|
||||
"PRAGMA journal_mode = WAL",
|
||||
"PRAGMA busy_timeout = 5000",
|
||||
}
|
||||
for _, statement := range statements {
|
||||
if _, err := s.db.ExecContext(ctx, statement); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) migrate(ctx context.Context) error {
|
||||
if _, err := s.db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
applied_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)`); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var current int
|
||||
if err := s.db.QueryRowContext(ctx, "SELECT COALESCE(MAX(version), 0) FROM schema_migrations").Scan(¤t); err != nil {
|
||||
return err
|
||||
}
|
||||
for i := current; i < len(migrations); i++ {
|
||||
version := i + 1
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, migrations[i].SQL); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("migration %d %s: %w", version, migrations[i].Name, err)
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (version, name) VALUES (?, ?)", version, migrations[i].Name); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateWorkspacePath(path string) (string, error) {
|
||||
if path == "" {
|
||||
return "", errors.New("workspace path is required")
|
||||
}
|
||||
if strings.ContainsRune(path, 0) {
|
||||
return "", errors.New("workspace path contains a NUL byte")
|
||||
}
|
||||
if !filepath.IsAbs(path) {
|
||||
return "", errors.New("workspace path must be absolute")
|
||||
}
|
||||
clean := filepath.Clean(path)
|
||||
if clean == string(filepath.Separator) {
|
||||
return "", errors.New("workspace path cannot be filesystem root")
|
||||
}
|
||||
return clean, nil
|
||||
}
|
||||
|
||||
func (s *Store) IsAllowed(ctx context.Context, telegramUserID int64) (bool, error) {
|
||||
var exists int
|
||||
err := s.db.QueryRowContext(ctx, "SELECT 1 FROM allowed_users WHERE telegram_user_id = ?", telegramUserID).Scan(&exists)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return err == nil, err
|
||||
}
|
||||
|
||||
func (s *Store) AddAllowedUser(ctx context.Context, telegramUserID int64, username, notes string) error {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
INSERT INTO allowed_users (telegram_user_id, username, notes)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(telegram_user_id) DO UPDATE SET username = excluded.username, notes = excluded.notes`,
|
||||
telegramUserID, username, notes)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) RemoveAllowedUser(ctx context.Context, telegramUserID int64) error {
|
||||
_, err := s.db.ExecContext(ctx, "DELETE FROM allowed_users WHERE telegram_user_id = ?", telegramUserID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) ListAllowedUsers(ctx context.Context) ([]AllowedUser, error) {
|
||||
rows, err := s.db.QueryContext(ctx, "SELECT telegram_user_id, username, notes, added_at FROM allowed_users ORDER BY telegram_user_id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []AllowedUser
|
||||
for rows.Next() {
|
||||
var user AllowedUser
|
||||
if err := rows.Scan(&user.TelegramUserID, &user.Username, &user.Notes, &user.AddedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
return users, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) AddWorkspace(ctx context.Context, path, label string, isDefault bool) (Workspace, error) {
|
||||
clean, err := ValidateWorkspacePath(path)
|
||||
if err != nil {
|
||||
return Workspace{}, err
|
||||
}
|
||||
if label == "" {
|
||||
label = filepath.Base(clean)
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return Workspace{}, err
|
||||
}
|
||||
if isDefault {
|
||||
if _, err := tx.ExecContext(ctx, "UPDATE workspaces SET is_default = 0"); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return Workspace{}, err
|
||||
}
|
||||
}
|
||||
result, err := tx.ExecContext(ctx, `
|
||||
INSERT INTO workspaces (path, label, is_default)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(path) DO UPDATE SET label = excluded.label, is_default = excluded.is_default`,
|
||||
clean, label, boolInt(isDefault))
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return Workspace{}, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return Workspace{}, err
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
if id == 0 {
|
||||
return s.GetWorkspaceByPath(ctx, clean)
|
||||
}
|
||||
return s.GetWorkspaceByPath(ctx, clean)
|
||||
}
|
||||
|
||||
func (s *Store) ListWorkspaces(ctx context.Context) ([]Workspace, error) {
|
||||
rows, err := s.db.QueryContext(ctx, "SELECT id, path, label, is_default, created_at FROM workspaces ORDER BY is_default DESC, label, path")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var workspaces []Workspace
|
||||
for rows.Next() {
|
||||
workspace, err := scanWorkspace(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workspaces = append(workspaces, workspace)
|
||||
}
|
||||
return workspaces, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) GetWorkspaceByID(ctx context.Context, id int64) (Workspace, error) {
|
||||
row := s.db.QueryRowContext(ctx, "SELECT id, path, label, is_default, created_at FROM workspaces WHERE id = ?", id)
|
||||
return scanWorkspace(row)
|
||||
}
|
||||
|
||||
func (s *Store) GetWorkspaceByPath(ctx context.Context, path string) (Workspace, error) {
|
||||
row := s.db.QueryRowContext(ctx, "SELECT id, path, label, is_default, created_at FROM workspaces WHERE path = ?", path)
|
||||
return scanWorkspace(row)
|
||||
}
|
||||
|
||||
func (s *Store) DefaultWorkspace(ctx context.Context) (Workspace, error) {
|
||||
row := s.db.QueryRowContext(ctx, "SELECT id, path, label, is_default, created_at FROM workspaces ORDER BY is_default DESC, id ASC LIMIT 1")
|
||||
return scanWorkspace(row)
|
||||
}
|
||||
|
||||
func (s *Store) GetOrCreateSession(ctx context.Context, telegramUserID int64, defaultModel, defaultSandbox string) (Session, error) {
|
||||
if defaultSandbox == "" {
|
||||
defaultSandbox = "workspace-write"
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
INSERT INTO sessions (telegram_user_id, model, sandbox)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(telegram_user_id) DO NOTHING`, telegramUserID, defaultModel, defaultSandbox)
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
}
|
||||
return s.GetSession(ctx, telegramUserID)
|
||||
}
|
||||
|
||||
func (s *Store) GetSession(ctx context.Context, telegramUserID int64) (Session, error) {
|
||||
row := s.db.QueryRowContext(ctx, `
|
||||
SELECT telegram_user_id, COALESCE(active_thread_id, 0), COALESCE(active_workspace_id, 0), model, COALESCE(reasoning_effort, ''), sandbox, active_turn_id, COALESCE(settings_chat_id, 0), COALESCE(settings_message_id, 0), updated_at
|
||||
FROM sessions WHERE telegram_user_id = ?`, telegramUserID)
|
||||
var session Session
|
||||
err := row.Scan(&session.TelegramUserID, &session.ActiveThreadID, &session.ActiveWorkspaceID, &session.Model, &session.ReasoningEffort, &session.Sandbox, &session.ActiveTurnID, &session.SettingsChatID, &session.SettingsMessageID, &session.UpdatedAt)
|
||||
return session, err
|
||||
}
|
||||
|
||||
func (s *Store) SetSessionWorkspace(ctx context.Context, telegramUserID, workspaceID int64) error {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET active_workspace_id = ?, updated_at = datetime('now')
|
||||
WHERE telegram_user_id = ?`, workspaceID, telegramUserID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) SetSessionModel(ctx context.Context, telegramUserID int64, model string) error {
|
||||
_, err := s.db.ExecContext(ctx, "UPDATE sessions SET model = ?, reasoning_effort = '', updated_at = datetime('now') WHERE telegram_user_id = ?", model, telegramUserID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) SetSessionReasoningEffort(ctx context.Context, telegramUserID int64, effort string) error {
|
||||
_, err := s.db.ExecContext(ctx, "UPDATE sessions SET reasoning_effort = ?, updated_at = datetime('now') WHERE telegram_user_id = ?", effort, telegramUserID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) SetSessionSettingsMessage(ctx context.Context, telegramUserID int64, chatID int64, messageID int) error {
|
||||
_, err := s.db.ExecContext(ctx, "UPDATE sessions SET settings_chat_id = ?, settings_message_id = ?, updated_at = datetime('now') WHERE telegram_user_id = ?", chatID, messageID, telegramUserID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) SetSessionSandbox(ctx context.Context, telegramUserID int64, sandbox string) error {
|
||||
_, err := s.db.ExecContext(ctx, "UPDATE sessions SET sandbox = ?, updated_at = datetime('now') WHERE telegram_user_id = ?", sandbox, telegramUserID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) SetActiveThread(ctx context.Context, telegramUserID, threadID int64) error {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET active_thread_id = ?, active_turn_id = '', updated_at = datetime('now')
|
||||
WHERE telegram_user_id = ?`, threadID, telegramUserID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) SetActiveTurn(ctx context.Context, telegramUserID int64, turnID string) error {
|
||||
_, err := s.db.ExecContext(ctx, "UPDATE sessions SET active_turn_id = ?, updated_at = datetime('now') WHERE telegram_user_id = ?", turnID, telegramUserID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) ClearActiveTurns(ctx context.Context) error {
|
||||
_, err := s.db.ExecContext(ctx, "UPDATE sessions SET active_turn_id = '', updated_at = datetime('now') WHERE active_turn_id <> ''")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) CreateThread(ctx context.Context, telegramUserID int64, codexThreadID string, workspaceID int64, title string) (Thread, error) {
|
||||
result, err := s.db.ExecContext(ctx, `
|
||||
INSERT INTO threads (telegram_user_id, codex_thread_id, workspace_id, title)
|
||||
VALUES (?, ?, ?, ?)`, telegramUserID, codexThreadID, workspaceID, title)
|
||||
if err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
return s.GetThreadByID(ctx, telegramUserID, id)
|
||||
}
|
||||
|
||||
func (s *Store) GetThreadByID(ctx context.Context, telegramUserID, id int64) (Thread, error) {
|
||||
row := s.db.QueryRowContext(ctx, `
|
||||
SELECT id, telegram_user_id, codex_thread_id, workspace_id, title, archived, created_at, updated_at
|
||||
FROM threads WHERE telegram_user_id = ? AND id = ?`, telegramUserID, id)
|
||||
return scanThread(row)
|
||||
}
|
||||
|
||||
func (s *Store) GetThreadByCodexID(ctx context.Context, codexThreadID string) (Thread, error) {
|
||||
row := s.db.QueryRowContext(ctx, `
|
||||
SELECT id, telegram_user_id, codex_thread_id, workspace_id, title, archived, created_at, updated_at
|
||||
FROM threads WHERE codex_thread_id = ?`, codexThreadID)
|
||||
return scanThread(row)
|
||||
}
|
||||
|
||||
func (s *Store) ListThreads(ctx context.Context, telegramUserID int64, includeArchived bool) ([]Thread, error) {
|
||||
return s.ListThreadsPage(ctx, telegramUserID, includeArchived, 20, 0)
|
||||
}
|
||||
|
||||
func (s *Store) ListThreadsPage(ctx context.Context, telegramUserID int64, includeArchived bool, limit, offset int) ([]Thread, error) {
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
query := `
|
||||
SELECT id, telegram_user_id, codex_thread_id, workspace_id, title, archived, created_at, updated_at
|
||||
FROM threads WHERE telegram_user_id = ?`
|
||||
args := []any{telegramUserID}
|
||||
if !includeArchived {
|
||||
query += " AND archived = 0"
|
||||
}
|
||||
query += " ORDER BY updated_at DESC, id DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var threads []Thread
|
||||
for rows.Next() {
|
||||
thread, err := scanThread(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
threads = append(threads, thread)
|
||||
}
|
||||
return threads, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) ArchiveThread(ctx context.Context, telegramUserID, id int64) error {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
UPDATE threads SET archived = 1, updated_at = datetime('now')
|
||||
WHERE telegram_user_id = ? AND id = ?`, telegramUserID, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) TouchThread(ctx context.Context, codexThreadID string) error {
|
||||
_, err := s.db.ExecContext(ctx, "UPDATE threads SET updated_at = datetime('now') WHERE codex_thread_id = ?", codexThreadID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) UpsertPendingApproval(ctx context.Context, approval PendingApproval) (PendingApproval, error) {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
INSERT INTO pending_approvals (
|
||||
telegram_user_id, codex_request_id, codex_thread_id, turn_id, item_id, kind, payload_json,
|
||||
message_chat_id, message_id, status
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 'pending')
|
||||
ON CONFLICT(telegram_user_id, codex_request_id) DO UPDATE SET
|
||||
payload_json = excluded.payload_json,
|
||||
message_chat_id = excluded.message_chat_id,
|
||||
message_id = excluded.message_id,
|
||||
status = 'pending',
|
||||
resolved_at = ''`,
|
||||
approval.TelegramUserID, approval.CodexRequestID, approval.CodexThreadID, approval.TurnID,
|
||||
approval.ItemID, approval.Kind, approval.PayloadJSON, approval.MessageChatID, approval.MessageID)
|
||||
if err != nil {
|
||||
return PendingApproval{}, err
|
||||
}
|
||||
return s.GetPendingApprovalByRequest(ctx, approval.TelegramUserID, approval.CodexRequestID)
|
||||
}
|
||||
|
||||
func (s *Store) UpdatePendingApprovalMessage(ctx context.Context, id int64, chatID int64, messageID int) error {
|
||||
_, err := s.db.ExecContext(ctx, "UPDATE pending_approvals SET message_chat_id = ?, message_id = ? WHERE id = ?", chatID, messageID, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) GetPendingApproval(ctx context.Context, telegramUserID, id int64) (PendingApproval, error) {
|
||||
row := s.db.QueryRowContext(ctx, `
|
||||
SELECT id, telegram_user_id, codex_request_id, codex_thread_id, turn_id, item_id, kind, payload_json,
|
||||
message_chat_id, message_id, status, created_at, COALESCE(resolved_at, '')
|
||||
FROM pending_approvals WHERE telegram_user_id = ? AND id = ?`, telegramUserID, id)
|
||||
return scanPendingApproval(row)
|
||||
}
|
||||
|
||||
func (s *Store) GetPendingApprovalByRequest(ctx context.Context, telegramUserID int64, requestID string) (PendingApproval, error) {
|
||||
row := s.db.QueryRowContext(ctx, `
|
||||
SELECT id, telegram_user_id, codex_request_id, codex_thread_id, turn_id, item_id, kind, payload_json,
|
||||
message_chat_id, message_id, status, created_at, COALESCE(resolved_at, '')
|
||||
FROM pending_approvals WHERE telegram_user_id = ? AND codex_request_id = ?`, telegramUserID, requestID)
|
||||
return scanPendingApproval(row)
|
||||
}
|
||||
|
||||
func (s *Store) ResolvePendingApproval(ctx context.Context, telegramUserID, id int64, status string) error {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
UPDATE pending_approvals SET status = ?, resolved_at = datetime('now')
|
||||
WHERE telegram_user_id = ? AND id = ? AND status = 'pending'`, status, telegramUserID, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) Audit(ctx context.Context, telegramUserID int64, action, details string) error {
|
||||
_, err := s.db.ExecContext(ctx, "INSERT INTO audit_log (telegram_user_id, action, details) VALUES (?, ?, ?)", telegramUserID, action, details)
|
||||
return err
|
||||
}
|
||||
|
||||
type scanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanWorkspace(row scanner) (Workspace, error) {
|
||||
var workspace Workspace
|
||||
var isDefault int
|
||||
if err := row.Scan(&workspace.ID, &workspace.Path, &workspace.Label, &isDefault, &workspace.CreatedAt); err != nil {
|
||||
return Workspace{}, err
|
||||
}
|
||||
workspace.IsDefault = isDefault != 0
|
||||
return workspace, nil
|
||||
}
|
||||
|
||||
func scanThread(row scanner) (Thread, error) {
|
||||
var thread Thread
|
||||
var archived int
|
||||
if err := row.Scan(&thread.ID, &thread.TelegramUserID, &thread.CodexThreadID, &thread.WorkspaceID, &thread.Title, &archived, &thread.CreatedAt, &thread.UpdatedAt); err != nil {
|
||||
return Thread{}, err
|
||||
}
|
||||
thread.Archived = archived != 0
|
||||
return thread, nil
|
||||
}
|
||||
|
||||
func scanPendingApproval(row scanner) (PendingApproval, error) {
|
||||
var approval PendingApproval
|
||||
if err := row.Scan(&approval.ID, &approval.TelegramUserID, &approval.CodexRequestID, &approval.CodexThreadID,
|
||||
&approval.TurnID, &approval.ItemID, &approval.Kind, &approval.PayloadJSON, &approval.MessageChatID,
|
||||
&approval.MessageID, &approval.Status, &approval.CreatedAt, &approval.ResolvedAt); err != nil {
|
||||
return PendingApproval{}, err
|
||||
}
|
||||
return approval, nil
|
||||
}
|
||||
|
||||
func boolInt(value bool) int {
|
||||
if value {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
137
internal/store/store_test.go
Normal file
137
internal/store/store_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStoreUsersWorkspacesSessions(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st, err := Open(ctx, filepath.Join(t.TempDir(), "bot.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer st.Close()
|
||||
|
||||
allowed, err := st.IsAllowed(ctx, 42)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if allowed {
|
||||
t.Fatal("new user should not be allowed")
|
||||
}
|
||||
if err := st.AddAllowedUser(ctx, 42, "alice", "owner"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
allowed, err = st.IsAllowed(ctx, 42)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !allowed {
|
||||
t.Fatal("user should be allowed")
|
||||
}
|
||||
|
||||
workspacePath := t.TempDir()
|
||||
ws, err := st.AddWorkspace(ctx, workspacePath, "repo", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !ws.IsDefault {
|
||||
t.Fatal("workspace should be default")
|
||||
}
|
||||
|
||||
session, err := st.GetOrCreateSession(ctx, 42, "test-model", "read-only")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if session.Model != "test-model" || session.Sandbox != "read-only" || session.ReasoningEffort != "" {
|
||||
t.Fatalf("unexpected session defaults: %+v", session)
|
||||
}
|
||||
if err := st.SetSessionWorkspace(ctx, 42, ws.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
effort := "value-returned-by-server"
|
||||
if err := st.SetSessionReasoningEffort(ctx, 42, effort); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
session, err = st.GetSession(ctx, 42)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if session.ActiveWorkspaceID != ws.ID || session.ReasoningEffort != effort {
|
||||
t.Fatalf("workspace not saved: %+v", session)
|
||||
}
|
||||
if err := st.SetSessionSettingsMessage(ctx, 42, 1001, 2002); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
session, err = st.GetSession(ctx, 42)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if session.SettingsChatID != 1001 || session.SettingsMessageID != 2002 {
|
||||
t.Fatalf("settings message not saved: %+v", session)
|
||||
}
|
||||
if err := st.SetActiveTurn(ctx, 42, "turn-123"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := st.ClearActiveTurns(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
session, err = st.GetSession(ctx, 42)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if session.ActiveTurnID != "" {
|
||||
t.Fatalf("active turn not cleared: %+v", session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListThreadsPage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st, err := Open(ctx, filepath.Join(t.TempDir(), "bot.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer st.Close()
|
||||
|
||||
ws, err := st.AddWorkspace(ctx, t.TempDir(), "repo", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
if _, err := st.CreateThread(ctx, 42, string(rune('a'+i)), ws.ID, "thread"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
threads, err := st.ListThreadsPage(ctx, 42, false, 2, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(threads) != 2 {
|
||||
t.Fatalf("got %d threads, want 2", len(threads))
|
||||
}
|
||||
threads, err = st.ListThreadsPage(ctx, 42, false, 2, 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(threads) != 1 {
|
||||
t.Fatalf("got %d threads on second page, want 1", len(threads))
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspacePath(t *testing.T) {
|
||||
if _, err := ValidateWorkspacePath("relative/path"); err == nil {
|
||||
t.Fatal("relative path should be rejected")
|
||||
}
|
||||
if _, err := ValidateWorkspacePath("/"); err == nil {
|
||||
t.Fatal("filesystem root should be rejected")
|
||||
}
|
||||
clean, err := ValidateWorkspacePath("/tmp/../tmp/project")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if clean != "/tmp/project" {
|
||||
t.Fatalf("unexpected clean path: %s", clean)
|
||||
}
|
||||
}
|
||||
211
internal/telegram/api.go
Normal file
211
internal/telegram/api.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
token string
|
||||
baseURL string
|
||||
fileBaseURL string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewClient(token string) *Client {
|
||||
return &Client{
|
||||
token: token,
|
||||
baseURL: "https://api.telegram.org/bot" + token,
|
||||
fileBaseURL: "https://api.telegram.org/file/bot" + token,
|
||||
httpClient: &http.Client{Timeout: 90 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) GetUpdates(ctx context.Context, offset int, timeoutSeconds int) ([]Update, error) {
|
||||
params := map[string]any{
|
||||
"offset": offset,
|
||||
"timeout": timeoutSeconds,
|
||||
"allowed_updates": []string{"message", "callback_query"},
|
||||
}
|
||||
var updates []Update
|
||||
if err := c.postJSON(ctx, "getUpdates", params, &updates); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return updates, nil
|
||||
}
|
||||
|
||||
func (c *Client) SendMessage(ctx context.Context, chatID int64, text string, opts SendMessageOptions) (Message, error) {
|
||||
params := map[string]any{
|
||||
"chat_id": chatID,
|
||||
"text": text,
|
||||
}
|
||||
if opts.ParseMode != "" {
|
||||
params["parse_mode"] = opts.ParseMode
|
||||
}
|
||||
if opts.ReplyMarkup != nil {
|
||||
params["reply_markup"] = opts.ReplyMarkup
|
||||
}
|
||||
var message Message
|
||||
if err := c.postJSON(ctx, "sendMessage", params, &message); err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func (c *Client) EditMessageText(ctx context.Context, chatID int64, messageID int, text string, opts EditMessageTextOptions) (Message, error) {
|
||||
params := map[string]any{
|
||||
"chat_id": chatID,
|
||||
"message_id": messageID,
|
||||
"text": text,
|
||||
}
|
||||
if opts.ParseMode != "" {
|
||||
params["parse_mode"] = opts.ParseMode
|
||||
}
|
||||
if opts.ReplyMarkup != nil {
|
||||
params["reply_markup"] = opts.ReplyMarkup
|
||||
}
|
||||
var message Message
|
||||
if err := c.postJSON(ctx, "editMessageText", params, &message); err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func (c *Client) PinChatMessage(ctx context.Context, chatID int64, messageID int, disableNotification bool) error {
|
||||
params := map[string]any{
|
||||
"chat_id": chatID,
|
||||
"message_id": messageID,
|
||||
"disable_notification": disableNotification,
|
||||
}
|
||||
var ignored bool
|
||||
return c.postJSON(ctx, "pinChatMessage", params, &ignored)
|
||||
}
|
||||
|
||||
func (c *Client) AnswerCallbackQuery(ctx context.Context, callbackQueryID, text string) error {
|
||||
params := map[string]any{
|
||||
"callback_query_id": callbackQueryID,
|
||||
}
|
||||
if text != "" {
|
||||
params["text"] = text
|
||||
}
|
||||
var ignored bool
|
||||
return c.postJSON(ctx, "answerCallbackQuery", params, &ignored)
|
||||
}
|
||||
|
||||
func (c *Client) GetFile(ctx context.Context, fileID string) (File, error) {
|
||||
var file File
|
||||
if err := c.postJSON(ctx, "getFile", map[string]any{"file_id": fileID}, &file); err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (c *Client) DownloadFile(ctx context.Context, filePath string) ([]byte, error) {
|
||||
u := c.fileBaseURL + "/" + url.PathEscape(filePath)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("download file: telegram returned %s", resp.Status)
|
||||
}
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func (c *Client) SendDocumentBytes(ctx context.Context, chatID int64, filename string, data []byte, caption string) (Message, error) {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
if err := writer.WriteField("chat_id", fmt.Sprint(chatID)); err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
if caption != "" {
|
||||
if err := writer.WriteField("caption", caption); err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
}
|
||||
part, err := writer.CreateFormFile("document", filepath.Base(filename))
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/sendDocument", &body)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
payload, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return Message{}, fmt.Errorf("sendDocument: telegram returned %s: %s", resp.Status, string(payload))
|
||||
}
|
||||
var decoded apiResponse[Message]
|
||||
if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
if !decoded.OK {
|
||||
return Message{}, fmt.Errorf("sendDocument: telegram error %d: %s", decoded.ErrorCode, decoded.Description)
|
||||
}
|
||||
return decoded.Result, nil
|
||||
}
|
||||
|
||||
func (c *Client) postJSON(ctx context.Context, method string, params any, result any) error {
|
||||
body, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/"+method, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
payload, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return fmt.Errorf("%s: telegram returned %s: %s", method, resp.Status, string(payload))
|
||||
}
|
||||
var decoded apiResponse[json.RawMessage]
|
||||
if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
|
||||
return err
|
||||
}
|
||||
if !decoded.OK {
|
||||
return fmt.Errorf("%s: telegram error %d: %s", method, decoded.ErrorCode, decoded.Description)
|
||||
}
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(decoded.Result, result)
|
||||
}
|
||||
|
||||
type apiResponse[T any] struct {
|
||||
OK bool `json:"ok"`
|
||||
Result T `json:"result"`
|
||||
ErrorCode int `json:"error_code,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
1760
internal/telegram/bot.go
Normal file
1760
internal/telegram/bot.go
Normal file
File diff suppressed because it is too large
Load Diff
26
internal/telegram/download.go
Normal file
26
internal/telegram/download.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *Client) DownloadFilePath(ctx context.Context, filePath string) ([]byte, error) {
|
||||
u := c.fileBaseURL + "/" + strings.TrimLeft(filePath, "/")
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("download file: telegram returned %s", resp.Status)
|
||||
}
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
183
internal/telegram/render.go
Normal file
183
internal/telegram/render.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const TelegramMessageLimit = 4096
|
||||
const TelegramHTMLMessageLimit = 3900
|
||||
|
||||
func EscapeHTML(text string) string {
|
||||
return html.EscapeString(text)
|
||||
}
|
||||
|
||||
func SummaryDetailsHTML(summary, details string) string {
|
||||
summary = strings.TrimSpace(summary)
|
||||
details = strings.TrimSpace(details)
|
||||
if details == "" {
|
||||
return EscapeHTML(summary)
|
||||
}
|
||||
if summary == "" {
|
||||
return ExpandableQuoteHTML(details)
|
||||
}
|
||||
return EscapeHTML(summary) + "\n" + ExpandableQuoteHTML(details)
|
||||
}
|
||||
|
||||
func ExpandableQuoteHTML(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
return "<blockquote expandable>" + EscapeHTML(text) + "</blockquote>"
|
||||
}
|
||||
|
||||
func SummaryDetailsHTMLLimited(summary, details string, limit int) string {
|
||||
if limit <= 0 {
|
||||
limit = TelegramHTMLMessageLimit
|
||||
}
|
||||
summary = strings.TrimSpace(summary)
|
||||
details = strings.TrimSpace(details)
|
||||
out := SummaryDetailsHTML(summary, details)
|
||||
if len([]rune(out)) <= limit || details == "" {
|
||||
return out
|
||||
}
|
||||
|
||||
suffix := "\n...[truncated]"
|
||||
runes := []rune(details)
|
||||
for len(runes) > 0 {
|
||||
candidateLen := len(runes) - max(1, (len([]rune(out))-limit)/2)
|
||||
if candidateLen < 0 {
|
||||
candidateLen = 0
|
||||
}
|
||||
if candidateLen > len(runes) {
|
||||
candidateLen = len(runes)
|
||||
}
|
||||
candidate := strings.TrimSpace(string(runes[:candidateLen])) + suffix
|
||||
out = SummaryDetailsHTML(summary, candidate)
|
||||
if len([]rune(out)) <= limit || candidateLen == 0 {
|
||||
return out
|
||||
}
|
||||
runes = runes[:candidateLen]
|
||||
}
|
||||
return SummaryDetailsHTML(summary, suffix)
|
||||
}
|
||||
|
||||
func ChunkText(text string, max int) []string {
|
||||
if max <= 0 {
|
||||
max = TelegramMessageLimit
|
||||
}
|
||||
runes := []rune(text)
|
||||
if len(runes) == 0 {
|
||||
return nil
|
||||
}
|
||||
var chunks []string
|
||||
for len(runes) > max {
|
||||
cut := max
|
||||
for i := max; i > max/2; i-- {
|
||||
if runes[i-1] == '\n' {
|
||||
cut = i
|
||||
break
|
||||
}
|
||||
}
|
||||
chunks = append(chunks, string(runes[:cut]))
|
||||
runes = runes[cut:]
|
||||
}
|
||||
if len(runes) > 0 {
|
||||
chunks = append(chunks, string(runes))
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
func ApprovalCallbackData(id int64, decision string) string {
|
||||
return fmt.Sprintf("approval:%d:%s", id, decision)
|
||||
}
|
||||
|
||||
func ParseApprovalCallbackData(data string) (int64, string, bool) {
|
||||
parts := strings.Split(data, ":")
|
||||
if len(parts) != 3 || parts[0] != "approval" {
|
||||
return 0, "", false
|
||||
}
|
||||
id, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
return 0, "", false
|
||||
}
|
||||
switch parts[2] {
|
||||
case "accept", "acceptForSession", "decline", "cancel", "details":
|
||||
return id, parts[2], true
|
||||
default:
|
||||
return 0, "", false
|
||||
}
|
||||
}
|
||||
|
||||
func WorkspaceCallbackData(id int64) string {
|
||||
return fmt.Sprintf("workspace:%d", id)
|
||||
}
|
||||
|
||||
func ParseWorkspaceCallbackData(data string) (int64, bool) {
|
||||
if !strings.HasPrefix(data, "workspace:") {
|
||||
return 0, false
|
||||
}
|
||||
id, err := strconv.ParseInt(strings.TrimPrefix(data, "workspace:"), 10, 64)
|
||||
return id, err == nil && id > 0
|
||||
}
|
||||
|
||||
func ResumeThreadCallbackData(id int64) string {
|
||||
return fmt.Sprintf("resume:thread:%d", id)
|
||||
}
|
||||
|
||||
func ParseResumeThreadCallbackData(data string) (int64, bool) {
|
||||
if !strings.HasPrefix(data, "resume:thread:") {
|
||||
return 0, false
|
||||
}
|
||||
id, err := strconv.ParseInt(strings.TrimPrefix(data, "resume:thread:"), 10, 64)
|
||||
return id, err == nil && id > 0
|
||||
}
|
||||
|
||||
func ResumePageCallbackData(page int) string {
|
||||
return fmt.Sprintf("resume:page:%d", page)
|
||||
}
|
||||
|
||||
func ParseResumePageCallbackData(data string) (int, bool) {
|
||||
if !strings.HasPrefix(data, "resume:page:") {
|
||||
return 0, false
|
||||
}
|
||||
page, err := strconv.Atoi(strings.TrimPrefix(data, "resume:page:"))
|
||||
return page, err == nil && page >= 0
|
||||
}
|
||||
|
||||
func ModelCallbackData(modelID string) (string, bool) {
|
||||
encoded := base64.RawURLEncoding.EncodeToString([]byte(modelID))
|
||||
data := "model:" + encoded
|
||||
return data, len([]rune(data)) <= 64
|
||||
}
|
||||
|
||||
func ParseModelCallbackData(data string) (string, bool) {
|
||||
if !strings.HasPrefix(data, "model:") {
|
||||
return "", false
|
||||
}
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(strings.TrimPrefix(data, "model:"))
|
||||
if err != nil || len(decoded) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return string(decoded), true
|
||||
}
|
||||
|
||||
func EffortCallbackData(effort string) string {
|
||||
encoded := base64.RawURLEncoding.EncodeToString([]byte(effort))
|
||||
return "effort:" + encoded
|
||||
}
|
||||
|
||||
func ParseEffortCallbackData(data string) (string, bool) {
|
||||
if !strings.HasPrefix(data, "effort:") {
|
||||
return "", false
|
||||
}
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(strings.TrimPrefix(data, "effort:"))
|
||||
if err != nil || len(decoded) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return string(decoded), true
|
||||
}
|
||||
152
internal/telegram/render_test.go
Normal file
152
internal/telegram/render_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"codex-telegram-bot/internal/store"
|
||||
)
|
||||
|
||||
func TestEscapeHTML(t *testing.T) {
|
||||
got := EscapeHTML(`<run & "test">`)
|
||||
want := "<run & "test">"
|
||||
if got != want {
|
||||
t.Fatalf("EscapeHTML() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkText(t *testing.T) {
|
||||
text := strings.Repeat("a", 25)
|
||||
chunks := ChunkText(text, 10)
|
||||
if len(chunks) != 3 {
|
||||
t.Fatalf("got %d chunks", len(chunks))
|
||||
}
|
||||
for _, chunk := range chunks {
|
||||
if len([]rune(chunk)) > 10 {
|
||||
t.Fatalf("chunk too long: %q", chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalCallbackData(t *testing.T) {
|
||||
data := ApprovalCallbackData(12, "accept")
|
||||
id, decision, ok := ParseApprovalCallbackData(data)
|
||||
if !ok || id != 12 || decision != "accept" {
|
||||
t.Fatalf("unexpected callback parse: id=%d decision=%s ok=%v", id, decision, ok)
|
||||
}
|
||||
if _, _, ok := ParseApprovalCallbackData("approval:12:unknown"); ok {
|
||||
t.Fatal("unknown decisions should be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalResponseForPermissions(t *testing.T) {
|
||||
approval := store.PendingApproval{
|
||||
Kind: "item/permissions/requestApproval",
|
||||
PayloadJSON: `{"permissions":{"network":{"enabled":true}}}`,
|
||||
}
|
||||
response, ok := approvalResponse(approval, "accept").(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("approval response should be a map")
|
||||
}
|
||||
if response["scope"] != "turn" {
|
||||
t.Fatalf("scope = %v, want turn", response["scope"])
|
||||
}
|
||||
permissions, ok := response["permissions"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("permissions should be a map")
|
||||
}
|
||||
network, ok := permissions["network"].(map[string]any)
|
||||
if !ok || network["enabled"] != true {
|
||||
t.Fatalf("unexpected permissions: %#v", permissions)
|
||||
}
|
||||
|
||||
denied, ok := approvalResponse(approval, "decline").(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("denied response should be a map")
|
||||
}
|
||||
deniedPermissions, ok := denied["permissions"].(map[string]any)
|
||||
if !ok || len(deniedPermissions) != 0 {
|
||||
t.Fatalf("denied permissions = %#v, want empty map", denied["permissions"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCommand(t *testing.T) {
|
||||
name, args, ok := parseCommand("/resume@my_bot 123")
|
||||
if !ok || name != "resume" || len(args) != 1 || args[0] != "123" {
|
||||
t.Fatalf("unexpected command parse: %q %#v %v", name, args, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderCodexCommandExecutionItem(t *testing.T) {
|
||||
output := "line 1\nline 2"
|
||||
exitCode := 0
|
||||
item := codexThreadItemView{
|
||||
Type: "commandExecution",
|
||||
Command: "go test ./...",
|
||||
AggregatedOutput: &output,
|
||||
ExitCode: &exitCode,
|
||||
}
|
||||
text := renderCodexItemCompleted(item)
|
||||
for _, want := range []string{"Tool call: command finished", "Command: go test ./...", "Exit code: 0", "line 1"} {
|
||||
if !strings.Contains(text, want) {
|
||||
t.Fatalf("rendered command item missing %q in %q", want, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderCodexStartedItems(t *testing.T) {
|
||||
text := renderCodexItemStarted(codexThreadItemView{Type: "webSearch", Query: "telegram bot api"})
|
||||
if !strings.Contains(text, "web search started") || !strings.Contains(text, "telegram bot api") {
|
||||
t.Fatalf("unexpected web search render: %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResumeCallbackData(t *testing.T) {
|
||||
threadID, ok := ParseResumeThreadCallbackData(ResumeThreadCallbackData(123))
|
||||
if !ok || threadID != 123 {
|
||||
t.Fatalf("unexpected resume thread callback: id=%d ok=%v", threadID, ok)
|
||||
}
|
||||
page, ok := ParseResumePageCallbackData(ResumePageCallbackData(2))
|
||||
if !ok || page != 2 {
|
||||
t.Fatalf("unexpected resume page callback: page=%d ok=%v", page, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResumeThreadListText(t *testing.T) {
|
||||
threads := []store.Thread{{ID: 42, Title: "do xyz"}, {ID: 43, Title: "executed xxx command"}}
|
||||
text := resumeThreadListText(threads, 0)
|
||||
for _, want := range []string{"Thread ID 42: do xyz", "Thread ID 43: executed xxx command"} {
|
||||
if !strings.Contains(text, want) {
|
||||
t.Fatalf("resume list missing %q in %q", want, text)
|
||||
}
|
||||
}
|
||||
markup := resumeThreadMarkup(threads, 0, true)
|
||||
if len(markup.InlineKeyboard) != 2 || markup.InlineKeyboard[0][0].Text != "ID 42" || markup.InlineKeyboard[0][1].Text != "ID 43" {
|
||||
t.Fatalf("unexpected resume buttons: %#v", markup.InlineKeyboard)
|
||||
}
|
||||
firstID, ok := ParseResumeThreadCallbackData(markup.InlineKeyboard[0][0].CallbackData)
|
||||
if !ok || firstID != 42 {
|
||||
t.Fatalf("first resume button targets id=%d ok=%v", firstID, ok)
|
||||
}
|
||||
secondID, ok := ParseResumeThreadCallbackData(markup.InlineKeyboard[0][1].CallbackData)
|
||||
if !ok || secondID != 43 {
|
||||
t.Fatalf("second resume button targets id=%d ok=%v", secondID, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelAndEffortCallbackData(t *testing.T) {
|
||||
modelID := strings.Join([]string{"server", "model", "id"}, "-")
|
||||
data, ok := ModelCallbackData(modelID)
|
||||
if !ok {
|
||||
t.Fatal("model callback should fit")
|
||||
}
|
||||
model, ok := ParseModelCallbackData(data)
|
||||
if !ok || model != modelID {
|
||||
t.Fatalf("unexpected model callback parse: model=%q ok=%v", model, ok)
|
||||
}
|
||||
effortName := strings.Join([]string{"server", "effort"}, "-")
|
||||
effort, ok := ParseEffortCallbackData(EffortCallbackData(effortName))
|
||||
if !ok || effort != effortName {
|
||||
t.Fatalf("unexpected effort callback parse: effort=%q ok=%v", effort, ok)
|
||||
}
|
||||
}
|
||||
76
internal/telegram/types.go
Normal file
76
internal/telegram/types.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package telegram
|
||||
|
||||
type Update struct {
|
||||
UpdateID int `json:"update_id"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
CallbackQuery *CallbackQuery `json:"callback_query,omitempty"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
IsBot bool `json:"is_bot"`
|
||||
FirstName string `json:"first_name,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
}
|
||||
|
||||
type Chat struct {
|
||||
ID int64 `json:"id"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
MessageID int `json:"message_id"`
|
||||
From *User `json:"from,omitempty"`
|
||||
Chat Chat `json:"chat"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Caption string `json:"caption,omitempty"`
|
||||
Document *Document `json:"document,omitempty"`
|
||||
Photo []PhotoSize `json:"photo,omitempty"`
|
||||
PinnedMessage *Message `json:"pinned_message,omitempty"`
|
||||
}
|
||||
|
||||
type Document struct {
|
||||
FileID string `json:"file_id"`
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
MimeType string `json:"mime_type,omitempty"`
|
||||
FileSize int64 `json:"file_size,omitempty"`
|
||||
}
|
||||
|
||||
type PhotoSize struct {
|
||||
FileID string `json:"file_id"`
|
||||
FileSize int64 `json:"file_size,omitempty"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
}
|
||||
|
||||
type CallbackQuery struct {
|
||||
ID string `json:"id"`
|
||||
From User `json:"from"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
Data string `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
FileID string `json:"file_id"`
|
||||
FilePath string `json:"file_path,omitempty"`
|
||||
FileSize int64 `json:"file_size,omitempty"`
|
||||
}
|
||||
|
||||
type InlineKeyboardMarkup struct {
|
||||
InlineKeyboard [][]InlineKeyboardButton `json:"inline_keyboard"`
|
||||
}
|
||||
|
||||
type InlineKeyboardButton struct {
|
||||
Text string `json:"text"`
|
||||
CallbackData string `json:"callback_data,omitempty"`
|
||||
}
|
||||
|
||||
type SendMessageOptions struct {
|
||||
ParseMode string `json:"parse_mode,omitempty"`
|
||||
ReplyMarkup *InlineKeyboardMarkup `json:"reply_markup,omitempty"`
|
||||
}
|
||||
|
||||
type EditMessageTextOptions struct {
|
||||
ParseMode string `json:"parse_mode,omitempty"`
|
||||
ReplyMarkup *InlineKeyboardMarkup `json:"reply_markup,omitempty"`
|
||||
}
|
||||
Reference in New Issue
Block a user