diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index ef2951365..d7461e76f 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -75,6 +75,8 @@ type processOptions struct { SessionKey string // Session identifier for history/context Channel string // Target channel for tool execution ChatID string // Target chat ID for tool execution + MessageID string // Current inbound platform message ID + ReplyToMessageID string // Current inbound reply target message ID SenderID string // Current sender ID for dynamic context SenderDisplayName string // Current sender display name for dynamic context UserMessage string // User message content (may include prefix) @@ -104,6 +106,7 @@ const ( metadataKeyAccountID = "account_id" metadataKeyGuildID = "guild_id" metadataKeyTeamID = "team_id" + metadataKeyReplyToMessage = "reply_to_message_id" metadataKeyParentPeerKind = "parent_peer_kind" metadataKeyParentPeerID = "parent_peer_id" ) @@ -222,17 +225,37 @@ func registerSharedTools( // Message tool if cfg.Tools.IsToolEnabled("message") { messageTool := tools.NewMessageTool() - messageTool.SetSendCallback(func(channel, chatID, content string) error { + messageTool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, - Content: content, + Channel: channel, + ChatID: chatID, + Content: content, + ReplyToMessageID: replyToMessageID, }) }) agent.Tools.Register(messageTool) } + if cfg.Tools.IsToolEnabled("reaction") { + reactionTool := tools.NewReactionTool() + reactionTool.SetReactionCallback(func(ctx context.Context, channel, chatID, messageID string) error { + if al.channelManager == nil { + return fmt.Errorf("channel manager not configured") + } + ch, ok := al.channelManager.GetChannel(channel) + if !ok { + return fmt.Errorf("channel %s not found", channel) + } + rc, ok := ch.(channels.ReactionCapable) + if !ok { + return fmt.Errorf("channel %s does not support reactions", channel) + } + _, err := rc.ReactToMessage(ctx, chatID, messageID) + return err + }) + agent.Tools.Register(reactionTool) + } // Send file tool (outbound media via MediaStore — store injected later by SetMediaStore) if cfg.Tools.IsToolEnabled("send_file") { @@ -1315,6 +1338,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) SessionKey: sessionKey, Channel: msg.Channel, ChatID: msg.ChatID, + MessageID: msg.MessageID, + ReplyToMessageID: inboundMetadata(msg, metadataKeyReplyToMessage), SenderID: msg.SenderID, SenderDisplayName: msg.Sender.DisplayName, UserMessage: msg.Content, @@ -2384,8 +2409,15 @@ turnLoop: } toolStart := time.Now() - toolResult := ts.agent.Tools.ExecuteWithContext( + execCtx := tools.WithToolInboundContext( turnCtx, + ts.channel, + ts.chatID, + ts.opts.MessageID, + ts.opts.ReplyToMessageID, + ) + toolResult := ts.agent.Tools.ExecuteWithContext( + execCtx, toolName, toolArgs, ts.channel, diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 25d20c689..58149f92c 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -531,6 +531,20 @@ func TestToolContext_Updates(t *testing.T) { if got := tools.ToolChannel(context.Background()); got != "" { t.Errorf("expected empty channel from bare context, got %q", got) } + + inboundCtx := tools.WithToolInboundContext( + context.Background(), + "telegram", + "chat-42", + "msg-123", + "msg-100", + ) + if got := tools.ToolMessageID(inboundCtx); got != "msg-123" { + t.Errorf("expected messageID 'msg-123', got %q", got) + } + if got := tools.ToolReplyToMessageID(inboundCtx); got != "msg-100" { + t.Errorf("expected replyToMessageID 'msg-100', got %q", got) + } } // TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved diff --git a/pkg/tools/base.go b/pkg/tools/base.go index ec743e164..afee95692 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -21,8 +21,10 @@ type Tool interface { type toolCtxKey struct{ name string } var ( - ctxKeyChannel = &toolCtxKey{"channel"} - ctxKeyChatID = &toolCtxKey{"chatID"} + ctxKeyChannel = &toolCtxKey{"channel"} + ctxKeyChatID = &toolCtxKey{"chatID"} + ctxKeyMessageID = &toolCtxKey{"messageID"} + ctxKeyReplyToMessageID = &toolCtxKey{"replyToMessageID"} ) // WithToolContext returns a child context carrying channel and chatID. @@ -32,6 +34,23 @@ func WithToolContext(ctx context.Context, channel, chatID string) context.Contex return ctx } +// WithToolMessageContext returns a child context carrying inbound message IDs. +func WithToolMessageContext(ctx context.Context, messageID, replyToMessageID string) context.Context { + ctx = context.WithValue(ctx, ctxKeyMessageID, messageID) + ctx = context.WithValue(ctx, ctxKeyReplyToMessageID, replyToMessageID) + return ctx +} + +// WithToolInboundContext returns a child context carrying channel/chat and inbound IDs. +func WithToolInboundContext( + ctx context.Context, + channel, chatID, messageID, replyToMessageID string, +) context.Context { + ctx = WithToolContext(ctx, channel, chatID) + ctx = WithToolMessageContext(ctx, messageID, replyToMessageID) + return ctx +} + // ToolChannel extracts the channel from ctx, or "" if unset. func ToolChannel(ctx context.Context) string { v, _ := ctx.Value(ctxKeyChannel).(string) @@ -44,6 +63,18 @@ func ToolChatID(ctx context.Context) string { return v } +// ToolMessageID extracts the current inbound message ID from ctx, or "" if unset. +func ToolMessageID(ctx context.Context) string { + v, _ := ctx.Value(ctxKeyMessageID).(string) + return v +} + +// ToolReplyToMessageID extracts the current inbound reply target from ctx, or "" if unset. +func ToolReplyToMessageID(ctx context.Context) string { + v, _ := ctx.Value(ctxKeyReplyToMessageID).(string) + return v +} + // AsyncCallback is a function type that async tools use to notify completion. // When an async tool finishes its work, it calls this callback with the result. // diff --git a/pkg/tools/message.go b/pkg/tools/message.go index 438ceeddd..064065a38 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -6,7 +6,7 @@ import ( "sync/atomic" ) -type SendCallback func(channel, chatID, content string) error +type SendCallback func(channel, chatID, content, replyToMessageID string) error type MessageTool struct { sendCallback SendCallback @@ -41,6 +41,10 @@ func (t *MessageTool) Parameters() map[string]any { "type": "string", "description": "Optional: target chat/user ID", }, + "reply_to_message_id": map[string]any{ + "type": "string", + "description": "Optional: reply target message ID for channels that support threaded replies", + }, }, "required": []string{"content"}, } @@ -69,6 +73,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes channel, _ := args["channel"].(string) chatID, _ := args["chat_id"].(string) + replyToMessageID, _ := args["reply_to_message_id"].(string) if channel == "" { channel = ToolChannel(ctx) @@ -85,7 +90,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes return &ToolResult{ForLLM: "Message sending not configured", IsError: true} } - if err := t.sendCallback(channel, chatID, content); err != nil { + if err := t.sendCallback(channel, chatID, content, replyToMessageID); err != nil { return &ToolResult{ ForLLM: fmt.Sprintf("sending message: %v", err), IsError: true, diff --git a/pkg/tools/message_test.go b/pkg/tools/message_test.go index 05630972e..93a611ee0 100644 --- a/pkg/tools/message_test.go +++ b/pkg/tools/message_test.go @@ -10,7 +10,7 @@ func TestMessageTool_Execute_Success(t *testing.T) { tool := NewMessageTool() var sentChannel, sentChatID, sentContent string - tool.SetSendCallback(func(channel, chatID, content string) error { + tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { sentChannel = channel sentChatID = chatID sentContent = content @@ -61,7 +61,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { tool := NewMessageTool() var sentChannel, sentChatID string - tool.SetSendCallback(func(channel, chatID, content string) error { + tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { sentChannel = channel sentChatID = chatID return nil @@ -96,7 +96,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) { tool := NewMessageTool() sendErr := errors.New("network error") - tool.SetSendCallback(func(channel, chatID, content string) error { + tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { return sendErr }) @@ -149,7 +149,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { tool := NewMessageTool() // No WithToolContext — channel/chatID are empty - tool.SetSendCallback(func(channel, chatID, content string) error { + tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { return nil }) @@ -251,4 +251,37 @@ func TestMessageTool_Parameters(t *testing.T) { if chatIDProp["type"] != "string" { t.Error("Expected chat_id type to be 'string'") } + + // Check reply_to_message_id property (optional) + replyToProp, ok := props["reply_to_message_id"].(map[string]any) + if !ok { + t.Error("Expected 'reply_to_message_id' property") + } + if replyToProp["type"] != "string" { + t.Error("Expected reply_to_message_id type to be 'string'") + } +} + +func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) { + tool := NewMessageTool() + + var sentReplyTo string + tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { + sentReplyTo = replyToMessageID + return nil + }) + + ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id") + args := map[string]any{ + "content": "Reply test", + "reply_to_message_id": "msg-123", + } + + result := tool.Execute(ctx, args) + if result.IsError { + t.Fatalf("expected success, got error: %s", result.ForLLM) + } + if sentReplyTo != "msg-123" { + t.Fatalf("expected reply_to_message_id msg-123, got %q", sentReplyTo) + } } diff --git a/pkg/tools/reaction.go b/pkg/tools/reaction.go new file mode 100644 index 000000000..3455b07a9 --- /dev/null +++ b/pkg/tools/reaction.go @@ -0,0 +1,87 @@ +package tools + +import ( + "context" + "fmt" +) + +type ReactionCallback func(ctx context.Context, channel, chatID, messageID string) error + +type ReactionTool struct { + reactionCallback ReactionCallback +} + +func NewReactionTool() *ReactionTool { + return &ReactionTool{} +} + +func (t *ReactionTool) Name() string { + return "reaction" +} + +func (t *ReactionTool) Description() string { + return "Add a reaction to a message. Defaults to the current inbound message when message_id is omitted." +} + +func (t *ReactionTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "message_id": map[string]any{ + "type": "string", + "description": "Optional: target message ID; defaults to the current inbound message", + }, + "channel": map[string]any{ + "type": "string", + "description": "Optional: target channel (telegram, whatsapp, etc.)", + }, + "chat_id": map[string]any{ + "type": "string", + "description": "Optional: target chat/user ID", + }, + }, + } +} + +func (t *ReactionTool) SetReactionCallback(callback ReactionCallback) { + t.reactionCallback = callback +} + +func (t *ReactionTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + channel, _ := args["channel"].(string) + chatID, _ := args["chat_id"].(string) + messageID, _ := args["message_id"].(string) + + if channel == "" { + channel = ToolChannel(ctx) + } + if chatID == "" { + chatID = ToolChatID(ctx) + } + if messageID == "" { + messageID = ToolMessageID(ctx) + } + + if channel == "" || chatID == "" { + return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true} + } + if messageID == "" { + return &ToolResult{ForLLM: "message_id is required", IsError: true} + } + if t.reactionCallback == nil { + return &ToolResult{ForLLM: "Reaction not configured", IsError: true} + } + + if err := t.reactionCallback(ctx, channel, chatID, messageID); err != nil { + return &ToolResult{ + ForLLM: fmt.Sprintf("adding reaction: %v", err), + IsError: true, + Err: err, + } + } + + return &ToolResult{ + ForLLM: fmt.Sprintf("Reaction added to %s:%s message %s", channel, chatID, messageID), + Silent: true, + } +} diff --git a/pkg/tools/reaction_test.go b/pkg/tools/reaction_test.go new file mode 100644 index 000000000..6fc90445a --- /dev/null +++ b/pkg/tools/reaction_test.go @@ -0,0 +1,96 @@ +package tools + +import ( + "context" + "errors" + "testing" +) + +func TestReactionTool_Execute_UsesContextMessageIDByDefault(t *testing.T) { + tool := NewReactionTool() + + var gotChannel, gotChatID, gotMessageID string + tool.SetReactionCallback(func(ctx context.Context, channel, chatID, messageID string) error { + gotChannel = channel + gotChatID = chatID + gotMessageID = messageID + return nil + }) + + ctx := WithToolInboundContext(context.Background(), "telegram", "chat-1", "msg-100", "") + result := tool.Execute(ctx, map[string]any{}) + if result.IsError { + t.Fatalf("expected success, got error: %s", result.ForLLM) + } + if gotChannel != "telegram" || gotChatID != "chat-1" || gotMessageID != "msg-100" { + t.Fatalf("unexpected callback args: channel=%q chatID=%q messageID=%q", gotChannel, gotChatID, gotMessageID) + } +} + +func TestReactionTool_Execute_AllowsExplicitMessageIDOverride(t *testing.T) { + tool := NewReactionTool() + + var gotMessageID string + tool.SetReactionCallback(func(ctx context.Context, channel, chatID, messageID string) error { + gotMessageID = messageID + return nil + }) + + ctx := WithToolInboundContext(context.Background(), "telegram", "chat-1", "msg-context", "") + result := tool.Execute(ctx, map[string]any{"message_id": "msg-explicit"}) + if result.IsError { + t.Fatalf("expected success, got error: %s", result.ForLLM) + } + if gotMessageID != "msg-explicit" { + t.Fatalf("expected explicit message id, got %q", gotMessageID) + } +} + +func TestReactionTool_Execute_MissingMessageID(t *testing.T) { + tool := NewReactionTool() + tool.SetReactionCallback(func(ctx context.Context, channel, chatID, messageID string) error { return nil }) + + ctx := WithToolContext(context.Background(), "telegram", "chat-1") + result := tool.Execute(ctx, map[string]any{}) + if !result.IsError { + t.Fatal("expected error") + } + if result.ForLLM != "message_id is required" { + t.Fatalf("unexpected error message: %q", result.ForLLM) + } +} + +func TestReactionTool_Execute_CallbackError(t *testing.T) { + tool := NewReactionTool() + tool.SetReactionCallback(func(ctx context.Context, channel, chatID, messageID string) error { + return errors.New("unsupported") + }) + + ctx := WithToolInboundContext(context.Background(), "telegram", "chat-1", "msg-100", "") + result := tool.Execute(ctx, map[string]any{}) + if !result.IsError { + t.Fatal("expected error") + } + if result.Err == nil { + t.Fatal("expected wrapped error") + } +} + +func TestReactionTool_Parameters(t *testing.T) { + tool := NewReactionTool() + params := tool.Parameters() + + props, ok := params["properties"].(map[string]any) + if !ok { + t.Fatal("expected properties map") + } + if _, ok := props["message_id"]; !ok { + t.Fatal("expected message_id parameter") + } + if _, ok := props["channel"]; !ok { + t.Fatal("expected channel parameter") + } + if _, ok := props["chat_id"]; !ok { + t.Fatal("expected chat_id parameter") + } +} diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go index db52749f6..16bd30928 100644 --- a/pkg/tools/registry_test.go +++ b/pkg/tools/registry_test.go @@ -190,6 +190,33 @@ func TestToolRegistry_ExecuteWithContext_EmptyContext(t *testing.T) { } } +func TestToolRegistry_ExecuteWithContext_PreservesMessageContext(t *testing.T) { + r := NewToolRegistry() + ct := &mockContextAwareTool{ + mockRegistryTool: *newMockTool("ctx_tool", "needs context"), + } + r.Register(ct) + + baseCtx := WithToolMessageContext(context.Background(), "msg-123", "msg-100") + r.ExecuteWithContext(baseCtx, "ctx_tool", nil, "telegram", "chat-42", nil) + + if ct.lastCtx == nil { + t.Fatal("expected Execute to be called") + } + if got := ToolChannel(ct.lastCtx); got != "telegram" { + t.Errorf("expected channel 'telegram', got %q", got) + } + if got := ToolChatID(ct.lastCtx); got != "chat-42" { + t.Errorf("expected chatID 'chat-42', got %q", got) + } + if got := ToolMessageID(ct.lastCtx); got != "msg-123" { + t.Errorf("expected messageID 'msg-123', got %q", got) + } + if got := ToolReplyToMessageID(ct.lastCtx); got != "msg-100" { + t.Errorf("expected replyToMessageID 'msg-100', got %q", got) + } +} + func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) { r := NewToolRegistry() at := &mockAsyncRegistryTool{