mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix: eliminate data races on shared tool instances (#1080)
* fix: eliminate data races on shared tool instances Signed-off-by: Boris Bliznioukov <blib@mail.com> * fix: remove unused indirect dependency on github.com/gdamore/tcell/v2 Signed-off-by: Boris Bliznioukov <blib@mail.com> * fix: reviewer comments improve context handling for tool execution and ensure defaults for non-conversation callers Signed-off-by: Boris Bliznioukov <blib@mail.com> --------- Signed-off-by: Boris Bliznioukov <blib@mail.com>
This commit is contained in:
committed by
GitHub
parent
204038ec60
commit
aef1e8e8c4
@@ -37,7 +37,6 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect
|
||||
github.com/gdamore/encoding v1.0.1 // indirect
|
||||
github.com/gdamore/tcell/v2 v2.13.8 // indirect
|
||||
github.com/h2non/filetype v1.1.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
|
||||
+11
-34
@@ -543,8 +543,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
|
||||
// Reset message-tool state for this round so we don't skip publishing due to a previous round.
|
||||
if tool, ok := agent.Tools.Get("message"); ok {
|
||||
if mt, ok := tool.(tools.ContextualTool); ok {
|
||||
mt.SetContext(msg.Channel, msg.ChatID)
|
||||
if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
|
||||
resetter.ResetSentInRound()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -659,10 +659,7 @@ func (al *AgentLoop) runAgentLoop(
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Update tool contexts
|
||||
al.updateToolContexts(agent, opts.Channel, opts.ChatID)
|
||||
|
||||
// 2. Build messages (skip history for heartbeat)
|
||||
// 1. Build messages (skip history for heartbeat)
|
||||
var history []providers.Message
|
||||
var summary string
|
||||
if !opts.NoHistory {
|
||||
@@ -682,10 +679,10 @@ func (al *AgentLoop) runAgentLoop(
|
||||
maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
|
||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||
|
||||
// 3. Save user message to session
|
||||
// 2. Save user message to session
|
||||
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||
|
||||
// 4. Run LLM iteration loop
|
||||
// 3. Run LLM iteration loop
|
||||
finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -694,21 +691,21 @@ func (al *AgentLoop) runAgentLoop(
|
||||
// If last tool had ForUser content and we already sent it, we might not need to send final response
|
||||
// This is controlled by the tool's Silent flag and ForUser content
|
||||
|
||||
// 5. Handle empty response
|
||||
// 4. Handle empty response
|
||||
if finalContent == "" {
|
||||
finalContent = opts.DefaultResponse
|
||||
}
|
||||
|
||||
// 6. Save final assistant message to session
|
||||
// 5. Save final assistant message to session
|
||||
agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
|
||||
agent.Sessions.Save(opts.SessionKey)
|
||||
|
||||
// 7. Optional: summarization
|
||||
// 6. Optional: summarization
|
||||
if opts.EnableSummary {
|
||||
al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID)
|
||||
}
|
||||
|
||||
// 8. Optional: send response via bus
|
||||
// 7. Optional: send response via bus
|
||||
if opts.SendResponse {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
@@ -717,7 +714,7 @@ func (al *AgentLoop) runAgentLoop(
|
||||
})
|
||||
}
|
||||
|
||||
// 9. Log response
|
||||
// 8. Log response
|
||||
responsePreview := utils.Truncate(finalContent, 120)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview),
|
||||
map[string]any{
|
||||
@@ -1059,7 +1056,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// Create async callback for tools that implement AsyncTool
|
||||
// Create async callback for tools that implement AsyncExecutor
|
||||
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
|
||||
@@ -1141,26 +1138,6 @@ func (al *AgentLoop) runLLMIteration(
|
||||
return finalContent, iteration, nil
|
||||
}
|
||||
|
||||
// updateToolContexts updates the context for tools that need channel/chatID info.
|
||||
func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID string) {
|
||||
// Use ContextualTool interface instead of type assertions
|
||||
if tool, ok := agent.Tools.Get("message"); ok {
|
||||
if mt, ok := tool.(tools.ContextualTool); ok {
|
||||
mt.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
if tool, ok := agent.Tools.Get("spawn"); ok {
|
||||
if st, ok := tool.(tools.ContextualTool); ok {
|
||||
st.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
if tool, ok := agent.Tools.Get("subagent"); ok {
|
||||
if st, ok := tool.(tools.ContextualTool); ok {
|
||||
st.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
|
||||
+11
-55
@@ -164,35 +164,21 @@ func TestToolRegistry_ToolRegistration(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolContext_Updates verifies tool context is updated with channel/chatID
|
||||
// TestToolContext_Updates verifies tool context helpers work correctly
|
||||
func TestToolContext_Updates(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
ctx := tools.WithToolContext(context.Background(), "telegram", "chat-42")
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
if got := tools.ToolChannel(ctx); got != "telegram" {
|
||||
t.Errorf("expected channel 'telegram', got %q", got)
|
||||
}
|
||||
if got := tools.ToolChatID(ctx); got != "chat-42" {
|
||||
t.Errorf("expected chatID 'chat-42', got %q", got)
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: "OK"}
|
||||
_ = NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Verify that ContextualTool interface is defined and can be implemented
|
||||
// This test validates the interface contract exists
|
||||
ctxTool := &mockContextualTool{}
|
||||
|
||||
// Verify the tool implements the interface correctly
|
||||
var _ tools.ContextualTool = ctxTool
|
||||
// Empty context returns empty strings
|
||||
if got := tools.ToolChannel(context.Background()); got != "" {
|
||||
t.Errorf("expected empty channel from bare context, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved
|
||||
@@ -359,36 +345,6 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool
|
||||
return tools.SilentResult("Custom tool executed")
|
||||
}
|
||||
|
||||
// mockContextualTool tracks context updates
|
||||
type mockContextualTool struct {
|
||||
lastChannel string
|
||||
lastChatID string
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Name() string {
|
||||
return "mock_contextual"
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Description() string {
|
||||
return "Mock contextual tool"
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
return tools.SilentResult("Contextual tool executed")
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) SetContext(channel, chatID string) {
|
||||
m.lastChannel = channel
|
||||
m.lastChatID = chatID
|
||||
}
|
||||
|
||||
// testHelper executes a message and returns the response
|
||||
type testHelper struct {
|
||||
al *AgentLoop
|
||||
|
||||
+50
-38
@@ -10,11 +10,38 @@ type Tool interface {
|
||||
Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
}
|
||||
|
||||
// ContextualTool is an optional interface that tools can implement
|
||||
// to receive the current message context (channel, chatID)
|
||||
type ContextualTool interface {
|
||||
Tool
|
||||
SetContext(channel, chatID string)
|
||||
// --- Request-scoped tool context (channel / chatID) ---
|
||||
//
|
||||
// Carried via context.Value so that concurrent tool calls each receive
|
||||
// their own immutable copy — no mutable state on singleton tool instances.
|
||||
//
|
||||
// Keys are unexported pointer-typed vars — guaranteed collision-free,
|
||||
// and only accessible through the helper functions below.
|
||||
|
||||
type toolCtxKey struct{ name string }
|
||||
|
||||
var (
|
||||
ctxKeyChannel = &toolCtxKey{"channel"}
|
||||
ctxKeyChatID = &toolCtxKey{"chatID"}
|
||||
)
|
||||
|
||||
// WithToolContext returns a child context carrying channel and chatID.
|
||||
func WithToolContext(ctx context.Context, channel, chatID string) context.Context {
|
||||
ctx = context.WithValue(ctx, ctxKeyChannel, channel)
|
||||
ctx = context.WithValue(ctx, ctxKeyChatID, chatID)
|
||||
return ctx
|
||||
}
|
||||
|
||||
// ToolChannel extracts the channel from ctx, or "" if unset.
|
||||
func ToolChannel(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeyChannel).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// ToolChatID extracts the chatID from ctx, or "" if unset.
|
||||
func ToolChatID(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeyChatID).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// AsyncCallback is a function type that async tools use to notify completion.
|
||||
@@ -22,51 +49,36 @@ type ContextualTool interface {
|
||||
//
|
||||
// The ctx parameter allows the callback to be canceled if the agent is shutting down.
|
||||
// The result parameter contains the tool's execution result.
|
||||
//
|
||||
// Example usage in an async tool:
|
||||
//
|
||||
// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
// // Start async work in background
|
||||
// go func() {
|
||||
// result := doAsyncWork()
|
||||
// if t.callback != nil {
|
||||
// t.callback(ctx, result)
|
||||
// }
|
||||
// }()
|
||||
// return AsyncResult("Async task started")
|
||||
// }
|
||||
type AsyncCallback func(ctx context.Context, result *ToolResult)
|
||||
|
||||
// AsyncTool is an optional interface that tools can implement to support
|
||||
// AsyncExecutor is an optional interface that tools can implement to support
|
||||
// asynchronous execution with completion callbacks.
|
||||
//
|
||||
// Async tools return immediately with an AsyncResult, then notify completion
|
||||
// via the callback set by SetCallback.
|
||||
// Unlike the old AsyncTool pattern (SetCallback + Execute), AsyncExecutor
|
||||
// receives the callback as a parameter of ExecuteAsync. This eliminates the
|
||||
// data race where concurrent calls could overwrite each other's callbacks
|
||||
// on a shared tool instance.
|
||||
//
|
||||
// This is useful for:
|
||||
// - Long-running operations that shouldn't block the agent loop
|
||||
// - Subagent spawns that complete independently
|
||||
// - Background tasks that need to report results later
|
||||
// - Long-running operations that shouldn't block the agent loop
|
||||
// - Subagent spawns that complete independently
|
||||
// - Background tasks that need to report results later
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type SpawnTool struct {
|
||||
// callback AsyncCallback
|
||||
// }
|
||||
//
|
||||
// func (t *SpawnTool) SetCallback(cb AsyncCallback) {
|
||||
// t.callback = cb
|
||||
// }
|
||||
//
|
||||
// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
// go t.runSubagent(ctx, args)
|
||||
// func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
|
||||
// go func() {
|
||||
// result := t.runSubagent(ctx, args)
|
||||
// if cb != nil { cb(ctx, result) }
|
||||
// }()
|
||||
// return AsyncResult("Subagent spawned, will report back")
|
||||
// }
|
||||
type AsyncTool interface {
|
||||
type AsyncExecutor interface {
|
||||
Tool
|
||||
// SetCallback registers a callback function to be invoked when the async operation completes.
|
||||
// The callback will be called from a goroutine and should handle thread-safety if needed.
|
||||
SetCallback(cb AsyncCallback)
|
||||
// ExecuteAsync runs the tool asynchronously. The callback cb will be
|
||||
// invoked (possibly from another goroutine) when the async operation
|
||||
// completes. cb is guaranteed to be non-nil by the caller (registry).
|
||||
ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult
|
||||
}
|
||||
|
||||
func ToolToSchema(tool Tool) map[string]any {
|
||||
|
||||
+4
-18
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -24,9 +23,6 @@ type CronTool struct {
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
execTool *ExecTool
|
||||
channel string
|
||||
chatID string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
@@ -102,14 +98,6 @@ func (t *CronTool) Parameters() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
// SetContext sets the current session context for job creation
|
||||
func (t *CronTool) SetContext(channel, chatID string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.channel = channel
|
||||
t.chatID = chatID
|
||||
}
|
||||
|
||||
// Execute runs the tool with the given arguments
|
||||
func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
action, ok := args["action"].(string)
|
||||
@@ -119,7 +107,7 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
|
||||
switch action {
|
||||
case "add":
|
||||
return t.addJob(args)
|
||||
return t.addJob(ctx, args)
|
||||
case "list":
|
||||
return t.listJobs()
|
||||
case "remove":
|
||||
@@ -133,11 +121,9 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
}
|
||||
}
|
||||
|
||||
func (t *CronTool) addJob(args map[string]any) *ToolResult {
|
||||
t.mu.RLock()
|
||||
channel := t.channel
|
||||
chatID := t.chatID
|
||||
t.mu.RUnlock()
|
||||
func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult {
|
||||
channel := ToolChannel(ctx)
|
||||
chatID := ToolChatID(ctx)
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
|
||||
|
||||
+8
-10
@@ -9,10 +9,8 @@ import (
|
||||
type SendCallback func(channel, chatID, content string) error
|
||||
|
||||
type MessageTool struct {
|
||||
sendCallback SendCallback
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
|
||||
sendCallback SendCallback
|
||||
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
|
||||
}
|
||||
|
||||
func NewMessageTool() *MessageTool {
|
||||
@@ -48,10 +46,10 @@ func (t *MessageTool) Parameters() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MessageTool) SetContext(channel, chatID string) {
|
||||
t.defaultChannel = channel
|
||||
t.defaultChatID = chatID
|
||||
t.sentInRound.Store(false) // Reset send tracking for new processing round
|
||||
// ResetSentInRound resets the per-round send tracker.
|
||||
// Called by the agent loop at the start of each inbound message processing round.
|
||||
func (t *MessageTool) ResetSentInRound() {
|
||||
t.sentInRound.Store(false)
|
||||
}
|
||||
|
||||
// HasSentInRound returns true if the message tool sent a message during the current round.
|
||||
@@ -73,10 +71,10 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
|
||||
chatID, _ := args["chat_id"].(string)
|
||||
|
||||
if channel == "" {
|
||||
channel = t.defaultChannel
|
||||
channel = ToolChannel(ctx)
|
||||
}
|
||||
if chatID == "" {
|
||||
chatID = t.defaultChatID
|
||||
chatID = ToolChatID(ctx)
|
||||
}
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
var sentChannel, sentChatID, sentContent string
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
@@ -18,7 +17,7 @@ func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
args := map[string]any{
|
||||
"content": "Hello, world!",
|
||||
}
|
||||
@@ -60,7 +59,6 @@ func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
|
||||
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("default-channel", "default-chat-id")
|
||||
|
||||
var sentChannel, sentChatID string
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
@@ -69,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "default-channel", "default-chat-id")
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
"channel": "custom-channel",
|
||||
@@ -96,14 +94,13 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
|
||||
func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
sendErr := errors.New("network error")
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
return sendErr
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
}
|
||||
@@ -133,9 +130,8 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
|
||||
func TestMessageTool_Execute_MissingContent(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
args := map[string]any{} // content missing
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
@@ -151,7 +147,7 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) {
|
||||
|
||||
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
// No SetContext called, so defaultChannel and defaultChatID are empty
|
||||
// No WithToolContext — channel/chatID are empty
|
||||
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
return nil
|
||||
@@ -175,10 +171,9 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
|
||||
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
// No SetSendCallback called
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
+15
-13
@@ -45,8 +45,9 @@ func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string
|
||||
}
|
||||
|
||||
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
|
||||
// If the tool implements AsyncTool and a non-nil callback is provided,
|
||||
// the callback will be set on the tool before execution.
|
||||
// If the tool implements AsyncExecutor and a non-nil callback is provided,
|
||||
// ExecuteAsync is called instead of Execute — the callback is a parameter,
|
||||
// never stored as mutable state on the tool.
|
||||
func (r *ToolRegistry) ExecuteWithContext(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
@@ -69,22 +70,23 @@ func (r *ToolRegistry) ExecuteWithContext(
|
||||
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
|
||||
}
|
||||
|
||||
// If tool implements ContextualTool, set context
|
||||
if contextualTool, ok := tool.(ContextualTool); ok && channel != "" && chatID != "" {
|
||||
contextualTool.SetContext(channel, chatID)
|
||||
}
|
||||
// Inject channel/chatID into ctx so tools read them via ToolChannel(ctx)/ToolChatID(ctx).
|
||||
// Always inject — tools validate what they require.
|
||||
ctx = WithToolContext(ctx, channel, chatID)
|
||||
|
||||
// If tool implements AsyncTool and callback is provided, set callback
|
||||
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
|
||||
asyncTool.SetCallback(asyncCallback)
|
||||
logger.DebugCF("tool", "Async callback injected",
|
||||
// If tool implements AsyncExecutor and callback is provided, use ExecuteAsync.
|
||||
// The callback is a call parameter, not mutable state on the tool instance.
|
||||
var result *ToolResult
|
||||
start := time.Now()
|
||||
if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil {
|
||||
logger.DebugCF("tool", "Executing async tool via ExecuteAsync",
|
||||
map[string]any{
|
||||
"tool": name,
|
||||
})
|
||||
result = asyncExec.ExecuteAsync(ctx, args, asyncCallback)
|
||||
} else {
|
||||
result = tool.Execute(ctx, args)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result := tool.Execute(ctx, args)
|
||||
duration := time.Since(start)
|
||||
|
||||
// Log based on result type
|
||||
|
||||
+32
-22
@@ -25,24 +25,24 @@ func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]any) *ToolRes
|
||||
return m.result
|
||||
}
|
||||
|
||||
type mockCtxTool struct {
|
||||
type mockContextAwareTool struct {
|
||||
mockRegistryTool
|
||||
channel string
|
||||
chatID string
|
||||
lastCtx context.Context
|
||||
}
|
||||
|
||||
func (m *mockCtxTool) SetContext(channel, chatID string) {
|
||||
m.channel = channel
|
||||
m.chatID = chatID
|
||||
func (m *mockContextAwareTool) Execute(ctx context.Context, _ map[string]any) *ToolResult {
|
||||
m.lastCtx = ctx
|
||||
return m.result
|
||||
}
|
||||
|
||||
type mockAsyncRegistryTool struct {
|
||||
mockRegistryTool
|
||||
cb AsyncCallback
|
||||
lastCB AsyncCallback
|
||||
}
|
||||
|
||||
func (m *mockAsyncRegistryTool) SetCallback(cb AsyncCallback) {
|
||||
m.cb = cb
|
||||
func (m *mockAsyncRegistryTool) ExecuteAsync(_ context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
|
||||
m.lastCB = cb
|
||||
return m.result
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
@@ -136,34 +136,44 @@ func TestToolRegistry_Execute_NotFound(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) {
|
||||
func TestToolRegistry_ExecuteWithContext_InjectsToolContext(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
ct := &mockCtxTool{
|
||||
ct := &mockContextAwareTool{
|
||||
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
|
||||
}
|
||||
r.Register(ct)
|
||||
|
||||
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "telegram", "chat-42", nil)
|
||||
|
||||
if ct.channel != "telegram" {
|
||||
t.Errorf("expected channel 'telegram', got %q", ct.channel)
|
||||
if ct.lastCtx == nil {
|
||||
t.Fatal("expected Execute to be called")
|
||||
}
|
||||
if ct.chatID != "chat-42" {
|
||||
t.Errorf("expected chatID 'chat-42', got %q", ct.chatID)
|
||||
if got := ToolChannel(ct.lastCtx); got != "telegram" {
|
||||
t.Errorf("expected channel 'telegram', got %q", got)
|
||||
}
|
||||
if got := ToolChatID(ct.lastCtx); got != "chat-42" {
|
||||
t.Errorf("expected chatID 'chat-42', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) {
|
||||
func TestToolRegistry_ExecuteWithContext_EmptyContext(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
ct := &mockCtxTool{
|
||||
ct := &mockContextAwareTool{
|
||||
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
|
||||
}
|
||||
r.Register(ct)
|
||||
|
||||
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "", "", nil)
|
||||
|
||||
if ct.channel != "" || ct.chatID != "" {
|
||||
t.Error("SetContext should not be called with empty channel/chatID")
|
||||
if ct.lastCtx == nil {
|
||||
t.Fatal("expected Execute to be called")
|
||||
}
|
||||
// Empty values are still injected; tools decide what to do with them.
|
||||
if got := ToolChannel(ct.lastCtx); got != "" {
|
||||
t.Errorf("expected empty channel, got %q", got)
|
||||
}
|
||||
if got := ToolChatID(ct.lastCtx); got != "" {
|
||||
t.Errorf("expected empty chatID, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,14 +189,14 @@ func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) {
|
||||
cb := func(_ context.Context, _ *ToolResult) { called = true }
|
||||
|
||||
result := r.ExecuteWithContext(context.Background(), "async_tool", nil, "", "", cb)
|
||||
if at.cb == nil {
|
||||
t.Error("expected SetCallback to have been called")
|
||||
if at.lastCB == nil {
|
||||
t.Error("expected ExecuteAsync to have received a callback")
|
||||
}
|
||||
if !result.Async {
|
||||
t.Error("expected async result")
|
||||
}
|
||||
|
||||
at.cb(context.Background(), SilentResult("done"))
|
||||
at.lastCB(context.Background(), SilentResult("done"))
|
||||
if !called {
|
||||
t.Error("expected callback to be invoked")
|
||||
}
|
||||
|
||||
+27
-17
@@ -8,25 +8,18 @@ import (
|
||||
|
||||
type SpawnTool struct {
|
||||
manager *SubagentManager
|
||||
originChannel string
|
||||
originChatID string
|
||||
allowlistCheck func(targetAgentID string) bool
|
||||
callback AsyncCallback // For async completion notification
|
||||
}
|
||||
|
||||
// Compile-time check: SpawnTool implements AsyncExecutor.
|
||||
var _ AsyncExecutor = (*SpawnTool)(nil)
|
||||
|
||||
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
|
||||
return &SpawnTool{
|
||||
manager: manager,
|
||||
originChannel: "cli",
|
||||
originChatID: "direct",
|
||||
manager: manager,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCallback implements AsyncTool interface for async completion notification
|
||||
func (t *SpawnTool) SetCallback(cb AsyncCallback) {
|
||||
t.callback = cb
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Name() string {
|
||||
return "spawn"
|
||||
}
|
||||
@@ -56,16 +49,21 @@ func (t *SpawnTool) Parameters() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SpawnTool) SetContext(channel, chatID string) {
|
||||
t.originChannel = channel
|
||||
t.originChatID = chatID
|
||||
}
|
||||
|
||||
func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) {
|
||||
t.allowlistCheck = check
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
return t.execute(ctx, args, nil)
|
||||
}
|
||||
|
||||
// ExecuteAsync implements AsyncExecutor. The callback is passed through to the
|
||||
// subagent manager as a call parameter — never stored on the SpawnTool instance.
|
||||
func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
|
||||
return t.execute(ctx, args, cb)
|
||||
}
|
||||
|
||||
func (t *SpawnTool) execute(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok || strings.TrimSpace(task) == "" {
|
||||
return ErrorResult("task is required and must be a non-empty string")
|
||||
@@ -85,8 +83,20 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResul
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
// Read channel/chatID from context (injected by registry).
|
||||
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
|
||||
// to preserve the same defaults as the original NewSpawnTool constructor.
|
||||
channel := ToolChannel(ctx)
|
||||
if channel == "" {
|
||||
channel = "cli"
|
||||
}
|
||||
chatID := ToolChatID(ctx)
|
||||
if chatID == "" {
|
||||
chatID = "direct"
|
||||
}
|
||||
|
||||
// Pass callback to manager for async completion notification
|
||||
result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback)
|
||||
result, err := t.manager.Spawn(ctx, task, label, agentID, channel, chatID, cb)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
|
||||
}
|
||||
|
||||
+14
-12
@@ -252,16 +252,12 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
|
||||
// and returns the result directly in the ToolResult.
|
||||
type SubagentTool struct {
|
||||
manager *SubagentManager
|
||||
originChannel string
|
||||
originChatID string
|
||||
manager *SubagentManager
|
||||
}
|
||||
|
||||
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
|
||||
return &SubagentTool{
|
||||
manager: manager,
|
||||
originChannel: "cli",
|
||||
originChatID: "direct",
|
||||
manager: manager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,11 +286,6 @@ func (t *SubagentTool) Parameters() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SubagentTool) SetContext(channel, chatID string) {
|
||||
t.originChannel = channel
|
||||
t.originChatID = chatID
|
||||
}
|
||||
|
||||
func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok {
|
||||
@@ -341,13 +332,24 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
|
||||
// to preserve the same defaults as the original NewSubagentTool constructor.
|
||||
channel := ToolChannel(ctx)
|
||||
if channel == "" {
|
||||
channel = "cli"
|
||||
}
|
||||
chatID := ToolChatID(ctx)
|
||||
if chatID == "" {
|
||||
chatID = "direct"
|
||||
}
|
||||
|
||||
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
|
||||
Provider: sm.provider,
|
||||
Model: sm.defaultModel,
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: llmOptions,
|
||||
}, messages, t.originChannel, t.originChatID)
|
||||
}, messages, channel, chatID)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
|
||||
}
|
||||
|
||||
@@ -50,9 +50,8 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
manager.SetLLMOptions(2048, 0.6)
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetContext("cli", "direct")
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
args := map[string]any{"task": "Do something"}
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
@@ -147,28 +146,14 @@ func TestSubagentTool_Parameters(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_SetContext verifies context setting
|
||||
func TestSubagentTool_SetContext(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
tool.SetContext("test-channel", "test-chat")
|
||||
|
||||
// Verify context is set (we can't directly access private fields,
|
||||
// but we can verify it doesn't crash)
|
||||
// The actual context usage is tested in Execute tests
|
||||
}
|
||||
|
||||
// TestSubagentTool_Execute_Success tests successful execution
|
||||
func TestSubagentTool_Execute_Success(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetContext("telegram", "chat-123")
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-123")
|
||||
args := map[string]any{
|
||||
"task": "Write a haiku about coding",
|
||||
"label": "haiku-task",
|
||||
@@ -297,12 +282,9 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
// Set context
|
||||
channel := "test-channel"
|
||||
chatID := "test-chat"
|
||||
tool.SetContext(channel, chatID)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), channel, chatID)
|
||||
args := map[string]any{
|
||||
"task": "Test context passing",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user