mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #2525 from afjcjsbx/fix/vision-unsupported-media-stuck
fix(agent): recover after image-input-unsupported failures
This commit is contained in:
@@ -0,0 +1,60 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func messagesContainMedia(messages []providers.Message) bool {
|
||||
for _, msg := range messages {
|
||||
for _, ref := range msg.Media {
|
||||
if strings.TrimSpace(ref) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func stripMessageMedia(messages []providers.Message) []providers.Message {
|
||||
if !messagesContainMedia(messages) {
|
||||
return messages
|
||||
}
|
||||
stripped := make([]providers.Message, len(messages))
|
||||
for i, msg := range messages {
|
||||
stripped[i] = msg
|
||||
stripped[i].Media = nil
|
||||
}
|
||||
return stripped
|
||||
}
|
||||
|
||||
// isVisionUnsupportedError reports whether err looks like a provider or model
// rejecting image input. Detection is purely textual: the error message is
// lowercased and searched for known phrases. A nil error yields false.
func isVisionUnsupportedError(err error) bool {
	if err == nil {
		return false
	}
	text := strings.ToLower(err.Error())

	// Phrases seen from OpenRouter (and OpenAI-compatible) endpoints first,
	// followed by common provider variants. All are lowercase because they
	// are matched against the lowercased error text.
	knownPhrases := [...]string{
		"no endpoints found that support image input",
		"does not support image input",
		"does not support image inputs",
		"does not support images",
		"image input is not supported",
		"images are not supported",
		"does not support vision",
		"unsupported content type: image_url",
	}
	for _, phrase := range knownPhrases {
		if strings.Contains(text, phrase) {
			return true
		}
	}

	// Some providers return a generic "invalid" message that still mentions image_url.
	return strings.Contains(text, "image_url") && strings.Contains(text, "invalid")
}
|
||||
@@ -2360,6 +2360,8 @@ turnLoop:
|
||||
var response *providers.LLMResponse
|
||||
var err error
|
||||
maxRetries := 2
|
||||
callHasMedia := messagesContainMedia(callMessages)
|
||||
didStripMedia := false
|
||||
for retry := 0; retry <= maxRetries; retry++ {
|
||||
response, err = callLLM(callMessages, providerToolDefs)
|
||||
if err == nil {
|
||||
@@ -2370,6 +2372,45 @@ turnLoop:
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
// If the provider/model doesn't support multimodal inputs, retry once with media stripped
|
||||
// so the session doesn't get "stuck" after a user sends an image.
|
||||
if callHasMedia && !didStripMedia && isVisionUnsupportedError(err) {
|
||||
didStripMedia = true
|
||||
if !ts.opts.NoHistory {
|
||||
history = ts.agent.Sessions.GetHistory(ts.sessionKey)
|
||||
ts.agent.Sessions.SetHistory(ts.sessionKey, stripMessageMedia(history))
|
||||
|
||||
// Keep persistedMessages aligned so abort restore-point trimming remains correct.
|
||||
ts.mu.Lock()
|
||||
for i := range ts.persistedMessages {
|
||||
ts.persistedMessages[i].Media = nil
|
||||
}
|
||||
ts.mu.Unlock()
|
||||
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
}
|
||||
|
||||
messages = stripMessageMedia(messages)
|
||||
callMessages = stripMessageMedia(callMessages)
|
||||
callHasMedia = false
|
||||
|
||||
al.emitEvent(
|
||||
EventKindLLMRetry,
|
||||
ts.eventMeta("runTurn", "turn.llm.retry"),
|
||||
LLMRetryPayload{
|
||||
Attempt: 1,
|
||||
MaxRetries: 1,
|
||||
Reason: "vision_unsupported",
|
||||
Error: err.Error(),
|
||||
Backoff: 0,
|
||||
},
|
||||
)
|
||||
response, err = callLLM(callMessages, providerToolDefs)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
isTimeoutError := errors.Is(err, context.DeadlineExceeded) ||
|
||||
strings.Contains(errMsg, "deadline exceeded") ||
|
||||
|
||||
@@ -2565,6 +2565,136 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// visionUnsupportedMediaProvider is a test double for an LLM provider that
// rejects any request containing media and accepts everything else.
type visionUnsupportedMediaProvider struct {
	// calls is the number of Chat invocations so far.
	calls int
	// mediaSeen holds one entry per Chat call, in call order, recording
	// whether that call's messages contained a non-blank media reference.
	mediaSeen []bool
}
|
||||
|
||||
func (p *visionUnsupportedMediaProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.calls++
|
||||
|
||||
hasMedia := false
|
||||
for _, msg := range messages {
|
||||
for _, ref := range msg.Media {
|
||||
if strings.TrimSpace(ref) != "" {
|
||||
hasMedia = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasMedia {
|
||||
break
|
||||
}
|
||||
}
|
||||
p.mediaSeen = append(p.mediaSeen, hasMedia)
|
||||
|
||||
if hasMedia {
|
||||
return nil, fmt.Errorf("API request failed: " +
|
||||
"Status: 404 Body: {\"error\":{\"message\":\"No endpoints found that support image input\"}}")
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: "ok",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *visionUnsupportedMediaProvider) GetDefaultModel() string {
|
||||
return "mock-fail-model"
|
||||
}
|
||||
|
||||
func TestAgentLoop_VisionUnsupportedErrorStripsSessionMedia(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &visionUnsupportedMediaProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
sessionKey := "agent:main:telegram:direct:user1"
|
||||
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), responseTimeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := al.processMessage(timeoutCtx, testInboundMessage(bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
MessageID: "m1",
|
||||
},
|
||||
Content: "describe this",
|
||||
Media: []string{"data:image/png;base64,abc123"},
|
||||
SessionKey: sessionKey,
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if resp != "ok" {
|
||||
t.Fatalf("response = %q, want %q", resp, "ok")
|
||||
}
|
||||
if provider.calls != 2 {
|
||||
t.Fatalf("calls = %d, want %d (fail with media, then retry without media)", provider.calls, 2)
|
||||
}
|
||||
if !slices.Equal(provider.mediaSeen, []bool{true, false}) {
|
||||
t.Fatalf("mediaSeen = %v, want %v", provider.mediaSeen, []bool{true, false})
|
||||
}
|
||||
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
history := agent.Sessions.GetHistory(sessionKey)
|
||||
for i, msg := range history {
|
||||
if len(msg.Media) > 0 {
|
||||
t.Fatalf("history[%d].Media = %v, want no media after stripping", i, msg.Media)
|
||||
}
|
||||
}
|
||||
|
||||
timeoutCtx2, cancel2 := context.WithTimeout(context.Background(), responseTimeout)
|
||||
defer cancel2()
|
||||
|
||||
resp2, err := al.processMessage(timeoutCtx2, testInboundMessage(bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
MessageID: "m2",
|
||||
},
|
||||
Content: "hello again",
|
||||
SessionKey: sessionKey,
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() second call error = %v", err)
|
||||
}
|
||||
if resp2 != "ok" {
|
||||
t.Fatalf("second response = %q, want %q", resp2, "ok")
|
||||
}
|
||||
if provider.calls != 3 {
|
||||
t.Fatalf("calls after second turn = %d, want %d", provider.calls, 3)
|
||||
}
|
||||
if !slices.Equal(provider.mediaSeen, []bool{true, false, false}) {
|
||||
t.Fatalf("mediaSeen = %v, want %v", provider.mediaSeen, []bool{true, false, false})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmptyModelResponseUsesAccurateFallback(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user