feat(pico): add support for tool_calls in chat messages

This commit is contained in:
lc6464
2026-04-25 23:43:10 +08:00
parent 77be169db4
commit 5cd10b594a
20 changed files with 815 additions and 409 deletions
+2
View File
@@ -111,8 +111,10 @@ const (
sessionKeyAgentPrefix = "agent:"
pendingTurnPrefix = "pending-"
metadataKeyMessageKind = "message_kind"
metadataKeyToolCalls = "tool_calls"
messageKindThought = "thought"
messageKindToolFeedback = "tool_feedback"
messageKindToolCalls = "tool_calls"
metadataKeyAccountID = "account_id"
metadataKeyGuildID = "guild_id"
metadataKeyTeamID = "team_id"
+90
View File
@@ -4,13 +4,17 @@ package agent
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
)
func (al *AgentLoop) maybePublishError(ctx context.Context, channel, chatID, sessionKey string, err error) bool {
@@ -123,6 +127,92 @@ func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent,
}
}
// publishPicoToolCallInterim publishes the interim output of a tool-call turn
// as up to three outbound messages, in this order:
//
//  1. reasoningContent, tagged with messageKindThought (when non-blank);
//  2. content, as a plain turn message (when non-blank);
//  3. an empty-content message tagged with messageKindToolCalls whose
//     Context.Raw[metadataKeyToolCalls] holds the JSON-encoded visible tool
//     calls (when there is at least one).
//
// Steps 2 and 3 only run when ts.opts.AllowInterimPicoPublish is set; the
// reasoning message is published regardless of that option. Each publish gets
// its own 3-second timeout derived from ctx. Publish failures are logged as
// warnings and never returned; deadline, cancellation and closed-bus errors
// are dropped silently.
//
// The call is a no-op when ts is nil, ts.chatID is empty, or the loop has no
// bus.
func (al *AgentLoop) publishPicoToolCallInterim(
	ctx context.Context,
	ts *turnState,
	reasoningContent string,
	content string,
	toolCalls []providers.ToolCall,
) {
	if ts == nil || ts.chatID == "" || al == nil || al.bus == nil {
		return
	}

	// warn logs a publish failure unless it is one of the expected shutdown
	// or timeout errors, which are not worth reporting.
	warn := func(pubErr error, what string) {
		if pubErr == nil {
			return
		}
		for _, benign := range []error{context.DeadlineExceeded, context.Canceled, bus.ErrBusClosed} {
			if errors.Is(pubErr, benign) {
				return
			}
		}
		logger.WarnCF("agent", what, map[string]any{
			"channel": ts.channel,
			"chat_id": ts.chatID,
			"error":   pubErr.Error(),
		})
	}

	// Reasoning goes out first and does not depend on the interim-publish option.
	if strings.TrimSpace(reasoningContent) != "" {
		thoughtCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
		pubErr := al.bus.PublishOutbound(
			thoughtCtx,
			outboundMessageForTurnWithKind(ts, reasoningContent, messageKindThought),
		)
		cancel()
		warn(pubErr, "Failed to publish pico reasoning")
	}

	if !ts.opts.AllowInterimPicoPublish {
		return
	}

	if strings.TrimSpace(content) != "" {
		contentCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
		pubErr := al.bus.PublishOutbound(contentCtx, outboundMessageForTurn(ts, content))
		cancel()
		warn(pubErr, "Failed to publish pico interim assistant content")
	}

	summaries := utils.BuildVisibleToolCalls(
		toolCalls,
		al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(),
	)
	if len(summaries) == 0 {
		return
	}

	encoded, marshalErr := json.Marshal(summaries)
	if marshalErr != nil {
		logger.WarnCF("agent", "Failed to serialize pico tool calls", map[string]any{
			"channel": ts.channel,
			"chat_id": ts.chatID,
			"error":   marshalErr.Error(),
		})
		return
	}

	// The tool-call summary travels in the raw context map, not in Content.
	toolMsg := outboundMessageForTurnWithKind(ts, "", messageKindToolCalls)
	if toolMsg.Context.Raw == nil {
		toolMsg.Context.Raw = map[string]string{}
	}
	toolMsg.Context.Raw[metadataKeyToolCalls] = string(encoded)

	toolCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	pubErr := al.bus.PublishOutbound(toolCtx, toolMsg)
	cancel()
	warn(pubErr, "Failed to publish pico tool calls")
}
func (al *AgentLoop) handleReasoning(
ctx context.Context,
reasoningContent, channelName, channelID string,
+33 -16
View File
@@ -3987,6 +3987,7 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
select {
case outbound := <-msgBus.OutboundChan():
escapedHeartbeatFile := strings.ReplaceAll(heartbeatFile, `\`, `\\`)
if outbound.Channel != "telegram" {
t.Fatalf("tool feedback channel = %q, want %q", outbound.Channel, "telegram")
}
@@ -4008,7 +4009,7 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
if !strings.Contains(outbound.Content, "\"path\":") {
t.Fatalf("tool feedback content = %q, want serialized tool arguments", outbound.Content)
}
if !strings.Contains(outbound.Content, heartbeatFile) {
if !strings.Contains(outbound.Content, escapedHeartbeatFile) {
t.Fatalf("tool feedback content = %q, want tool argument value", outbound.Content)
}
if strings.Contains(outbound.Content, "Previous turn explanation") {
@@ -4250,6 +4251,7 @@ func TestProcessMessage_DoesNotLeakReasoningContentInToolFeedback(t *testing.T)
select {
case outbound := <-msgBus.OutboundChan():
escapedHeartbeatFile := strings.ReplaceAll(heartbeatFile, `\`, `\\`)
if !strings.Contains(outbound.Content, "`read_file`") {
t.Fatalf("tool feedback content = %q, want read_file summary", outbound.Content)
}
@@ -4262,7 +4264,7 @@ func TestProcessMessage_DoesNotLeakReasoningContentInToolFeedback(t *testing.T)
if !strings.Contains(outbound.Content, "\"path\":") {
t.Fatalf("tool feedback content = %q, want serialized tool arguments", outbound.Content)
}
if !strings.Contains(outbound.Content, heartbeatFile) {
if !strings.Contains(outbound.Content, escapedHeartbeatFile) {
t.Fatalf("tool feedback content = %q, want tool argument value", outbound.Content)
}
if strings.Contains(outbound.Content, "Read README.md first") {
@@ -4422,22 +4424,28 @@ func TestRun_PicoPublishesAssistantContentDuringToolCallsWithoutFinalDuplicate(t
t.Fatalf("PublishInbound() error = %v", err)
}
outputs := make([]string, 0, 2)
outputs := make([]bus.OutboundMessage, 0, 3)
deadline := time.After(2 * time.Second)
for len(outputs) < 2 {
for len(outputs) < 3 {
select {
case outbound := <-msgBus.OutboundChan():
outputs = append(outputs, outbound.Content)
outputs = append(outputs, outbound)
case <-deadline:
t.Fatalf("timed out waiting for pico outputs, got %v", outputs)
}
}
if outputs[0] != "intermediate model text" {
t.Fatalf("first outbound content = %q, want %q", outputs[0], "intermediate model text")
if outputs[0].Content != "intermediate model text" {
t.Fatalf("first outbound content = %q, want %q", outputs[0].Content, "intermediate model text")
}
if outputs[1] != "final model text" {
t.Fatalf("second outbound content = %q, want %q", outputs[1], "final model text")
if outputs[1].Context.Raw[metadataKeyMessageKind] != messageKindToolCalls {
t.Fatalf("second outbound = %+v, want tool_calls message", outputs[1])
}
if !strings.Contains(outputs[1].Context.Raw[metadataKeyToolCalls], "tool_limit_test_tool") {
t.Fatalf("second outbound tool_calls = %q, want tool name", outputs[1].Context.Raw[metadataKeyToolCalls])
}
if outputs[2].Content != "final model text" {
t.Fatalf("third outbound content = %q, want %q", outputs[2].Content, "final model text")
}
runCancel()
@@ -4552,22 +4560,31 @@ func TestRun_PicoToolFeedbackSuppressesDuplicateInterimAssistantContent(t *testi
t.Fatalf("PublishInbound() error = %v", err)
}
outputs := make([]string, 0, 2)
outputs := make([]bus.OutboundMessage, 0, 3)
deadline := time.After(2 * time.Second)
for len(outputs) < 2 {
for len(outputs) < 3 {
select {
case outbound := <-msgBus.OutboundChan():
outputs = append(outputs, outbound.Content)
outputs = append(outputs, outbound)
case <-deadline:
t.Fatalf("timed out waiting for pico outputs, got %v", outputs)
}
}
if outputs[0] != "🔧 `tool_limit_test_tool`\nintermediate model text\n```json\n{\n \"value\": \"x\"\n}\n```" {
t.Fatalf("first outbound content = %q, want tool feedback summary", outputs[0])
if outputs[0].Content != "intermediate model text" {
t.Fatalf("first outbound content = %q, want %q", outputs[0].Content, "intermediate model text")
}
if outputs[1] != "final model text" {
t.Fatalf("second outbound content = %q, want %q", outputs[1], "final model text")
if outputs[1].Context.Raw[metadataKeyMessageKind] != messageKindToolCalls {
t.Fatalf("second outbound = %+v, want tool_calls message", outputs[1])
}
if outputs[1].Content != "" {
t.Fatalf("second outbound content = %q, want empty tool_calls content", outputs[1].Content)
}
if !strings.Contains(outputs[1].Context.Raw[metadataKeyToolCalls], "tool_limit_test_tool") {
t.Fatalf("second outbound tool_calls = %q, want tool name", outputs[1].Context.Raw[metadataKeyToolCalls])
}
if outputs[2].Content != "final model text" {
t.Fatalf("third outbound content = %q, want %q", outputs[2].Content, "final model text")
}
runCancel()
+2 -2
View File
@@ -80,7 +80,7 @@ toolLoop:
},
)
if shouldPublishToolFeedback(al.cfg, ts) {
if shouldPublishToolFeedback(al.cfg, ts) && ts.channel != "pico" {
toolFeedbackMaxLen := al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength()
toolFeedbackExplanation := toolFeedbackExplanationForToolCall(
exec.response,
@@ -362,7 +362,7 @@ toolLoop:
},
)
if shouldPublishToolFeedback(al.cfg, ts) {
if shouldPublishToolFeedback(al.cfg, ts) && ts.channel != "pico" {
toolFeedbackMaxLen := al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength()
toolFeedbackExplanation := toolFeedbackExplanationForToolCall(
exec.response,
+14 -26
View File
@@ -10,7 +10,6 @@ import (
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
@@ -383,7 +382,11 @@ func (p *Pipeline) CallLLM(
}
reasoningContent := responseReasoningContent(exec.response)
if ts.channel == "pico" {
shouldPublishPicoToolCallInterim := ts.channel == "pico" && len(exec.response.ToolCalls) > 0
if shouldPublishPicoToolCallInterim {
// Pico tool-call turns publish their reasoning/content/tool summary as a
// structured sequence after the tool-call payload is normalized below.
} else if ts.channel == "pico" {
go al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID)
} else {
go al.handleReasoning(
@@ -419,30 +422,6 @@ func (p *Pipeline) CallLLM(
}
logger.DebugCF("agent", "LLM response", llmResponseFields)
if al.bus != nil &&
ts.channel == "pico" &&
len(exec.response.ToolCalls) > 0 &&
ts.opts.AllowInterimPicoPublish &&
!shouldPublishToolFeedback(al.cfg, ts) {
if strings.TrimSpace(exec.response.Content) != "" {
outCtx, outCancel := context.WithTimeout(turnCtx, 3*time.Second)
publishErr := al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
Channel: ts.channel,
ChatID: ts.chatID,
Content: exec.response.Content,
})
outCancel()
if publishErr != nil {
logger.WarnCF("agent", "Failed to publish pico interim tool-call content", map[string]any{
"error": publishErr.Error(),
"channel": ts.channel,
"chat_id": ts.chatID,
"iteration": iteration,
})
}
}
}
// No-tool-call path: steering check and direct response
if len(exec.response.ToolCalls) == 0 || exec.gracefulTerminal {
responseContent := exec.response.Content
@@ -531,6 +510,15 @@ func (p *Pipeline) CallLLM(
ts.recordPersistedMessage(assistantMsg)
ts.ingestMessage(turnCtx, al, assistantMsg)
}
if shouldPublishPicoToolCallInterim {
al.publishPicoToolCallInterim(
turnCtx,
ts,
reasoningContent,
exec.response.Content,
assistantMsg.ToolCalls,
)
}
return ControlToolLoop, nil
}
+31 -1
View File
@@ -23,6 +23,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
// picoConn represents a single WebSocket connection.
@@ -57,8 +58,17 @@ func outboundMessageIsToolFeedback(msg bus.OutboundMessage) bool {
return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), "tool_feedback")
}
// outboundMessageIsToolCalls reports whether msg is tagged as a tool-calls
// message, i.e. the "message_kind" entry of msg.Context.Raw equals
// MessageKindToolCalls after trimming surrounding whitespace and ignoring
// case. A message with no raw context entries is never a tool-calls message.
func outboundMessageIsToolCalls(msg bus.OutboundMessage) bool {
	raw := msg.Context.Raw
	return len(raw) != 0 &&
		strings.EqualFold(strings.TrimSpace(raw["message_kind"]), MessageKindToolCalls)
}
func outboundMessageFinalizesTrackedToolFeedback(msg bus.OutboundMessage) bool {
return !outboundMessageIsToolFeedback(msg) && !outboundMessageIsThought(msg)
return !outboundMessageIsToolFeedback(msg) &&
!outboundMessageIsThought(msg) &&
!outboundMessageIsToolCalls(msg)
}
// writeJSON sends a JSON message to the connection with write locking.
@@ -289,6 +299,7 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
}
isThought := outboundMessageIsThought(msg)
isToolFeedback := outboundMessageIsToolFeedback(msg)
isToolCalls := outboundMessageIsToolCalls(msg)
if isToolFeedback {
if msgID, handled, err := c.progress.Update(ctx, msg.ChatID, msg.Content); handled {
if err != nil {
@@ -315,6 +326,12 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
PayloadKeyThought: isThought,
"message_id": msgID,
}
if isToolCalls {
payload[PayloadKeyKind] = MessageKindToolCalls
if toolCalls, ok := picoToolCallsPayload(msg); ok {
payload[PayloadKeyToolCalls] = toolCalls
}
}
setContextUsagePayload(payload, msg.ContextUsage)
outMsg := newMessage(TypeMessageCreate, payload)
@@ -1070,6 +1087,19 @@ func setContextUsagePayload(payload map[string]any, u *bus.ContextUsage) {
}
}
// picoToolCallsPayload decodes the JSON tool-call list stored under
// PayloadKeyToolCalls in msg.Context.Raw.
//
// It returns the decoded calls and true only when the entry is non-blank,
// parses as JSON, and contains at least one call. In every other case it
// returns (nil, false); the decode error itself is not reported.
func picoToolCallsPayload(msg bus.OutboundMessage) ([]utils.VisibleToolCall, bool) {
	encoded := strings.TrimSpace(msg.Context.Raw[PayloadKeyToolCalls])
	if encoded == "" {
		return nil, false
	}

	var decoded []utils.VisibleToolCall
	if err := json.Unmarshal([]byte(encoded), &decoded); err != nil {
		return nil, false
	}
	if len(decoded) == 0 {
		return nil, false
	}
	return decoded, true
}
func (c *PicoChannel) editMessage(
ctx context.Context,
chatID string,
+6 -3
View File
@@ -19,10 +19,13 @@ const (
TypeError = "error"
TypePong = "pong"
PayloadKeyContent = "content"
PayloadKeyThought = "thought"
PayloadKeyContent = "content"
PayloadKeyThought = "thought"
PayloadKeyKind = "kind"
PayloadKeyToolCalls = "tool_calls"
MessageKindThought = "thought"
MessageKindThought = "thought"
MessageKindToolCalls = "tool_calls"
)
// PicoMessage is the wire format for all Pico Protocol messages.
+109
View File
@@ -0,0 +1,109 @@
package utils
import (
"bytes"
"encoding/json"
"strings"
"github.com/sipeed/picoclaw/pkg/providers"
)
type (
	// VisibleToolCall is the JSON shape of a single tool call as exposed to
	// clients. Every field is omitted from the encoded form when empty; the
	// Function and ExtraContent sections are pointers so that they can be
	// left out entirely.
	VisibleToolCall struct {
		ID           string                       `json:"id,omitempty"`
		Type         string                       `json:"type,omitempty"`
		Function     *VisibleToolCallFunction     `json:"function,omitempty"`
		ExtraContent *VisibleToolCallExtraContent `json:"extra_content,omitempty"`
	}

	// VisibleToolCallFunction carries the tool name and its arguments as a
	// string.
	VisibleToolCallFunction struct {
		Name      string `json:"name,omitempty"`
		Arguments string `json:"arguments,omitempty"`
	}

	// VisibleToolCallExtraContent carries the optional tool feedback
	// explanation attached to a tool call.
	VisibleToolCallExtraContent struct {
		ToolFeedbackExplanation string `json:"tool_feedback_explanation,omitempty"`
	}
)
// BuildVisibleToolCalls converts provider tool calls into their
// client-visible form.
//
// For each call the name comes from VisibleToolCallNameAndArguments and the
// arguments from VisibleToolCallArgumentsPreview. The tool feedback
// explanation, when present, is trimmed and, if maxArgsLen is positive,
// passed through Truncate with maxArgsLen. A call with no name, no arguments
// preview and no explanation is dropped. Type defaults to "function" when the
// source call does not specify one.
//
// It returns nil when toolCalls is empty or when no call survives filtering.
func BuildVisibleToolCalls(
	toolCalls []providers.ToolCall,
	maxArgsLen int,
) []VisibleToolCall {
	if len(toolCalls) == 0 {
		return nil
	}

	out := make([]VisibleToolCall, 0, len(toolCalls))
	for _, call := range toolCalls {
		fnName, _ := VisibleToolCallNameAndArguments(call)
		preview := VisibleToolCallArgumentsPreview(call, maxArgsLen)

		var why string
		if extra := call.ExtraContent; extra != nil {
			why = strings.TrimSpace(extra.ToolFeedbackExplanation)
			if maxArgsLen > 0 {
				why = Truncate(why, maxArgsLen)
			}
		}

		// Nothing worth showing for this call: skip it.
		hasFunction := fnName != "" || preview != ""
		if !hasFunction && why == "" {
			continue
		}

		entry := VisibleToolCall{
			ID:   strings.TrimSpace(call.ID),
			Type: "function",
		}
		if kind := strings.TrimSpace(call.Type); kind != "" {
			entry.Type = kind
		}
		if hasFunction {
			entry.Function = &VisibleToolCallFunction{
				Name:      fnName,
				Arguments: preview,
			}
		}
		if why != "" {
			entry.ExtraContent = &VisibleToolCallExtraContent{
				ToolFeedbackExplanation: why,
			}
		}
		out = append(out, entry)
	}

	if len(out) == 0 {
		return nil
	}
	return out
}
// VisibleToolCallNameAndArguments returns the trimmed tool name and the
// trimmed JSON arguments string of tc.
//
// The name is tc.Name, falling back to tc.Function.Name when tc.Name is
// blank. The arguments are tc.Function.Arguments; when those are blank and
// tc.Arguments is non-empty, tc.Arguments is JSON-encoded instead. If that
// encoding fails the arguments are returned as the empty string.
func VisibleToolCallNameAndArguments(tc providers.ToolCall) (string, string) {
	toolName := strings.TrimSpace(tc.Name)

	var rawArgs string
	if fn := tc.Function; fn != nil {
		if toolName == "" {
			toolName = strings.TrimSpace(fn.Name)
		}
		rawArgs = strings.TrimSpace(fn.Arguments)
	}

	// Explicit arguments win; with no structured arguments there is nothing
	// left to encode.
	if rawArgs != "" || len(tc.Arguments) == 0 {
		return toolName, rawArgs
	}

	encoded, err := json.Marshal(tc.Arguments)
	if err != nil {
		return toolName, ""
	}
	return toolName, strings.TrimSpace(string(encoded))
}
// VisibleToolCallArgumentsPreview returns a display form of tc's arguments.
//
// The arguments string from VisibleToolCallNameAndArguments is re-indented
// with json.Indent; if it is not valid JSON it is used unchanged. When maxLen
// is positive the result is passed through Truncate with maxLen. An empty
// arguments string yields the empty string.
func VisibleToolCallArgumentsPreview(tc providers.ToolCall, maxLen int) string {
	_, preview := VisibleToolCallNameAndArguments(tc)
	if preview == "" {
		return ""
	}

	var indented bytes.Buffer
	if json.Indent(&indented, []byte(preview), "", " ") == nil {
		preview = indented.String()
	}

	if maxLen <= 0 {
		return preview
	}
	return Truncate(preview, maxLen)
}