diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 39cd4ccf9..26b35c2f1 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -241,12 +241,23 @@ func registerSharedTools( // Message tool if cfg.Tools.IsToolEnabled("message") { messageTool := tools.NewMessageTool() - messageTool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { + messageTool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + ) error { pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() outboundCtx := bus.NewOutboundContext(channel, chatID, replyToMessageID) + outboundAgentID, outboundSessionKey, outboundScope := outboundTurnMetadata( + tools.ToolAgentID(ctx), + tools.ToolSessionKey(ctx), + tools.ToolSessionScope(ctx), + ) return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Context: outboundCtx, + AgentID: outboundAgentID, + SessionKey: outboundSessionKey, + Scope: outboundScope, Content: content, ReplyToMessageID: replyToMessageID, }) @@ -2748,6 +2759,12 @@ turnLoop: ts.opts.Dispatch.MessageID(), ts.opts.Dispatch.ReplyToMessageID(), ) + execCtx = tools.WithToolSessionContext( + execCtx, + ts.agent.ID, + ts.sessionKey, + ts.opts.Dispatch.SessionScope, + ) toolResult := ts.agent.Tools.ExecuteWithContext( execCtx, toolName, diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 64ea7a943..975956bcb 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -1274,6 +1274,36 @@ func (m *handledUserProvider) GetDefaultModel() string { return "handled-user-model" } +type messageToolProvider struct { + calls int +} + +func (m *messageToolProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + if m.calls == 1 { + return &providers.LLMResponse{ + Content: "", + ToolCalls: []providers.ToolCall{{ + ID: "call_message", + Type: "function", + Name: "message", + Arguments: map[string]any{"content": "direct tool message"}, + }}, + }, nil + } + return &providers.LLMResponse{}, nil +} + +func (m *messageToolProvider) GetDefaultModel() string { + return "message-tool-model" +} + type artifactThenSendProvider struct { calls int } @@ -3058,6 +3088,53 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) { } } +func TestProcessMessage_MessageToolPublishesOutboundWithTurnMetadata(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = t.TempDir() + cfg.Agents.Defaults.ModelName = "test-model" + cfg.Agents.Defaults.MaxTokens = 4096 + cfg.Agents.Defaults.MaxToolIterations = 10 + cfg.Session.Dimensions = []string{"chat"} + + msgBus := bus.NewMessageBus() + provider := &messageToolProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ + Channel: "telegram", + SenderID: "user-1", + ChatID: "chat-1", + Content: "send a direct message", + })) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response == "" { + t.Fatal("expected processMessage() to return a final loop response") + } + + select { + case outbound := <-msgBus.OutboundChan(): + if outbound.Content != "direct tool message" { + t.Fatalf("outbound content = %q, want direct tool message", outbound.Content) + } + if outbound.AgentID != "main" { + t.Fatalf("outbound agent_id = %q, want main", outbound.AgentID) + } + if outbound.SessionKey == "" { + t.Fatal("expected message tool outbound to carry session_key") + } + if outbound.Scope == nil || outbound.Scope.Values["chat"] != "direct:chat-1" { + t.Fatalf("unexpected message tool outbound scope: %+v", outbound.Scope) + } + if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "chat-1" { + t.Fatalf("unexpected message tool outbound context: %+v", outbound.Context) + } + case <-time.After(2 * time.Second): + t.Fatal("expected message tool outbound") + } +} + func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) { store := media.NewFileMediaStore() dir := t.TempDir() diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index 6c9ef19c5..a7051890d 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -324,28 +324,16 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance { if !ok || agent == nil { continue } - scopeReader, ok := agent.Sessions.(interface { - GetSessionScope(sessionKey string) *session.SessionScope - }) - if !ok { + resolvedAgentID := session.ResolveAgentID(agent.Sessions, sessionKey) + if resolvedAgentID == "" { continue } - scope := scopeReader.GetSessionScope(sessionKey) - if scope == nil || strings.TrimSpace(scope.AgentID) == "" { - continue - } - if scopedAgent, ok := registry.GetAgent(scope.AgentID); ok { + if scopedAgent, ok := registry.GetAgent(resolvedAgentID); ok { return scopedAgent } return agent } - if parsed := session.ParseLegacyAgentSessionKey(sessionKey); parsed != nil { - if agent, ok := registry.GetAgent(parsed.AgentID); ok { - return agent - } - } - return registry.GetDefaultAgent() } diff --git a/pkg/session/key.go b/pkg/session/key.go index 6f1ee438f..fb0836bc1 100644 --- a/pkg/session/key.go +++ b/pkg/session/key.go @@ -62,6 +62,26 @@ func ParseLegacyAgentSessionKey(sessionKey string) *ParsedLegacySessionKey { return &ParsedLegacySessionKey{AgentID: agentID, Rest: rest} } +// ResolveAgentID returns the routed agent ID associated with a session. It +// prefers structured session scope metadata when available and falls back to +// legacy agent-scoped session keys for compatibility. +func ResolveAgentID(store any, sessionKey string) string { + if scopeReader, ok := store.(interface { + GetSessionScope(sessionKey string) *SessionScope + }); ok { + scope := scopeReader.GetSessionScope(sessionKey) + if scope != nil && strings.TrimSpace(scope.AgentID) != "" { + return routing.NormalizeAgentID(scope.AgentID) + } + } + + if parsed := ParseLegacyAgentSessionKey(sessionKey); parsed != nil { + return routing.NormalizeAgentID(parsed.AgentID) + } + + return "" +} + func BuildLegacyMainAlias(agentID string) string { return fmt.Sprintf("agent:%s:main", routing.NormalizeAgentID(agentID)) } diff --git a/pkg/session/key_test.go b/pkg/session/key_test.go index ede38d468..6cdf397e1 100644 --- a/pkg/session/key_test.go +++ b/pkg/session/key_test.go @@ -2,6 +2,14 @@ package session import "testing" +type testScopeReader struct { + scope *SessionScope +} + +func (r testScopeReader) GetSessionScope(sessionKey string) *SessionScope { + return CloneScope(r.scope) +} + func TestIsExplicitSessionKey(t *testing.T) { tests := []struct { key string @@ -70,3 +78,23 @@ func TestBuildMainSessionKey(t *testing.T) { t.Fatalf("BuildMainSessionKey() = %q, want stable main-key hash", got) } } + +func TestResolveAgentID_PrefersSessionScope(t *testing.T) { + store := testScopeReader{ + scope: &SessionScope{ + Version: ScopeVersionV1, + AgentID: "Support", + Channel: "slack", + }, + } + + if got := ResolveAgentID(store, "sk_v1_anything"); got != "support" { + t.Fatalf("ResolveAgentID() = %q, want support", got) + } +} + +func TestResolveAgentID_FallsBackToLegacyKey(t *testing.T) { + if got := ResolveAgentID(nil, "agent:Sales:telegram:direct:user123"); got != "sales" { + t.Fatalf("ResolveAgentID() = %q, want sales", got) + } +} diff --git a/pkg/tools/base.go b/pkg/tools/base.go index afee95692..e1f9aacc0 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -1,6 +1,10 @@ package tools -import "context" +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/session" +) // Tool is the interface that all tools must implement. type Tool interface { @@ -25,6 +29,9 @@ var ( ctxKeyChatID = &toolCtxKey{"chatID"} ctxKeyMessageID = &toolCtxKey{"messageID"} ctxKeyReplyToMessageID = &toolCtxKey{"replyToMessageID"} + ctxKeyAgentID = &toolCtxKey{"agentID"} + ctxKeySessionKey = &toolCtxKey{"sessionKey"} + ctxKeySessionScope = &toolCtxKey{"sessionScope"} ) // WithToolContext returns a child context carrying channel and chatID. @@ -51,6 +58,18 @@ func WithToolInboundContext( return ctx } +// WithToolSessionContext returns a child context carrying turn-scoped session metadata. +func WithToolSessionContext( + ctx context.Context, + agentID, sessionKey string, + scope *session.SessionScope, +) context.Context { + ctx = context.WithValue(ctx, ctxKeyAgentID, agentID) + ctx = context.WithValue(ctx, ctxKeySessionKey, sessionKey) + ctx = context.WithValue(ctx, ctxKeySessionScope, session.CloneScope(scope)) + return ctx +} + // ToolChannel extracts the channel from ctx, or "" if unset. func ToolChannel(ctx context.Context) string { v, _ := ctx.Value(ctxKeyChannel).(string) @@ -75,6 +94,24 @@ func ToolReplyToMessageID(ctx context.Context) string { return v } +// ToolAgentID extracts the active turn's agent ID from ctx, or "" if unset. +func ToolAgentID(ctx context.Context) string { + v, _ := ctx.Value(ctxKeyAgentID).(string) + return v +} + +// ToolSessionKey extracts the active turn's session key from ctx, or "" if unset. +func ToolSessionKey(ctx context.Context) string { + v, _ := ctx.Value(ctxKeySessionKey).(string) + return v +} + +// ToolSessionScope extracts the active turn's structured session scope from ctx. +func ToolSessionScope(ctx context.Context) *session.SessionScope { + scope, _ := ctx.Value(ctxKeySessionScope).(*session.SessionScope) + return session.CloneScope(scope) +} + // AsyncCallback is a function type that async tools use to notify completion. // When an async tool finishes its work, it calls this callback with the result. // diff --git a/pkg/tools/message.go b/pkg/tools/message.go index 064065a38..ec04f042e 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -6,10 +6,10 @@ import ( "sync/atomic" ) -type SendCallback func(channel, chatID, content, replyToMessageID string) error +type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error type MessageTool struct { - sendCallback SendCallback + sendCallback SendCallbackWithContext sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round } @@ -61,7 +61,7 @@ func (t *MessageTool) HasSentInRound() bool { return t.sentInRound.Load() } -func (t *MessageTool) SetSendCallback(callback SendCallback) { +func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) { t.sendCallback = callback } @@ -90,7 +90,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes return &ToolResult{ForLLM: "Message sending not configured", IsError: true} } - if err := t.sendCallback(channel, chatID, content, replyToMessageID); err != nil { + if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil { return &ToolResult{ ForLLM: fmt.Sprintf("sending message: %v", err), IsError: true, diff --git a/pkg/tools/message_test.go b/pkg/tools/message_test.go index 93a611ee0..649593252 100644 --- a/pkg/tools/message_test.go +++ b/pkg/tools/message_test.go @@ -4,16 +4,22 @@ import ( "context" "errors" "testing" + + "github.com/sipeed/picoclaw/pkg/session" ) func TestMessageTool_Execute_Success(t *testing.T) { tool := NewMessageTool() var sentChannel, sentChatID, sentContent string - tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { sentChannel = channel sentChatID = chatID sentContent = content + if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil { + t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v", + ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx)) + } return nil }) @@ -61,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { tool := NewMessageTool() var sentChannel, sentChatID string - tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { sentChannel = channel sentChatID = chatID return nil @@ -96,7 +102,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) { tool := NewMessageTool() sendErr := errors.New("network error") - tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { return sendErr }) @@ -149,7 +155,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { tool := NewMessageTool() // No WithToolContext — channel/chatID are empty - tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { return nil }) @@ -266,7 +272,7 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) { tool := NewMessageTool() var sentReplyTo string - tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { sentReplyTo = replyToMessageID return nil }) @@ -285,3 +291,41 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) { t.Fatalf("expected reply_to_message_id msg-123, got %q", sentReplyTo) } } + +func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) { + tool := NewMessageTool() + + var gotAgentID, gotSessionKey string + var gotScope *session.SessionScope + tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + gotAgentID = ToolAgentID(ctx) + gotSessionKey = ToolSessionKey(ctx) + gotScope = ToolSessionScope(ctx) + return nil + }) + + ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id") + ctx = WithToolSessionContext(ctx, "main", "sk_v1_tool", &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "main", + Channel: "telegram", + Dimensions: []string{"chat"}, + Values: map[string]string{ + "chat": "direct:test-chat-id", + }, + }) + + result := tool.Execute(ctx, map[string]any{"content": "Hello, world!"}) + if result.IsError { + t.Fatalf("expected success, got error: %s", result.ForLLM) + } + if gotAgentID != "main" { + t.Fatalf("ToolAgentID() = %q, want main", gotAgentID) + } + if gotSessionKey != "sk_v1_tool" { + t.Fatalf("ToolSessionKey() = %q, want sk_v1_tool", gotSessionKey) + } + if gotScope == nil || gotScope.Values["chat"] != "direct:test-chat-id" { + t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope) + } +}