mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(pico): add support for tool_calls in chat messages
This commit is contained in:
@@ -111,8 +111,10 @@ const (
|
||||
sessionKeyAgentPrefix = "agent:"
|
||||
pendingTurnPrefix = "pending-"
|
||||
metadataKeyMessageKind = "message_kind"
|
||||
metadataKeyToolCalls = "tool_calls"
|
||||
messageKindThought = "thought"
|
||||
messageKindToolFeedback = "tool_feedback"
|
||||
messageKindToolCalls = "tool_calls"
|
||||
metadataKeyAccountID = "account_id"
|
||||
metadataKeyGuildID = "guild_id"
|
||||
metadataKeyTeamID = "team_id"
|
||||
|
||||
@@ -4,13 +4,17 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
func (al *AgentLoop) maybePublishError(ctx context.Context, channel, chatID, sessionKey string, err error) bool {
|
||||
@@ -123,6 +127,92 @@ func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent,
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) publishPicoToolCallInterim(
|
||||
ctx context.Context,
|
||||
ts *turnState,
|
||||
reasoningContent string,
|
||||
content string,
|
||||
toolCalls []providers.ToolCall,
|
||||
) {
|
||||
if ts == nil || ts.chatID == "" || al == nil || al.bus == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(reasoningContent) != "" {
|
||||
pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
err := al.bus.PublishOutbound(
|
||||
pubCtx,
|
||||
outboundMessageForTurnWithKind(ts, reasoningContent, messageKindThought),
|
||||
)
|
||||
pubCancel()
|
||||
if err != nil && !errors.Is(err, context.DeadlineExceeded) &&
|
||||
!errors.Is(err, context.Canceled) &&
|
||||
!errors.Is(err, bus.ErrBusClosed) {
|
||||
logger.WarnCF("agent", "Failed to publish pico reasoning", map[string]any{
|
||||
"channel": ts.channel,
|
||||
"chat_id": ts.chatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if !ts.opts.AllowInterimPicoPublish {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(content) != "" {
|
||||
pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
err := al.bus.PublishOutbound(pubCtx, outboundMessageForTurn(ts, content))
|
||||
pubCancel()
|
||||
if err != nil && !errors.Is(err, context.DeadlineExceeded) &&
|
||||
!errors.Is(err, context.Canceled) &&
|
||||
!errors.Is(err, bus.ErrBusClosed) {
|
||||
logger.WarnCF("agent", "Failed to publish pico interim assistant content", map[string]any{
|
||||
"channel": ts.channel,
|
||||
"chat_id": ts.chatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
visibleToolCalls := utils.BuildVisibleToolCalls(
|
||||
toolCalls,
|
||||
al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(),
|
||||
)
|
||||
if len(visibleToolCalls) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
rawToolCalls, err := json.Marshal(visibleToolCalls)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to serialize pico tool calls", map[string]any{
|
||||
"channel": ts.channel,
|
||||
"chat_id": ts.chatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
msg := outboundMessageForTurnWithKind(ts, "", messageKindToolCalls)
|
||||
if msg.Context.Raw == nil {
|
||||
msg.Context.Raw = map[string]string{}
|
||||
}
|
||||
msg.Context.Raw[metadataKeyToolCalls] = string(rawToolCalls)
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
err = al.bus.PublishOutbound(pubCtx, msg)
|
||||
pubCancel()
|
||||
if err != nil && !errors.Is(err, context.DeadlineExceeded) &&
|
||||
!errors.Is(err, context.Canceled) &&
|
||||
!errors.Is(err, bus.ErrBusClosed) {
|
||||
logger.WarnCF("agent", "Failed to publish pico tool calls", map[string]any{
|
||||
"channel": ts.channel,
|
||||
"chat_id": ts.chatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleReasoning(
|
||||
ctx context.Context,
|
||||
reasoningContent, channelName, channelID string,
|
||||
|
||||
+33
-16
@@ -3987,6 +3987,7 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
escapedHeartbeatFile := strings.ReplaceAll(heartbeatFile, `\`, `\\`)
|
||||
if outbound.Channel != "telegram" {
|
||||
t.Fatalf("tool feedback channel = %q, want %q", outbound.Channel, "telegram")
|
||||
}
|
||||
@@ -4008,7 +4009,7 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
|
||||
if !strings.Contains(outbound.Content, "\"path\":") {
|
||||
t.Fatalf("tool feedback content = %q, want serialized tool arguments", outbound.Content)
|
||||
}
|
||||
if !strings.Contains(outbound.Content, heartbeatFile) {
|
||||
if !strings.Contains(outbound.Content, escapedHeartbeatFile) {
|
||||
t.Fatalf("tool feedback content = %q, want tool argument value", outbound.Content)
|
||||
}
|
||||
if strings.Contains(outbound.Content, "Previous turn explanation") {
|
||||
@@ -4250,6 +4251,7 @@ func TestProcessMessage_DoesNotLeakReasoningContentInToolFeedback(t *testing.T)
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
escapedHeartbeatFile := strings.ReplaceAll(heartbeatFile, `\`, `\\`)
|
||||
if !strings.Contains(outbound.Content, "`read_file`") {
|
||||
t.Fatalf("tool feedback content = %q, want read_file summary", outbound.Content)
|
||||
}
|
||||
@@ -4262,7 +4264,7 @@ func TestProcessMessage_DoesNotLeakReasoningContentInToolFeedback(t *testing.T)
|
||||
if !strings.Contains(outbound.Content, "\"path\":") {
|
||||
t.Fatalf("tool feedback content = %q, want serialized tool arguments", outbound.Content)
|
||||
}
|
||||
if !strings.Contains(outbound.Content, heartbeatFile) {
|
||||
if !strings.Contains(outbound.Content, escapedHeartbeatFile) {
|
||||
t.Fatalf("tool feedback content = %q, want tool argument value", outbound.Content)
|
||||
}
|
||||
if strings.Contains(outbound.Content, "Read README.md first") {
|
||||
@@ -4422,22 +4424,28 @@ func TestRun_PicoPublishesAssistantContentDuringToolCallsWithoutFinalDuplicate(t
|
||||
t.Fatalf("PublishInbound() error = %v", err)
|
||||
}
|
||||
|
||||
outputs := make([]string, 0, 2)
|
||||
outputs := make([]bus.OutboundMessage, 0, 3)
|
||||
deadline := time.After(2 * time.Second)
|
||||
for len(outputs) < 2 {
|
||||
for len(outputs) < 3 {
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
outputs = append(outputs, outbound.Content)
|
||||
outputs = append(outputs, outbound)
|
||||
case <-deadline:
|
||||
t.Fatalf("timed out waiting for pico outputs, got %v", outputs)
|
||||
}
|
||||
}
|
||||
|
||||
if outputs[0] != "intermediate model text" {
|
||||
t.Fatalf("first outbound content = %q, want %q", outputs[0], "intermediate model text")
|
||||
if outputs[0].Content != "intermediate model text" {
|
||||
t.Fatalf("first outbound content = %q, want %q", outputs[0].Content, "intermediate model text")
|
||||
}
|
||||
if outputs[1] != "final model text" {
|
||||
t.Fatalf("second outbound content = %q, want %q", outputs[1], "final model text")
|
||||
if outputs[1].Context.Raw[metadataKeyMessageKind] != messageKindToolCalls {
|
||||
t.Fatalf("second outbound = %+v, want tool_calls message", outputs[1])
|
||||
}
|
||||
if !strings.Contains(outputs[1].Context.Raw[metadataKeyToolCalls], "tool_limit_test_tool") {
|
||||
t.Fatalf("second outbound tool_calls = %q, want tool name", outputs[1].Context.Raw[metadataKeyToolCalls])
|
||||
}
|
||||
if outputs[2].Content != "final model text" {
|
||||
t.Fatalf("third outbound content = %q, want %q", outputs[2].Content, "final model text")
|
||||
}
|
||||
|
||||
runCancel()
|
||||
@@ -4552,22 +4560,31 @@ func TestRun_PicoToolFeedbackSuppressesDuplicateInterimAssistantContent(t *testi
|
||||
t.Fatalf("PublishInbound() error = %v", err)
|
||||
}
|
||||
|
||||
outputs := make([]string, 0, 2)
|
||||
outputs := make([]bus.OutboundMessage, 0, 3)
|
||||
deadline := time.After(2 * time.Second)
|
||||
for len(outputs) < 2 {
|
||||
for len(outputs) < 3 {
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
outputs = append(outputs, outbound.Content)
|
||||
outputs = append(outputs, outbound)
|
||||
case <-deadline:
|
||||
t.Fatalf("timed out waiting for pico outputs, got %v", outputs)
|
||||
}
|
||||
}
|
||||
|
||||
if outputs[0] != "🔧 `tool_limit_test_tool`\nintermediate model text\n```json\n{\n \"value\": \"x\"\n}\n```" {
|
||||
t.Fatalf("first outbound content = %q, want tool feedback summary", outputs[0])
|
||||
if outputs[0].Content != "intermediate model text" {
|
||||
t.Fatalf("first outbound content = %q, want %q", outputs[0].Content, "intermediate model text")
|
||||
}
|
||||
if outputs[1] != "final model text" {
|
||||
t.Fatalf("second outbound content = %q, want %q", outputs[1], "final model text")
|
||||
if outputs[1].Context.Raw[metadataKeyMessageKind] != messageKindToolCalls {
|
||||
t.Fatalf("second outbound = %+v, want tool_calls message", outputs[1])
|
||||
}
|
||||
if outputs[1].Content != "" {
|
||||
t.Fatalf("second outbound content = %q, want empty tool_calls content", outputs[1].Content)
|
||||
}
|
||||
if !strings.Contains(outputs[1].Context.Raw[metadataKeyToolCalls], "tool_limit_test_tool") {
|
||||
t.Fatalf("second outbound tool_calls = %q, want tool name", outputs[1].Context.Raw[metadataKeyToolCalls])
|
||||
}
|
||||
if outputs[2].Content != "final model text" {
|
||||
t.Fatalf("third outbound content = %q, want %q", outputs[2].Content, "final model text")
|
||||
}
|
||||
|
||||
runCancel()
|
||||
|
||||
@@ -80,7 +80,7 @@ toolLoop:
|
||||
},
|
||||
)
|
||||
|
||||
if shouldPublishToolFeedback(al.cfg, ts) {
|
||||
if shouldPublishToolFeedback(al.cfg, ts) && ts.channel != "pico" {
|
||||
toolFeedbackMaxLen := al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength()
|
||||
toolFeedbackExplanation := toolFeedbackExplanationForToolCall(
|
||||
exec.response,
|
||||
@@ -362,7 +362,7 @@ toolLoop:
|
||||
},
|
||||
)
|
||||
|
||||
if shouldPublishToolFeedback(al.cfg, ts) {
|
||||
if shouldPublishToolFeedback(al.cfg, ts) && ts.channel != "pico" {
|
||||
toolFeedbackMaxLen := al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength()
|
||||
toolFeedbackExplanation := toolFeedbackExplanationForToolCall(
|
||||
exec.response,
|
||||
|
||||
+14
-26
@@ -10,7 +10,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
@@ -383,7 +382,11 @@ func (p *Pipeline) CallLLM(
|
||||
}
|
||||
|
||||
reasoningContent := responseReasoningContent(exec.response)
|
||||
if ts.channel == "pico" {
|
||||
shouldPublishPicoToolCallInterim := ts.channel == "pico" && len(exec.response.ToolCalls) > 0
|
||||
if shouldPublishPicoToolCallInterim {
|
||||
// Pico tool-call turns publish their reasoning/content/tool summary as a
|
||||
// structured sequence after the tool-call payload is normalized below.
|
||||
} else if ts.channel == "pico" {
|
||||
go al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID)
|
||||
} else {
|
||||
go al.handleReasoning(
|
||||
@@ -419,30 +422,6 @@ func (p *Pipeline) CallLLM(
|
||||
}
|
||||
logger.DebugCF("agent", "LLM response", llmResponseFields)
|
||||
|
||||
if al.bus != nil &&
|
||||
ts.channel == "pico" &&
|
||||
len(exec.response.ToolCalls) > 0 &&
|
||||
ts.opts.AllowInterimPicoPublish &&
|
||||
!shouldPublishToolFeedback(al.cfg, ts) {
|
||||
if strings.TrimSpace(exec.response.Content) != "" {
|
||||
outCtx, outCancel := context.WithTimeout(turnCtx, 3*time.Second)
|
||||
publishErr := al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Content: exec.response.Content,
|
||||
})
|
||||
outCancel()
|
||||
if publishErr != nil {
|
||||
logger.WarnCF("agent", "Failed to publish pico interim tool-call content", map[string]any{
|
||||
"error": publishErr.Error(),
|
||||
"channel": ts.channel,
|
||||
"chat_id": ts.chatID,
|
||||
"iteration": iteration,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No-tool-call path: steering check and direct response
|
||||
if len(exec.response.ToolCalls) == 0 || exec.gracefulTerminal {
|
||||
responseContent := exec.response.Content
|
||||
@@ -531,6 +510,15 @@ func (p *Pipeline) CallLLM(
|
||||
ts.recordPersistedMessage(assistantMsg)
|
||||
ts.ingestMessage(turnCtx, al, assistantMsg)
|
||||
}
|
||||
if shouldPublishPicoToolCallInterim {
|
||||
al.publishPicoToolCallInterim(
|
||||
turnCtx,
|
||||
ts,
|
||||
reasoningContent,
|
||||
exec.response.Content,
|
||||
assistantMsg.ToolCalls,
|
||||
)
|
||||
}
|
||||
|
||||
return ControlToolLoop, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user