mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(session): tighten legacy boundary and tool context
This commit is contained in:
+38
-1
@@ -1,6 +1,10 @@
|
||||
package tools
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
// Tool is the interface that all tools must implement.
|
||||
type Tool interface {
|
||||
@@ -25,6 +29,9 @@ var (
|
||||
ctxKeyChatID = &toolCtxKey{"chatID"}
|
||||
ctxKeyMessageID = &toolCtxKey{"messageID"}
|
||||
ctxKeyReplyToMessageID = &toolCtxKey{"replyToMessageID"}
|
||||
ctxKeyAgentID = &toolCtxKey{"agentID"}
|
||||
ctxKeySessionKey = &toolCtxKey{"sessionKey"}
|
||||
ctxKeySessionScope = &toolCtxKey{"sessionScope"}
|
||||
)
|
||||
|
||||
// WithToolContext returns a child context carrying channel and chatID.
|
||||
@@ -51,6 +58,18 @@ func WithToolInboundContext(
|
||||
return ctx
|
||||
}
|
||||
|
||||
// WithToolSessionContext returns a child context carrying turn-scoped session metadata.
|
||||
func WithToolSessionContext(
|
||||
ctx context.Context,
|
||||
agentID, sessionKey string,
|
||||
scope *session.SessionScope,
|
||||
) context.Context {
|
||||
ctx = context.WithValue(ctx, ctxKeyAgentID, agentID)
|
||||
ctx = context.WithValue(ctx, ctxKeySessionKey, sessionKey)
|
||||
ctx = context.WithValue(ctx, ctxKeySessionScope, session.CloneScope(scope))
|
||||
return ctx
|
||||
}
|
||||
|
||||
// ToolChannel extracts the channel from ctx, or "" if unset.
|
||||
func ToolChannel(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeyChannel).(string)
|
||||
@@ -75,6 +94,24 @@ func ToolReplyToMessageID(ctx context.Context) string {
|
||||
return v
|
||||
}
|
||||
|
||||
// ToolAgentID extracts the active turn's agent ID from ctx, or "" if unset.
|
||||
func ToolAgentID(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeyAgentID).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// ToolSessionKey extracts the active turn's session key from ctx, or "" if unset.
|
||||
func ToolSessionKey(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeySessionKey).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// ToolSessionScope extracts the active turn's structured session scope from ctx.
|
||||
func ToolSessionScope(ctx context.Context) *session.SessionScope {
|
||||
scope, _ := ctx.Value(ctxKeySessionScope).(*session.SessionScope)
|
||||
return session.CloneScope(scope)
|
||||
}
|
||||
|
||||
// AsyncCallback is a function type that async tools use to notify completion.
|
||||
// When an async tool finishes its work, it calls this callback with the result.
|
||||
//
|
||||
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type SendCallback func(channel, chatID, content, replyToMessageID string) error
|
||||
type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error
|
||||
|
||||
type MessageTool struct {
|
||||
sendCallback SendCallback
|
||||
sendCallback SendCallbackWithContext
|
||||
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func (t *MessageTool) HasSentInRound() bool {
|
||||
return t.sentInRound.Load()
|
||||
}
|
||||
|
||||
func (t *MessageTool) SetSendCallback(callback SendCallback) {
|
||||
func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) {
|
||||
t.sendCallback = callback
|
||||
}
|
||||
|
||||
@@ -90,7 +90,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
|
||||
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
|
||||
}
|
||||
|
||||
if err := t.sendCallback(channel, chatID, content, replyToMessageID); err != nil {
|
||||
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil {
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("sending message: %v", err),
|
||||
IsError: true,
|
||||
|
||||
@@ -4,16 +4,22 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentChannel, sentChatID, sentContent string
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
sentContent = content
|
||||
if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil {
|
||||
t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v",
|
||||
ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -61,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentChannel, sentChatID string
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
return nil
|
||||
@@ -96,7 +102,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
sendErr := errors.New("network error")
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
return sendErr
|
||||
})
|
||||
|
||||
@@ -149,7 +155,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
// No WithToolContext — channel/chatID are empty
|
||||
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -266,7 +272,7 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentReplyTo string
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
sentReplyTo = replyToMessageID
|
||||
return nil
|
||||
})
|
||||
@@ -285,3 +291,41 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
|
||||
t.Fatalf("expected reply_to_message_id msg-123, got %q", sentReplyTo)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var gotAgentID, gotSessionKey string
|
||||
var gotScope *session.SessionScope
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
gotAgentID = ToolAgentID(ctx)
|
||||
gotSessionKey = ToolSessionKey(ctx)
|
||||
gotScope = ToolSessionScope(ctx)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
ctx = WithToolSessionContext(ctx, "main", "sk_v1_tool", &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "direct:test-chat-id",
|
||||
},
|
||||
})
|
||||
|
||||
result := tool.Execute(ctx, map[string]any{"content": "Hello, world!"})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if gotAgentID != "main" {
|
||||
t.Fatalf("ToolAgentID() = %q, want main", gotAgentID)
|
||||
}
|
||||
if gotSessionKey != "sk_v1_tool" {
|
||||
t.Fatalf("ToolSessionKey() = %q, want sk_v1_tool", gotSessionKey)
|
||||
}
|
||||
if gotScope == nil || gotScope.Values["chat"] != "direct:test-chat-id" {
|
||||
t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user