Refine Telegram thread commands
This commit is contained in:
@@ -339,6 +339,13 @@ WHERE telegram_user_id = ?`, threadID, telegramUserID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) ClearActiveThread(ctx context.Context, telegramUserID, threadID int64) error {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET active_thread_id = NULL, active_turn_id = '', updated_at = datetime('now')
|
||||
WHERE telegram_user_id = ? AND active_thread_id = ?`, telegramUserID, threadID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) SetActiveTurn(ctx context.Context, telegramUserID int64, turnID string) error {
|
||||
_, err := s.db.ExecContext(ctx, "UPDATE sessions SET active_turn_id = ?, updated_at = datetime('now') WHERE telegram_user_id = ?", turnID, telegramUserID)
|
||||
return err
|
||||
@@ -413,6 +420,18 @@ func (s *Store) ListThreads(ctx context.Context, telegramUserID int64, includeAr
|
||||
}
|
||||
|
||||
func (s *Store) ListThreadsPage(ctx context.Context, telegramUserID int64, includeArchived bool, limit, offset int) ([]Thread, error) {
|
||||
archivedFilter := ""
|
||||
if !includeArchived {
|
||||
archivedFilter = "archived = 0"
|
||||
}
|
||||
return s.listThreadsPage(ctx, telegramUserID, archivedFilter, limit, offset)
|
||||
}
|
||||
|
||||
func (s *Store) ListArchivedThreadsPage(ctx context.Context, telegramUserID int64, limit, offset int) ([]Thread, error) {
|
||||
return s.listThreadsPage(ctx, telegramUserID, "archived = 1", limit, offset)
|
||||
}
|
||||
|
||||
func (s *Store) listThreadsPage(ctx context.Context, telegramUserID int64, archivedFilter string, limit, offset int) ([]Thread, error) {
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
@@ -423,8 +442,8 @@ func (s *Store) ListThreadsPage(ctx context.Context, telegramUserID int64, inclu
|
||||
SELECT id, telegram_user_id, codex_thread_id, workspace_id, title, archived, created_at, updated_at
|
||||
FROM threads WHERE telegram_user_id = ?`
|
||||
args := []any{telegramUserID}
|
||||
if !includeArchived {
|
||||
query += " AND archived = 0"
|
||||
if archivedFilter != "" {
|
||||
query += " AND " + archivedFilter
|
||||
}
|
||||
query += " ORDER BY updated_at DESC, id DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
@@ -453,6 +472,52 @@ WHERE telegram_user_id = ? AND id = ?`, telegramUserID, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) UnarchiveThread(ctx context.Context, telegramUserID, id int64) error {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
UPDATE threads SET archived = 0, updated_at = datetime('now')
|
||||
WHERE telegram_user_id = ? AND id = ?`, telegramUserID, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) DeleteThread(ctx context.Context, telegramUserID, id int64) error {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
UPDATE sessions SET active_thread_id = NULL, active_turn_id = '', updated_at = datetime('now')
|
||||
WHERE telegram_user_id = ? AND active_thread_id = ?`, telegramUserID, id); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
DELETE FROM threads
|
||||
WHERE telegram_user_id = ? AND id = ?`, telegramUserID, id); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *Store) DeleteThreadByCodexID(ctx context.Context, codexThreadID string) error {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
UPDATE sessions
|
||||
SET active_thread_id = NULL, active_turn_id = '', updated_at = datetime('now')
|
||||
WHERE active_thread_id IN (SELECT id FROM threads WHERE codex_thread_id = ?)`, codexThreadID); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, "DELETE FROM threads WHERE codex_thread_id = ?", codexThreadID); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *Store) TouchThread(ctx context.Context, codexThreadID string) error {
|
||||
_, err := s.db.ExecContext(ctx, "UPDATE threads SET updated_at = datetime('now') WHERE codex_thread_id = ?", codexThreadID)
|
||||
return err
|
||||
|
||||
@@ -109,6 +109,42 @@ func TestStoreUsersWorkspacesSessions(t *testing.T) {
|
||||
if session.ActiveTurnID != "" {
|
||||
t.Fatalf("active turn not cleared: %+v", session)
|
||||
}
|
||||
if err := st.SetActiveThread(ctx, 42, thread.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := st.SetActiveTurn(ctx, 42, "turn-delete"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := st.DeleteThread(ctx, 42, thread.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
session, err = st.GetSession(ctx, 42)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if session.ActiveThreadID != 0 || session.ActiveTurnID != "" {
|
||||
t.Fatalf("delete should clear active thread and turn: %+v", session)
|
||||
}
|
||||
if _, err := st.GetThreadByID(ctx, 42, thread.ID); err == nil {
|
||||
t.Fatal("deleted thread should not be found")
|
||||
}
|
||||
thread, err = st.CreateThread(ctx, 42, "codex-thread-delete-by-id", ws.ID, "delete by codex id")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := st.SetActiveThread(ctx, 42, thread.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := st.DeleteThreadByCodexID(ctx, "codex-thread-delete-by-id"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
session, err = st.GetSession(ctx, 42)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if session.ActiveThreadID != 0 {
|
||||
t.Fatalf("delete by codex id should clear active thread: %+v", session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListThreadsPage(t *testing.T) {
|
||||
@@ -142,6 +178,26 @@ func TestListThreadsPage(t *testing.T) {
|
||||
if len(threads) != 1 {
|
||||
t.Fatalf("got %d threads on second page, want 1", len(threads))
|
||||
}
|
||||
if err := st.ArchiveThread(ctx, 42, threads[0].ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
archived, err := st.ListArchivedThreadsPage(ctx, 42, 10, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(archived) != 1 || !archived[0].Archived {
|
||||
t.Fatalf("archived threads = %+v, want one archived thread", archived)
|
||||
}
|
||||
if err := st.UnarchiveThread(ctx, 42, archived[0].ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
archived, err = st.ListArchivedThreadsPage(ctx, 42, 10, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(archived) != 0 {
|
||||
t.Fatalf("archived threads after unarchive = %+v, want none", archived)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenameThread(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user