mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(session): tighten legacy boundary and tool context
This commit is contained in:
+18
-1
@@ -241,12 +241,23 @@ func registerSharedTools(
|
||||
// Message tool
|
||||
if cfg.Tools.IsToolEnabled("message") {
|
||||
messageTool := tools.NewMessageTool()
|
||||
messageTool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
messageTool.SetSendCallback(func(
|
||||
ctx context.Context,
|
||||
channel, chatID, content, replyToMessageID string,
|
||||
) error {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
outboundCtx := bus.NewOutboundContext(channel, chatID, replyToMessageID)
|
||||
outboundAgentID, outboundSessionKey, outboundScope := outboundTurnMetadata(
|
||||
tools.ToolAgentID(ctx),
|
||||
tools.ToolSessionKey(ctx),
|
||||
tools.ToolSessionScope(ctx),
|
||||
)
|
||||
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Context: outboundCtx,
|
||||
AgentID: outboundAgentID,
|
||||
SessionKey: outboundSessionKey,
|
||||
Scope: outboundScope,
|
||||
Content: content,
|
||||
ReplyToMessageID: replyToMessageID,
|
||||
})
|
||||
@@ -2748,6 +2759,12 @@ turnLoop:
|
||||
ts.opts.Dispatch.MessageID(),
|
||||
ts.opts.Dispatch.ReplyToMessageID(),
|
||||
)
|
||||
execCtx = tools.WithToolSessionContext(
|
||||
execCtx,
|
||||
ts.agent.ID,
|
||||
ts.sessionKey,
|
||||
ts.opts.Dispatch.SessionScope,
|
||||
)
|
||||
toolResult := ts.agent.Tools.ExecuteWithContext(
|
||||
execCtx,
|
||||
toolName,
|
||||
|
||||
@@ -1274,6 +1274,36 @@ func (m *handledUserProvider) GetDefaultModel() string {
|
||||
return "handled-user-model"
|
||||
}
|
||||
|
||||
type messageToolProvider struct {
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *messageToolProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.calls++
|
||||
if m.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
Content: "",
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_message",
|
||||
Type: "function",
|
||||
Name: "message",
|
||||
Arguments: map[string]any{"content": "direct tool message"},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
return &providers.LLMResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *messageToolProvider) GetDefaultModel() string {
|
||||
return "message-tool-model"
|
||||
}
|
||||
|
||||
type artifactThenSendProvider struct {
|
||||
calls int
|
||||
}
|
||||
@@ -3058,6 +3088,53 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_MessageToolPublishesOutboundWithTurnMetadata(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Workspace = t.TempDir()
|
||||
cfg.Agents.Defaults.ModelName = "test-model"
|
||||
cfg.Agents.Defaults.MaxTokens = 4096
|
||||
cfg.Agents.Defaults.MaxToolIterations = 10
|
||||
cfg.Session.Dimensions = []string{"chat"}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &messageToolProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user-1",
|
||||
ChatID: "chat-1",
|
||||
Content: "send a direct message",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response == "" {
|
||||
t.Fatal("expected processMessage() to return a final loop response")
|
||||
}
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "direct tool message" {
|
||||
t.Fatalf("outbound content = %q, want direct tool message", outbound.Content)
|
||||
}
|
||||
if outbound.AgentID != "main" {
|
||||
t.Fatalf("outbound agent_id = %q, want main", outbound.AgentID)
|
||||
}
|
||||
if outbound.SessionKey == "" {
|
||||
t.Fatal("expected message tool outbound to carry session_key")
|
||||
}
|
||||
if outbound.Scope == nil || outbound.Scope.Values["chat"] != "direct:chat-1" {
|
||||
t.Fatalf("unexpected message tool outbound scope: %+v", outbound.Scope)
|
||||
}
|
||||
if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "chat-1" {
|
||||
t.Fatalf("unexpected message tool outbound context: %+v", outbound.Context)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected message tool outbound")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
+3
-15
@@ -324,28 +324,16 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
|
||||
if !ok || agent == nil {
|
||||
continue
|
||||
}
|
||||
scopeReader, ok := agent.Sessions.(interface {
|
||||
GetSessionScope(sessionKey string) *session.SessionScope
|
||||
})
|
||||
if !ok {
|
||||
resolvedAgentID := session.ResolveAgentID(agent.Sessions, sessionKey)
|
||||
if resolvedAgentID == "" {
|
||||
continue
|
||||
}
|
||||
scope := scopeReader.GetSessionScope(sessionKey)
|
||||
if scope == nil || strings.TrimSpace(scope.AgentID) == "" {
|
||||
continue
|
||||
}
|
||||
if scopedAgent, ok := registry.GetAgent(scope.AgentID); ok {
|
||||
if scopedAgent, ok := registry.GetAgent(resolvedAgentID); ok {
|
||||
return scopedAgent
|
||||
}
|
||||
return agent
|
||||
}
|
||||
|
||||
if parsed := session.ParseLegacyAgentSessionKey(sessionKey); parsed != nil {
|
||||
if agent, ok := registry.GetAgent(parsed.AgentID); ok {
|
||||
return agent
|
||||
}
|
||||
}
|
||||
|
||||
return registry.GetDefaultAgent()
|
||||
}
|
||||
|
||||
|
||||
@@ -62,6 +62,26 @@ func ParseLegacyAgentSessionKey(sessionKey string) *ParsedLegacySessionKey {
|
||||
return &ParsedLegacySessionKey{AgentID: agentID, Rest: rest}
|
||||
}
|
||||
|
||||
// ResolveAgentID returns the routed agent ID associated with a session. It
|
||||
// prefers structured session scope metadata when available and falls back to
|
||||
// legacy agent-scoped session keys for compatibility.
|
||||
func ResolveAgentID(store any, sessionKey string) string {
|
||||
if scopeReader, ok := store.(interface {
|
||||
GetSessionScope(sessionKey string) *SessionScope
|
||||
}); ok {
|
||||
scope := scopeReader.GetSessionScope(sessionKey)
|
||||
if scope != nil && strings.TrimSpace(scope.AgentID) != "" {
|
||||
return routing.NormalizeAgentID(scope.AgentID)
|
||||
}
|
||||
}
|
||||
|
||||
if parsed := ParseLegacyAgentSessionKey(sessionKey); parsed != nil {
|
||||
return routing.NormalizeAgentID(parsed.AgentID)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func BuildLegacyMainAlias(agentID string) string {
|
||||
return fmt.Sprintf("agent:%s:main", routing.NormalizeAgentID(agentID))
|
||||
}
|
||||
|
||||
@@ -2,6 +2,14 @@ package session
|
||||
|
||||
import "testing"
|
||||
|
||||
type testScopeReader struct {
|
||||
scope *SessionScope
|
||||
}
|
||||
|
||||
func (r testScopeReader) GetSessionScope(sessionKey string) *SessionScope {
|
||||
return CloneScope(r.scope)
|
||||
}
|
||||
|
||||
func TestIsExplicitSessionKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
key string
|
||||
@@ -70,3 +78,23 @@ func TestBuildMainSessionKey(t *testing.T) {
|
||||
t.Fatalf("BuildMainSessionKey() = %q, want stable main-key hash", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAgentID_PrefersSessionScope(t *testing.T) {
|
||||
store := testScopeReader{
|
||||
scope: &SessionScope{
|
||||
Version: ScopeVersionV1,
|
||||
AgentID: "Support",
|
||||
Channel: "slack",
|
||||
},
|
||||
}
|
||||
|
||||
if got := ResolveAgentID(store, "sk_v1_anything"); got != "support" {
|
||||
t.Fatalf("ResolveAgentID() = %q, want support", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAgentID_FallsBackToLegacyKey(t *testing.T) {
|
||||
if got := ResolveAgentID(nil, "agent:Sales:telegram:direct:user123"); got != "sales" {
|
||||
t.Fatalf("ResolveAgentID() = %q, want sales", got)
|
||||
}
|
||||
}
|
||||
|
||||
+38
-1
@@ -1,6 +1,10 @@
|
||||
package tools
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
// Tool is the interface that all tools must implement.
|
||||
type Tool interface {
|
||||
@@ -25,6 +29,9 @@ var (
|
||||
ctxKeyChatID = &toolCtxKey{"chatID"}
|
||||
ctxKeyMessageID = &toolCtxKey{"messageID"}
|
||||
ctxKeyReplyToMessageID = &toolCtxKey{"replyToMessageID"}
|
||||
ctxKeyAgentID = &toolCtxKey{"agentID"}
|
||||
ctxKeySessionKey = &toolCtxKey{"sessionKey"}
|
||||
ctxKeySessionScope = &toolCtxKey{"sessionScope"}
|
||||
)
|
||||
|
||||
// WithToolContext returns a child context carrying channel and chatID.
|
||||
@@ -51,6 +58,18 @@ func WithToolInboundContext(
|
||||
return ctx
|
||||
}
|
||||
|
||||
// WithToolSessionContext returns a child context carrying turn-scoped session metadata.
|
||||
func WithToolSessionContext(
|
||||
ctx context.Context,
|
||||
agentID, sessionKey string,
|
||||
scope *session.SessionScope,
|
||||
) context.Context {
|
||||
ctx = context.WithValue(ctx, ctxKeyAgentID, agentID)
|
||||
ctx = context.WithValue(ctx, ctxKeySessionKey, sessionKey)
|
||||
ctx = context.WithValue(ctx, ctxKeySessionScope, session.CloneScope(scope))
|
||||
return ctx
|
||||
}
|
||||
|
||||
// ToolChannel extracts the channel from ctx, or "" if unset.
|
||||
func ToolChannel(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeyChannel).(string)
|
||||
@@ -75,6 +94,24 @@ func ToolReplyToMessageID(ctx context.Context) string {
|
||||
return v
|
||||
}
|
||||
|
||||
// ToolAgentID extracts the active turn's agent ID from ctx, or "" if unset.
|
||||
func ToolAgentID(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeyAgentID).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// ToolSessionKey extracts the active turn's session key from ctx, or "" if unset.
|
||||
func ToolSessionKey(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeySessionKey).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// ToolSessionScope extracts the active turn's structured session scope from ctx.
|
||||
func ToolSessionScope(ctx context.Context) *session.SessionScope {
|
||||
scope, _ := ctx.Value(ctxKeySessionScope).(*session.SessionScope)
|
||||
return session.CloneScope(scope)
|
||||
}
|
||||
|
||||
// AsyncCallback is a function type that async tools use to notify completion.
|
||||
// When an async tool finishes its work, it calls this callback with the result.
|
||||
//
|
||||
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type SendCallback func(channel, chatID, content, replyToMessageID string) error
|
||||
type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error
|
||||
|
||||
type MessageTool struct {
|
||||
sendCallback SendCallback
|
||||
sendCallback SendCallbackWithContext
|
||||
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func (t *MessageTool) HasSentInRound() bool {
|
||||
return t.sentInRound.Load()
|
||||
}
|
||||
|
||||
func (t *MessageTool) SetSendCallback(callback SendCallback) {
|
||||
func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) {
|
||||
t.sendCallback = callback
|
||||
}
|
||||
|
||||
@@ -90,7 +90,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
|
||||
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
|
||||
}
|
||||
|
||||
if err := t.sendCallback(channel, chatID, content, replyToMessageID); err != nil {
|
||||
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil {
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("sending message: %v", err),
|
||||
IsError: true,
|
||||
|
||||
@@ -4,16 +4,22 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentChannel, sentChatID, sentContent string
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
sentContent = content
|
||||
if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil {
|
||||
t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v",
|
||||
ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -61,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentChannel, sentChatID string
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
return nil
|
||||
@@ -96,7 +102,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
sendErr := errors.New("network error")
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
return sendErr
|
||||
})
|
||||
|
||||
@@ -149,7 +155,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
// No WithToolContext — channel/chatID are empty
|
||||
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -266,7 +272,7 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentReplyTo string
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
sentReplyTo = replyToMessageID
|
||||
return nil
|
||||
})
|
||||
@@ -285,3 +291,41 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
|
||||
t.Fatalf("expected reply_to_message_id msg-123, got %q", sentReplyTo)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var gotAgentID, gotSessionKey string
|
||||
var gotScope *session.SessionScope
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
gotAgentID = ToolAgentID(ctx)
|
||||
gotSessionKey = ToolSessionKey(ctx)
|
||||
gotScope = ToolSessionScope(ctx)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
ctx = WithToolSessionContext(ctx, "main", "sk_v1_tool", &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "direct:test-chat-id",
|
||||
},
|
||||
})
|
||||
|
||||
result := tool.Execute(ctx, map[string]any{"content": "Hello, world!"})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if gotAgentID != "main" {
|
||||
t.Fatalf("ToolAgentID() = %q, want main", gotAgentID)
|
||||
}
|
||||
if gotSessionKey != "sk_v1_tool" {
|
||||
t.Fatalf("ToolSessionKey() = %q, want sk_v1_tool", gotSessionKey)
|
||||
}
|
||||
if gotScope == nil || gotScope.Values["chat"] != "direct:test-chat-id" {
|
||||
t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user