refactor(session): tighten legacy boundary and tool context

This commit is contained in:
Hoshina
2026-04-07 22:39:46 +08:00
parent 9f23ec22d6
commit 3d60385958
8 changed files with 237 additions and 26 deletions
+18 -1
View File
@@ -241,12 +241,23 @@ func registerSharedTools(
// Message tool
if cfg.Tools.IsToolEnabled("message") {
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
messageTool.SetSendCallback(func(
ctx context.Context,
channel, chatID, content, replyToMessageID string,
) error {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
outboundCtx := bus.NewOutboundContext(channel, chatID, replyToMessageID)
outboundAgentID, outboundSessionKey, outboundScope := outboundTurnMetadata(
tools.ToolAgentID(ctx),
tools.ToolSessionKey(ctx),
tools.ToolSessionScope(ctx),
)
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Context: outboundCtx,
AgentID: outboundAgentID,
SessionKey: outboundSessionKey,
Scope: outboundScope,
Content: content,
ReplyToMessageID: replyToMessageID,
})
@@ -2748,6 +2759,12 @@ turnLoop:
ts.opts.Dispatch.MessageID(),
ts.opts.Dispatch.ReplyToMessageID(),
)
execCtx = tools.WithToolSessionContext(
execCtx,
ts.agent.ID,
ts.sessionKey,
ts.opts.Dispatch.SessionScope,
)
toolResult := ts.agent.Tools.ExecuteWithContext(
execCtx,
toolName,
+77
View File
@@ -1274,6 +1274,36 @@ func (m *handledUserProvider) GetDefaultModel() string {
return "handled-user-model"
}
type messageToolProvider struct {
calls int
}
func (m *messageToolProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
m.calls++
if m.calls == 1 {
return &providers.LLMResponse{
Content: "",
ToolCalls: []providers.ToolCall{{
ID: "call_message",
Type: "function",
Name: "message",
Arguments: map[string]any{"content": "direct tool message"},
}},
}, nil
}
return &providers.LLMResponse{}, nil
}
func (m *messageToolProvider) GetDefaultModel() string {
return "message-tool-model"
}
type artifactThenSendProvider struct {
calls int
}
@@ -3058,6 +3088,53 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
}
}
func TestProcessMessage_MessageToolPublishesOutboundWithTurnMetadata(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Workspace = t.TempDir()
cfg.Agents.Defaults.ModelName = "test-model"
cfg.Agents.Defaults.MaxTokens = 4096
cfg.Agents.Defaults.MaxToolIterations = 10
cfg.Session.Dimensions = []string{"chat"}
msgBus := bus.NewMessageBus()
provider := &messageToolProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "telegram",
SenderID: "user-1",
ChatID: "chat-1",
Content: "send a direct message",
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response == "" {
t.Fatal("expected processMessage() to return a final loop response")
}
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "direct tool message" {
t.Fatalf("outbound content = %q, want direct tool message", outbound.Content)
}
if outbound.AgentID != "main" {
t.Fatalf("outbound agent_id = %q, want main", outbound.AgentID)
}
if outbound.SessionKey == "" {
t.Fatal("expected message tool outbound to carry session_key")
}
if outbound.Scope == nil || outbound.Scope.Values["chat"] != "direct:chat-1" {
t.Fatalf("unexpected message tool outbound scope: %+v", outbound.Scope)
}
if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "chat-1" {
t.Fatalf("unexpected message tool outbound context: %+v", outbound.Context)
}
case <-time.After(2 * time.Second):
t.Fatal("expected message tool outbound")
}
}
func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
+3 -15
View File
@@ -324,28 +324,16 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
if !ok || agent == nil {
continue
}
scopeReader, ok := agent.Sessions.(interface {
GetSessionScope(sessionKey string) *session.SessionScope
})
if !ok {
resolvedAgentID := session.ResolveAgentID(agent.Sessions, sessionKey)
if resolvedAgentID == "" {
continue
}
scope := scopeReader.GetSessionScope(sessionKey)
if scope == nil || strings.TrimSpace(scope.AgentID) == "" {
continue
}
if scopedAgent, ok := registry.GetAgent(scope.AgentID); ok {
if scopedAgent, ok := registry.GetAgent(resolvedAgentID); ok {
return scopedAgent
}
return agent
}
if parsed := session.ParseLegacyAgentSessionKey(sessionKey); parsed != nil {
if agent, ok := registry.GetAgent(parsed.AgentID); ok {
return agent
}
}
return registry.GetDefaultAgent()
}
+20
View File
@@ -62,6 +62,26 @@ func ParseLegacyAgentSessionKey(sessionKey string) *ParsedLegacySessionKey {
return &ParsedLegacySessionKey{AgentID: agentID, Rest: rest}
}
// ResolveAgentID returns the routed agent ID associated with a session. It
// prefers structured session scope metadata when available and falls back to
// legacy agent-scoped session keys for compatibility.
func ResolveAgentID(store any, sessionKey string) string {
if scopeReader, ok := store.(interface {
GetSessionScope(sessionKey string) *SessionScope
}); ok {
scope := scopeReader.GetSessionScope(sessionKey)
if scope != nil && strings.TrimSpace(scope.AgentID) != "" {
return routing.NormalizeAgentID(scope.AgentID)
}
}
if parsed := ParseLegacyAgentSessionKey(sessionKey); parsed != nil {
return routing.NormalizeAgentID(parsed.AgentID)
}
return ""
}
func BuildLegacyMainAlias(agentID string) string {
return fmt.Sprintf("agent:%s:main", routing.NormalizeAgentID(agentID))
}
+28
View File
@@ -2,6 +2,14 @@ package session
import "testing"
type testScopeReader struct {
scope *SessionScope
}
func (r testScopeReader) GetSessionScope(sessionKey string) *SessionScope {
return CloneScope(r.scope)
}
func TestIsExplicitSessionKey(t *testing.T) {
tests := []struct {
key string
@@ -70,3 +78,23 @@ func TestBuildMainSessionKey(t *testing.T) {
t.Fatalf("BuildMainSessionKey() = %q, want stable main-key hash", got)
}
}
func TestResolveAgentID_PrefersSessionScope(t *testing.T) {
store := testScopeReader{
scope: &SessionScope{
Version: ScopeVersionV1,
AgentID: "Support",
Channel: "slack",
},
}
if got := ResolveAgentID(store, "sk_v1_anything"); got != "support" {
t.Fatalf("ResolveAgentID() = %q, want support", got)
}
}
func TestResolveAgentID_FallsBackToLegacyKey(t *testing.T) {
if got := ResolveAgentID(nil, "agent:Sales:telegram:direct:user123"); got != "sales" {
t.Fatalf("ResolveAgentID() = %q, want sales", got)
}
}
+38 -1
View File
@@ -1,6 +1,10 @@
package tools
import "context"
import (
"context"
"github.com/sipeed/picoclaw/pkg/session"
)
// Tool is the interface that all tools must implement.
type Tool interface {
@@ -25,6 +29,9 @@ var (
ctxKeyChatID = &toolCtxKey{"chatID"}
ctxKeyMessageID = &toolCtxKey{"messageID"}
ctxKeyReplyToMessageID = &toolCtxKey{"replyToMessageID"}
ctxKeyAgentID = &toolCtxKey{"agentID"}
ctxKeySessionKey = &toolCtxKey{"sessionKey"}
ctxKeySessionScope = &toolCtxKey{"sessionScope"}
)
// WithToolContext returns a child context carrying channel and chatID.
@@ -51,6 +58,18 @@ func WithToolInboundContext(
return ctx
}
// WithToolSessionContext returns a child context carrying turn-scoped session metadata.
func WithToolSessionContext(
ctx context.Context,
agentID, sessionKey string,
scope *session.SessionScope,
) context.Context {
ctx = context.WithValue(ctx, ctxKeyAgentID, agentID)
ctx = context.WithValue(ctx, ctxKeySessionKey, sessionKey)
ctx = context.WithValue(ctx, ctxKeySessionScope, session.CloneScope(scope))
return ctx
}
// ToolChannel extracts the channel from ctx, or "" if unset.
func ToolChannel(ctx context.Context) string {
v, _ := ctx.Value(ctxKeyChannel).(string)
@@ -75,6 +94,24 @@ func ToolReplyToMessageID(ctx context.Context) string {
return v
}
// ToolAgentID extracts the active turn's agent ID from ctx, or "" if unset.
func ToolAgentID(ctx context.Context) string {
v, _ := ctx.Value(ctxKeyAgentID).(string)
return v
}
// ToolSessionKey extracts the active turn's session key from ctx, or "" if unset.
func ToolSessionKey(ctx context.Context) string {
v, _ := ctx.Value(ctxKeySessionKey).(string)
return v
}
// ToolSessionScope extracts the active turn's structured session scope from ctx.
func ToolSessionScope(ctx context.Context) *session.SessionScope {
scope, _ := ctx.Value(ctxKeySessionScope).(*session.SessionScope)
return session.CloneScope(scope)
}
// AsyncCallback is a function type that async tools use to notify completion.
// When an async tool finishes its work, it calls this callback with the result.
//
+4 -4
View File
@@ -6,10 +6,10 @@ import (
"sync/atomic"
)
type SendCallback func(channel, chatID, content, replyToMessageID string) error
type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error
type MessageTool struct {
sendCallback SendCallback
sendCallback SendCallbackWithContext
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
}
@@ -61,7 +61,7 @@ func (t *MessageTool) HasSentInRound() bool {
return t.sentInRound.Load()
}
func (t *MessageTool) SetSendCallback(callback SendCallback) {
func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) {
t.sendCallback = callback
}
@@ -90,7 +90,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
}
if err := t.sendCallback(channel, chatID, content, replyToMessageID); err != nil {
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil {
return &ToolResult{
ForLLM: fmt.Sprintf("sending message: %v", err),
IsError: true,
+49 -5
View File
@@ -4,16 +4,22 @@ import (
"context"
"errors"
"testing"
"github.com/sipeed/picoclaw/pkg/session"
)
func TestMessageTool_Execute_Success(t *testing.T) {
tool := NewMessageTool()
var sentChannel, sentChatID, sentContent string
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
sentChannel = channel
sentChatID = chatID
sentContent = content
if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil {
t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v",
ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx))
}
return nil
})
@@ -61,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
tool := NewMessageTool()
var sentChannel, sentChatID string
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
sentChannel = channel
sentChatID = chatID
return nil
@@ -96,7 +102,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
tool := NewMessageTool()
sendErr := errors.New("network error")
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
return sendErr
})
@@ -149,7 +155,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
tool := NewMessageTool()
// No WithToolContext — channel/chatID are empty
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
return nil
})
@@ -266,7 +272,7 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
tool := NewMessageTool()
var sentReplyTo string
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
sentReplyTo = replyToMessageID
return nil
})
@@ -285,3 +291,41 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
t.Fatalf("expected reply_to_message_id msg-123, got %q", sentReplyTo)
}
}
func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
tool := NewMessageTool()
var gotAgentID, gotSessionKey string
var gotScope *session.SessionScope
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
gotAgentID = ToolAgentID(ctx)
gotSessionKey = ToolSessionKey(ctx)
gotScope = ToolSessionScope(ctx)
return nil
})
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
ctx = WithToolSessionContext(ctx, "main", "sk_v1_tool", &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "telegram",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "direct:test-chat-id",
},
})
result := tool.Execute(ctx, map[string]any{"content": "Hello, world!"})
if result.IsError {
t.Fatalf("expected success, got error: %s", result.ForLLM)
}
if gotAgentID != "main" {
t.Fatalf("ToolAgentID() = %q, want main", gotAgentID)
}
if gotSessionKey != "sk_v1_tool" {
t.Fatalf("ToolSessionKey() = %q, want sk_v1_tool", gotSessionKey)
}
if gotScope == nil || gotScope.Values["chat"] != "direct:test-chat-id" {
t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope)
}
}