Add prompt-based picture command
This commit is contained in:
@@ -14,6 +14,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type PhotoUpload struct {
|
||||
Filename string
|
||||
Data []byte
|
||||
Caption string
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
token string
|
||||
baseURL string
|
||||
@@ -205,6 +211,80 @@ func (c *Client) SendPhotoBytes(ctx context.Context, chatID int64, filename stri
|
||||
return decoded.Result, nil
|
||||
}
|
||||
|
||||
func (c *Client) SendPhotoGroupBytes(ctx context.Context, chatID int64, photos []PhotoUpload) ([]Message, error) {
|
||||
if len(photos) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if len(photos) == 1 {
|
||||
message, err := c.SendPhotoBytes(ctx, chatID, photos[0].Filename, photos[0].Data, photos[0].Caption)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []Message{message}, nil
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
if err := writer.WriteField("chat_id", fmt.Sprint(chatID)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
media := make([]map[string]string, 0, len(photos))
|
||||
for i, photo := range photos {
|
||||
name := fmt.Sprintf("photo%d", i)
|
||||
entry := map[string]string{
|
||||
"type": "photo",
|
||||
"media": "attach://" + name,
|
||||
}
|
||||
if photo.Caption != "" {
|
||||
entry["caption"] = photo.Caption
|
||||
}
|
||||
media = append(media, entry)
|
||||
}
|
||||
mediaJSON, err := json.Marshal(media)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := writer.WriteField("media", string(mediaJSON)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i, photo := range photos {
|
||||
name := fmt.Sprintf("photo%d", i)
|
||||
part, err := writer.CreateFormFile(name, filepath.Base(photo.Filename))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := part.Write(photo.Data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/sendMediaGroup", &body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, c.redactError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
payload, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return nil, fmt.Errorf("sendMediaGroup: telegram returned %s: %s", resp.Status, string(payload))
|
||||
}
|
||||
var decoded apiResponse[[]Message]
|
||||
if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !decoded.OK {
|
||||
return nil, fmt.Errorf("sendMediaGroup: telegram error %d: %s", decoded.ErrorCode, decoded.Description)
|
||||
}
|
||||
return decoded.Result, nil
|
||||
}
|
||||
|
||||
func (c *Client) SendDocumentBytes(ctx context.Context, chatID int64, filename string, data []byte, caption string) (Message, error) {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
|
||||
@@ -29,6 +29,7 @@ const (
|
||||
telegramThreadCWDDirectiveStart = "<!-- codex-thread-cwd "
|
||||
telegramDirectiveEnd = " -->"
|
||||
telegramPhotoCaptionLimit = 1024
|
||||
pictureMediaGroupLimit = 10
|
||||
)
|
||||
|
||||
type Bot struct {
|
||||
@@ -73,11 +74,17 @@ type outputState struct {
|
||||
chatID int64
|
||||
assistant strings.Builder
|
||||
sentAny bool
|
||||
pictureRequest bool
|
||||
tools map[string]toolMessageState
|
||||
sentImages map[string]bool
|
||||
generatedImages []generatedImageOutput
|
||||
workingIndicatorOff context.CancelFunc
|
||||
}
|
||||
|
||||
type generatedImageOutput struct {
|
||||
Path string
|
||||
}
|
||||
|
||||
type toolMessageState struct {
|
||||
chatID int64
|
||||
messageID int
|
||||
@@ -216,7 +223,7 @@ func (b *Bot) handleCommand(ctx context.Context, message *Message, session store
|
||||
case "start", "help":
|
||||
return true, b.sendHelp(ctx, chatID)
|
||||
case "new":
|
||||
_, _, err := b.createNewThread(ctx, userID, chatID, session)
|
||||
_, _, err := b.createNewThread(ctx, userID, chatID, session, true)
|
||||
return true, err
|
||||
case "thread", "threads":
|
||||
return true, b.sendThreads(ctx, userID, chatID)
|
||||
@@ -240,6 +247,8 @@ func (b *Bot) handleCommand(ctx context.Context, message *Message, session store
|
||||
return true, b.handleModelCommand(ctx, userID, chatID, session, args)
|
||||
case "sandbox":
|
||||
return true, b.handleSandboxCommand(ctx, userID, chatID, session, args)
|
||||
case "pic":
|
||||
return true, b.handlePictureCommand(ctx, userID, chatID, session, args)
|
||||
case "diff":
|
||||
return true, b.sendDiff(ctx, chatID, session)
|
||||
default:
|
||||
@@ -265,6 +274,7 @@ func (b *Bot) sendHelp(ctx context.Context, chatID int64) error {
|
||||
"/workspace [ID] - select workspace",
|
||||
"/model - choose model and reasoning effort",
|
||||
"/sandbox [read-only|workspace-write|danger-full-access] - show or set sandbox",
|
||||
"/pic PROMPT - generate image(s) from a prompt",
|
||||
"/diff - show the latest streamed diff",
|
||||
"",
|
||||
"Plain text continues the active thread. Images are staged as local Codex image inputs; other files are staged and sent as paths.",
|
||||
@@ -873,7 +883,7 @@ func (b *Bot) syncUserThreadStates(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Bot) createNewThread(ctx context.Context, userID, chatID int64, session store.Session) (store.Thread, store.Workspace, error) {
|
||||
func (b *Bot) createNewThread(ctx context.Context, userID, chatID int64, session store.Session, announce bool) (store.Thread, store.Workspace, error) {
|
||||
workspace, err := b.resolveWorkspace(ctx, userID, session)
|
||||
if err != nil {
|
||||
return store.Thread{}, store.Workspace{}, b.sendWorkspaceMissing(ctx, chatID)
|
||||
@@ -899,7 +909,9 @@ func (b *Bot) createNewThread(ctx context.Context, userID, chatID int64, session
|
||||
if err := b.store.SetSessionWorkspace(ctx, userID, thread.WorkspaceID); err != nil {
|
||||
return store.Thread{}, store.Workspace{}, err
|
||||
}
|
||||
_, err = b.tg.SendMessage(ctx, chatID, fmt.Sprintf("New thread #%d in %s.", thread.ID, threadWorkspace.Label), SendMessageOptions{})
|
||||
if announce {
|
||||
_, err = b.tg.SendMessage(ctx, chatID, fmt.Sprintf("New thread #%d in %s.", thread.ID, threadWorkspace.Label), SendMessageOptions{})
|
||||
}
|
||||
return thread, threadWorkspace, err
|
||||
}
|
||||
|
||||
@@ -916,7 +928,54 @@ func (b *Bot) ensureThread(ctx context.Context, userID, chatID int64, session st
|
||||
return thread, workspace, err
|
||||
}
|
||||
}
|
||||
return b.createNewThread(ctx, userID, chatID, session)
|
||||
return b.createNewThread(ctx, userID, chatID, session, true)
|
||||
}
|
||||
|
||||
func (b *Bot) ensureThreadForPicture(ctx context.Context, userID, chatID int64, session store.Session) (store.Thread, store.Workspace, error) {
|
||||
if session.ActiveThreadID != 0 {
|
||||
return b.ensureThread(ctx, userID, chatID, session)
|
||||
}
|
||||
return b.createNewThread(ctx, userID, chatID, session, false)
|
||||
}
|
||||
|
||||
func (b *Bot) handlePictureCommand(ctx context.Context, userID, chatID int64, session store.Session, args []string) error {
|
||||
prompt := strings.TrimSpace(strings.Join(args, " "))
|
||||
if prompt == "" {
|
||||
_, err := b.tg.SendMessage(ctx, chatID, "Use /pic PROMPT to generate image(s).", SendMessageOptions{})
|
||||
return err
|
||||
}
|
||||
if session.ActiveTurnID != "" {
|
||||
_, err := b.tg.SendMessage(ctx, chatID, "A Codex turn is already running. Use /cancel first, or wait for it to finish.", SendMessageOptions{})
|
||||
return err
|
||||
}
|
||||
thread, _, err := b.ensureThreadForPicture(ctx, userID, chatID, session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := []codexapp.InputItem{{Type: "text", Text: pictureGenerationInstruction(prompt)}}
|
||||
b.registerPictureOutput(thread.CodexThreadID, chatID)
|
||||
turn, err := b.codex.StartTurn(ctx, thread.CodexThreadID, "", session.Model, session.ReasoningEffort, session.Sandbox, input)
|
||||
if err != nil {
|
||||
b.clearOutput(thread.CodexThreadID)
|
||||
return b.sendError(ctx, chatID, "Codex image generation failed", err)
|
||||
}
|
||||
if err := b.store.SetActiveTurn(ctx, userID, turn.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = b.store.TouchThread(ctx, thread.CodexThreadID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func pictureGenerationInstruction(prompt string) string {
|
||||
return strings.Join([]string{
|
||||
"You are handling a Telegram /pic command.",
|
||||
"Use only the built-in image generation capability to create image(s) from the user prompt below.",
|
||||
"Do not browse the web, run shell commands, call MCP tools, edit files, or ask follow-up questions.",
|
||||
"Avoid extra explanatory text; the Telegram bot will send generated image files automatically.",
|
||||
"",
|
||||
"User image prompt:",
|
||||
strings.TrimSpace(prompt),
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
func (b *Bot) activeThread(ctx context.Context, userID int64, session store.Session) (store.Thread, error) {
|
||||
@@ -1339,6 +1398,9 @@ func (b *Bot) handleCodexNotification(ctx context.Context, event codexapp.Event)
|
||||
return b.flushAssistantMessage(ctx, params.ThreadID)
|
||||
}
|
||||
if params.ThreadID != "" {
|
||||
if b.shouldSuppressPictureToolMessage(params.ThreadID, item) {
|
||||
return nil
|
||||
}
|
||||
return b.upsertToolMessage(ctx, params.ThreadID, item.ID, renderCodexItemStarted(item))
|
||||
}
|
||||
case "item/agentMessage/delta":
|
||||
@@ -1373,6 +1435,9 @@ func (b *Bot) handleCodexNotification(ctx context.Context, event codexapp.Event)
|
||||
return b.flushAssistantMessage(ctx, params.ThreadID)
|
||||
}
|
||||
if params.ThreadID != "" {
|
||||
if b.queuePictureImageOutput(params.ThreadID, item) {
|
||||
return nil
|
||||
}
|
||||
if err := b.upsertToolMessage(ctx, params.ThreadID, item.ID, renderCodexItemCompleted(item)); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1513,6 +1578,15 @@ func (b *Bot) registerOutput(threadID string, chatID int64) {
|
||||
b.outputs[threadID] = b.newOutputState(chatID)
|
||||
}
|
||||
|
||||
func (b *Bot) registerPictureOutput(threadID string, chatID int64) {
|
||||
b.registerOutput(threadID, chatID)
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if state := b.outputs[threadID]; state != nil {
|
||||
state.pictureRequest = true
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bot) clearOutput(threadID string) {
|
||||
b.mu.Lock()
|
||||
state := b.outputs[threadID]
|
||||
@@ -1594,6 +1668,38 @@ func (b *Bot) failActiveOutputs(ctx context.Context, message string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bot) shouldSuppressPictureToolMessage(threadID string, item codexThreadItemView) bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
state := b.outputs[threadID]
|
||||
return state != nil && state.pictureRequest && item.Type == "imageGeneration"
|
||||
}
|
||||
|
||||
func (b *Bot) queuePictureImageOutput(threadID string, item codexThreadItemView) bool {
|
||||
if item.Type != "imageGeneration" {
|
||||
return false
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
state := b.outputs[threadID]
|
||||
if state == nil || !state.pictureRequest {
|
||||
return false
|
||||
}
|
||||
path := strings.TrimSpace(item.SavedPath)
|
||||
if path == "" {
|
||||
return true
|
||||
}
|
||||
if state.sentImages == nil {
|
||||
state.sentImages = make(map[string]bool)
|
||||
}
|
||||
if state.sentImages[path] {
|
||||
return true
|
||||
}
|
||||
state.sentImages[path] = true
|
||||
state.generatedImages = append(state.generatedImages, generatedImageOutput{Path: path})
|
||||
return true
|
||||
}
|
||||
|
||||
func (b *Bot) sendImageOutput(ctx context.Context, threadID string, item codexThreadItemView) error {
|
||||
if item.Type != "imageGeneration" || strings.TrimSpace(item.SavedPath) == "" {
|
||||
return nil
|
||||
@@ -2068,9 +2174,13 @@ func (b *Bot) flushAssistantMessage(ctx context.Context, threadID string) error
|
||||
}
|
||||
chatID := state.chatID
|
||||
text := state.assistant.String()
|
||||
pictureRequest := state.pictureRequest
|
||||
state.assistant.Reset()
|
||||
b.mu.Unlock()
|
||||
|
||||
if pictureRequest {
|
||||
return nil
|
||||
}
|
||||
if err := b.sendAssistantText(ctx, threadID, chatID, text); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -2090,6 +2200,8 @@ func (b *Bot) completeTurnOutput(ctx context.Context, threadID string) error {
|
||||
}
|
||||
chatID := state.chatID
|
||||
sentAny := state.sentAny
|
||||
pictureRequest := state.pictureRequest
|
||||
generatedImages := append([]generatedImageOutput(nil), state.generatedImages...)
|
||||
workingIndicatorOff := state.workingIndicatorOff
|
||||
delete(b.outputs, threadID)
|
||||
b.mu.Unlock()
|
||||
@@ -2097,6 +2209,13 @@ func (b *Bot) completeTurnOutput(ctx context.Context, threadID string) error {
|
||||
workingIndicatorOff()
|
||||
}
|
||||
|
||||
if pictureRequest {
|
||||
if len(generatedImages) == 0 {
|
||||
_, err := b.tg.SendMessage(ctx, chatID, "No image was generated.", SendMessageOptions{})
|
||||
return err
|
||||
}
|
||||
return b.sendGeneratedImageOutputs(ctx, chatID, generatedImages)
|
||||
}
|
||||
if !sentAny {
|
||||
_, err := b.tg.SendMessage(ctx, chatID, "Done.", SendMessageOptions{})
|
||||
return err
|
||||
@@ -2104,6 +2223,37 @@ func (b *Bot) completeTurnOutput(ctx context.Context, threadID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Bot) sendGeneratedImageOutputs(ctx context.Context, chatID int64, images []generatedImageOutput) error {
|
||||
uploads := make([]PhotoUpload, 0, len(images))
|
||||
for _, image := range images {
|
||||
path := strings.TrimSpace(image.Path)
|
||||
if path == "" {
|
||||
continue
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
b.logger.Printf("read generated image %s: %v", path, err)
|
||||
continue
|
||||
}
|
||||
uploads = append(uploads, PhotoUpload{Filename: path, Data: data})
|
||||
}
|
||||
if len(uploads) == 0 {
|
||||
_, err := b.tg.SendMessage(ctx, chatID, "Generated image file was not readable by the bot.", SendMessageOptions{})
|
||||
return err
|
||||
}
|
||||
for len(uploads) > 0 {
|
||||
count := len(uploads)
|
||||
if count > pictureMediaGroupLimit {
|
||||
count = pictureMediaGroupLimit
|
||||
}
|
||||
if _, err := b.tg.SendPhotoGroupBytes(ctx, chatID, uploads[:count]); err != nil {
|
||||
return err
|
||||
}
|
||||
uploads = uploads[count:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Bot) outputChatID(ctx context.Context, threadID string) (int64, error) {
|
||||
b.mu.Lock()
|
||||
state := b.outputs[threadID]
|
||||
|
||||
@@ -90,6 +90,20 @@ func TestParseCommand(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPictureGenerationInstruction(t *testing.T) {
|
||||
instruction := pictureGenerationInstruction("generate a blue cube")
|
||||
for _, want := range []string{"Telegram /pic command", "built-in image generation", "generate a blue cube"} {
|
||||
if !strings.Contains(instruction, want) {
|
||||
t.Fatalf("instruction missing %q in %q", want, instruction)
|
||||
}
|
||||
}
|
||||
for _, unwanted := range []string{"/home", "repo/playground"} {
|
||||
if strings.Contains(instruction, unwanted) {
|
||||
t.Fatalf("instruction contains non-portable text %q: %q", unwanted, instruction)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitAssistantMessageSegmentsWithPhotoDirective(t *testing.T) {
|
||||
photoPath := filepath.Join(string(filepath.Separator), "workspace", "photo.jpg")
|
||||
text := fmt.Sprintf("before\n<!-- telegram-photo {\"path\":%q,\"caption\":\"hello\"} -->\nafter", photoPath)
|
||||
|
||||
Reference in New Issue
Block a user