refactor(session): tighten legacy boundary and tool context

This commit is contained in:
Hoshina
2026-04-07 22:39:46 +08:00
parent 9f23ec22d6
commit 3d60385958
8 changed files with 237 additions and 26 deletions
+38 -1
View File
@@ -1,6 +1,10 @@
package tools
import "context"
import (
"context"
"github.com/sipeed/picoclaw/pkg/session"
)
// Tool is the interface that all tools must implement.
type Tool interface {
@@ -25,6 +29,9 @@ var (
ctxKeyChatID = &toolCtxKey{"chatID"}
ctxKeyMessageID = &toolCtxKey{"messageID"}
ctxKeyReplyToMessageID = &toolCtxKey{"replyToMessageID"}
ctxKeyAgentID = &toolCtxKey{"agentID"}
ctxKeySessionKey = &toolCtxKey{"sessionKey"}
ctxKeySessionScope = &toolCtxKey{"sessionScope"}
)
// WithToolContext returns a child context carrying channel and chatID.
@@ -51,6 +58,18 @@ func WithToolInboundContext(
return ctx
}
// WithToolSessionContext returns a child context carrying turn-scoped session metadata.
func WithToolSessionContext(
ctx context.Context,
agentID, sessionKey string,
scope *session.SessionScope,
) context.Context {
ctx = context.WithValue(ctx, ctxKeyAgentID, agentID)
ctx = context.WithValue(ctx, ctxKeySessionKey, sessionKey)
ctx = context.WithValue(ctx, ctxKeySessionScope, session.CloneScope(scope))
return ctx
}
// ToolChannel extracts the channel from ctx, or "" if unset.
func ToolChannel(ctx context.Context) string {
v, _ := ctx.Value(ctxKeyChannel).(string)
@@ -75,6 +94,24 @@ func ToolReplyToMessageID(ctx context.Context) string {
return v
}
// ToolAgentID extracts the active turn's agent ID from ctx, or "" if unset.
func ToolAgentID(ctx context.Context) string {
v, _ := ctx.Value(ctxKeyAgentID).(string)
return v
}
// ToolSessionKey extracts the active turn's session key from ctx, or "" if unset.
func ToolSessionKey(ctx context.Context) string {
v, _ := ctx.Value(ctxKeySessionKey).(string)
return v
}
// ToolSessionScope extracts the active turn's structured session scope from ctx.
func ToolSessionScope(ctx context.Context) *session.SessionScope {
scope, _ := ctx.Value(ctxKeySessionScope).(*session.SessionScope)
return session.CloneScope(scope)
}
// AsyncCallback is a function type that async tools use to notify completion.
// When an async tool finishes its work, it calls this callback with the result.
//
+4 -4
View File
@@ -6,10 +6,10 @@ import (
"sync/atomic"
)
type SendCallback func(channel, chatID, content, replyToMessageID string) error
type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error
type MessageTool struct {
sendCallback SendCallback
sendCallback SendCallbackWithContext
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
}
@@ -61,7 +61,7 @@ func (t *MessageTool) HasSentInRound() bool {
return t.sentInRound.Load()
}
func (t *MessageTool) SetSendCallback(callback SendCallback) {
func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) {
t.sendCallback = callback
}
@@ -90,7 +90,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
}
if err := t.sendCallback(channel, chatID, content, replyToMessageID); err != nil {
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil {
return &ToolResult{
ForLLM: fmt.Sprintf("sending message: %v", err),
IsError: true,
+49 -5
View File
@@ -4,16 +4,22 @@ import (
"context"
"errors"
"testing"
"github.com/sipeed/picoclaw/pkg/session"
)
func TestMessageTool_Execute_Success(t *testing.T) {
tool := NewMessageTool()
var sentChannel, sentChatID, sentContent string
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
sentChannel = channel
sentChatID = chatID
sentContent = content
if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil {
t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v",
ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx))
}
return nil
})
@@ -61,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
tool := NewMessageTool()
var sentChannel, sentChatID string
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
sentChannel = channel
sentChatID = chatID
return nil
@@ -96,7 +102,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
tool := NewMessageTool()
sendErr := errors.New("network error")
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
return sendErr
})
@@ -149,7 +155,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
tool := NewMessageTool()
// No WithToolContext — channel/chatID are empty
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
return nil
})
@@ -266,7 +272,7 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
tool := NewMessageTool()
var sentReplyTo string
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
sentReplyTo = replyToMessageID
return nil
})
@@ -285,3 +291,41 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
t.Fatalf("expected reply_to_message_id msg-123, got %q", sentReplyTo)
}
}
func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
tool := NewMessageTool()
var gotAgentID, gotSessionKey string
var gotScope *session.SessionScope
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
gotAgentID = ToolAgentID(ctx)
gotSessionKey = ToolSessionKey(ctx)
gotScope = ToolSessionScope(ctx)
return nil
})
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
ctx = WithToolSessionContext(ctx, "main", "sk_v1_tool", &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "telegram",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "direct:test-chat-id",
},
})
result := tool.Execute(ctx, map[string]any{"content": "Hello, world!"})
if result.IsError {
t.Fatalf("expected success, got error: %s", result.ForLLM)
}
if gotAgentID != "main" {
t.Fatalf("ToolAgentID() = %q, want main", gotAgentID)
}
if gotSessionKey != "sk_v1_tool" {
t.Fatalf("ToolSessionKey() = %q, want sk_v1_tool", gotSessionKey)
}
if gotScope == nil || gotScope.Values["chat"] != "direct:test-chat-id" {
t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope)
}
}