Merge remote-tracking branch 'origin/main' into refactor/line-sdk

# Conflicts:
#	pkg/channels/line/line.go
This commit is contained in:
ex-takashima
2026-04-15 23:07:04 +09:00
424 changed files with 36520 additions and 8067 deletions
+54 -25
View File
@@ -685,43 +685,60 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
// tool result messages following it. This is required by strict providers
// like DeepSeek that enforce: "An assistant message with 'tool_calls' must
// be followed by tool messages responding to each 'tool_call_id'."
//
// Deduplication is scoped to the contiguous tool-result block that follows a
// single assistant tool-call message. Some providers legitimately reuse call
// IDs across separate turns (for example "call_0"), so global deduplication
// would incorrectly delete later valid tool results and leave an
// assistant(tool_calls) -> assistant sequence behind.
final := make([]providers.Message, 0, len(sanitized))
seenToolCallID := make(map[string]bool)
for i := 0; i < len(sanitized); i++ {
msg := sanitized[i]
// Deduplicate tool results by ToolCallID
if msg.Role == "tool" && msg.ToolCallID != "" {
if seenToolCallID[msg.ToolCallID] {
logger.DebugCF("agent", "Dropping duplicate tool result", map[string]any{
"tool_call_id": msg.ToolCallID,
})
continue
}
seenToolCallID[msg.ToolCallID] = true
}
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
// Collect expected tool_call IDs
expected := make(map[string]bool, len(msg.ToolCalls))
invalidToolCallID := false
for _, tc := range msg.ToolCalls {
if tc.ID == "" {
invalidToolCallID = true
continue
}
expected[tc.ID] = false
}
// Check following messages for matching tool results
toolMsgCount := 0
for j := i + 1; j < len(sanitized); j++ {
if sanitized[j].Role != "tool" {
block := make([]providers.Message, 0, len(expected))
seenInBlock := make(map[string]bool, len(expected))
j := i + 1
for ; j < len(sanitized); j++ {
next := sanitized[j]
if next.Role != "tool" {
break
}
toolMsgCount++
if _, exists := expected[sanitized[j].ToolCallID]; exists {
expected[sanitized[j].ToolCallID] = true
if next.ToolCallID == "" {
logger.DebugCF("agent", "Dropping tool result without tool_call_id", map[string]any{})
continue
}
if _, ok := expected[next.ToolCallID]; !ok {
logger.DebugCF("agent", "Dropping unexpected tool result", map[string]any{
"tool_call_id": next.ToolCallID,
})
continue
}
if seenInBlock[next.ToolCallID] {
logger.DebugCF("agent", "Dropping duplicate tool result in tool block", map[string]any{
"tool_call_id": next.ToolCallID,
})
continue
}
seenInBlock[next.ToolCallID] = true
expected[next.ToolCallID] = true
block = append(block, next)
}
// If any tool_call_id is missing, drop this assistant message and its partial tool messages
allFound := true
allFound := !invalidToolCallID
if invalidToolCallID {
logger.DebugCF("agent", "Dropping assistant message with empty tool_call_id", map[string]any{})
}
for toolCallID, found := range expected {
if !found {
allFound = false
@@ -731,7 +748,7 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
map[string]any{
"missing_tool_call_id": toolCallID,
"expected_count": len(expected),
"found_count": toolMsgCount,
"found_count": len(block),
},
)
break
@@ -739,11 +756,23 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
}
if !allFound {
// Skip this assistant message and its tool messages
i += toolMsgCount
i = j - 1
continue
}
final = append(final, msg)
final = append(final, block...)
i = j - 1
continue
}
if msg.Role == "tool" {
logger.DebugCF("agent", "Dropping orphaned tool message after validation", map[string]any{
"tool_call_id": msg.ToolCallID,
})
continue
}
final = append(final, msg)
}
+12 -2
View File
@@ -42,7 +42,7 @@ func (m *legacyContextManager) Compact(_ context.Context, req *CompactRequest) e
if result, ok := m.forceCompression(req.SessionKey); ok {
m.al.emitEvent(
EventKindContextCompress,
m.al.newTurnEventScope("", req.SessionKey).meta(0, "forceCompression", "turn.context.compress"),
m.al.newTurnEventScope("", req.SessionKey, nil).meta(0, "forceCompression", "turn.context.compress"),
ContextCompressPayload{
Reason: req.Reason,
DroppedMessages: result.DroppedMessages,
@@ -61,6 +61,16 @@ func (m *legacyContextManager) Ingest(_ context.Context, _ *IngestRequest) error
return nil
}
func (m *legacyContextManager) Clear(_ context.Context, sessionKey string) error {
agent := m.al.registry.GetDefaultAgent()
if agent == nil || agent.Sessions == nil {
return fmt.Errorf("sessions not initialized")
}
agent.Sessions.SetHistory(sessionKey, []providers.Message{})
agent.Sessions.SetSummary(sessionKey, "")
return agent.Sessions.Save(sessionKey)
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
// It runs asynchronously in a goroutine.
func (m *legacyContextManager) maybeSummarize(sessionKey string) {
@@ -237,7 +247,7 @@ func (m *legacyContextManager) summarizeSession(agent *AgentInstance, sessionKey
agent.Sessions.Save(sessionKey)
m.al.emitEvent(
EventKindSessionSummarize,
m.al.newTurnEventScope(agent.ID, sessionKey).meta(0, "summarizeSession", "turn.session.summarize"),
m.al.newTurnEventScope(agent.ID, sessionKey, nil).meta(0, "summarizeSession", "turn.session.summarize"),
SessionSummarizePayload{
SummarizedMessages: len(validMessages),
KeptMessages: keepCount,
+4
View File
@@ -24,6 +24,10 @@ type ContextManager interface {
// Ingest records a message into the ContextManager's own storage.
// Called after each message is persisted to session JSONL.
Ingest(ctx context.Context, req *IngestRequest) error
// Clear removes all stored context for a session (messages, summaries, etc.).
// Called when the user issues /clear or /reset.
Clear(ctx context.Context, sessionKey string) error
}
// AssembleRequest is the input to Assemble.
+3
View File
@@ -690,6 +690,7 @@ func (m *noopContextManager) Assemble(_ context.Context, req *AssembleRequest) (
}
func (m *noopContextManager) Compact(_ context.Context, _ *CompactRequest) error { return nil }
func (m *noopContextManager) Ingest(_ context.Context, _ *IngestRequest) error { return nil }
func (m *noopContextManager) Clear(_ context.Context, _ string) error { return nil }
// trackingContextManager tracks call counts for each method.
type trackingContextManager struct {
@@ -726,6 +727,8 @@ func (m *trackingContextManager) Ingest(_ context.Context, req *IngestRequest) e
return nil
}
func (m *trackingContextManager) Clear(_ context.Context, _ string) error { return nil }
// resetCMRegistry clears the global factory registry and returns a cleanup
// function that restores the original state after the test.
func resetCMRegistry() func() {
+14 -1
View File
@@ -1,4 +1,4 @@
//go:build !mipsle && !netbsd
//go:build !mipsle && !netbsd && !(freebsd && arm)
package agent
@@ -154,6 +154,19 @@ func (m *seahorseContextManager) Ingest(ctx context.Context, req *IngestRequest)
return err
}
// Clear removes all stored context for a session (seahorse DB + JSONL).
func (m *seahorseContextManager) Clear(ctx context.Context, sessionKey string) error {
if err := m.engine.ClearSession(ctx, sessionKey); err != nil {
return err
}
if m.sessions != nil {
m.sessions.SetHistory(sessionKey, []providers.Message{})
m.sessions.SetSummary(sessionKey, "")
return m.sessions.Save(sessionKey)
}
return nil
}
// bootstrapSession reconciles JSONL session history into seahorse SQLite.
func (m *seahorseContextManager) bootstrapSession(ctx context.Context, sessionKey string) {
if m.sessions == nil {
+1 -1
View File
@@ -1,4 +1,4 @@
//go:build mipsle || netbsd
//go:build mipsle || netbsd || (freebsd && arm)
package agent
+41
View File
@@ -213,6 +213,47 @@ func TestSanitizeHistoryForProvider_DuplicateToolResults(t *testing.T) {
}
}
func TestSanitizeHistoryForProvider_ReusedToolCallIDAcrossRounds(t *testing.T) {
history := []providers.Message{
msg("user", "first"),
assistantWithTools("call_0"),
toolResult("call_0"),
msg("assistant", "first done"),
msg("user", "second"),
assistantWithTools("call_0"),
toolResult("call_0"),
msg("assistant", "second done"),
}
result := sanitizeHistoryForProvider(history)
if len(result) != 8 {
t.Fatalf("expected 8 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user", "assistant", "tool", "assistant", "user", "assistant", "tool", "assistant")
if result[2].ToolCallID != "call_0" || result[6].ToolCallID != "call_0" {
t.Fatalf(
"expected both tool results to be preserved, got IDs %q and %q",
result[2].ToolCallID,
result[6].ToolCallID,
)
}
}
func TestSanitizeHistoryForProvider_DropsAssistantWithEmptyToolCallID(t *testing.T) {
history := []providers.Message{
msg("user", "do something"),
assistantWithTools(""),
toolResult(""),
msg("assistant", "done"),
}
result := sanitizeHistoryForProvider(history)
if len(result) != 2 {
t.Fatalf("expected 2 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user", "assistant")
}
func roles(msgs []providers.Message) []string {
r := make([]string, len(msgs))
for i, m := range msgs {
+147
View File
@@ -0,0 +1,147 @@
package agent
import (
"strings"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
)
// DispatchRequest is the normalized runtime input passed into the agent loop
// after routing and session allocation have completed.
type DispatchRequest struct {
SessionKey string
SessionAliases []string
InboundContext *bus.InboundContext
RouteResult *routing.ResolvedRoute
SessionScope *session.SessionScope
UserMessage string
Media []string
}
func (r DispatchRequest) Channel() string {
if r.InboundContext == nil {
return ""
}
return r.InboundContext.Channel
}
func (r DispatchRequest) ChatID() string {
if r.InboundContext == nil {
return ""
}
return r.InboundContext.ChatID
}
func (r DispatchRequest) MessageID() string {
if r.InboundContext == nil {
return ""
}
return r.InboundContext.MessageID
}
func (r DispatchRequest) ReplyToMessageID() string {
if r.InboundContext == nil {
return ""
}
return r.InboundContext.ReplyToMessageID
}
func (r DispatchRequest) SenderID() string {
if r.InboundContext == nil {
return ""
}
return r.InboundContext.SenderID
}
func normalizeProcessOptionsInPlace(opts *processOptions) {
if opts == nil {
return
}
*opts = normalizeProcessOptions(*opts)
}
func normalizeProcessOptions(opts processOptions) processOptions {
if opts.Dispatch.SessionKey == "" {
opts.Dispatch.SessionKey = strings.TrimSpace(opts.SessionKey)
}
if len(opts.Dispatch.SessionAliases) == 0 && len(opts.SessionAliases) > 0 {
opts.Dispatch.SessionAliases = append([]string(nil), opts.SessionAliases...)
}
if opts.Dispatch.UserMessage == "" {
opts.Dispatch.UserMessage = opts.UserMessage
}
if len(opts.Dispatch.Media) == 0 && len(opts.Media) > 0 {
opts.Dispatch.Media = append([]string(nil), opts.Media...)
}
if opts.Dispatch.RouteResult == nil {
opts.Dispatch.RouteResult = cloneResolvedRoute(opts.RouteResult)
}
if opts.Dispatch.SessionScope == nil {
opts.Dispatch.SessionScope = session.CloneScope(opts.SessionScope)
}
if opts.Dispatch.InboundContext == nil {
if opts.InboundContext != nil {
opts.Dispatch.InboundContext = cloneInboundContext(opts.InboundContext)
} else if opts.Channel != "" || opts.ChatID != "" || opts.SenderID != "" ||
opts.MessageID != "" || opts.ReplyToMessageID != "" {
inbound := bus.InboundContext{
Channel: strings.TrimSpace(opts.Channel),
ChatID: strings.TrimSpace(opts.ChatID),
SenderID: strings.TrimSpace(opts.SenderID),
MessageID: strings.TrimSpace(opts.MessageID),
ReplyToMessageID: strings.TrimSpace(opts.ReplyToMessageID),
}
inbound.ChatType = inferChatTypeFromSessionScope(opts.Dispatch.SessionScope)
if inbound.Channel != "" || inbound.ChatID != "" || inbound.SenderID != "" ||
inbound.MessageID != "" || inbound.ReplyToMessageID != "" {
inbound = bus.NormalizeInboundMessage(bus.InboundMessage{Context: inbound}).Context
opts.Dispatch.InboundContext = &inbound
}
}
}
// Keep legacy mirrors populated while the rest of the runtime migrates.
opts.SessionKey = opts.Dispatch.SessionKey
opts.SessionAliases = append([]string(nil), opts.Dispatch.SessionAliases...)
opts.UserMessage = opts.Dispatch.UserMessage
opts.Media = append([]string(nil), opts.Dispatch.Media...)
opts.InboundContext = cloneInboundContext(opts.Dispatch.InboundContext)
opts.RouteResult = cloneResolvedRoute(opts.Dispatch.RouteResult)
opts.SessionScope = session.CloneScope(opts.Dispatch.SessionScope)
if opts.InboundContext != nil {
if opts.Channel == "" {
opts.Channel = opts.InboundContext.Channel
}
if opts.ChatID == "" {
opts.ChatID = opts.InboundContext.ChatID
}
if opts.MessageID == "" {
opts.MessageID = opts.InboundContext.MessageID
}
if opts.ReplyToMessageID == "" {
opts.ReplyToMessageID = opts.InboundContext.ReplyToMessageID
}
if opts.SenderID == "" {
opts.SenderID = opts.InboundContext.SenderID
}
}
return opts
}
func inferChatTypeFromSessionScope(scope *session.SessionScope) string {
if scope == nil || len(scope.Values) == 0 {
return ""
}
chatValue := strings.TrimSpace(scope.Values["chat"])
if chatValue == "" {
return ""
}
chatType, _, ok := strings.Cut(chatValue, ":")
if !ok {
return ""
}
return strings.ToLower(strings.TrimSpace(chatType))
}
+135
View File
@@ -0,0 +1,135 @@
package agent
import (
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
)
func TestNormalizeProcessOptions_PopulatesDispatchFromLegacyFields(t *testing.T) {
opts := normalizeProcessOptions(processOptions{
SessionKey: "session-1",
SessionAliases: []string{"legacy:one"},
Channel: "telegram",
ChatID: "chat-1",
MessageID: "msg-1",
ReplyToMessageID: "reply-1",
SenderID: "user-1",
UserMessage: "hello",
Media: []string{"media://one"},
})
if opts.Dispatch.SessionKey != "session-1" {
t.Fatalf("Dispatch.SessionKey = %q, want session-1", opts.Dispatch.SessionKey)
}
if len(opts.Dispatch.SessionAliases) != 1 || opts.Dispatch.SessionAliases[0] != "legacy:one" {
t.Fatalf("Dispatch.SessionAliases = %v, want [legacy:one]", opts.Dispatch.SessionAliases)
}
if opts.Dispatch.Channel() != "telegram" || opts.Dispatch.ChatID() != "chat-1" {
t.Fatalf(
"dispatch addressing = (%q,%q), want (telegram,chat-1)",
opts.Dispatch.Channel(),
opts.Dispatch.ChatID(),
)
}
if opts.Dispatch.SenderID() != "user-1" || opts.Dispatch.MessageID() != "msg-1" {
t.Fatalf("dispatch sender/message = (%q,%q)", opts.Dispatch.SenderID(), opts.Dispatch.MessageID())
}
if opts.Dispatch.ReplyToMessageID() != "reply-1" {
t.Fatalf("Dispatch.ReplyToMessageID() = %q, want reply-1", opts.Dispatch.ReplyToMessageID())
}
if opts.Dispatch.UserMessage != "hello" {
t.Fatalf("Dispatch.UserMessage = %q, want hello", opts.Dispatch.UserMessage)
}
if len(opts.Dispatch.Media) != 1 || opts.Dispatch.Media[0] != "media://one" {
t.Fatalf("Dispatch.Media = %v, want [media://one]", opts.Dispatch.Media)
}
}
func TestNormalizeProcessOptions_UsesDispatchAsSourceOfTruth(t *testing.T) {
inbound := &bus.InboundContext{
Channel: "slack",
ChatID: "C123",
ChatType: "channel",
SenderID: "U123",
MessageID: "m-1",
ReplyToMessageID: "parent-1",
}
route := &routing.ResolvedRoute{
AgentID: "support",
Channel: "slack",
AccountID: "workspace-a",
MatchedBy: "dispatch.rule:test",
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"chat", "sender"},
},
}
scope := &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "support",
Channel: "slack",
Account: "workspace-a",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "channel:c123",
},
}
opts := normalizeProcessOptions(processOptions{
Dispatch: DispatchRequest{
SessionKey: "sk_v1_example",
SessionAliases: []string{"agent:support:slack:channel:c123"},
InboundContext: inbound,
RouteResult: route,
SessionScope: scope,
UserMessage: "hello",
Media: []string{"media://one"},
},
})
if opts.SessionKey != "sk_v1_example" {
t.Fatalf("SessionKey = %q, want sk_v1_example", opts.SessionKey)
}
if opts.Channel != "slack" || opts.ChatID != "C123" {
t.Fatalf("legacy mirrors = (%q,%q), want (slack,C123)", opts.Channel, opts.ChatID)
}
if opts.SenderID != "U123" || opts.MessageID != "m-1" {
t.Fatalf("legacy sender/message = (%q,%q)", opts.SenderID, opts.MessageID)
}
if opts.ReplyToMessageID != "parent-1" {
t.Fatalf("ReplyToMessageID = %q, want parent-1", opts.ReplyToMessageID)
}
if opts.RouteResult == nil || opts.RouteResult.AgentID != "support" {
t.Fatalf("RouteResult = %#v, want support route", opts.RouteResult)
}
if opts.SessionScope == nil || opts.SessionScope.AgentID != "support" {
t.Fatalf("SessionScope = %#v, want support scope", opts.SessionScope)
}
}
func TestNormalizeProcessOptions_InfersLegacyChatTypeFromSessionScope(t *testing.T) {
opts := normalizeProcessOptions(processOptions{
Channel: "telegram",
ChatID: "-100123",
SenderID: "user-1",
UserMessage: "hello",
SessionScope: &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "telegram",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "group:-100123",
},
},
})
if opts.Dispatch.InboundContext == nil {
t.Fatal("Dispatch.InboundContext is nil")
}
if opts.Dispatch.InboundContext.ChatType != "group" {
t.Fatalf("Dispatch.InboundContext.ChatType = %q, want group", opts.Dispatch.InboundContext.ChatType)
}
}
+39 -7
View File
@@ -10,6 +10,8 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -136,6 +138,31 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
InboundContext: &bus.InboundContext{
Channel: "cli",
ChatID: "direct",
ChatType: "direct",
SenderID: "tester",
},
RouteResult: &routing.ResolvedRoute{
AgentID: "main",
Channel: "cli",
AccountID: routing.DefaultAccountID,
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"sender"},
},
MatchedBy: "default",
},
SessionScope: &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "cli",
Account: routing.DefaultAccountID,
Dimensions: []string{"sender"},
Values: map[string]string{
"sender": "tester",
},
},
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
@@ -176,6 +203,18 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
if evt.Meta.SessionKey != "session-1" {
t.Fatalf("event %d has session key %q, want session-1", i, evt.Meta.SessionKey)
}
if evt.Context == nil || evt.Context.Inbound == nil {
t.Fatalf("event %d missing inbound turn context", i)
}
if evt.Context.Inbound.Channel != "cli" || evt.Context.Inbound.SenderID != "tester" {
t.Fatalf("event %d inbound context = %+v", i, evt.Context.Inbound)
}
if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" {
t.Fatalf("event %d missing route context: %+v", i, evt.Context.Route)
}
if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "tester" {
t.Fatalf("event %d missing session scope: %+v", i, evt.Context.Scope)
}
}
startPayload, ok := events[0].Payload.(TurnStartPayload)
@@ -472,7 +511,6 @@ func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
// Use legacyContextManager's summarizeSession via contextManager interface
lcm := &legacyContextManager{al: al}
lcm.summarizeSession(defaultAgent, "session-1")
@@ -572,12 +610,6 @@ func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
if payload.SourceTool != "async_followup" {
t.Fatalf("expected source tool async_followup, got %q", payload.SourceTool)
}
if payload.Channel != "cli" {
t.Fatalf("expected channel cli, got %q", payload.Channel)
}
if payload.ChatID != "direct" {
t.Fatalf("expected chat id direct, got %q", payload.ChatID)
}
if payload.ContentLen != len("background result") {
t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen)
}
+2 -4
View File
@@ -86,6 +86,7 @@ type Event struct {
Kind EventKind
Time time.Time
Meta EventMeta
Context *TurnContext
Payload any
}
@@ -98,6 +99,7 @@ type EventMeta struct {
Iteration int
TracePath string
Source string
turnContext *TurnContext
}
// TurnEndStatus describes the terminal state of a turn.
@@ -114,8 +116,6 @@ const (
// TurnStartPayload describes the start of a turn.
type TurnStartPayload struct {
Channel string
ChatID string
UserMessage string
MediaCount int
}
@@ -217,8 +217,6 @@ type SteeringInjectedPayload struct {
// FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus.
type FollowUpQueuedPayload struct {
SourceTool string
Channel string
ChatID string
ContentLen int
}
+11 -2
View File
@@ -12,7 +12,9 @@ import (
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/isolation"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/tools"
)
const (
@@ -90,7 +92,8 @@ type processHookAfterLLMResponse struct {
type processHookBeforeToolResponse struct {
processHookDecisionResponse
Call *ToolCallHookRequest `json:"call,omitempty"`
Call *ToolCallHookRequest `json:"call,omitempty"`
Result *tools.ToolResult `json:"result,omitempty"` // Result returned directly by hook (for respond action)
}
type processHookAfterToolResponse struct {
@@ -120,7 +123,9 @@ func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (
if err != nil {
return nil, fmt.Errorf("create process hook stderr: %w", err)
}
if err := cmd.Start(); err != nil {
// Route hook subprocess startup through the shared isolation entry point so
// process hooks inherit the same isolation behavior as other child processes.
if err := isolation.Start(cmd); err != nil {
return nil, fmt.Errorf("start process hook: %w", err)
}
@@ -241,6 +246,10 @@ func (ph *ProcessHook) BeforeTool(
if resp.Call == nil {
resp.Call = call
}
// If hook returned a Result, carry it in ToolCallHookRequest
if resp.Result != nil {
resp.Call.HookResult = resp.Result
}
return resp.Call, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
}
+126
View File
@@ -7,10 +7,13 @@ import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/isolation"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -178,6 +181,76 @@ func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
}
}
func TestAgentLoop_MountProcessHook_IsolationSupportsRelativeDirAndCommand(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("linux-only isolation path handling")
}
provider := &llmHookTestProvider{}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
root := t.TempDir()
t.Setenv(config.EnvHome, filepath.Join(root, "picoclaw-home"))
binDir := filepath.Join(root, "bin")
hookDir := filepath.Join(root, "hooks")
if err := os.MkdirAll(binDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(hookDir, 0o755); err != nil {
t.Fatal(err)
}
writeFakeBwrap(t, filepath.Join(binDir, "bwrap"))
t.Setenv("PATH", binDir+string(os.PathListSeparator)+os.Getenv("PATH"))
linkTestBinary(t, os.Args[0], filepath.Join(hookDir, "hook-helper"))
cfg := config.DefaultConfig()
cfg.Isolation.Enabled = true
isolation.Configure(cfg)
t.Cleanup(func() { isolation.Configure(config.DefaultConfig()) })
cwd, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
relHookDir, err := filepath.Rel(cwd, hookDir)
if err != nil {
t.Fatal(err)
}
mountErr := al.MountProcessHook(context.Background(), "ipc-relative", ProcessHookOptions{
Command: []string{"./hook-helper", "-test.run=TestProcessHook_HelperProcess", "--"},
Dir: relHookDir,
Env: processHookHelperEnv("rewrite", ""),
InterceptLLM: true,
})
if mountErr != nil {
t.Fatalf("MountProcessHook failed with relative dir/command under isolation: %v", mountErr)
}
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-relative",
Channel: "cli",
ChatID: "direct",
UserMessage: "hello",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
if resp != "provider content|ipc" {
t.Fatalf("expected process-hooked llm content, got %q", resp)
}
provider.mu.Lock()
lastModel := provider.lastModel
provider.mu.Unlock()
if lastModel != "process-model" {
t.Fatalf("expected process model, got %q", lastModel)
}
}
func processHookHelperCommand() []string {
return []string{os.Args[0], "-test.run=TestProcessHook_HelperProcess", "--"}
}
@@ -193,6 +266,59 @@ func processHookHelperEnv(mode, eventLog string) []string {
return env
}
func writeFakeBwrap(t *testing.T, path string) {
t.Helper()
script := `#!/bin/sh
set -eu
workdir=
while [ "$#" -gt 0 ]; do
case "$1" in
--)
shift
break
;;
--chdir)
workdir="$2"
shift 2
;;
--bind|--ro-bind)
shift 3
;;
--proc|--dev)
shift 2
;;
--die-with-parent|--unshare-ipc)
shift
;;
*)
shift
;;
esac
done
if [ -n "$workdir" ]; then
cd "$workdir"
fi
exec "$@"
`
if err := os.WriteFile(path, []byte(script), 0o755); err != nil {
t.Fatalf("write fake bwrap: %v", err)
}
}
func linkTestBinary(t *testing.T, source, target string) {
t.Helper()
if err := os.Symlink(source, target); err == nil {
return
}
data, err := os.ReadFile(source)
if err != nil {
t.Fatalf("read test binary: %v", err)
}
if err := os.WriteFile(target, data, 0o755); err != nil {
t.Fatalf("create hook helper binary: %v", err)
}
}
func waitForFileContains(t *testing.T, path, substring string) {
t.Helper()
+34 -13
View File
@@ -25,6 +25,7 @@ type HookAction string
const (
HookActionContinue HookAction = "continue"
HookActionModify HookAction = "modify"
HookActionRespond HookAction = "respond" // Return result directly, skip tool execution. SECURITY: This bypasses ApproveTool checks, allowing hooks to return results for any tool (including sensitive ones like bash) without approval. Use with caution.
HookActionDenyTool HookAction = "deny_tool"
HookActionAbortTurn HookAction = "abort_turn"
HookActionHardAbort HookAction = "hard_abort"
@@ -89,12 +90,11 @@ type ToolApprover interface {
type LLMHookRequest struct {
Meta EventMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Model string `json:"model"`
Messages []providers.Message `json:"messages,omitempty"`
Tools []providers.ToolDefinition `json:"tools,omitempty"`
Options map[string]any `json:"options,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
GracefulTerminal bool `json:"graceful_terminal,omitempty"`
}
@@ -103,6 +103,8 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Messages = cloneProviderMessages(r.Messages)
cloned.Tools = cloneToolDefinitions(r.Tools)
cloned.Options = cloneStringAnyMap(r.Options)
@@ -111,10 +113,9 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
type LLMHookResponse struct {
Meta EventMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Model string `json:"model"`
Response *providers.LLMResponse `json:"response,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
func (r *LLMHookResponse) Clone() *LLMHookResponse {
@@ -122,16 +123,20 @@ func (r *LLMHookResponse) Clone() *LLMHookResponse {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Response = cloneLLMResponse(r.Response)
return &cloned
}
type ToolCallHookRequest struct {
Meta EventMeta `json:"meta"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
Meta EventMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
HookResult *tools.ToolResult `json:"hook_result,omitempty"` // Result returned directly by hook (for respond action). Media is supported - see Media handling section in docs.
}
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
@@ -139,16 +144,18 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Arguments = cloneStringAnyMap(r.Arguments)
cloned.HookResult = cloneToolResult(r.HookResult)
return &cloned
}
type ToolApprovalRequest struct {
Meta EventMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
@@ -156,18 +163,19 @@ func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Arguments = cloneStringAnyMap(r.Arguments)
return &cloned
}
type ToolResultHookResponse struct {
Meta EventMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
Result *tools.ToolResult `json:"result,omitempty"`
Duration time.Duration `json:"duration"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
@@ -175,6 +183,8 @@ func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Arguments = cloneStringAnyMap(r.Arguments)
cloned.Result = cloneToolResult(r.Result)
return &cloned
@@ -382,6 +392,10 @@ func (hm *HookManager) BeforeTool(
if next != nil {
current = next
}
case HookActionRespond:
// Hook returns result directly, skip tool execution
// Carry HookResult in ToolCallHookRequest and return
return next, decision
case HookActionDenyTool, HookActionAbortTurn, HookActionHardAbort:
return current, decision
default:
@@ -793,6 +807,13 @@ func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
if len(result.Media) > 0 {
cloned.Media = append([]string(nil), result.Media...)
}
if len(result.ArtifactTags) > 0 {
cloned.ArtifactTags = append([]string(nil), result.ArtifactTags...)
}
if len(result.Messages) > 0 {
cloned.Messages = make([]providers.Message, len(result.Messages))
copy(cloned.Messages, result.Messages)
}
return &cloned
}
+582 -1
View File
@@ -2,6 +2,7 @@ package agent
import (
"context"
"errors"
"os"
"sync"
"testing"
@@ -10,6 +11,8 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -106,7 +109,8 @@ func (p *llmHookTestProvider) GetDefaultModel() string {
}
type llmObserverHook struct {
eventCh chan Event
eventCh chan Event
lastInbound *bus.InboundContext
}
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
@@ -123,6 +127,9 @@ func (h *llmObserverHook) BeforeLLM(
ctx context.Context,
req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
if req.Context != nil {
h.lastInbound = cloneInboundContext(req.Context.Inbound)
}
next := req.Clone()
next.Model = "hook-model"
return next, HookDecision{Action: HookActionModify}, nil
@@ -155,6 +162,31 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
InboundContext: &bus.InboundContext{
Channel: "cli",
ChatID: "direct",
ChatType: "direct",
SenderID: "hook-user",
},
RouteResult: &routing.ResolvedRoute{
AgentID: "main",
Channel: "cli",
AccountID: routing.DefaultAccountID,
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"sender"},
},
MatchedBy: "default",
},
SessionScope: &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "cli",
Account: routing.DefaultAccountID,
Dimensions: []string{"sender"},
Values: map[string]string{
"sender": "hook-user",
},
},
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
@@ -169,12 +201,30 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
if lastModel != "hook-model" {
t.Fatalf("expected model hook-model, got %q", lastModel)
}
if hook.lastInbound == nil {
t.Fatal("expected hook to receive inbound context")
}
if hook.lastInbound.Channel != "cli" || hook.lastInbound.SenderID != "hook-user" {
t.Fatalf("hook inbound context = %+v", hook.lastInbound)
}
if hook.lastInbound != nil && hook.lastInbound.ChatID != "direct" {
t.Fatalf("hook inbound chat ID = %q, want direct", hook.lastInbound.ChatID)
}
select {
case evt := <-hook.eventCh:
if evt.Kind != EventKindTurnEnd {
t.Fatalf("expected turn end event, got %v", evt.Kind)
}
if evt.Context == nil || evt.Context.Inbound == nil {
t.Fatal("expected observer event to carry inbound context")
}
if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" {
t.Fatalf("expected observer event to carry route context, got %+v", evt.Context.Route)
}
if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "hook-user" {
t.Fatalf("expected observer event to carry session scope, got %+v", evt.Context.Scope)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for hook observer event")
}
@@ -343,3 +393,534 @@ func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
}
}
// respondHook is a test hook for testing HookActionRespond functionality
type respondHook struct {
respondTools map[string]bool // tool names to respond to
}
func (h *respondHook) BeforeTool(
ctx context.Context,
call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, error) {
if h.respondTools[call.Tool] {
next := call.Clone()
next.HookResult = &tools.ToolResult{
ForLLM: "hook-responded: " + call.Tool,
ForUser: "",
Silent: false,
IsError: false,
}
return next, HookDecision{Action: HookActionRespond}, nil
}
return call, HookDecision{Action: HookActionContinue}, nil
}
func (h *respondHook) AfterTool(
ctx context.Context,
result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, error) {
// Should not be called since respond skips tool execution
return result, HookDecision{Action: HookActionContinue}, nil
}
func TestAgentLoop_Hooks_ToolRespondAction(t *testing.T) {
provider := &toolHookProvider{}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
al.RegisterTool(&echoTextTool{})
if err := al.MountHook(NamedHook("respond-hook", &respondHook{
respondTools: map[string]bool{"echo_text": true},
})); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "run tool",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
// Verify response comes from hook, not tool
expected := "hook-responded: echo_text"
if resp != expected {
t.Fatalf("expected %q, got %q", expected, resp)
}
// Verify event stream has ToolExecEnd, not actual tool execution
events := collectEventStream(sub.C)
endEvt, ok := findEvent(events, EventKindToolExecEnd)
if !ok {
t.Fatal("expected tool exec end event")
}
payload, ok := endEvt.Payload.(ToolExecEndPayload)
if !ok {
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
}
if payload.Tool != "echo_text" {
t.Fatalf("expected tool echo_text, got %q", payload.Tool)
}
if payload.ForLLMLen != len(expected) {
t.Fatalf("expected ForLLMLen %d, got %d", len(expected), payload.ForLLMLen)
}
}
// denyToolHook tests HookActionDenyTool functionality
type denyToolHook struct {
denyTools map[string]bool
}
func (h *denyToolHook) BeforeTool(
ctx context.Context,
call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, error) {
if h.denyTools[call.Tool] {
return call, HookDecision{Action: HookActionDenyTool, Reason: "tool denied by hook"}, nil
}
return call, HookDecision{Action: HookActionContinue}, nil
}
func (h *denyToolHook) AfterTool(
ctx context.Context,
result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, error) {
return result, HookDecision{Action: HookActionContinue}, nil
}
func TestAgentLoop_Hooks_ToolDenyAction(t *testing.T) {
provider := &toolHookProvider{}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
al.RegisterTool(&echoTextTool{})
if err := al.MountHook(NamedHook("deny-hook", &denyToolHook{
denyTools: map[string]bool{"echo_text": true},
})); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "run tool",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
expected := "Tool execution denied by hook: tool denied by hook"
if resp != expected {
t.Fatalf("expected %q, got %q", expected, resp)
}
}
func TestHookManager_BeforeTool_RespondAction(t *testing.T) {
hm := NewHookManager(nil)
defer hm.Close()
hook := &respondHook{
respondTools: map[string]bool{"test_tool": true},
}
if err := hm.Mount(NamedHook("respond-test", hook)); err != nil {
t.Fatalf("mount hook: %v", err)
}
req := &ToolCallHookRequest{
Tool: "test_tool",
Arguments: map[string]any{"arg": "value"},
}
result, decision := hm.BeforeTool(context.Background(), req)
if decision.Action != HookActionRespond {
t.Fatalf("expected action %q, got %q", HookActionRespond, decision.Action)
}
if result.HookResult == nil {
t.Fatal("expected HookResult to be set")
}
if result.HookResult.ForLLM != "hook-responded: test_tool" {
t.Fatalf("unexpected HookResult.ForLLM: %q", result.HookResult.ForLLM)
}
}
type respondWithMediaHook struct {
respondTools map[string]bool
media []string
responseHandled bool
forLLM string
}
func (h *respondWithMediaHook) BeforeTool(
ctx context.Context,
call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, error) {
if h.respondTools[call.Tool] {
next := call.Clone()
next.HookResult = &tools.ToolResult{
ForLLM: h.forLLM,
ForUser: "media result",
Media: h.media,
ResponseHandled: h.responseHandled,
Silent: false,
IsError: false,
}
return next, HookDecision{Action: HookActionRespond}, nil
}
return call, HookDecision{Action: HookActionContinue}, nil
}
func (h *respondWithMediaHook) AfterTool(
ctx context.Context,
result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, error) {
return result, HookDecision{Action: HookActionContinue}, nil
}
type errorMediaChannel struct {
fakeChannel
sendErr error
}
func (f *errorMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
return nil, f.sendErr
}
func TestAgentLoop_HookRespond_MediaError(t *testing.T) {
provider := &multiToolProvider{
toolCalls: []providers.ToolCall{
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
},
finalContent: "done",
}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
hook := &respondWithMediaHook{
respondTools: map[string]bool{"media_tool": true},
media: []string{"media://test/image.png"},
responseHandled: true,
forLLM: "media sent successfully",
}
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
al.channelManager = newStartedTestChannelManager(t, al.bus, al.mediaStore, "discord", &errorMediaChannel{
sendErr: errors.New("channel unavailable"),
})
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
_, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-media-err",
Channel: "discord",
ChatID: "chat1",
UserMessage: "send media",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
events := collectEventStream(sub.C)
endEvt, ok := findEvent(events, EventKindToolExecEnd)
if !ok {
t.Fatal("expected ToolExecEnd event")
}
payload, ok := endEvt.Payload.(ToolExecEndPayload)
if !ok {
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
}
if !payload.IsError {
t.Fatal("expected IsError=true when SendMedia fails")
}
if payload.ForLLMLen < 30 {
t.Fatalf("expected ForLLM to contain error message, got ForLLMLen=%d", payload.ForLLMLen)
}
}
func TestAgentLoop_HookRespond_BusFallback(t *testing.T) {
provider := &multiToolProvider{
toolCalls: []providers.ToolCall{
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
},
finalContent: "done",
}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
hook := &respondWithMediaHook{
respondTools: map[string]bool{"media_tool": true},
media: []string{"media://test/image.png"},
responseHandled: true,
forLLM: "media queued",
}
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-bus-fallback",
Channel: "cli",
ChatID: "chat1",
UserMessage: "send media",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
events := collectEventStream(sub.C)
endEvt, ok := findEvent(events, EventKindToolExecEnd)
if !ok {
t.Fatal("expected ToolExecEnd event")
}
payload, ok := endEvt.Payload.(ToolExecEndPayload)
if !ok {
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
}
if payload.IsError {
t.Fatal("expected IsError=false for bus fallback (media queued, not delivered)")
}
if resp != "done" {
t.Fatalf("expected response 'done', got %q", resp)
}
}
type multiToolProvider struct {
mu sync.Mutex
callCount int
toolCalls []providers.ToolCall
finalContent string
}
func (p *multiToolProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.mu.Lock()
defer p.mu.Unlock()
p.callCount++
if p.callCount == 1 && len(p.toolCalls) > 0 {
return &providers.LLMResponse{
ToolCalls: p.toolCalls,
}, nil
}
return &providers.LLMResponse{
Content: p.finalContent,
}, nil
}
func (p *multiToolProvider) GetDefaultModel() string {
return "multi-tool-provider"
}
func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) {
provider := &multiToolProvider{
toolCalls: []providers.ToolCall{
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
},
finalContent: "done",
}
al, _, cleanup := newHookTestLoop(t, provider)
defer cleanup()
tool1ExecCh := make(chan struct{}, 1)
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond, execCh: tool1ExecCh})
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
hook := &respondHook{
respondTools: map[string]bool{"tool_one": true},
}
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
type result struct {
resp string
err error
}
resultCh := make(chan result, 1)
go func() {
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"run tools",
sessionKey,
"cli",
"chat1",
)
resultCh <- result{resp: resp, err: err}
}()
time.Sleep(50 * time.Millisecond)
if err := al.InterruptGraceful("stop now"); err != nil {
t.Fatalf("InterruptGraceful failed: %v", err)
}
select {
case r := <-resultCh:
if r.err != nil {
t.Fatalf("unexpected error: %v", r.err)
}
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for result")
}
events := collectEventStream(sub.C)
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
if len(skippedEvts) < 1 {
t.Fatal("expected at least one ToolExecSkipped event after interrupt")
}
for _, evt := range skippedEvts {
payload, ok := evt.Payload.(ToolExecSkippedPayload)
if !ok {
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
}
if payload.Reason != "graceful interrupt requested" {
t.Fatalf("expected skip reason 'graceful interrupt requested', got %q", payload.Reason)
}
}
}
func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
provider := &multiToolProvider{
toolCalls: []providers.ToolCall{
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
},
finalContent: "done",
}
al, _, cleanup := newHookTestLoop(t, provider)
defer cleanup()
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond})
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
hook := &respondHook{
respondTools: map[string]bool{"tool_one": true},
}
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
type result struct {
resp string
err error
}
resultCh := make(chan result, 1)
go func() {
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"run tools",
sessionKey,
"cli",
"chat1",
)
resultCh <- result{resp: resp, err: err}
}()
collectedEvents := make([]Event, 0, 8)
steered := false
deadline := time.After(3 * time.Second)
for !steered {
select {
case evt := <-sub.C:
collectedEvents = append(collectedEvents, evt)
if evt.Kind != EventKindToolExecEnd {
continue
}
payload, ok := evt.Payload.(ToolExecEndPayload)
if !ok || payload.Tool != "tool_one" {
continue
}
al.Steer(providers.Message{Role: "user", Content: "change direction"})
steered = true
case <-deadline:
t.Fatal("timeout waiting for tool_one to finish before steering")
}
}
select {
case r := <-resultCh:
if r.err != nil {
t.Fatalf("unexpected error: %v", r.err)
}
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for result")
}
events := append(collectedEvents, collectEventStream(sub.C)...)
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
if len(skippedEvts) < 1 {
t.Fatal("expected at least one ToolExecSkipped event after steering")
}
for _, evt := range skippedEvts {
payload, ok := evt.Payload.(ToolExecSkippedPayload)
if !ok {
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
}
if payload.Reason != "queued user steering message" {
t.Fatalf("expected skip reason 'queued user steering message', got %q", payload.Reason)
}
}
}
func filterEvents(events []Event, kind EventKind) []Event {
var result []Event
for _, evt := range events {
if evt.Kind == kind {
result = append(result, evt)
}
}
return result
}
+7
View File
@@ -9,6 +9,7 @@ import (
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/isolation"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/memory"
@@ -64,6 +65,12 @@ func NewAgentInstance(
cfg *config.Config,
provider providers.LLMProvider,
) *AgentInstance {
if cfg != nil {
// Keep the subprocess isolation runtime aligned with the latest loaded config
// before any tools or providers start spawning child processes.
isolation.Configure(cfg)
}
workspace := resolveAgentWorkspace(agentCfg, defaults)
os.MkdirAll(workspace, 0o755)
+60
View File
@@ -0,0 +1,60 @@
package agent
import (
"strings"
"github.com/sipeed/picoclaw/pkg/providers"
)
func messagesContainMedia(messages []providers.Message) bool {
for _, msg := range messages {
for _, ref := range msg.Media {
if strings.TrimSpace(ref) != "" {
return true
}
}
}
return false
}
func stripMessageMedia(messages []providers.Message) []providers.Message {
if !messagesContainMedia(messages) {
return messages
}
stripped := make([]providers.Message, len(messages))
for i, msg := range messages {
stripped[i] = msg
stripped[i].Media = nil
}
return stripped
}
func isVisionUnsupportedError(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
// OpenRouter (and OpenAI-compatible) style.
if strings.Contains(msg, "no endpoints found that support image input") {
return true
}
// Common provider variants.
if strings.Contains(msg, "does not support image input") ||
strings.Contains(msg, "does not support image inputs") ||
strings.Contains(msg, "does not support images") ||
strings.Contains(msg, "image input is not supported") ||
strings.Contains(msg, "images are not supported") ||
strings.Contains(msg, "does not support vision") ||
strings.Contains(msg, "unsupported content type: image_url") {
return true
}
// Some providers return a generic "invalid" message that still mentions image_url.
if strings.Contains(msg, "image_url") && strings.Contains(msg, "invalid") {
return true
}
return false
}
+817 -239
View File
File diff suppressed because it is too large Load Diff
+10
View File
@@ -24,6 +24,16 @@ type mcpRuntime struct {
initErr error
}
func (r *mcpRuntime) reset() *mcp.Manager {
r.mu.Lock()
manager := r.manager
r.manager = nil
r.initErr = nil
r.initOnce = sync.Once{}
r.mu.Unlock()
return manager
}
func (r *mcpRuntime) setManager(manager *mcp.Manager) {
r.mu.Lock()
r.manager = manager
+60
View File
@@ -7,13 +7,73 @@
package agent
import (
"context"
"errors"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/mcp"
)
func boolPtr(b bool) *bool { return &b }
func TestMCPRuntimeResetClearsState(t *testing.T) {
var rt mcpRuntime
manager := mcp.NewManager()
rt.setManager(manager)
rt.setInitErr(errors.New("stale init error"))
rt.initOnce.Do(func() {})
got := rt.reset()
if got != manager {
t.Fatalf("reset() manager = %p, want %p", got, manager)
}
if rt.hasManager() {
t.Fatal("expected manager to be cleared after reset")
}
if err := rt.getInitErr(); err != nil {
t.Fatalf("getInitErr() = %v, want nil", err)
}
reran := false
rt.initOnce.Do(func() { reran = true })
if !reran {
t.Fatal("expected initOnce to be reset")
}
}
func TestReloadProviderAndConfig_ResetsMCPRuntime(t *testing.T) {
al, cfg, _, _, cleanup := newTestAgentLoop(t)
defer cleanup()
defer al.Close()
manager := mcp.NewManager()
al.mcp.setManager(manager)
al.mcp.setInitErr(errors.New("stale init error"))
al.mcp.initOnce.Do(func() {})
if !al.mcp.hasManager() {
t.Fatal("expected MCP manager to exist before reload")
}
if err := al.ReloadProviderAndConfig(context.Background(), &mockProvider{}, cfg); err != nil {
t.Fatalf("ReloadProviderAndConfig() error = %v", err)
}
if al.mcp.hasManager() {
t.Fatal("expected MCP manager to be cleared when reloaded config has MCP disabled")
}
if err := al.mcp.getInitErr(); err != nil {
t.Fatalf("getInitErr() = %v, want nil", err)
}
reran := false
al.mcp.initOnce.Do(func() { reran = true })
if !reran {
t.Fatal("expected MCP initOnce to be reset after reload")
}
}
func TestServerIsDeferred(t *testing.T) {
tests := []struct {
name string
+859 -105
View File
File diff suppressed because it is too large Load Diff
+4 -3
View File
@@ -3,6 +3,7 @@ package agent
import (
"sync"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
@@ -64,9 +65,9 @@ func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) {
return agent, ok
}
// ResolveRoute determines which agent handles the message.
func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute {
return r.resolver.ResolveRoute(input)
// ResolveRoute determines which agent handles the normalized inbound context.
func (r *AgentRegistry) ResolveRoute(inbound bus.InboundContext) routing.ResolvedRoute {
return r.resolver.ResolveRoute(inbound)
}
// ListAgentIDs returns all registered agent IDs.
+35 -8
View File
@@ -3,12 +3,14 @@ package agent
import (
"context"
"fmt"
"sort"
"strings"
"sync"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -290,12 +292,22 @@ func (al *AgentLoop) continueWithSteeringMessages(
ctx context.Context,
agent *AgentInstance,
sessionKey, channel, chatID string,
scope *session.SessionScope,
steeringMsgs []providers.Message,
) (string, error) {
dispatch := DispatchRequest{
SessionKey: sessionKey,
SessionScope: session.CloneScope(scope),
}
if channel != "" || chatID != "" {
dispatch.InboundContext = &bus.InboundContext{
Channel: channel,
ChatID: chatID,
ChatType: inferChatTypeFromSessionScope(scope),
}
}
return al.runAgentLoop(ctx, agent, processOptions{
SessionKey: sessionKey,
Channel: channel,
ChatID: chatID,
Dispatch: dispatch,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
@@ -310,9 +322,19 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
return nil
}
if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil {
if agent, ok := registry.GetAgent(parsed.AgentID); ok {
return agent
agentIDs := registry.ListAgentIDs()
sort.Strings(agentIDs)
for _, agentID := range agentIDs {
agent, ok := registry.GetAgent(agentID)
if !ok || agent == nil {
continue
}
resolvedAgentID := session.ResolveAgentID(agent.Sessions, sessionKey)
if resolvedAgentID == "" {
continue
}
if scopedAgent, ok := registry.GetAgent(resolvedAgentID); ok {
return scopedAgent
}
}
@@ -352,7 +374,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
}
}
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs)
var scope *session.SessionScope
if metaStore, ok := agent.Sessions.(session.MetadataAwareSessionStore); ok {
scope = metaStore.GetSessionScope(sessionKey)
}
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, scope, steeringMsgs)
}
func (al *AgentLoop) InterruptGraceful(hint string) error {
+89 -36
View File
@@ -17,6 +17,7 @@ import (
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -357,7 +358,7 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
},
},
Session: config.SessionConfig{
DMScope: "per-peer",
Dimensions: []string{"sender"},
},
}
@@ -365,14 +366,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
al := NewAgentLoop(cfg, msgBus, &mockProvider{})
activeMsg := bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "active turn",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "active turn",
}
activeScope, activeAgentID, ok := al.resolveSteeringTarget(activeMsg)
if !ok {
@@ -380,14 +380,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
}
otherMsg := bus.InboundMessage{
Channel: "telegram",
SenderID: "user2",
ChatID: "chat2",
Content: "other session",
Peer: bus.Peer{
Kind: "direct",
ID: "user2",
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "chat2",
ChatType: "direct",
SenderID: "user2",
},
Content: "other session",
}
otherScope, _, ok := al.resolveSteeringTarget(otherMsg)
if !ok {
@@ -422,9 +421,9 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
select {
case <-ctx.Done():
t.Fatalf("timeout waiting for requeued message on outbound bus")
case requeued := <-msgBus.OutboundChan():
if requeued.Channel != otherMsg.Channel || requeued.ChatID != otherMsg.ChatID ||
t.Fatalf("timeout waiting for requeued message on inbound bus")
case requeued := <-msgBus.InboundChan():
if requeued.Context.Channel != otherMsg.Context.Channel || requeued.Context.ChatID != otherMsg.Context.ChatID ||
requeued.Content != otherMsg.Content {
t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg)
}
@@ -841,24 +840,22 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
}()
first := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "first message",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "first message",
}
late := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "late append",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "late append",
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
@@ -949,7 +946,7 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
},
}
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
provider := &blockingDirectProvider{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
@@ -1013,6 +1010,62 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
}
}
func TestAgentLoop_AgentForSession_UsesStoredScopeMetadata(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
List: []config.AgentConfig{
{ID: "sales", Default: true},
{ID: "support"},
},
},
}
al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{})
support, ok := al.registry.GetAgent("support")
if !ok || support == nil {
t.Fatal("expected support agent")
}
metaStore, ok := support.Sessions.(session.MetadataAwareSessionStore)
if !ok {
t.Fatal("support session store does not support metadata")
}
alias := "agent:support:slack:channel:c001"
key := session.BuildOpaqueSessionKey(alias)
scope := &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "support",
Channel: "slack",
Account: "default",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "channel:c001",
},
}
metaStore.EnsureSessionMetadata(key, scope, []string{alias})
got := al.agentForSession(key)
if got == nil {
t.Fatal("agentForSession() returned nil")
}
if got.ID != "support" {
t.Fatalf("agentForSession() = %q, want %q", got.ID, "support")
}
}
func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
@@ -1060,7 +1113,7 @@ func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
},
}
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
al.SetMediaStore(store)
@@ -1168,7 +1221,7 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
al := NewAgentLoop(cfg, msgBus, provider)
al.RegisterTool(tool1)
al.RegisterTool(tool2)
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
@@ -1322,7 +1375,7 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
al := NewAgentLoop(cfg, msgBus, provider)
started := make(chan struct{})
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
+13 -7
View File
@@ -351,15 +351,17 @@ func spawnSubTurn(
}
// Create processOptions for the child turn
dispatch := DispatchRequest{
SessionKey: childID,
UserMessage: cfg.SystemPrompt,
Media: nil,
InboundContext: cloneInboundContext(parentTS.opts.Dispatch.InboundContext),
}
opts := processOptions{
SessionKey: childID,
Channel: parentTS.channel,
ChatID: parentTS.chatID,
SenderID: parentTS.opts.SenderID,
Dispatch: dispatch,
SenderID: parentTS.opts.Dispatch.SenderID(),
SenderDisplayName: parentTS.opts.SenderDisplayName,
UserMessage: cfg.SystemPrompt, // Task description becomes the first user message
SystemPromptOverride: cfg.ActualSystemPrompt,
Media: nil,
InitialSteeringMessages: cfg.InitialMessages,
DefaultResponse: "",
EnableSummary: false,
@@ -369,7 +371,11 @@ func spawnSubTurn(
}
// Create event scope for the child turn
scope := al.newTurnEventScope(agent.ID, childID)
scope := al.newTurnEventScope(
agent.ID,
childID,
newTurnContext(opts.Dispatch.InboundContext, opts.Dispatch.RouteResult, opts.Dispatch.SessionScope),
)
// Create child turnState using the new API
childTS := newTurnState(&agent, opts, scope)
+15 -12
View File
@@ -56,6 +56,7 @@ type turnState struct {
turnID string
agentID string
sessionKey string
turnCtx *TurnContext
channel string
chatID string
@@ -115,11 +116,12 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
scope: scope,
turnID: scope.turnID,
agentID: agent.ID,
sessionKey: opts.SessionKey,
channel: opts.Channel,
chatID: opts.ChatID,
userMessage: opts.UserMessage,
media: append([]string(nil), opts.Media...),
sessionKey: opts.Dispatch.SessionKey,
turnCtx: cloneTurnContext(scope.context),
channel: opts.Dispatch.Channel(),
chatID: opts.Dispatch.ChatID(),
userMessage: opts.Dispatch.UserMessage,
media: append([]string(nil), opts.Dispatch.Media...),
phase: TurnPhaseSetup,
startedAt: time.Now(),
}
@@ -127,7 +129,7 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
// Bind session store and capture initial history length for rollback logic
if agent != nil && agent.Sessions != nil {
ts.session = agent.Sessions
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.SessionKey))
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.Dispatch.SessionKey))
}
return ts
@@ -302,12 +304,13 @@ func (ts *turnState) hardAbortRequested() bool {
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
snap := ts.snapshot()
return EventMeta{
AgentID: snap.AgentID,
TurnID: snap.TurnID,
SessionKey: snap.SessionKey,
Iteration: snap.Iteration,
Source: source,
TracePath: tracePath,
AgentID: snap.AgentID,
TurnID: snap.TurnID,
SessionKey: snap.SessionKey,
Iteration: snap.Iteration,
Source: source,
TracePath: tracePath,
turnContext: cloneTurnContext(ts.turnCtx),
}
}
+92
View File
@@ -0,0 +1,92 @@
package agent
import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
)
// TurnContext carries normalized turn-scoped facts that can be shared across
// events, hooks, and other runtime observers without re-parsing legacy fields.
type TurnContext struct {
Inbound *bus.InboundContext `json:"inbound,omitempty"`
Route *routing.ResolvedRoute `json:"route,omitempty"`
Scope *session.SessionScope `json:"scope,omitempty"`
}
func newTurnContext(
inbound *bus.InboundContext,
route *routing.ResolvedRoute,
scope *session.SessionScope,
) *TurnContext {
if inbound == nil && route == nil && scope == nil {
return nil
}
return &TurnContext{
Inbound: cloneInboundContext(inbound),
Route: cloneResolvedRoute(route),
Scope: session.CloneScope(scope),
}
}
func cloneTurnContext(ctx *TurnContext) *TurnContext {
if ctx == nil {
return nil
}
cloned := *ctx
cloned.Inbound = cloneInboundContext(ctx.Inbound)
cloned.Route = cloneResolvedRoute(ctx.Route)
cloned.Scope = session.CloneScope(ctx.Scope)
return &cloned
}
func cloneInboundContext(ctx *bus.InboundContext) *bus.InboundContext {
if ctx == nil {
return nil
}
cloned := *ctx
cloned.ReplyHandles = cloneStringMap(ctx.ReplyHandles)
cloned.Raw = cloneStringMap(ctx.Raw)
return &cloned
}
func cloneStringMap(src map[string]string) map[string]string {
if len(src) == 0 {
return nil
}
cloned := make(map[string]string, len(src))
for k, v := range src {
cloned[k] = v
}
return cloned
}
func cloneEventMeta(meta EventMeta) EventMeta {
meta.turnContext = cloneTurnContext(meta.turnContext)
return meta
}
func cloneResolvedRoute(route *routing.ResolvedRoute) *routing.ResolvedRoute {
if route == nil {
return nil
}
cloned := *route
cloned.SessionPolicy = routing.SessionPolicy{
Dimensions: append([]string(nil), route.SessionPolicy.Dimensions...),
IdentityLinks: cloneIdentityLinks(route.SessionPolicy.IdentityLinks),
}
return &cloned
}
func cloneIdentityLinks(src map[string][]string) map[string][]string {
if len(src) == 0 {
return nil
}
cloned := make(map[string][]string, len(src))
for canonical, ids := range src {
dup := make([]string, len(ids))
copy(dup, ids)
cloned[canonical] = dup
}
return cloned
}
+10 -9
View File
@@ -226,8 +226,7 @@ func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) {
logger.ErrorCF("voice-agent", "Failed to publish leave control", map[string]any{"error": err})
}
if err := a.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: channelType,
ChatID: acc.chatID,
Context: bus.NewOutboundContext(channelType, acc.chatID, ""),
Content: "Goodbye! Leaving the voice channel.",
}); err != nil {
logger.ErrorCF("voice-agent", "Failed to publish goodbye message", map[string]any{"error": err})
@@ -238,14 +237,16 @@ func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) {
oralPrompt := "\n\n[SYSTEM]: The user just spoke this to you over voice chat. Please reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally."
if err := a.bus.PublishInbound(ctx, bus.InboundMessage{
Channel: channelType,
SenderID: acc.speakerID,
ChatID: acc.chatID,
Content: res.Text + oralPrompt,
Peer: bus.Peer{Kind: "channel", ID: acc.chatID},
Metadata: map[string]string{
"is_voice": "true",
Context: bus.InboundContext{
Channel: channelType,
ChatID: acc.chatID,
ChatType: "channel",
SenderID: acc.speakerID,
Raw: map[string]string{
"is_voice": "true",
},
},
Content: res.Text + oralPrompt,
}); err != nil {
logger.ErrorCF("voice-agent", "Failed to publish inbound message", map[string]any{"error": err})
}
+2 -2
View File
@@ -185,8 +185,8 @@ func TestAgentCheckSilencePublishesInboundAndCleansUp(t *testing.T) {
if !strings.Contains(msg.Content, "hello there") {
t.Fatalf("unexpected inbound content: %q", msg.Content)
}
if msg.Metadata["is_voice"] != "true" {
t.Fatalf("expected is_voice metadata, got %#v", msg.Metadata)
if msg.Context.Raw["is_voice"] != "true" {
t.Fatalf("expected is_voice metadata, got %#v", msg.Context.Raw)
}
case <-time.After(500 * time.Millisecond):
t.Fatal("expected inbound publish")
+19 -1
View File
@@ -12,6 +12,12 @@ import (
// ErrBusClosed is returned when publishing to a closed MessageBus.
var ErrBusClosed = errors.New("message bus closed")
var (
ErrMissingInboundContext = errors.New("inbound message context is required")
ErrMissingOutboundContext = errors.New("outbound message context is required")
ErrMissingOutboundMediaContext = errors.New("outbound media context is required")
)
const defaultBusBufferSize = 64
// StreamDelegate is implemented by the channel Manager to provide streaming
@@ -49,7 +55,7 @@ func NewMessageBus() *MessageBus {
inbound: make(chan InboundMessage, defaultBusBufferSize),
outbound: make(chan OutboundMessage, defaultBusBufferSize),
outboundMedia: make(chan OutboundMediaMessage, defaultBusBufferSize),
audioChunks: make(chan AudioChunk, defaultBusBufferSize*4), // Audio chunks need more buffer
audioChunks: make(chan AudioChunk, defaultBusBufferSize*4), // Audio chunks need more buffer.
voiceControls: make(chan VoiceControl, defaultBusBufferSize),
done: make(chan struct{}),
}
@@ -84,6 +90,10 @@ func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error
}
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
msg = NormalizeInboundMessage(msg)
if msg.Context.isZero() {
return ErrMissingInboundContext
}
return publish(ctx, mb, mb.inbound, msg)
}
@@ -92,6 +102,10 @@ func (mb *MessageBus) InboundChan() <-chan InboundMessage {
}
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
msg = NormalizeOutboundMessage(msg)
if msg.Context.isZero() {
return ErrMissingOutboundContext
}
return publish(ctx, mb, mb.outbound, msg)
}
@@ -100,6 +114,10 @@ func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
}
func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
msg = NormalizeOutboundMediaMessage(msg)
if msg.Context.isZero() {
return ErrMissingOutboundMediaContext
}
return publish(ctx, mb, mb.outboundMedia, msg)
}
+438 -15
View File
@@ -14,10 +14,13 @@ func TestPublishConsume(t *testing.T) {
ctx := context.Background()
msg := InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "hello",
Context: InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "hello",
}
if err := mb.PublishInbound(ctx, msg); err != nil {
@@ -34,6 +37,138 @@ func TestPublishConsume(t *testing.T) {
if got.Channel != "test" {
t.Fatalf("expected channel 'test', got %q", got.Channel)
}
if got.Context.Channel != "test" {
t.Fatalf("expected context channel 'test', got %q", got.Context.Channel)
}
if got.Context.ChatID != "chat1" {
t.Fatalf("expected context chat ID 'chat1', got %q", got.Context.ChatID)
}
if got.Context.SenderID != "user1" {
t.Fatalf("expected context sender ID 'user1', got %q", got.Context.SenderID)
}
}
func TestPublishInbound_NormalizesContext(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := InboundMessage{
Context: InboundContext{
Channel: "slack",
Account: "workspace-a",
ChatID: "C456/1712",
ChatType: "group",
TopicID: "1712",
SpaceID: "T001",
SpaceType: "team",
SenderID: "U123",
MessageID: "1712.01",
ReplyToMessageID: "1700.01",
Mentioned: true,
},
Content: "hello",
}
if err := mb.PublishInbound(context.Background(), msg); err != nil {
t.Fatalf("PublishInbound failed: %v", err)
}
got := <-mb.InboundChan()
if got.Context.Channel != "slack" {
t.Fatalf("expected context channel slack, got %q", got.Context.Channel)
}
if got.Context.Account != "workspace-a" {
t.Fatalf("expected context account workspace-a, got %q", got.Context.Account)
}
if got.Context.ChatType != "group" {
t.Fatalf("expected context chat type group, got %q", got.Context.ChatType)
}
if got.Context.TopicID != "1712" {
t.Fatalf("expected topic 1712, got %q", got.Context.TopicID)
}
if got.Context.SpaceType != "team" || got.Context.SpaceID != "T001" {
t.Fatalf("expected team space T001, got %q/%q", got.Context.SpaceType, got.Context.SpaceID)
}
if !got.Context.Mentioned {
t.Fatal("expected mentioned=true in context")
}
if got.Context.ReplyToMessageID != "1700.01" {
t.Fatalf("expected reply_to_message_id 1700.01, got %q", got.Context.ReplyToMessageID)
}
}
func TestPublishInbound_MirrorsContextIntoConvenienceFields(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := InboundMessage{
Context: InboundContext{
Channel: "telegram",
Account: "bot-a",
ChatID: "-1001",
ChatType: "group",
TopicID: "42",
SpaceID: "guild-9",
SpaceType: "guild",
SenderID: "user-1",
MessageID: "777",
Mentioned: true,
ReplyToMessageID: "666",
},
Content: "hi",
}
if err := mb.PublishInbound(context.Background(), msg); err != nil {
t.Fatalf("PublishInbound failed: %v", err)
}
got := <-mb.InboundChan()
if got.Channel != "telegram" {
t.Fatalf("expected legacy channel telegram, got %q", got.Channel)
}
if got.ChatID != "-1001" {
t.Fatalf("expected legacy chat ID -1001, got %q", got.ChatID)
}
if got.SenderID != "user-1" {
t.Fatalf("expected legacy sender ID user-1, got %q", got.SenderID)
}
if got.MessageID != "777" {
t.Fatalf("expected legacy message ID 777, got %q", got.MessageID)
}
if got.Context.Account != "bot-a" || got.Context.SpaceID != "guild-9" || got.Context.TopicID != "42" {
t.Fatalf("unexpected normalized context: %+v", got.Context)
}
}
func TestPublishInbound_BackfillsContextFromLegacyFields(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := InboundMessage{
Channel: "pico",
ChatID: "session-1",
SenderID: "user-1",
MessageID: "msg-1",
Content: "hello",
}
if err := mb.PublishInbound(context.Background(), msg); err != nil {
t.Fatalf("PublishInbound failed: %v", err)
}
got := <-mb.InboundChan()
if got.Context.Channel != "pico" {
t.Fatalf("expected context channel pico, got %q", got.Context.Channel)
}
if got.Context.ChatID != "session-1" {
t.Fatalf("expected context chat ID session-1, got %q", got.Context.ChatID)
}
if got.Context.SenderID != "user-1" {
t.Fatalf("expected context sender ID user-1, got %q", got.Context.SenderID)
}
if got.Context.MessageID != "msg-1" {
t.Fatalf("expected context message ID msg-1, got %q", got.Context.MessageID)
}
}
func TestPublishOutboundSubscribe(t *testing.T) {
@@ -43,8 +178,10 @@ func TestPublishOutboundSubscribe(t *testing.T) {
ctx := context.Background()
msg := OutboundMessage{
Channel: "telegram",
ChatID: "123",
Context: InboundContext{
Channel: "telegram",
ChatID: "123",
},
Content: "world",
}
@@ -59,6 +196,222 @@ func TestPublishOutboundSubscribe(t *testing.T) {
if got.Content != "world" {
t.Fatalf("expected content 'world', got %q", got.Content)
}
if got.Context.Channel != "telegram" || got.Context.ChatID != "123" {
t.Fatalf("expected normalized outbound context, got %+v", got.Context)
}
}
func TestPublishOutbound_MirrorsContextToLegacyFields(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := OutboundMessage{
Context: InboundContext{
Channel: "telegram",
ChatID: "chat-42",
ReplyToMessageID: "msg-9",
},
AgentID: "main",
SessionKey: "sk_v1_123",
Scope: &OutboundScope{
Version: 1,
AgentID: "main",
Channel: "telegram",
Account: "bot-a",
Dimensions: []string{"chat", "sender"},
Values: map[string]string{
"chat": "direct:chat-42",
"sender": "user-1",
},
},
Content: "reply",
}
if err := mb.PublishOutbound(context.Background(), msg); err != nil {
t.Fatalf("PublishOutbound failed: %v", err)
}
got := <-mb.OutboundChan()
if got.Channel != "telegram" {
t.Fatalf("expected legacy channel telegram, got %q", got.Channel)
}
if got.ChatID != "chat-42" {
t.Fatalf("expected legacy chat ID chat-42, got %q", got.ChatID)
}
if got.ReplyToMessageID != "msg-9" {
t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID)
}
if got.AgentID != "main" || got.SessionKey != "sk_v1_123" {
t.Fatalf("unexpected outbound turn metadata: agent=%q session=%q", got.AgentID, got.SessionKey)
}
if got.Scope == nil || got.Scope.AgentID != "main" || got.Scope.Values["chat"] != "direct:chat-42" {
t.Fatalf("unexpected outbound scope: %+v", got.Scope)
}
if got.Context.Channel != "telegram" || got.Context.ChatID != "chat-42" {
t.Fatalf("unexpected outbound context: %+v", got.Context)
}
}
func TestPublishOutbound_PreservesExplicitReplyToMessageID(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := OutboundMessage{
Context: InboundContext{
Channel: "telegram",
ChatID: "chat-42",
},
ReplyToMessageID: "msg-9",
Content: "reply",
}
if err := mb.PublishOutbound(context.Background(), msg); err != nil {
t.Fatalf("PublishOutbound failed: %v", err)
}
got := <-mb.OutboundChan()
if got.ReplyToMessageID != "msg-9" {
t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID)
}
if got.Context.ReplyToMessageID != "msg-9" {
t.Fatalf("expected context reply_to_message_id msg-9, got %q", got.Context.ReplyToMessageID)
}
}
func TestPublishOutbound_PreservesExplicitReplyToMessageIDWhenContextReplyIsBlank(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := OutboundMessage{
Context: InboundContext{
Channel: "telegram",
ChatID: "chat-42",
ReplyToMessageID: " ",
},
ReplyToMessageID: "msg-9",
Content: "reply",
}
if err := mb.PublishOutbound(context.Background(), msg); err != nil {
t.Fatalf("PublishOutbound failed: %v", err)
}
got := <-mb.OutboundChan()
if got.ReplyToMessageID != "msg-9" {
t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID)
}
if got.Context.ReplyToMessageID != "msg-9" {
t.Fatalf("expected context reply_to_message_id msg-9, got %q", got.Context.ReplyToMessageID)
}
}
func TestPublishOutboundMedia_MirrorsContextToLegacyFields(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := OutboundMediaMessage{
Context: InboundContext{
Channel: "slack",
ChatID: "C001",
},
AgentID: "support",
SessionKey: "sk_v1_media",
Scope: &OutboundScope{
Version: 1,
AgentID: "support",
Channel: "slack",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "channel:c001",
},
},
Parts: []MediaPart{{Type: "image", Ref: "media://1"}},
}
if err := mb.PublishOutboundMedia(context.Background(), msg); err != nil {
t.Fatalf("PublishOutboundMedia failed: %v", err)
}
got := <-mb.OutboundMediaChan()
if got.Channel != "slack" {
t.Fatalf("expected legacy channel slack, got %q", got.Channel)
}
if got.ChatID != "C001" {
t.Fatalf("expected legacy chat ID C001, got %q", got.ChatID)
}
if got.AgentID != "support" || got.SessionKey != "sk_v1_media" {
t.Fatalf("unexpected outbound media turn metadata: agent=%q session=%q", got.AgentID, got.SessionKey)
}
if got.Scope == nil || got.Scope.Values["chat"] != "channel:c001" {
t.Fatalf("unexpected outbound media scope: %+v", got.Scope)
}
if got.Context.Channel != "slack" || got.Context.ChatID != "C001" {
t.Fatalf("unexpected outbound media context: %+v", got.Context)
}
}
func TestPublishAudioChunkSubscribe(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
chunk := AudioChunk{
SessionID: "voice-1",
SpeakerID: "speaker-1",
ChatID: "chat-1",
Channel: "discord",
Sequence: 7,
Format: "opus",
Data: []byte{0x01, 0x02},
}
if err := mb.PublishAudioChunk(context.Background(), chunk); err != nil {
t.Fatalf("PublishAudioChunk failed: %v", err)
}
got, ok := <-mb.AudioChunksChan()
if !ok {
t.Fatal("AudioChunksChan returned ok=false")
}
if got.SessionID != "voice-1" || got.Sequence != 7 {
t.Fatalf("unexpected audio chunk: %+v", got)
}
}
func TestPublishVoiceControlSubscribe(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
ctrl := VoiceControl{
SessionID: "voice-1",
ChatID: "chat-1",
Type: "command",
Action: "start",
}
if err := mb.PublishVoiceControl(context.Background(), ctrl); err != nil {
t.Fatalf("PublishVoiceControl failed: %v", err)
}
got, ok := <-mb.VoiceControlsChan()
if !ok {
t.Fatal("VoiceControlsChan returned ok=false")
}
if got.Type != "command" || got.Action != "start" {
t.Fatalf("unexpected voice control: %+v", got)
}
}
func TestNewOutboundContext_NormalizesReplyAddress(t *testing.T) {
ctx := NewOutboundContext(" telegram ", " chat-42 ", " msg-9 ")
if ctx.Channel != "telegram" {
t.Fatalf("expected channel telegram, got %q", ctx.Channel)
}
if ctx.ChatID != "chat-42" {
t.Fatalf("expected chat_id chat-42, got %q", ctx.ChatID)
}
if ctx.ReplyToMessageID != "msg-9" {
t.Fatalf("expected reply_to_message_id msg-9, got %q", ctx.ReplyToMessageID)
}
}
func TestPublishInbound_ContextCancel(t *testing.T) {
@@ -68,7 +421,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
// Fill the buffer
ctx := context.Background()
for i := range defaultBusBufferSize {
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
if err := mb.PublishInbound(ctx, InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat-fill",
ChatType: "direct",
SenderID: "user-fill",
},
Content: "fill",
}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
}
@@ -77,7 +438,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
cancelCtx, cancel := context.WithCancel(context.Background())
cancel()
err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"})
err := mb.PublishInbound(cancelCtx, InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat-overflow",
ChatType: "direct",
SenderID: "user-overflow",
},
Content: "overflow",
})
if err == nil {
t.Fatal("expected error from canceled context, got nil")
}
@@ -90,7 +459,15 @@ func TestPublishInbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
mb.Close()
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
err := mb.PublishInbound(context.Background(), InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "test",
})
if err != ErrBusClosed {
t.Fatalf("expected ErrBusClosed, got %v", err)
}
@@ -100,7 +477,13 @@ func TestPublishOutbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
mb.Close()
err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"})
err := mb.PublishOutbound(context.Background(), OutboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat1",
},
Content: "test",
})
if err != ErrBusClosed {
t.Fatalf("expected ErrBusClosed, got %v", err)
}
@@ -112,14 +495,30 @@ func TestConsumeInbound_ContextCancel(t *testing.T) {
defer mb.Close()
for i := range defaultBusBufferSize {
if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil {
if err := mb.PublishInbound(context.Background(), InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat-fill",
ChatType: "direct",
SenderID: "user-fill",
},
Content: "fill",
}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"})
mb.PublishInbound(ctx, InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat-cancel",
ChatType: "direct",
SenderID: "user-cancel",
},
Content: "ContextCancel",
})
select {
case <-ctx.Done():
@@ -213,7 +612,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
// Fill the buffer
for i := range defaultBusBufferSize {
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
if err := mb.PublishInbound(ctx, InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat-fill",
ChatType: "direct",
SenderID: "user-fill",
},
Content: "fill",
}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
}
@@ -222,7 +629,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"})
err := mb.PublishInbound(timeoutCtx, InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat-overflow",
ChatType: "direct",
SenderID: "user-overflow",
},
Content: "overflow",
})
if err == nil {
t.Fatal("expected error when buffer is full and context times out")
}
@@ -240,7 +655,15 @@ func TestCloseIdempotent(t *testing.T) {
mb.Close()
// After close, publish should return ErrBusClosed
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
err := mb.PublishInbound(context.Background(), InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "test",
})
if err != ErrBusClosed {
t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err)
}
+81
View File
@@ -0,0 +1,81 @@
package bus
import "strings"
// NormalizeInboundMessage ensures the inbound context is normalized and keeps
// convenience mirrors in sync for runtime consumers.
func NormalizeInboundMessage(msg InboundMessage) InboundMessage {
if msg.Context.Channel == "" {
msg.Context.Channel = msg.Channel
}
if msg.Context.ChatID == "" {
msg.Context.ChatID = msg.ChatID
}
if msg.Context.SenderID == "" {
msg.Context.SenderID = msg.SenderID
}
if msg.Context.MessageID == "" {
msg.Context.MessageID = msg.MessageID
}
msg.Context = normalizeInboundContext(msg.Context)
msg.Channel = msg.Context.Channel
msg.SenderID = msg.Context.SenderID
msg.ChatID = msg.Context.ChatID
if msg.MessageID == "" {
msg.MessageID = msg.Context.MessageID
}
if msg.Context.MessageID == "" {
msg.Context.MessageID = msg.MessageID
}
return msg
}
func (ctx InboundContext) isZero() bool {
return ctx.Channel == "" &&
ctx.Account == "" &&
ctx.ChatID == "" &&
ctx.ChatType == "" &&
ctx.TopicID == "" &&
ctx.SpaceID == "" &&
ctx.SpaceType == "" &&
ctx.SenderID == "" &&
ctx.MessageID == "" &&
!ctx.Mentioned &&
ctx.ReplyToMessageID == "" &&
ctx.ReplyToSenderID == "" &&
len(ctx.ReplyHandles) == 0 &&
len(ctx.Raw) == 0
}
func normalizeInboundContext(ctx InboundContext) InboundContext {
ctx.Channel = strings.TrimSpace(ctx.Channel)
ctx.Account = strings.TrimSpace(ctx.Account)
ctx.ChatID = strings.TrimSpace(ctx.ChatID)
ctx.ChatType = normalizeKind(ctx.ChatType)
ctx.TopicID = strings.TrimSpace(ctx.TopicID)
ctx.SpaceID = strings.TrimSpace(ctx.SpaceID)
ctx.SpaceType = normalizeKind(ctx.SpaceType)
ctx.SenderID = strings.TrimSpace(ctx.SenderID)
ctx.MessageID = strings.TrimSpace(ctx.MessageID)
ctx.ReplyToMessageID = strings.TrimSpace(ctx.ReplyToMessageID)
ctx.ReplyToSenderID = strings.TrimSpace(ctx.ReplyToSenderID)
ctx.ReplyHandles = cloneStringMap(ctx.ReplyHandles)
ctx.Raw = cloneStringMap(ctx.Raw)
return ctx
}
func cloneStringMap(src map[string]string) map[string]string {
if len(src) == 0 {
return nil
}
dst := make(map[string]string, len(src))
for k, v := range src {
dst[k] = v
}
return dst
}
func normalizeKind(kind string) string {
return strings.ToLower(strings.TrimSpace(kind))
}
+84
View File
@@ -0,0 +1,84 @@
package bus
import "strings"
// NewOutboundContext builds the minimal normalized addressing context required
// to deliver an outbound text message or reply.
func NewOutboundContext(channel, chatID, replyToMessageID string) InboundContext {
return normalizeInboundContext(InboundContext{
Channel: strings.TrimSpace(channel),
ChatID: strings.TrimSpace(chatID),
ReplyToMessageID: strings.TrimSpace(replyToMessageID),
})
}
// NormalizeOutboundMessage ensures Context is normalized and keeps convenience
// mirrors in sync for runtime consumers.
func NormalizeOutboundMessage(msg OutboundMessage) OutboundMessage {
msg.Channel = strings.TrimSpace(msg.Channel)
msg.ChatID = strings.TrimSpace(msg.ChatID)
msg.ReplyToMessageID = strings.TrimSpace(msg.ReplyToMessageID)
if msg.Context.Channel == "" {
msg.Context.Channel = msg.Channel
}
if msg.Context.ChatID == "" {
msg.Context.ChatID = msg.ChatID
}
if msg.Context.ReplyToMessageID == "" {
msg.Context.ReplyToMessageID = msg.ReplyToMessageID
}
msg.Context = normalizeInboundContext(msg.Context)
if msg.Channel == "" {
msg.Channel = msg.Context.Channel
}
if msg.ChatID == "" {
msg.ChatID = msg.Context.ChatID
}
if msg.ReplyToMessageID == "" {
msg.ReplyToMessageID = msg.Context.ReplyToMessageID
}
if msg.Context.ReplyToMessageID == "" {
msg.Context.ReplyToMessageID = msg.ReplyToMessageID
}
msg.Scope = cloneOutboundScope(msg.Scope)
return msg
}
// NormalizeOutboundMediaMessage ensures media outbound messages also carry a
// normalized context while keeping convenience mirrors in sync.
func NormalizeOutboundMediaMessage(msg OutboundMediaMessage) OutboundMediaMessage {
msg.Channel = strings.TrimSpace(msg.Channel)
msg.ChatID = strings.TrimSpace(msg.ChatID)
if msg.Context.Channel == "" {
msg.Context.Channel = msg.Channel
}
if msg.Context.ChatID == "" {
msg.Context.ChatID = msg.ChatID
}
msg.Context = normalizeInboundContext(msg.Context)
if msg.Channel == "" {
msg.Channel = msg.Context.Channel
}
if msg.ChatID == "" {
msg.ChatID = msg.Context.ChatID
}
msg.Scope = cloneOutboundScope(msg.Scope)
return msg
}
func cloneOutboundScope(scope *OutboundScope) *OutboundScope {
if scope == nil {
return nil
}
cloned := *scope
if len(scope.Dimensions) > 0 {
cloned.Dimensions = append([]string(nil), scope.Dimensions...)
}
if len(scope.Values) > 0 {
cloned.Values = make(map[string]string, len(scope.Values))
for key, value := range scope.Values {
cloned.Values[key] = value
}
}
return &cloned
}
+64 -25
View File
@@ -1,11 +1,5 @@
package bus
// Peer identifies the routing peer for a message (direct, group, channel, etc.)
type Peer struct {
Kind string `json:"kind"` // "direct" | "group" | "channel" | ""
ID string `json:"id"`
}
// SenderInfo provides structured sender identity information.
type SenderInfo struct {
Platform string `json:"platform,omitempty"` // "telegram", "discord", "slack", ...
@@ -15,26 +9,67 @@ type SenderInfo struct {
DisplayName string `json:"display_name,omitempty"` // display name
}
// InboundContext captures the normalized, platform-agnostic facts about an
// inbound message. This is the source of truth for routing and session
// allocation.
type InboundContext struct {
Channel string `json:"channel"`
Account string `json:"account,omitempty"`
ChatID string `json:"chat_id"`
ChatType string `json:"chat_type,omitempty"` // direct / group / channel
TopicID string `json:"topic_id,omitempty"`
SpaceID string `json:"space_id,omitempty"`
SpaceType string `json:"space_type,omitempty"` // guild / team / workspace / tenant
SenderID string `json:"sender_id"`
MessageID string `json:"message_id,omitempty"`
Mentioned bool `json:"mentioned,omitempty"`
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
ReplyToSenderID string `json:"reply_to_sender_id,omitempty"`
ReplyHandles map[string]string `json:"reply_handles,omitempty"`
Raw map[string]string `json:"raw,omitempty"`
}
type InboundMessage struct {
Channel string `json:"channel"`
SenderID string `json:"sender_id"`
Sender SenderInfo `json:"sender"`
ChatID string `json:"chat_id"`
Content string `json:"content"`
Media []string `json:"media,omitempty"`
Peer Peer `json:"peer"` // routing peer
MessageID string `json:"message_id,omitempty"` // platform message ID
MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope
SessionKey string `json:"session_key"`
Metadata map[string]string `json:"metadata,omitempty"`
Context InboundContext `json:"context"`
Sender SenderInfo `json:"sender"`
Content string `json:"content"`
Media []string `json:"media,omitempty"`
MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope
SessionKey string `json:"session_key"`
// Convenience mirrors derived from Context for runtime consumers.
Channel string `json:"channel"`
SenderID string `json:"sender_id"`
ChatID string `json:"chat_id"`
MessageID string `json:"message_id,omitempty"` // platform message ID
}
// OutboundScope captures the structured session scope associated with an
// outbound turn result without depending on the session package.
type OutboundScope struct {
Version int `json:"version,omitempty"`
AgentID string `json:"agent_id,omitempty"`
Channel string `json:"channel,omitempty"`
Account string `json:"account,omitempty"`
Dimensions []string `json:"dimensions,omitempty"`
Values map[string]string `json:"values,omitempty"`
}
type OutboundMessage struct {
Channel string `json:"channel"`
ChatID string `json:"chat_id"`
Content string `json:"content"`
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
Channel string `json:"channel"`
ChatID string `json:"chat_id"`
Context InboundContext `json:"context"`
AgentID string `json:"agent_id,omitempty"`
SessionKey string `json:"session_key,omitempty"`
Scope *OutboundScope `json:"scope,omitempty"`
Content string `json:"content"`
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
}
// MediaPart describes a single media attachment to send.
@@ -48,9 +83,13 @@ type MediaPart struct {
// OutboundMediaMessage carries media attachments from Agent to channels via the bus.
type OutboundMediaMessage struct {
Channel string `json:"channel"`
ChatID string `json:"chat_id"`
Parts []MediaPart `json:"parts"`
Channel string `json:"channel"`
ChatID string `json:"chat_id"`
Context InboundContext `json:"context"`
AgentID string `json:"agent_id,omitempty"`
SessionKey string `json:"session_key,omitempty"`
Scope *OutboundScope `json:"scope,omitempty"`
Parts []MediaPart `json:"parts"`
}
// AudioChunk represents a chunk of streaming voice data.
+80 -33
View File
@@ -327,8 +327,13 @@ import (
)
func init() {
channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewTelegramChannel(cfg, b)
channels.RegisterFactory(config.ChannelTelegram, func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil { return nil, err }
c, ok := decoded.(*config.TelegramSettings)
if !ok { return nil, channels.ErrSendFailed }
return NewTelegramChannel(bc, c, b)
})
}
```
@@ -427,8 +432,13 @@ import (
)
func init() {
channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewMatrixChannel(cfg, b)
channels.RegisterFactory(config.ChannelMatrix, func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil { return nil, err }
c, ok := decoded.(*config.MatrixSettings)
if !ok { return nil, channels.ErrSendFailed }
return NewMatrixChannel(bc, c, b)
})
}
```
@@ -773,41 +783,59 @@ When the Agent finishes processing a message, Manager's `preSend` automatically:
### 3.5 Register Configuration and Gateway Integration
#### Add configuration in `pkg/config/config.go`
#### Add configuration entry
Channels now use a unified map-based configuration (`map[string]*config.Channel`).
Each channel entry stores common fields (`enabled`, `type`, `allow_from`, etc.) at
the top level, with channel-specific settings in the `settings` sub-key:
```json
{
"channels": {
"matrix": {
"enabled": true,
"type": "matrix",
"allow_from": ["@user:example.com"],
"settings": {
"home_server": "https://matrix.org",
"user_id": "@bot:example.com",
"access_token": "enc://..."
}
}
}
}
```
Secure fields (tokens, passwords, API keys) go into `.security.yml`:
```yaml
channels:
matrix:
access_token: "your-matrix-access-token"
```
Channel types must be registered in `channelSettingsFactory` in
`pkg/config/config_channel.go`:
```go
type ChannelsConfig struct {
var channelSettingsFactory = map[string]any{
// ... existing channels
Matrix MatrixChannelConfig `json:"matrix"`
}
type MatrixChannelConfig struct {
Enabled bool `json:"enabled"`
HomeServer string `json:"home_server"`
Token string `json:"token"`
AllowFrom []string `json:"allow_from"`
GroupTrigger GroupTriggerConfig `json:"group_trigger"`
Placeholder PlaceholderConfig `json:"placeholder"`
ReasoningChannelID string `json:"reasoning_channel_id"`
ChannelMatrix: (MatrixSettings{}),
}
```
#### Add entry in Manager.initChannels()
#### No Manager changes needed
```go
// In the initChannels() method of pkg/channels/manager.go
if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
m.initChannel("matrix", "Matrix")
}
```
The Manager uses `InitChannelList()` to validate types and decode settings,
then looks up factories by `bc.Type`. No per-channel entry needed in Manager
just register the factory and the config entry.
> **Note**: If your channel has multiple modes (like WhatsApp Bridge vs Native), branch in initChannels based on config:
> **Note**: If your channel has multiple modes (like WhatsApp Bridge vs Native),
> register both types in `channelSettingsFactory` and branch on config:
> ```go
> if cfg.UseNative {
> m.initChannel("whatsapp_native", "WhatsApp Native")
> } else {
> m.initChannel("whatsapp", "WhatsApp")
> }
> // In config_channel.go:
> ChannelWhatsApp: (WhatsAppSettings{}),
> ChannelWhatsAppNative: (WhatsAppSettings{}),
> ```
#### Add blank import in Gateway
@@ -947,10 +975,29 @@ channels.WithReasoningChannelID(id) // Set reasoning chain routing target
**File**: `pkg/channels/registry.go`
```go
type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error)
type ChannelFactory func(channelName, channelType string, cfg *config.Config, bus *bus.MessageBus) (Channel, error)
func RegisterFactory(name string, f ChannelFactory) // Called in sub-package init()
func getFactory(name string) (ChannelFactory, bool) // Called internally by Manager
func RegisterFactory(name string, f ChannelFactory) // Called in sub-package init()
func getFactory(name string) (ChannelFactory, bool) // Called internally by Manager
func GetRegisteredFactoryNames() []string // Returns all registered factory names
```
For convenience, `RegisterSafeFactory[S any]` provides automatic type-safe settings decoding:
```go
// Instead of manual GetDecoded() + type assertion:
channels.RegisterFactory(config.ChannelTelegram,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil { return nil, err }
c, ok := decoded.(*config.TelegramSettings)
if !ok { return nil, ErrSendFailed }
return NewTelegramChannel(bc, c, b)
})
// You can use RegisterSafeFactory (same safety, less boilerplate):
channels.RegisterSafeFactory(config.ChannelTelegram, NewTelegramChannel)
```
The factory registry is protected by `sync.RWMutex` and registrations occur during `init()` phase (completed at process startup). Manager looks up factories by name in `initChannel()` and calls them.
+79 -33
View File
@@ -327,8 +327,13 @@ import (
)
func init() {
channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewTelegramChannel(cfg, b)
channels.RegisterFactory(config.ChannelTelegram, func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil { return nil, err }
c, ok := decoded.(*config.TelegramSettings)
if !ok { return nil, channels.ErrSendFailed }
return NewTelegramChannel(bc, c, b)
})
}
```
@@ -427,8 +432,13 @@ import (
)
func init() {
channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewMatrixChannel(cfg, b)
channels.RegisterFactory(config.ChannelMatrix, func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil { return nil, err }
c, ok := decoded.(*config.MatrixSettings)
if !ok { return nil, channels.ErrSendFailed }
return NewMatrixChannel(bc, c, b)
})
}
```
@@ -772,41 +782,58 @@ if c.owner != nil && c.placeholderRecorder != nil {
### 3.5 注册配置和 Gateway 接入
#### `pkg/config/config.go`添加配置
#### 添加配置入口
Channels 现在使用统一的 map 类型配置(`map[string]*config.Channel`)。
每个 channel 条目将通用字段(`enabled``type``allow_from` 等)放在顶层,
channel 特定的设置放在 `settings` 子键中:
```json
{
"channels": {
"matrix": {
"enabled": true,
"type": "matrix",
"allow_from": ["@user:example.com"],
"settings": {
"home_server": "https://matrix.org",
"user_id": "@bot:example.com",
"access_token": "enc://..."
}
}
}
}
```
安全字段(token、密码、API 密钥)放入 `.security.yml`
```yaml
channels:
matrix:
access_token: "your-matrix-access-token"
```
Channel 类型必须在 `pkg/config/config_channel.go``channelSettingsFactory` 中注册:
```go
type ChannelsConfig struct {
var channelSettingsFactory = map[string]any{
// ... 现有 channels
Matrix MatrixChannelConfig `json:"matrix"`
}
type MatrixChannelConfig struct {
Enabled bool `json:"enabled"`
HomeServer string `json:"home_server"`
Token string `json:"token"`
AllowFrom []string `json:"allow_from"`
GroupTrigger GroupTriggerConfig `json:"group_trigger"`
Placeholder PlaceholderConfig `json:"placeholder"`
ReasoningChannelID string `json:"reasoning_channel_id"`
ChannelMatrix: (MatrixSettings{}),
}
```
#### Manager.initChannels() 中添加入口
#### 无需修改 Manager
```go
// pkg/channels/manager.go 的 initChannels() 方法中
if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
m.initChannel("matrix", "Matrix")
}
```
Manager 使用 `InitChannelList()` 来验证类型和解码设置,
然后通过 `bc.Type` 查找工厂。不需要在 Manager 中添加每个 channel 的条目——
只需注册工厂和配置条目即可。
> **注意**:如果你的 channel 有多种模式(如 WhatsApp Bridge vs Native),需要在 initChannels 中根据配置分支:
> **注意**:如果你的 channel 有多种模式(如 WhatsApp Bridge vs Native),
> 在 `channelSettingsFactory` 中注册两种类型,并根据配置分支:
> ```go
> if cfg.UseNative {
> m.initChannel("whatsapp_native", "WhatsApp Native")
> } else {
> m.initChannel("whatsapp", "WhatsApp")
> }
> // 在 config_channel.go 中:
> ChannelWhatsApp: (WhatsAppSettings{}),
> ChannelWhatsAppNative: (WhatsAppSettings{}),
> ```
#### 在 Gateway 中添加 blank import
@@ -946,10 +973,29 @@ channels.WithReasoningChannelID(id) // 设置思维链路由目标 channe
**文件**`pkg/channels/registry.go`
```go
type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error)
type ChannelFactory func(channelName, channelType string, cfg *config.Config, bus *bus.MessageBus) (Channel, error)
func RegisterFactory(name string, f ChannelFactory) // 子包 init() 中调用
func getFactory(name string) (ChannelFactory, bool) // Manager 内部调用
func RegisterFactory(name string, f ChannelFactory) // 子包 init() 中调用
func getFactory(name string) (ChannelFactory, bool) // Manager 内部调用
func GetRegisteredFactoryNames() []string // 返回所有已注册的工厂名称
```
为方便使用,`RegisterSafeFactory[S any]` 提供自动类型安全的设置解码:
```go
// 不使用 RegisterSafeFactory(手动 GetDecoded() + 类型断言):
channels.RegisterFactory(config.ChannelTelegram,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil { return nil, err }
c, ok := decoded.(*config.TelegramSettings)
if !ok { return nil, ErrSendFailed }
return NewTelegramChannel(bc, c, b)
})
// 使用 RegisterSafeFactory(同等安全,减少样板代码):
channels.RegisterSafeFactory(config.ChannelTelegram, NewTelegramChannel)
```
工厂注册表使用 `sync.RWMutex` 保护,在 `init()` 阶段注册(进程启动时完成)。Manager 在 `initChannel()` 中通过名字查找工厂并调用它。
+55 -19
View File
@@ -103,6 +103,16 @@ func NewBaseChannel(
allowList []string,
opts ...BaseChannelOption,
) *BaseChannel {
isEmpty := true
for _, s := range allowList {
if s != "" {
isEmpty = false
break
}
}
if isEmpty {
allowList = []string{}
}
bc := &BaseChannel{
config: config,
bus: bus,
@@ -177,6 +187,12 @@ func (c *BaseChannel) Name() string {
return c.name
}
// SetName updates the channel name. Used by the manager after channel creation
// to ensure the name matches the config key (which may differ from the type).
func (c *BaseChannel) SetName(name string) {
c.name = name
}
func (c *BaseChannel) ReasoningChannelID() string {
return c.reasoningChannelID
}
@@ -244,12 +260,11 @@ func (c *BaseChannel) IsAllowedSender(sender bus.SenderInfo) bool {
return false
}
func (c *BaseChannel) HandleMessage(
func (c *BaseChannel) HandleMessageWithContext(
ctx context.Context,
peer bus.Peer,
messageID, senderID, chatID, content string,
deliveryChatID, content string,
media []string,
metadata map[string]string,
inboundCtx bus.InboundContext,
senderOpts ...bus.SenderInfo,
) {
// Use SenderInfo-based allow check when available, else fall back to string
@@ -257,6 +272,7 @@ func (c *BaseChannel) HandleMessage(
if len(senderOpts) > 0 {
sender = senderOpts[0]
}
senderID := strings.TrimSpace(inboundCtx.SenderID)
if sender.CanonicalID != "" || sender.PlatformID != "" {
if !c.IsAllowedSender(sender) {
return
@@ -273,20 +289,28 @@ func (c *BaseChannel) HandleMessage(
resolvedSenderID = sender.CanonicalID
}
scope := BuildMediaScope(c.name, chatID, messageID)
if resolvedSenderID == "" {
resolvedSenderID = senderID
}
inboundCtx.Channel = c.name
if inboundCtx.ChatID == "" {
inboundCtx.ChatID = deliveryChatID
}
if inboundCtx.SenderID == "" {
inboundCtx.SenderID = resolvedSenderID
}
scope := BuildMediaScope(c.name, deliveryChatID, inboundCtx.MessageID)
msg := bus.InboundMessage{
Channel: c.name,
SenderID: resolvedSenderID,
Context: inboundCtx,
Sender: sender,
ChatID: chatID,
Content: content,
Media: media,
Peer: peer,
MessageID: messageID,
MediaScope: scope,
Metadata: metadata,
}
msg = bus.NormalizeInboundMessage(msg)
// Auto-trigger typing indicator, message reaction, and placeholder before publishing.
// Each capability is independent — all three may fire for the same message.
@@ -297,14 +321,14 @@ func (c *BaseChannel) HandleMessage(
if c.owner != nil && c.placeholderRecorder != nil {
// Typing
if tc, ok := c.owner.(TypingCapable); ok {
if stop, err := tc.StartTyping(ctx, chatID); err == nil {
c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop)
if stop, err := tc.StartTyping(ctx, deliveryChatID); err == nil {
c.placeholderRecorder.RecordTypingStop(c.name, deliveryChatID, stop)
}
}
// Reaction
if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" {
if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil {
c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo)
if rc, ok := c.owner.(ReactionCapable); ok && msg.MessageID != "" {
if undo, err := rc.ReactToMessage(ctx, deliveryChatID, msg.MessageID); err == nil {
c.placeholderRecorder.RecordReactionUndo(c.name, deliveryChatID, undo)
}
}
// Placeholder — independent pipeline.
@@ -313,8 +337,8 @@ func (c *BaseChannel) HandleMessage(
// "Thinking…" only once the voice has been processed.
if !audioAnnotationRe.MatchString(content) {
if pc, ok := c.owner.(PlaceholderCapable); ok {
if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" {
c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID)
if phID, err := pc.SendPlaceholder(ctx, deliveryChatID); err == nil && phID != "" {
c.placeholderRecorder.RecordPlaceholder(c.name, deliveryChatID, phID)
}
}
}
@@ -323,12 +347,24 @@ func (c *BaseChannel) HandleMessage(
if err := c.bus.PublishInbound(ctx, msg); err != nil {
logger.ErrorCF("channels", "Failed to publish inbound message", map[string]any{
"channel": c.name,
"chat_id": chatID,
"chat_id": deliveryChatID,
"error": err.Error(),
})
}
}
// HandleInboundContext publishes a normalized inbound message using only the
// structured context.
func (c *BaseChannel) HandleInboundContext(
ctx context.Context,
deliveryChatID, content string,
media []string,
inboundCtx bus.InboundContext,
senderOpts ...bus.SenderInfo,
) {
c.HandleMessageWithContext(ctx, deliveryChatID, content, media, inboundCtx, senderOpts...)
}
func (c *BaseChannel) SetRunning(running bool) {
c.running.Store(running)
}
+56
View File
@@ -1,6 +1,7 @@
package channels
import (
"context"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -263,3 +264,58 @@ func TestIsAllowedSender(t *testing.T) {
})
}
}
func TestHandleInboundContext_PublishesNormalizedContext(t *testing.T) {
tests := []struct {
name string
inbound bus.InboundContext
wantChat string
wantSender string
}{
{
name: "direct uses sender as peer",
inbound: bus.InboundContext{
Channel: "test",
ChatID: "chat-1",
ChatType: "direct",
SenderID: "user-1",
MessageID: "msg-1",
},
wantChat: "chat-1",
wantSender: "user-1",
},
{
name: "group uses chat as peer",
inbound: bus.InboundContext{
Channel: "test",
ChatID: "group-1",
ChatType: "group",
SenderID: "user-2",
MessageID: "msg-2",
},
wantChat: "group-1",
wantSender: "user-2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgBus := bus.NewMessageBus()
defer msgBus.Close()
ch := NewBaseChannel("test", nil, msgBus, nil)
ch.HandleInboundContext(context.Background(), tt.inbound.ChatID, "hello", nil, tt.inbound)
msg := <-msgBus.InboundChan()
if msg.ChatID != tt.wantChat {
t.Fatalf("ChatID = %q, want %q", msg.ChatID, tt.wantChat)
}
if msg.SenderID != tt.wantSender {
t.Fatalf("SenderID = %q, want %q", msg.SenderID, tt.wantSender)
}
if msg.Context.ChatType != tt.inbound.ChatType {
t.Fatalf("ChatType = %q, want %q", msg.Context.ChatType, tt.inbound.ChatType)
}
})
}
}
+31 -15
View File
@@ -25,7 +25,7 @@ import (
// It uses WebSocket for receiving messages via stream mode and API for sending
type DingTalkChannel struct {
*channels.BaseChannel
config config.DingTalkConfig
config *config.DingTalkSettings
clientID string
clientSecret string
streamClient *client.StreamClient
@@ -36,7 +36,11 @@ type DingTalkChannel struct {
}
// NewDingTalkChannel creates a new DingTalk channel instance
func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (*DingTalkChannel, error) {
func NewDingTalkChannel(
bc *config.Channel,
cfg *config.DingTalkSettings,
messageBus *bus.MessageBus,
) (*DingTalkChannel, error) {
if cfg.ClientID == "" || cfg.ClientSecret.String() == "" {
return nil, fmt.Errorf("dingtalk client_id and client_secret are required")
}
@@ -44,10 +48,10 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (
// Set the logger for the Stream SDK
dinglog.SetLogger(logger.NewLogger("dingtalk"))
base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom,
base := channels.NewBaseChannel("dingtalk", cfg, messageBus, bc.AllowFrom,
channels.WithMaxMessageLength(20000),
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
channels.WithGroupTrigger(bc.GroupTrigger),
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
return &DingTalkChannel{
@@ -181,16 +185,15 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
"session_webhook": data.SessionWebhook,
}
var peer bus.Peer
var (
chatType string
isMentioned bool
)
if data.ConversationType == "1" {
peerID := senderID
if peerID == "" {
peerID = chatID
}
peer = bus.Peer{Kind: "direct", ID: peerID}
chatType = "direct"
} else {
peer = bus.Peer{Kind: "group", ID: data.ConversationId}
isMentioned := data.IsInAtList
chatType = "group"
isMentioned = data.IsInAtList
if isMentioned {
content = stripLeadingAtMentions(content)
}
@@ -228,8 +231,21 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
return nil, nil
}
// Handle the message through the base channel
c.HandleMessage(ctx, peer, "", resolvedSenderID, chatID, content, nil, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: "dingtalk",
ChatID: chatID,
ChatType: chatType,
SenderID: resolvedSenderID,
Mentioned: isMentioned,
Raw: metadata,
}
if data.SessionWebhook != "" {
inboundCtx.ReplyHandles = map[string]string{
"session_webhook": data.SessionWebhook,
}
}
c.HandleInboundContext(ctx, chatID, content, nil, inboundCtx, sender)
// Return nil to indicate we've handled the message asynchronously
// The response will be sent through the message bus
+23 -10
View File
@@ -11,7 +11,11 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
)
func newTestDingTalkChannel(t *testing.T, cfg config.DingTalkConfig) (*DingTalkChannel, *bus.MessageBus) {
func newTestDingTalkChannel(
t *testing.T,
cfg config.DingTalkSettings,
bc *config.Channel,
) (*DingTalkChannel, *bus.MessageBus) {
t.Helper()
if cfg.ClientID == "" {
@@ -22,7 +26,10 @@ func newTestDingTalkChannel(t *testing.T, cfg config.DingTalkConfig) (*DingTalkC
}
msgBus := bus.NewMessageBus()
ch, err := NewDingTalkChannel(cfg, msgBus)
if bc == nil {
bc = &config.Channel{Type: config.ChannelDingTalk, Enabled: true}
}
ch, err := NewDingTalkChannel(bc, &cfg, msgBus)
if err != nil {
t.Fatalf("new channel: %v", err)
}
@@ -41,9 +48,12 @@ func mustReceiveInbound(t *testing.T, msgBus *bus.MessageBus) bus.InboundMessage
}
func TestOnChatBotMessageReceived_GroupMentionOnlyUsesIsInAtListAndStripsMention(t *testing.T) {
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkConfig{
bc := &config.Channel{
Type: config.ChannelDingTalk,
Enabled: true,
GroupTrigger: config.GroupTriggerConfig{MentionOnly: true},
})
}
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkSettings{}, bc)
_, err := ch.onChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{
Text: chatbot.BotCallbackDataTextModel{Content: " @bot /help "},
@@ -65,8 +75,8 @@ func TestOnChatBotMessageReceived_GroupMentionOnlyUsesIsInAtListAndStripsMention
if inbound.ChatID != "group-abc" {
t.Fatalf("chat_id=%q", inbound.ChatID)
}
if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-abc" {
t.Fatalf("peer=%+v", inbound.Peer)
if inbound.Context.ChatType != "group" {
t.Fatalf("chat_type=%q", inbound.Context.ChatType)
}
if inbound.Content != "/help" {
t.Fatalf("content=%q", inbound.Content)
@@ -74,7 +84,7 @@ func TestOnChatBotMessageReceived_GroupMentionOnlyUsesIsInAtListAndStripsMention
}
func TestOnChatBotMessageReceived_DirectFallbackSenderIDUsesConversationID(t *testing.T) {
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkConfig{})
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkSettings{}, nil)
_, err := ch.onChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{
Text: chatbot.BotCallbackDataTextModel{Content: "ping"},
@@ -93,12 +103,15 @@ func TestOnChatBotMessageReceived_DirectFallbackSenderIDUsesConversationID(t *te
if inbound.ChatID != "conv-direct-42" {
t.Fatalf("chat_id=%q", inbound.ChatID)
}
if inbound.Peer.Kind != "direct" || inbound.Peer.ID != "openid-user-42" {
t.Fatalf("peer=%+v", inbound.Peer)
if inbound.Context.ChatType != "direct" {
t.Fatalf("chat_type=%q", inbound.Context.ChatType)
}
if inbound.SenderID != "dingtalk:openid-user-42" {
if inbound.SenderID != "openid-user-42" {
t.Fatalf("sender_id=%q", inbound.SenderID)
}
if inbound.Sender.CanonicalID != "dingtalk:openid-user-42" {
t.Fatalf("sender canonical_id=%q", inbound.Sender.CanonicalID)
}
if _, ok := ch.sessionWebhooks.Load("conv-direct-42"); !ok {
t.Fatal("expected session webhook keyed by conversation_id")
+22 -3
View File
@@ -7,7 +7,26 @@ import (
)
func init() {
channels.RegisterFactory("dingtalk", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewDingTalkChannel(cfg.Channels.DingTalk, b)
})
channels.RegisterFactory(
config.ChannelDingTalk,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.DingTalkSettings)
if !ok {
return nil, channels.ErrSendFailed
}
ch, err := NewDingTalkChannel(bc, c, b)
if err != nil {
return nil, err
}
if channelName != config.ChannelDingTalk {
ch.SetName(channelName)
}
return ch, nil
},
)
}
+31 -13
View File
@@ -38,8 +38,9 @@ var (
type DiscordChannel struct {
*channels.BaseChannel
bc *config.Channel
session *discordgo.Session
config config.DiscordConfig
config *config.DiscordSettings
ctx context.Context
cancel context.CancelFunc
typingMu sync.Mutex
@@ -56,7 +57,11 @@ type DiscordChannel struct {
ttsPlayID uint64
}
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
func NewDiscordChannel(
bc *config.Channel,
cfg *config.DiscordSettings,
bus *bus.MessageBus,
) (*DiscordChannel, error) {
discordgo.Logger = logger.NewLogger("discord").
WithLevels(map[int]logger.LogLevel{
discordgo.LogError: logger.ERROR,
@@ -73,14 +78,15 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
if err := applyDiscordProxy(session, cfg.Proxy); err != nil {
return nil, err
}
base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom,
base := channels.NewBaseChannel("discord", cfg, bus, bc.AllowFrom,
channels.WithMaxMessageLength(2000),
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
channels.WithGroupTrigger(bc.GroupTrigger),
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
return &DiscordChannel{
BaseChannel: base,
bc: bc,
session: session,
config: cfg,
ctx: context.Background(),
@@ -297,11 +303,11 @@ func (c *DiscordChannel) EditMessage(ctx context.Context, chatID string, message
// It sends a placeholder message that will later be edited to the actual
// response via EditMessage (channels.MessageEditor).
func (c *DiscordChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
if !c.config.Placeholder.Enabled {
if !c.bc.Placeholder.Enabled {
return "", nil
}
text := c.config.Placeholder.GetRandomText()
text := c.bc.Placeholder.GetRandomText()
msg, err := c.session.ChannelMessageSend(chatID, text)
if err != nil {
@@ -402,8 +408,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
// In guild (group) channels, apply unified group trigger filtering
// DMs (GuildID is empty) always get a response
isMentioned := false
if m.GuildID != "" {
isMentioned := false
for _, mention := range m.Mentions {
if mention.ID == c.botUserID {
isMentioned = true
@@ -500,14 +506,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
})
peerKind := "channel"
peerID := m.ChannelID
if m.GuildID == "" {
peerKind = "direct"
peerID = senderID
}
peer := bus.Peer{Kind: peerKind, ID: peerID}
metadata := map[string]string{
"user_id": senderID,
"username": m.Author.Username,
@@ -516,8 +518,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
"channel_id": m.ChannelID,
"is_dm": fmt.Sprintf("%t", m.GuildID == ""),
}
inboundCtx := bus.InboundContext{
Channel: c.Name(),
ChatID: m.ChannelID,
ChatType: peerKind,
SenderID: senderID,
MessageID: m.ID,
Mentioned: isMentioned,
Raw: metadata,
}
if m.GuildID != "" {
inboundCtx.SpaceID = m.GuildID
inboundCtx.SpaceType = "guild"
}
if m.MessageReference != nil {
inboundCtx.ReplyToMessageID = m.MessageReference.MessageID
}
c.HandleMessage(c.ctx, peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata, sender)
c.HandleInboundContext(c.ctx, m.ChannelID, content, mediaPaths, inboundCtx, sender)
}
// startTyping starts a continuous typing indicator loop for the given chatID.
+19 -7
View File
@@ -8,11 +8,23 @@ import (
)
func init() {
channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
ch, err := NewDiscordChannel(cfg.Channels.Discord, b)
if err == nil {
ch.tts = tts.DetectTTS(cfg)
}
return ch, err
})
channels.RegisterFactory(
config.ChannelDiscord,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.DiscordSettings)
if !ok {
return nil, channels.ErrSendFailed
}
ch, err := NewDiscordChannel(bc, c, b)
if err == nil {
ch.tts = tts.DetectTTS(cfg)
}
return ch, err
},
)
}
+1 -1
View File
@@ -19,7 +19,7 @@ type FeishuChannel struct {
var errUnsupported = errors.New("feishu channel is not supported on 32-bit architectures")
// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
func NewFeishuChannel(bc *config.Channel, cfg *config.FeishuSettings, bus *bus.MessageBus) (*FeishuChannel, error) {
return nil, errors.New(
"feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). Please use a 64-bit system or disable feishu in your config",
)
+53 -27
View File
@@ -14,6 +14,7 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
lark "github.com/larksuite/oapi-sdk-go/v3"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
@@ -37,21 +38,28 @@ const errCodeTenantTokenInvalid = 99991663
type FeishuChannel struct {
*channels.BaseChannel
config config.FeishuConfig
bc *config.Channel
config *config.FeishuSettings
client *lark.Client
wsClient *larkws.Client
tokenCache *tokenCache // custom cache that supports invalidation
botOpenID atomic.Value // stores string; populated lazily for @mention detection
botOpenID atomic.Value // stores string; populated lazily for @mention detection
messageCache sync.Map // caches fetched messages (messageID -> *larkim.Message)
mu sync.Mutex
cancel context.CancelFunc
}
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
base := channels.NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom,
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
type cachedMessage struct {
msg *larkim.Message
expiry time.Time
}
func NewFeishuChannel(bc *config.Channel, cfg *config.FeishuSettings, bus *bus.MessageBus) (*FeishuChannel, error) {
base := channels.NewBaseChannel("feishu", cfg, bus, bc.AllowFrom,
channels.WithGroupTrigger(bc.GroupTrigger),
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
tc := newTokenCache()
@@ -61,6 +69,7 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
}
ch := &FeishuChannel{
BaseChannel: base,
bc: bc,
config: cfg,
tokenCache: tc,
client: lark.NewClient(cfg.AppID, cfg.AppSecret.String(), opts...),
@@ -204,14 +213,14 @@ func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, cont
// SendPlaceholder implements channels.PlaceholderCapable.
// Sends an interactive card with placeholder text and returns its message ID.
func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
if !c.config.Placeholder.Enabled {
if !c.bc.Placeholder.Enabled {
logger.DebugCF("feishu", "Placeholder disabled, skipping", map[string]any{
"chat_id": chatID,
})
return "", nil
}
text := c.config.Placeholder.GetRandomText()
text := c.bc.Placeholder.GetRandomText()
cardContent, err := buildMarkdownCard(text)
if err != nil {
@@ -439,30 +448,20 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
if content == "" {
content = "[empty message]"
}
metadata := map[string]string{}
if messageID != "" {
metadata["message_id"] = messageID
}
if messageType != "" {
metadata["message_type"] = messageType
}
chatType := stringValue(message.ChatType)
if chatType != "" {
metadata["chat_type"] = chatType
}
if sender != nil && sender.TenantKey != nil {
metadata["tenant_key"] = *sender.TenantKey
}
metadata := buildInboundMetadata(message, sender)
var peer bus.Peer
var (
inboundChatType string
isMentioned bool
)
if chatType == "p2p" {
peer = bus.Peer{Kind: "direct", ID: senderID}
inboundChatType = "direct"
} else {
peer = bus.Peer{Kind: "group", ID: chatID}
inboundChatType = "group"
// Check if bot was mentioned
isMentioned := c.isBotMentioned(message)
isMentioned = c.isBotMentioned(message)
// Strip mention placeholders from content before group trigger check
if len(message.Mentions) > 0 {
@@ -477,14 +476,41 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
content = cleaned
}
if replyTargetID(message) != "" || stringValue(message.ThreadId) != "" {
content, mediaRefs = c.prependReplyContext(ctx, message, chatID, content, mediaRefs)
}
if content == "" {
content = "[empty message]"
}
logger.InfoCF("feishu", "Feishu message received", map[string]any{
"sender_id": senderID,
"chat_id": chatID,
"message_id": messageID,
"preview": utils.Truncate(content, 80),
})
logger.InfoCF("feishu", "Feishu reply linkage", map[string]any{
"message_id": messageID,
"parent_id": stringValue(message.ParentId),
"root_id": stringValue(message.RootId),
"thread_id": stringValue(message.ThreadId),
})
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo)
inboundCtx := bus.InboundContext{
Channel: "feishu",
ChatID: chatID,
ChatType: inboundChatType,
SenderID: senderID,
MessageID: messageID,
Mentioned: isMentioned,
Raw: metadata,
}
if sender != nil && sender.TenantKey != nil && *sender.TenantKey != "" {
inboundCtx.SpaceType = "tenant"
inboundCtx.SpaceID = *sender.TenantKey
}
c.HandleInboundContext(ctx, chatID, content, mediaRefs, inboundCtx, senderInfo)
return nil
}
+298
View File
@@ -0,0 +1,298 @@
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
package feishu
import (
"context"
"fmt"
"strings"
"time"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
const messageCacheTTL = 30 * time.Second
const (
maxReplyContextLen = 600
)
func (c *FeishuChannel) prependReplyContext(
ctx context.Context,
message *larkim.EventMessage,
chatID string,
content string,
mediaRefs []string,
) (string, []string) {
if message == nil {
return content, mediaRefs
}
lookupCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
targetMessageID := c.resolveReplyTargetMessageID(lookupCtx, message)
if targetMessageID == "" {
logger.DebugCF("feishu", "No reply target resolved; skip reply context", map[string]any{
"message_id": stringValue(message.MessageId),
"parent_id": stringValue(message.ParentId),
"root_id": stringValue(message.RootId),
"thread_id": stringValue(message.ThreadId),
})
return content, mediaRefs
}
repliedMessage, err := c.fetchMessageByID(lookupCtx, targetMessageID)
if err != nil {
logger.DebugCF("feishu", "Failed to fetch replied message context", map[string]any{
"target_message_id": targetMessageID,
"error": err.Error(),
})
return content, mediaRefs
}
messageType := stringValue(repliedMessage.MsgType)
rawContent := ""
if repliedMessage.Body != nil {
rawContent = stringValue(repliedMessage.Body.Content)
}
var repliedMediaRefs []string
if store := c.GetMediaStore(); store != nil {
repliedMediaRefs = c.downloadInboundMedia(lookupCtx, chatID, targetMessageID, messageType, rawContent, store)
if messageType == larkim.MsgTypeInteractive {
_, externalURLs := extractCardImageKeys(rawContent)
if len(externalURLs) > 0 {
repliedMediaRefs = append(repliedMediaRefs, externalURLs...)
}
}
}
repliedContent := normalizeRepliedContent(messageType, rawContent, repliedMediaRefs)
if len(repliedMediaRefs) > 0 {
mediaRefs = append(repliedMediaRefs, mediaRefs...)
}
return formatReplyContext(targetMessageID, repliedContent, content), mediaRefs
}
func (c *FeishuChannel) resolveReplyTargetMessageID(ctx context.Context, message *larkim.EventMessage) string {
if targetID := replyTargetID(message); targetID != "" {
logger.DebugCF("feishu", "Resolved reply target from event payload", map[string]any{
"message_id": stringValue(message.MessageId),
"parent_id": stringValue(message.ParentId),
"root_id": stringValue(message.RootId),
"target_id": targetID,
})
return targetID
}
currentMessageID := stringValue(message.MessageId)
if currentMessageID == "" {
return ""
}
if stringValue(message.ThreadId) == "" {
logger.DebugCF("feishu", "No reply target found; message is not in a thread", map[string]any{
"message_id": stringValue(message.MessageId),
})
return ""
}
msg, err := c.fetchMessageByID(ctx, currentMessageID)
if err != nil {
logger.DebugCF("feishu", "Failed to query current message detail for reply info", map[string]any{
"message_id": currentMessageID,
"error": err.Error(),
})
return ""
}
targetID := replyTargetIDFromMessage(msg)
if targetID != "" {
logger.DebugCF("feishu", "Resolved reply target from message detail", map[string]any{
"message_id": currentMessageID,
"parent_id": stringValue(msg.ParentId),
"root_id": stringValue(msg.RootId),
"target_id": targetID,
})
}
return targetID
}
func (c *FeishuChannel) fetchMessageByID(ctx context.Context, messageID string) (*larkim.Message, error) {
if cached, ok := c.messageCache.Load(messageID); ok {
cm := cached.(*cachedMessage)
if time.Now().Before(cm.expiry) {
return cm.msg, nil
}
c.messageCache.Delete(messageID)
}
req := larkim.NewGetMessageReqBuilder().
MessageId(messageID).
Build()
resp, err := c.client.Im.V1.Message.Get(ctx, req)
if err != nil {
return nil, fmt.Errorf("feishu get message: %w", err)
}
if !resp.Success() {
c.invalidateTokenOnAuthError(resp.Code)
return nil, fmt.Errorf("feishu get message api error (code=%d msg=%s)", resp.Code, resp.Msg)
}
if resp.Data == nil || len(resp.Data.Items) == 0 || resp.Data.Items[0] == nil {
return nil, fmt.Errorf("feishu get message: empty response")
}
// Items[0] contains the target message - the Feishu API returns a list
// but we request a single message by ID, so the list always has at most one item.
msg := resp.Data.Items[0]
c.messageCache.Store(messageID, &cachedMessage{msg: msg, expiry: time.Now().Add(messageCacheTTL)})
return msg, nil
}
func replyTargetID(message *larkim.EventMessage) string {
if message == nil {
return ""
}
if parentID := stringValue(message.ParentId); parentID != "" {
return parentID
}
return stringValue(message.RootId)
}
func replyTargetIDFromMessage(message *larkim.Message) string {
if message == nil {
return ""
}
if parentID := stringValue(message.ParentId); parentID != "" {
return parentID
}
return stringValue(message.RootId)
}
func buildInboundMetadata(message *larkim.EventMessage, sender *larkim.EventSender) map[string]string {
metadata := map[string]string{}
if message == nil {
return metadata
}
messageID := stringValue(message.MessageId)
if messageID != "" {
metadata["message_id"] = messageID
}
messageType := stringValue(message.MessageType)
if messageType != "" {
metadata["message_type"] = messageType
}
chatType := stringValue(message.ChatType)
if chatType != "" {
metadata["chat_type"] = chatType
}
parentID := stringValue(message.ParentId)
if parentID != "" {
metadata["parent_id"] = parentID
}
rootID := stringValue(message.RootId)
if rootID != "" {
metadata["root_id"] = rootID
}
if replyTo := replyTargetID(message); replyTo != "" {
metadata["reply_to_message_id"] = replyTo
}
threadID := stringValue(message.ThreadId)
if threadID != "" {
metadata["thread_id"] = threadID
}
if sender != nil && sender.TenantKey != nil && *sender.TenantKey != "" {
metadata["tenant_key"] = *sender.TenantKey
}
return metadata
}
func normalizeRepliedContent(messageType, rawContent string, mediaRefs []string) string {
content := extractContent(messageType, rawContent)
if containsFeishuUpgradePlaceholder(rawContent) || containsFeishuUpgradePlaceholder(content) {
content = ""
}
content = appendMediaTags(content, messageType, mediaRefs)
if strings.TrimSpace(content) != "" {
return content
}
switch messageType {
case larkim.MsgTypeImage:
return "[replied image]"
case larkim.MsgTypeFile:
return "[replied file]"
case larkim.MsgTypeAudio:
return "[replied audio]"
case larkim.MsgTypeMedia:
return "[replied video]"
case larkim.MsgTypeInteractive:
return "[replied interactive card]"
default:
return "[replied message content unavailable]"
}
}
func containsFeishuUpgradePlaceholder(s string) bool {
upgradePrompt := "\u8bf7\u5347\u7ea7\u81f3\u6700\u65b0\u7248\u672c\u5ba2\u6237\u7aef"
upgradePromptEscaped := "\\u8bf7\\u5347\\u7ea7\\u81f3\\u6700\\u65b0\\u7248\\u672c\\u5ba2\\u6237\\u7aef"
return strings.Contains(s, upgradePrompt) || strings.Contains(s, upgradePromptEscaped)
}
func formatReplyContext(parentID, repliedContent, content string) string {
parentID = strings.TrimSpace(parentID)
repliedContent = strings.TrimSpace(repliedContent)
content = strings.TrimSpace(content)
if parentID == "" || repliedContent == "" {
return content
}
repliedContent = utils.Truncate(repliedContent, maxReplyContextLen)
repliedContent = sanitizeReplyContextContent(repliedContent)
content = sanitizeReplyContextContent(content)
header := fmt.Sprintf("[replied_message id=%q]", parentID)
footer := "[/replied_message]"
if content == "" {
return header + "\n" + repliedContent + "\n" + footer
}
if hasLeadingCommandPrefix(content) {
return content + "\n\n" + header + "\n" + repliedContent + "\n" + footer
}
return header + "\n" + repliedContent + "\n" + footer + "\n\n[current_message]\n" + content + "\n[/current_message]"
}
func hasLeadingCommandPrefix(s string) bool {
tokens := strings.Fields(strings.TrimSpace(s))
if len(tokens) == 0 {
return false
}
first := tokens[0]
return strings.HasPrefix(first, "/") || strings.HasPrefix(first, "!")
}
func sanitizeReplyContextContent(s string) string {
tagEscaper := strings.NewReplacer(
"[replied_message", `\[replied_message`,
"[/replied_message]", `\[/replied_message]`,
"[current_message]", `\[current_message]`,
"[/current_message]", `\[/current_message]`,
)
return tagEscaper.Replace(s)
}
+229
View File
@@ -0,0 +1,229 @@
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
package feishu
import (
"strings"
"testing"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
)
func TestBuildInboundMetadata(t *testing.T) {
strPtr := func(s string) *string { return &s }
t.Run("includes basic and reply fields", func(t *testing.T) {
message := &larkim.EventMessage{
MessageId: strPtr("om_msg_1"),
MessageType: strPtr("text"),
ChatType: strPtr("group"),
ParentId: strPtr("om_parent_1"),
RootId: strPtr("om_root_1"),
ThreadId: strPtr("omt_thread_1"),
}
sender := &larkim.EventSender{TenantKey: strPtr("tenant_x")}
got := buildInboundMetadata(message, sender)
if got["message_id"] != "om_msg_1" {
t.Fatalf("message_id = %q, want %q", got["message_id"], "om_msg_1")
}
if got["message_type"] != "text" {
t.Fatalf("message_type = %q, want %q", got["message_type"], "text")
}
if got["chat_type"] != "group" {
t.Fatalf("chat_type = %q, want %q", got["chat_type"], "group")
}
if got["parent_id"] != "om_parent_1" {
t.Fatalf("parent_id = %q, want %q", got["parent_id"], "om_parent_1")
}
if got["reply_to_message_id"] != "om_parent_1" {
t.Fatalf("reply_to_message_id = %q, want %q", got["reply_to_message_id"], "om_parent_1")
}
if got["root_id"] != "om_root_1" {
t.Fatalf("root_id = %q, want %q", got["root_id"], "om_root_1")
}
if got["thread_id"] != "omt_thread_1" {
t.Fatalf("thread_id = %q, want %q", got["thread_id"], "omt_thread_1")
}
if got["tenant_key"] != "tenant_x" {
t.Fatalf("tenant_key = %q, want %q", got["tenant_key"], "tenant_x")
}
})
t.Run("falls back reply_to_message_id to root_id", func(t *testing.T) {
message := &larkim.EventMessage{
MessageId: strPtr("om_msg_3"),
RootId: strPtr("om_root_3"),
}
got := buildInboundMetadata(message, nil)
if got["root_id"] != "om_root_3" {
t.Fatalf("root_id = %q, want %q", got["root_id"], "om_root_3")
}
if got["reply_to_message_id"] != "om_root_3" {
t.Fatalf("reply_to_message_id = %q, want %q", got["reply_to_message_id"], "om_root_3")
}
})
t.Run("omits empty values", func(t *testing.T) {
message := &larkim.EventMessage{
MessageId: strPtr("om_msg_2"),
}
got := buildInboundMetadata(message, nil)
if got["message_id"] != "om_msg_2" {
t.Fatalf("message_id = %q, want %q", got["message_id"], "om_msg_2")
}
if _, ok := got["parent_id"]; ok {
t.Fatalf("parent_id should be absent, got %q", got["parent_id"])
}
if _, ok := got["reply_to_message_id"]; ok {
t.Fatalf("reply_to_message_id should be absent, got %q", got["reply_to_message_id"])
}
if _, ok := got["tenant_key"]; ok {
t.Fatalf("tenant_key should be absent, got %q", got["tenant_key"])
}
})
t.Run("nil message returns empty map", func(t *testing.T) {
got := buildInboundMetadata(nil, nil)
if len(got) != 0 {
t.Fatalf("len(metadata) = %d, want 0", len(got))
}
})
}
func TestFormatReplyContext(t *testing.T) {
t.Run("formats reply context with content", func(t *testing.T) {
got := formatReplyContext("om_parent_1", "original message", "new reply")
want := "[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]\n\n[current_message]\nnew reply\n[/current_message]"
if got != want {
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
}
})
t.Run("returns reply context when current content is empty", func(t *testing.T) {
got := formatReplyContext("om_parent_1", "original message", "")
want := "[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]"
if got != want {
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
}
})
t.Run("returns original content when parent or replied content missing", func(t *testing.T) {
if got := formatReplyContext("", "original", "new reply"); got != "new reply" {
t.Fatalf("missing parent: got %q, want %q", got, "new reply")
}
if got := formatReplyContext("om_parent_1", "", "new reply"); got != "new reply" {
t.Fatalf("missing replied content: got %q, want %q", got, "new reply")
}
})
t.Run("escapes reserved wrapper tags in payload", func(t *testing.T) {
replied := "payload [replied_message id=\"x\"] x [/replied_message]"
current := "hello [current_message]injected[/current_message]"
got := formatReplyContext("om_parent_1", replied, current)
if !strings.HasPrefix(got, "[replied_message id=\"om_parent_1\"]") {
t.Fatalf("outer replied_message wrapper missing: %q", got)
}
if strings.Contains(got, "\n[replied_message id=\"x\"]") {
t.Fatalf("nested replied_message tag should be escaped: %q", got)
}
if strings.Contains(got, "\n[current_message]injected") {
t.Fatalf("nested current_message tag should be escaped: %q", got)
}
if !strings.Contains(got, `\[replied_message id="x"]`) {
t.Fatalf("escaped replied tag missing: %q", got)
}
})
t.Run("preserves leading slash command prefix", func(t *testing.T) {
got := formatReplyContext("om_parent_1", "original message", "/help")
want := "/help\n\n[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]"
if got != want {
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
}
})
t.Run("preserves leading bang command prefix", func(t *testing.T) {
got := formatReplyContext("om_parent_1", "original message", "!status now")
want := "!status now\n\n[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]"
if got != want {
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
}
})
}
func TestReplyTargetID(t *testing.T) {
strPtr := func(s string) *string { return &s }
t.Run("prefer parent_id", func(t *testing.T) {
msg := &larkim.EventMessage{ParentId: strPtr("om_parent"), RootId: strPtr("om_root")}
if got := replyTargetID(msg); got != "om_parent" {
t.Fatalf("replyTargetID() = %q, want %q", got, "om_parent")
}
})
t.Run("fallback to root_id", func(t *testing.T) {
msg := &larkim.EventMessage{RootId: strPtr("om_root")}
if got := replyTargetID(msg); got != "om_root" {
t.Fatalf("replyTargetID() = %q, want %q", got, "om_root")
}
})
t.Run("empty when no fields", func(t *testing.T) {
if got := replyTargetID(&larkim.EventMessage{}); got != "" {
t.Fatalf("replyTargetID() = %q, want empty", got)
}
})
}
func TestNormalizeRepliedContent(t *testing.T) {
t.Run("filters feishu upgrade placeholder for interactive", func(t *testing.T) {
raw := `{"text":"\u8bf7\u5347\u7ea7\u81f3\u6700\u65b0\u7248\u672c\u5ba2\u6237\u7aef\uff0c\u4ee5\u67e5\u770b\u5185\u5bb9"}`
got := normalizeRepliedContent("interactive", raw, nil)
if got != "[replied interactive card]" {
t.Fatalf("normalizeRepliedContent() = %q, want %q", got, "[replied interactive card]")
}
})
t.Run("keeps filename and file tag for replied file", func(t *testing.T) {
got := normalizeRepliedContent("file", `{"file_key":"file_xxx","file_name":"doc.pdf"}`, []string{"media://r1"})
if got != "doc.pdf [file]" {
t.Fatalf("normalizeRepliedContent() = %q, want %q", got, "doc.pdf [file]")
}
})
t.Run("falls back when file content missing", func(t *testing.T) {
got := normalizeRepliedContent("file", `{"file_key":"file_xxx"}`, nil)
if got != "[replied file]" {
t.Fatalf("normalizeRepliedContent() = %q, want %q", got, "[replied file]")
}
})
}
func TestHasLeadingCommandPrefix(t *testing.T) {
tests := []struct {
name string
input string
want bool
}{
{name: "slash command", input: "/help", want: true},
{name: "bang command", input: "!status", want: true},
{name: "leading spaces slash", input: " /ping arg", want: true},
{name: "normal text", input: "hello /help", want: false},
{name: "empty", input: "", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := hasLeadingCommandPrefix(tt.input); got != tt.want {
t.Fatalf("hasLeadingCommandPrefix(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
+15 -3
View File
@@ -7,7 +7,19 @@ import (
)
func init() {
channels.RegisterFactory("feishu", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewFeishuChannel(cfg.Channels.Feishu, b)
})
channels.RegisterFactory(
config.ChannelFeishu,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.FeishuSettings)
if !ok {
return nil, channels.ErrSendFailed
}
return NewFeishuChannel(bc, c, b)
},
)
}
+18 -5
View File
@@ -51,14 +51,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
isDM := !strings.HasPrefix(target, "#") && !strings.HasPrefix(target, "&")
var chatID string
var peer bus.Peer
if isDM {
chatID = nick
peer = bus.Peer{Kind: "direct", ID: nick}
} else {
chatID = target
peer = bus.Peer{Kind: "group", ID: target}
}
sender := bus.SenderInfo{
@@ -73,9 +70,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
return
}
isMentioned := false
// For channel messages, check group trigger (mention detection)
if !isDM {
isMentioned := isBotMentioned(content, currentNick)
isMentioned = isBotMentioned(content, currentNick)
if isMentioned {
content = stripBotMention(content, currentNick)
}
@@ -100,7 +99,21 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
metadata["channel"] = target
}
c.HandleMessage(c.ctx, peer, messageID, nick, chatID, content, nil, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: "irc",
ChatID: chatID,
SenderID: nick,
MessageID: messageID,
Mentioned: isMentioned,
Raw: metadata,
}
if isDM {
inboundCtx.ChatType = "direct"
} else {
inboundCtx.ChatType = "group"
}
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
}
// nickMentionedAt returns the byte index where botNick is mentioned in content
+25 -6
View File
@@ -7,10 +7,29 @@ import (
)
func init() {
channels.RegisterFactory("irc", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
if !cfg.Channels.IRC.Enabled {
return nil, nil
}
return NewIRCChannel(cfg.Channels.IRC, b)
})
channels.RegisterFactory(
config.ChannelIRC,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
if bc == nil || !bc.Enabled {
return nil, nil
}
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.IRCSettings)
if !ok {
return nil, channels.ErrSendFailed
}
ch, err := NewIRCChannel(bc, c, b)
if err != nil {
return nil, err
}
if channelName != config.ChannelIRC {
ch.SetName(channelName)
}
return ch, nil
},
)
}
+8 -6
View File
@@ -18,14 +18,15 @@ import (
// IRCChannel implements the Channel interface for IRC servers.
type IRCChannel struct {
*channels.BaseChannel
config config.IRCConfig
bc *config.Channel
config *config.IRCSettings
conn *ircevent.Connection
ctx context.Context
cancel context.CancelFunc
}
// NewIRCChannel creates a new IRC channel.
func NewIRCChannel(cfg config.IRCConfig, messageBus *bus.MessageBus) (*IRCChannel, error) {
func NewIRCChannel(bc *config.Channel, cfg *config.IRCSettings, messageBus *bus.MessageBus) (*IRCChannel, error) {
if cfg.Server == "" {
return nil, fmt.Errorf("irc server is required")
}
@@ -33,14 +34,15 @@ func NewIRCChannel(cfg config.IRCConfig, messageBus *bus.MessageBus) (*IRCChanne
return nil, fmt.Errorf("irc nick is required")
}
base := channels.NewBaseChannel("irc", cfg, messageBus, cfg.AllowFrom,
base := channels.NewBaseChannel("irc", cfg, messageBus, bc.AllowFrom,
channels.WithMaxMessageLength(400),
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
channels.WithGroupTrigger(bc.GroupTrigger),
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
return &IRCChannel{
BaseChannel: base,
bc: bc,
config: cfg,
}, nil
}
@@ -166,7 +168,7 @@ func (c *IRCChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]strin
func (c *IRCChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
noop := func() {}
if !c.config.Typing.Enabled || !c.IsRunning() || c.conn == nil {
if !c.bc.Typing.Enabled || !c.IsRunning() || c.conn == nil {
return noop, nil
}
+9 -6
View File
@@ -11,28 +11,31 @@ func TestNewIRCChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("missing server", func(t *testing.T) {
cfg := config.IRCConfig{Nick: "bot"}
_, err := NewIRCChannel(cfg, msgBus)
bc := &config.Channel{Type: config.ChannelIRC, Enabled: true}
cfg := &config.IRCSettings{Nick: "bot"}
_, err := NewIRCChannel(bc, cfg, msgBus)
if err == nil {
t.Error("expected error for missing server, got nil")
}
})
t.Run("missing nick", func(t *testing.T) {
cfg := config.IRCConfig{Server: "irc.example.com:6667"}
_, err := NewIRCChannel(cfg, msgBus)
bc := &config.Channel{Type: config.ChannelIRC, Enabled: true}
cfg := &config.IRCSettings{Server: "irc.example.com:6667"}
_, err := NewIRCChannel(bc, cfg, msgBus)
if err == nil {
t.Error("expected error for missing nick, got nil")
}
})
t.Run("valid config", func(t *testing.T) {
cfg := config.IRCConfig{
bc := &config.Channel{Type: config.ChannelIRC, Enabled: true}
cfg := &config.IRCSettings{
Server: "irc.example.com:6667",
Nick: "testbot",
Channels: []string{"#test"},
}
ch, err := NewIRCChannel(cfg, msgBus)
ch, err := NewIRCChannel(bc, cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
+15 -3
View File
@@ -7,7 +7,19 @@ import (
)
func init() {
channels.RegisterFactory("line", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewLINEChannel(cfg.Channels.LINE, b)
})
channels.RegisterFactory(
config.ChannelLINE,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.LINESettings)
if !ok {
return nil, channels.ErrSendFailed
}
return NewLINEChannel(bc, c, b)
},
)
}
+30 -13
View File
@@ -40,7 +40,7 @@ type replyTokenEntry struct {
// and the official LINE Bot SDK for sending messages.
type LINEChannel struct {
*channels.BaseChannel
config config.LINEConfig
config *config.LINESettings
client *messaging_api.MessagingApiAPI
botUserID string // Bot's user ID
botBasicID string // Bot's basic ID (e.g. @216ru...)
@@ -52,7 +52,11 @@ type LINEChannel struct {
}
// NewLINEChannel creates a new LINE channel instance.
func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINEChannel, error) {
func NewLINEChannel(
bc *config.Channel,
cfg *config.LINESettings,
messageBus *bus.MessageBus,
) (*LINEChannel, error) {
if cfg.ChannelSecret.String() == "" || cfg.ChannelAccessToken.String() == "" {
return nil, fmt.Errorf("line channel_secret and channel_access_token are required")
}
@@ -62,10 +66,10 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha
return nil, fmt.Errorf("failed to create LINE messaging client: %w", err)
}
base := channels.NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom,
base := channels.NewBaseChannel("line", cfg, messageBus, bc.AllowFrom,
channels.WithMaxMessageLength(5000),
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
channels.WithGroupTrigger(bc.GroupTrigger),
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
return &LINEChannel{
@@ -191,6 +195,7 @@ func (c *LINEChannel) processEvent(event webhook.EventInterface) {
var content string
var mediaPaths []string
var messageID string
var quoteToken string
var isMentioned bool
// Helper to register a local file with the media store
@@ -214,6 +219,7 @@ func (c *LINEChannel) processEvent(event webhook.EventInterface) {
isMentioned = c.isBotMentioned(msg)
// Store quote token for quoting the original message in reply
if msg.QuoteToken != "" {
quoteToken = msg.QuoteToken
c.quoteTokens.Store(chatID, msg.QuoteToken)
}
// Strip bot mention from text in group chats
@@ -275,13 +281,6 @@ func (c *LINEChannel) processEvent(event webhook.EventInterface) {
"source_type": sourceType,
}
var peer bus.Peer
if isGroup {
peer = bus.Peer{Kind: "group", ID: chatID}
} else {
peer = bus.Peer{Kind: "direct", ID: senderID}
}
logger.DebugCF("line", "Received message", map[string]any{
"sender_id": senderID,
"chat_id": chatID,
@@ -300,7 +299,25 @@ func (c *LINEChannel) processEvent(event webhook.EventInterface) {
return
}
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: c.Name(),
ChatID: chatID,
ChatType: map[bool]string{true: "group", false: "direct"}[isGroup],
SenderID: senderID,
MessageID: messageID,
Mentioned: isMentioned,
Raw: metadata,
}
if msgEvent.ReplyToken != "" {
inboundCtx.ReplyHandles = map[string]string{
"reply_token": msgEvent.ReplyToken,
}
if quoteToken != "" {
inboundCtx.ReplyHandles["quote_token"] = quoteToken
}
}
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
}
// isBotMentioned checks if the bot is mentioned in the message.
+9 -5
View File
@@ -6,10 +6,12 @@ import (
"net/http/httptest"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestWebhookRejectsOversizedBody(t *testing.T) {
ch := &LINEChannel{}
ch := &LINEChannel{config: &config.LINESettings{}}
oversized := bytes.Repeat([]byte("A"), maxWebhookBodySize+1)
req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(oversized))
@@ -23,7 +25,7 @@ func TestWebhookRejectsOversizedBody(t *testing.T) {
}
func TestWebhookAcceptsMaxBodySize(t *testing.T) {
ch := &LINEChannel{}
ch := &LINEChannel{config: &config.LINESettings{}}
body := bytes.Repeat([]byte("A"), maxWebhookBodySize)
req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(body))
@@ -38,7 +40,7 @@ func TestWebhookAcceptsMaxBodySize(t *testing.T) {
}
func TestWebhookRejectsOversizedBodyBeforeSignatureCheck(t *testing.T) {
ch := &LINEChannel{}
ch := &LINEChannel{config: &config.LINESettings{}}
oversized := bytes.Repeat([]byte("A"), maxWebhookBodySize+1)
req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(oversized))
@@ -53,7 +55,7 @@ func TestWebhookRejectsOversizedBodyBeforeSignatureCheck(t *testing.T) {
}
func TestWebhookRejectsNonPostMethod(t *testing.T) {
ch := &LINEChannel{}
ch := &LINEChannel{config: &config.LINESettings{}}
req := httptest.NewRequest(http.MethodGet, "/webhook", nil)
rec := httptest.NewRecorder()
@@ -66,7 +68,9 @@ func TestWebhookRejectsNonPostMethod(t *testing.T) {
}
func TestWebhookRejectsInvalidSignature(t *testing.T) {
ch := &LINEChannel{}
ch := &LINEChannel{
config: &config.LINESettings{},
}
body := `{"events":[]}`
req := httptest.NewRequest(http.MethodPost, "/webhook", strings.NewReader(body))
+15 -3
View File
@@ -7,7 +7,19 @@ import (
)
func init() {
channels.RegisterFactory("maixcam", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewMaixCamChannel(cfg.Channels.MaixCam, b)
})
channels.RegisterFactory(
config.ChannelMaixCam,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.MaixCamSettings)
if !ok {
return nil, channels.ErrSendFailed
}
return NewMaixCamChannel(bc, c, b)
},
)
}
+17 -15
View File
@@ -17,7 +17,7 @@ import (
type MaixCamChannel struct {
*channels.BaseChannel
config config.MaixCamConfig
config *config.MaixCamSettings
listener net.Listener
ctx context.Context
cancel context.CancelFunc
@@ -32,13 +32,17 @@ type MaixCamMessage struct {
Data map[string]any `json:"data"`
}
func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) {
func NewMaixCamChannel(
bc *config.Channel,
cfg *config.MaixCamSettings,
bus *bus.MessageBus,
) (*MaixCamChannel, error) {
base := channels.NewBaseChannel(
"maixcam",
cfg,
bus,
cfg.AllowFrom,
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
bc.AllowFrom,
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
return &MaixCamChannel{
@@ -196,17 +200,15 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) {
return
}
c.HandleMessage(
c.ctx,
bus.Peer{Kind: "channel", ID: "default"},
"",
senderID,
chatID,
content,
[]string{},
metadata,
sender,
)
inboundCtx := bus.InboundContext{
Channel: "maixcam",
ChatID: chatID,
ChatType: "channel",
SenderID: senderID,
Raw: metadata,
}
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
}
func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) {
+195 -130
View File
@@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"math"
"net"
"net/http"
"sort"
"sync"
@@ -86,6 +87,7 @@ type Manager struct {
dispatchTask *asyncTask
mux *dynamicServeMux
httpServer *http.Server
httpListeners []net.Listener
mu sync.RWMutex
placeholders sync.Map // "channel:chatID" → placeholderID (string)
typingStops sync.Map // "channel:chatID" → func()
@@ -98,6 +100,22 @@ type asyncTask struct {
cancel context.CancelFunc
}
func outboundMessageChannel(msg bus.OutboundMessage) string {
return msg.Context.Channel
}
func outboundMessageChatID(msg bus.OutboundMessage) string {
return msg.ChatID
}
func outboundMediaChannel(msg bus.OutboundMediaMessage) string {
return msg.Context.Channel
}
func outboundMediaChatID(msg bus.OutboundMediaMessage) string {
return msg.ChatID
}
// RecordPlaceholder registers a placeholder message for later editing.
// Implements PlaceholderRecorder.
func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) {
@@ -161,7 +179,8 @@ func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) {
// preSend handles typing stop, reaction undo, and placeholder editing before sending a message.
// Returns the delivered message IDs and true when delivery completed before a normal Send.
func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) ([]string, bool) {
key := name + ":" + msg.ChatID
chatID := outboundMessageChatID(msg)
key := name + ":" + chatID
// 1. Stop typing
if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
@@ -183,9 +202,9 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
// Prefer deleting the placeholder (cleaner UX than editing to same content)
if deleter, ok := ch.(MessageDeleter); ok {
deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort
deleter.DeleteMessage(ctx, chatID, entry.id) // best effort
} else if editor, ok := ch.(MessageEditor); ok {
editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content) // fallback
editor.EditMessage(ctx, chatID, entry.id, msg.Content) // fallback
}
}
}
@@ -196,7 +215,7 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
if editor, ok := ch.(MessageEditor); ok {
if err := editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content); err == nil {
if err := editor.EditMessage(ctx, chatID, entry.id, msg.Content); err == nil {
return []string{entry.id}, true
}
// edit failed → fall through to normal Send
@@ -212,7 +231,8 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
// delivery never edits the placeholder because there is no text payload to
// replace it with; it only attempts to delete the placeholder when possible.
func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.OutboundMediaMessage, ch Channel) {
key := name + ":" + msg.ChatID
chatID := outboundMediaChatID(msg)
key := name + ":" + chatID
// 1. Stop typing
if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
@@ -235,7 +255,7 @@ func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.Outboun
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
if deleter, ok := ch.(MessageDeleter); ok {
deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort
deleter.DeleteMessage(ctx, chatID, entry.id) // best effort
}
}
}
@@ -311,22 +331,27 @@ func (s *finalizeHookStreamer) Finalize(ctx context.Context, content string) err
return nil
}
// initChannel is a helper that looks up a factory by name and creates the channel.
func (m *Manager) initChannel(name, displayName string) {
f, ok := getFactory(name)
// initChannel is a helper that looks up a factory by type name and creates the channel.
// typeName is the channel type used for factory lookup (e.g., "telegram").
// channelName is the config map key used as the channel's runtime name (e.g., "my_telegram").
func (m *Manager) initChannel(typeName, channelName string) {
f, ok := getFactory(typeName)
if !ok {
logger.WarnCF("channels", "Factory not registered", map[string]any{
"channel": displayName,
"channel": channelName,
"type": typeName,
})
return
}
logger.DebugCF("channels", "Attempting to initialize channel", map[string]any{
"channel": displayName,
"channel": channelName,
"type": typeName,
})
ch, err := f(m.config, m.bus)
ch, err := f(channelName, typeName, m.config, m.bus)
if err != nil {
logger.ErrorCF("channels", "Failed to initialize channel", map[string]any{
"channel": displayName,
"channel": channelName,
"type": typeName,
"error": err.Error(),
})
} else {
@@ -344,103 +369,100 @@ func (m *Manager) initChannel(name, displayName string) {
if setter, ok := ch.(interface{ SetOwner(ch Channel) }); ok {
setter.SetOwner(ch)
}
m.channels[name] = ch
m.channels[channelName] = ch
logger.InfoCF("channels", "Channel enabled successfully", map[string]any{
"channel": displayName,
"channel": channelName,
"type": typeName,
})
}
}
func (m *Manager) getChannelConfigAndEnabled(channelName string) (*config.Channel, bool) {
bc, ok := m.config.Channels[channelName]
if !ok || bc == nil {
return nil, false
}
if !bc.Enabled {
return bc, false
}
// Use Type to determine the config struct for validation.
// The map key (channelName) is the config key, which may differ from the type.
channelType := bc.Type
if channelType == "" {
channelType = channelName
}
// Settings have already been decoded by InitChannelList, so we just need to
// type-assert and check the relevant fields.
decoded, err := bc.GetDecoded()
if err != nil {
return bc, false
}
//nolint:revive
switch settings := decoded.(type) {
case *config.WhatsAppSettings:
if channelType == config.ChannelWhatsApp {
return bc, settings.BridgeURL != ""
}
return bc, channelType == config.ChannelWhatsAppNative && settings.UseNative
case *config.MatrixSettings:
return bc, settings.Homeserver != "" && settings.UserID != "" && settings.AccessToken.String() != ""
case *config.WeComSettings:
return bc, settings.BotID != "" && settings.Secret.String() != ""
case *config.PicoClientSettings:
return bc, settings.URL != ""
case *config.DingTalkSettings:
return bc, settings.ClientID != ""
case *config.SlackSettings:
return bc, settings.BotToken.String() != ""
case *config.WeixinSettings:
return bc, settings.Token.String() != ""
case *config.PicoSettings:
return bc, settings.Token.String() != ""
case *config.IRCSettings:
return bc, settings.Server != ""
case *config.LINESettings:
return bc, settings.ChannelAccessToken.String() != ""
case *config.OneBotSettings:
return bc, settings.WSUrl != ""
case *config.QQSettings:
return bc, settings.AppSecret.String() != ""
case *config.TelegramSettings:
return bc, settings.Token.String() != ""
case *config.FeishuSettings:
return bc, settings.AppSecret.String() != ""
case *config.MaixCamSettings:
return bc, true
case *config.TeamsWebhookSettings:
return bc, true
case *config.DiscordSettings:
return bc, settings.Token.String() != ""
case *config.VKSettings:
return bc, settings.GroupID != 0 && settings.Token.String() != ""
}
return bc, bc.Enabled
}
// initChannels initializes all enabled channels based on the configuration.
// It iterates config entries and uses bc.Type to look up the appropriate factory.
func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
logger.InfoC("channels", "Initializing channel manager")
if channels.Telegram.Enabled && channels.Telegram.Token.String() != "" {
m.initChannel("telegram", "Telegram")
}
if channels.WhatsApp.Enabled {
waCfg := channels.WhatsApp
if waCfg.UseNative {
m.initChannel("whatsapp_native", "WhatsApp Native")
} else if waCfg.BridgeURL != "" {
m.initChannel("whatsapp", "WhatsApp")
for name, bc := range *channels {
if !bc.Enabled {
continue
}
}
if channels.Feishu.Enabled {
m.initChannel("feishu", "Feishu")
}
if channels.Discord.Enabled && channels.Discord.Token.String() != "" {
m.initChannel("discord", "Discord")
}
if channels.MaixCam.Enabled {
m.initChannel("maixcam", "MaixCam")
}
if channels.QQ.Enabled {
m.initChannel("qq", "QQ")
}
if channels.DingTalk.Enabled && channels.DingTalk.ClientID != "" {
m.initChannel("dingtalk", "DingTalk")
}
if channels.Slack.Enabled && channels.Slack.BotToken.String() != "" {
m.initChannel("slack", "Slack")
}
if channels.Matrix.Enabled &&
m.config.Channels.Matrix.Homeserver != "" &&
m.config.Channels.Matrix.UserID != "" &&
m.config.Channels.Matrix.AccessToken.String() != "" {
m.initChannel("matrix", "Matrix")
}
if channels.LINE.Enabled && channels.LINE.ChannelAccessToken.String() != "" {
m.initChannel("line", "LINE")
}
if channels.OneBot.Enabled && channels.OneBot.WSUrl != "" {
m.initChannel("onebot", "OneBot")
}
if channels.WeCom.Enabled && channels.WeCom.BotID != "" && channels.WeCom.Secret.String() != "" {
m.initChannel("wecom", "WeCom")
}
if channels.Weixin.Enabled && channels.Weixin.Token.String() != "" {
m.initChannel("weixin", "Weixin")
}
if channels.Pico.Enabled && channels.Pico.Token.String() != "" {
m.initChannel("pico", "Pico")
}
if channels.PicoClient.Enabled && channels.PicoClient.URL != "" {
m.initChannel("pico_client", "Pico Client")
}
if channels.IRC.Enabled && channels.IRC.Server != "" {
m.initChannel("irc", "IRC")
}
if channels.VK.Enabled && channels.VK.Token.String() != "" && channels.VK.GroupID != 0 {
m.initChannel("vk", "VK")
}
if channels.TeamsWebhook.Enabled && len(channels.TeamsWebhook.Webhooks) > 0 {
hasValidTarget := false
for _, target := range channels.TeamsWebhook.Webhooks {
if target.WebhookURL.String() != "" {
hasValidTarget = true
break
}
_, ready := m.getChannelConfigAndEnabled(name)
if !ready {
continue
}
if hasValidTarget {
m.initChannel("teams_webhook", "Teams Webhook")
typeName := bc.Type
if typeName == "" {
typeName = name
}
m.initChannel(typeName, name)
}
logger.InfoCF("channels", "Channel initialization completed", map[string]any{
@@ -454,6 +476,12 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
// It registers health endpoints from the health server and discovers channels
// that implement WebhookHandler and/or HealthChecker to register their handlers.
func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
m.SetupHTTPServerListeners(nil, addr, healthServer)
}
// SetupHTTPServerListeners creates a shared HTTP server on pre-opened listeners.
// When listeners is empty it falls back to Addr-based ListenAndServe behavior.
func (m *Manager) SetupHTTPServerListeners(listeners []net.Listener, addr string, healthServer *health.Server) {
m.mux = newDynamicServeMux()
// Register health endpoints
@@ -470,6 +498,7 @@ func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
}
m.httpListeners = append([]net.Listener(nil), listeners...)
}
// registerHTTPHandlersLocked registers webhook and health-check handlers for
@@ -548,7 +577,13 @@ func (m *Manager) StartAll(ctx context.Context) error {
continue
}
// Lazily create worker only after channel starts successfully
w := newChannelWorker(name, channel)
channelType := name
if m.config != nil {
if bc := m.config.Channels.Get(name); bc != nil && bc.Type != "" {
channelType = bc.Type
}
}
w := newChannelWorker(name, channel, channelType)
m.workers[name] = w
go m.runWorker(dispatchCtx, name, w)
go m.runMediaWorker(dispatchCtx, name, w)
@@ -593,16 +628,33 @@ func (m *Manager) StartAll(ctx context.Context) error {
// Start shared HTTP server if configured
if m.httpServer != nil {
go func() {
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
"addr": m.httpServer.Addr,
})
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
"error": err.Error(),
})
if len(m.httpListeners) > 0 {
for _, listener := range m.httpListeners {
ln := listener
go func() {
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
"addr": ln.Addr().String(),
})
if err := m.httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
"addr": ln.Addr().String(),
"error": err.Error(),
})
}
}()
}
}()
} else {
go func() {
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
"addr": m.httpServer.Addr,
})
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
"error": err.Error(),
})
}
}()
}
}
logger.InfoCF("channels", "Channel startup completed", map[string]any{
@@ -629,6 +681,7 @@ func (m *Manager) StopAll(ctx context.Context) error {
})
}
m.httpServer = nil
m.httpListeners = nil
}
// Cancel dispatcher
@@ -678,10 +731,10 @@ func (m *Manager) StopAll(ctx context.Context) error {
}
// newChannelWorker creates a channelWorker with a rate limiter configured
// for the given channel name.
func newChannelWorker(name string, ch Channel) *channelWorker {
// for the given channel type. channelType is used for rate limit lookup.
func newChannelWorker(name string, ch Channel, channelType string) *channelWorker {
rateVal := float64(defaultRateLimit)
if r, ok := channelRateConfig[name]; ok {
if r, ok := channelRateConfig[channelType]; ok {
rateVal = r
}
burst := int(math.Max(1, math.Ceil(rateVal/2)))
@@ -812,7 +865,7 @@ func (m *Manager) sendWithRetry(
// All retries exhausted or permanent failure
logger.ErrorCF("channels", "Send failed", map[string]any{
"channel": name,
"chat_id": msg.ChatID,
"chat_id": outboundMessageChatID(msg),
"error": lastErr.Error(),
"retries": maxRetries,
})
@@ -874,7 +927,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
dispatchLoop(
ctx, m,
m.bus.OutboundChan(),
func(msg bus.OutboundMessage) string { return msg.Channel },
func(msg bus.OutboundMessage) string { return outboundMessageChannel(msg) },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
select {
case w.queue <- msg:
@@ -894,7 +947,7 @@ func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
dispatchLoop(
ctx, m,
m.bus.OutboundMediaChan(),
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
func(msg bus.OutboundMediaMessage) string { return outboundMediaChannel(msg) },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
select {
case w.mediaQueue <- msg:
@@ -993,7 +1046,7 @@ func (m *Manager) sendMediaWithRetry(
// All retries exhausted or permanent failure
logger.ErrorCF("channels", "SendMedia failed", map[string]any{
"channel": name,
"chat_id": msg.ChatID,
"chat_id": outboundMediaChatID(msg),
"error": lastErr.Error(),
"retries": maxRetries,
})
@@ -1137,7 +1190,13 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
continue
}
// Lazily create worker only after channel starts successfully
w := newChannelWorker(name, channel)
channelType := name
if m.config != nil {
if bc := m.config.Channels.Get(name); bc != nil && bc.Type != "" {
channelType = bc.Type
}
}
w := newChannelWorker(name, channel, channelType)
m.workers[name] = w
go m.runWorker(dispatchCtx, name, w)
go m.runMediaWorker(dispatchCtx, name, w)
@@ -1186,16 +1245,19 @@ func (m *Manager) UnregisterChannel(name string) {
// delivered (or all retries are exhausted), which preserves ordering when
// a subsequent operation depends on the message having been sent.
func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) error {
msg = bus.NormalizeOutboundMessage(msg)
channelName := outboundMessageChannel(msg)
m.mu.RLock()
_, exists := m.channels[msg.Channel]
w, wExists := m.workers[msg.Channel]
_, exists := m.channels[channelName]
w, wExists := m.workers[channelName]
m.mu.RUnlock()
if !exists {
return fmt.Errorf("channel %s not found", msg.Channel)
return fmt.Errorf("channel %s not found", channelName)
}
if !wExists || w == nil {
return fmt.Errorf("channel %s has no active worker", msg.Channel)
return fmt.Errorf("channel %s has no active worker", channelName)
}
maxLen := 0
@@ -1206,10 +1268,10 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro
for _, chunk := range SplitMessage(msg.Content, maxLen) {
chunkMsg := msg
chunkMsg.Content = chunk
m.sendWithRetry(ctx, msg.Channel, w, chunkMsg)
m.sendWithRetry(ctx, channelName, w, chunkMsg)
}
} else {
m.sendWithRetry(ctx, msg.Channel, w, msg)
m.sendWithRetry(ctx, channelName, w, msg)
}
return nil
}
@@ -1219,19 +1281,22 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro
// retries are exhausted), which preserves ordering when later agent behavior
// depends on actual media delivery.
func (m *Manager) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
msg = bus.NormalizeOutboundMediaMessage(msg)
channelName := outboundMediaChannel(msg)
m.mu.RLock()
_, exists := m.channels[msg.Channel]
w, wExists := m.workers[msg.Channel]
_, exists := m.channels[channelName]
w, wExists := m.workers[channelName]
m.mu.RUnlock()
if !exists {
return fmt.Errorf("channel %s not found", msg.Channel)
return fmt.Errorf("channel %s not found", channelName)
}
if !wExists || w == nil {
return fmt.Errorf("channel %s has no active worker", msg.Channel)
return fmt.Errorf("channel %s has no active worker", channelName)
}
_, err := m.sendMediaWithRetry(ctx, msg.Channel, w, msg)
_, err := m.sendMediaWithRetry(ctx, channelName, w, msg)
return err
}
@@ -1246,10 +1311,10 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten
}
msg := bus.OutboundMessage{
Channel: channelName,
ChatID: chatID,
Context: bus.NewOutboundContext(channelName, chatID, ""),
Content: content,
}
msg = bus.NormalizeOutboundMessage(msg)
if wExists && w != nil {
select {
+62 -110
View File
@@ -6,7 +6,6 @@ import (
"encoding/json"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
func toChannelHashes(cfg *config.Config) map[string]string {
@@ -21,7 +20,7 @@ func toChannelHashes(cfg *config.Config) map[string]string {
if !value["enabled"].(bool) {
continue
}
hiddenValues(key, value, ch)
hiddenValues(key, value, ch.Get(key))
valueBytes, _ := json.Marshal(value)
hash := md5.Sum(valueBytes)
result[key] = hex.EncodeToString(hash[:])
@@ -30,43 +29,77 @@ func toChannelHashes(cfg *config.Config) map[string]string {
return result
}
func hiddenValues(key string, value map[string]any, ch config.ChannelsConfig) {
func hiddenValues(key string, value map[string]any, ch *config.Channel) {
v, err := ch.GetDecoded()
if err != nil {
return
}
switch key {
case "pico":
value["token"] = ch.Pico.Token.String()
if settings, ok := v.(*config.PicoSettings); ok {
value["token"] = settings.Token.String()
}
case "telegram":
value["token"] = ch.Telegram.Token.String()
if settings, ok := v.(*config.TelegramSettings); ok {
value["token"] = settings.Token.String()
}
case "discord":
value["token"] = ch.Discord.Token.String()
if settings, ok := v.(*config.DiscordSettings); ok {
value["token"] = settings.Token.String()
}
case "slack":
value["bot_token"] = ch.Slack.BotToken.String()
value["app_token"] = ch.Slack.AppToken.String()
if settings, ok := v.(*config.SlackSettings); ok {
value["bot_token"] = settings.BotToken.String()
value["app_token"] = settings.AppToken.String()
}
case "matrix":
value["token"] = ch.Matrix.AccessToken.String()
if settings, ok := v.(*config.MatrixSettings); ok {
value["token"] = settings.AccessToken.String()
}
case "onebot":
value["token"] = ch.OneBot.AccessToken.String()
if settings, ok := v.(*config.OneBotSettings); ok {
value["token"] = settings.AccessToken.String()
}
case "line":
value["token"] = ch.LINE.ChannelAccessToken.String()
value["secret"] = ch.LINE.ChannelSecret.String()
if settings, ok := v.(*config.LINESettings); ok {
value["token"] = settings.ChannelAccessToken.String()
value["secret"] = settings.ChannelSecret.String()
}
case "wecom":
value["secret"] = ch.WeCom.Secret.String()
if settings, ok := v.(*config.WeComSettings); ok {
value["secret"] = settings.Secret.String()
}
case "dingtalk":
value["secret"] = ch.DingTalk.ClientSecret.String()
if settings, ok := v.(*config.DingTalkSettings); ok {
value["secret"] = settings.ClientSecret.String()
}
case "qq":
value["secret"] = ch.QQ.AppSecret.String()
if settings, ok := v.(*config.QQSettings); ok {
value["secret"] = settings.AppSecret.String()
}
case "irc":
value["password"] = ch.IRC.Password.String()
value["serv_password"] = ch.IRC.NickServPassword.String()
value["sasl_password"] = ch.IRC.SASLPassword.String()
if settings, ok := v.(*config.IRCSettings); ok {
value["password"] = settings.Password.String()
value["serv_password"] = settings.NickServPassword.String()
value["sasl_password"] = settings.SASLPassword.String()
}
case "feishu":
value["app_secret"] = ch.Feishu.AppSecret.String()
value["encrypt_key"] = ch.Feishu.EncryptKey.String()
value["verification_token"] = ch.Feishu.VerificationToken.String()
if settings, ok := v.(*config.FeishuSettings); ok {
value["app_secret"] = settings.AppSecret.String()
value["encrypt_key"] = settings.EncryptKey.String()
value["verification_token"] = settings.VerificationToken.String()
}
case "teams_webhook":
// Expose webhook URLs for hash computation (they contain secrets)
vv := value["webhooks"]
webhooks := make(map[string]string)
for name, target := range ch.TeamsWebhook.Webhooks {
webhooks[name] = target.WebhookURL.String()
if vv != nil {
webhooks = vv.(map[string]string)
}
if settings, ok := v.(*config.TeamsWebhookSettings); ok {
for name, target := range settings.Webhooks {
webhooks[name] = target.WebhookURL.String()
}
}
value["webhooks"] = webhooks
}
@@ -92,94 +125,13 @@ func compareChannels(old, news map[string]string) (added, removed []string) {
}
func toChannelConfig(cfg *config.Config, list []string) (*config.ChannelsConfig, error) {
result := &config.ChannelsConfig{}
ch := cfg.Channels
// should not be error
marshal, _ := json.Marshal(ch)
var channelConfig map[string]map[string]any
_ = json.Unmarshal(marshal, &channelConfig)
temp := make(map[string]map[string]any, 0)
for key, value := range channelConfig {
found := false
for _, s := range list {
if key == s {
found = true
break
}
}
if !found || !value["enabled"].(bool) {
result := make(config.ChannelsConfig)
for _, name := range list {
bc, ok := cfg.Channels[name]
if !ok || !bc.Enabled {
continue
}
temp[key] = value
}
marshal, err := json.Marshal(temp)
if err != nil {
logger.Errorf("marshal error: %v", err)
return nil, err
}
err = json.Unmarshal(marshal, result)
if err != nil {
logger.Errorf("unmarshal error: %v", err)
return nil, err
}
updateKeys(result, &ch)
return result, nil
}
func updateKeys(newcfg, old *config.ChannelsConfig) {
if newcfg.Pico.Enabled {
newcfg.Pico.Token = old.Pico.Token
}
if newcfg.Telegram.Enabled {
newcfg.Telegram.Token = old.Telegram.Token
}
if newcfg.Discord.Enabled {
newcfg.Discord.Token = old.Discord.Token
}
if newcfg.Slack.Enabled {
newcfg.Slack.BotToken = old.Slack.BotToken
newcfg.Slack.AppToken = old.Slack.AppToken
}
if newcfg.Matrix.Enabled {
newcfg.Matrix.AccessToken = old.Matrix.AccessToken
}
if newcfg.OneBot.Enabled {
newcfg.OneBot.AccessToken = old.OneBot.AccessToken
}
if newcfg.LINE.Enabled {
newcfg.LINE.ChannelAccessToken = old.LINE.ChannelAccessToken
newcfg.LINE.ChannelSecret = old.LINE.ChannelSecret
}
if newcfg.WeCom.Enabled {
newcfg.WeCom.Secret = old.WeCom.Secret
}
if newcfg.DingTalk.Enabled {
newcfg.DingTalk.ClientSecret = old.DingTalk.ClientSecret
}
if newcfg.QQ.Enabled {
newcfg.QQ.AppSecret = old.QQ.AppSecret
}
if newcfg.IRC.Enabled {
newcfg.IRC.Password = old.IRC.Password
newcfg.IRC.NickServPassword = old.IRC.NickServPassword
newcfg.IRC.SASLPassword = old.IRC.SASLPassword
}
if newcfg.Feishu.Enabled {
newcfg.Feishu.AppSecret = old.Feishu.AppSecret
newcfg.Feishu.EncryptKey = old.Feishu.EncryptKey
newcfg.Feishu.VerificationToken = old.Feishu.VerificationToken
}
if newcfg.TeamsWebhook.Enabled {
// Copy SecureString webhook URLs from old config
for name, oldTarget := range old.TeamsWebhook.Webhooks {
if newTarget, ok := newcfg.TeamsWebhook.Webhooks[name]; ok {
newTarget.WebhookURL = oldTarget.WebhookURL
newcfg.TeamsWebhook.Webhooks[name] = newTarget
}
}
result[name] = bc
}
return &result, nil
}
+111 -9
View File
@@ -1,6 +1,7 @@
package channels
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
@@ -15,37 +16,138 @@ func TestToChannelHashes(t *testing.T) {
results := toChannelHashes(cfg)
assert.Equal(t, 0, len(results))
logger.Debugf("results: %v", results)
// Add dingtalk channel via map
cfg2 := config.DefaultConfig()
cfg2.Channels.DingTalk.Enabled = true
cfg2.Channels["dingtalk"] = &config.Channel{
Enabled: true,
Type: config.ChannelDingTalk,
Settings: config.RawNode(`{"enabled":true}`),
}
results2 := toChannelHashes(cfg2)
assert.Equal(t, 1, len(results2))
logger.Debugf("results2: %v", results2)
added, removed := compareChannels(results, results2)
assert.EqualValues(t, []string{"dingtalk"}, added)
assert.EqualValues(t, []string(nil), removed)
// Add telegram channel
cfg3 := config.DefaultConfig()
cfg3.Channels.Telegram.Enabled = true
cfg3.Channels["telegram"] = &config.Channel{
Enabled: true,
Type: config.ChannelTelegram,
Settings: config.RawNode(`{"enabled":true,"token":"test-token"}`),
}
results3 := toChannelHashes(cfg3)
assert.Equal(t, 1, len(results3))
logger.Debugf("results3: %v", results3)
added, removed = compareChannels(results2, results3)
assert.EqualValues(t, []string{"dingtalk"}, removed)
assert.EqualValues(t, []string{"telegram"}, added)
cfg3.Channels.Telegram.SetToken("114314")
// Modify telegram channel — hash should change
cfg3.Channels["telegram"] = &config.Channel{
Enabled: true,
Type: config.ChannelTelegram,
Settings: config.RawNode(`{"enabled":true,"token":"114314"}`),
}
results4 := toChannelHashes(cfg3)
assert.Equal(t, 1, len(results4))
logger.Debugf("results4: %v", results4)
added, removed = compareChannels(results3, results4)
assert.EqualValues(t, []string{"telegram"}, removed)
assert.EqualValues(t, []string{"telegram"}, added)
// toChannelConfig with telegram
cc, err := toChannelConfig(cfg3, added)
assert.NoError(t, err)
logger.Debugf("cc: %#v", cc.Telegram)
assert.Equal(t, "114314", cc.Telegram.Token.String())
assert.Equal(t, true, cc.Telegram.Enabled)
bc := cc.Get("telegram")
assert.NotNil(t, bc)
var tc config.TelegramSettings
bc.Decode(&tc)
assert.Equal(t, "114314", tc.Token.String())
assert.Equal(t, true, bc.Enabled)
// toChannelConfig with dingtalk (no telegram)
cc, err = toChannelConfig(cfg2, added)
assert.NoError(t, err)
logger.Debugf("cc: %#v", cc.Telegram)
assert.Equal(t, "", cc.Telegram.Token.String())
assert.Equal(t, false, cc.Telegram.Enabled)
bc = cc.Get("telegram")
assert.Nil(t, bc)
}
func TestToChannelHashes_SerializationStability(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Channels["test"] = &config.Channel{
Enabled: true,
Settings: config.RawNode(`{"enabled":true,"key":"value"}`),
}
h1 := toChannelHashes(cfg)
// Same config should produce same hash
cfg2 := config.DefaultConfig()
cfg2.Channels["test"] = &config.Channel{
Enabled: true,
Settings: config.RawNode(`{"enabled":true,"key":"value"}`),
}
h2 := toChannelHashes(cfg2)
assert.Equal(t, h1["test"], h2["test"])
}
func TestCompareChannels_NoChanges(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Channels["a"] = &config.Channel{Enabled: true, Settings: config.RawNode(`{}`)}
cfg.Channels["b"] = &config.Channel{Enabled: true, Settings: config.RawNode(`{}`)}
h := toChannelHashes(cfg)
added, removed := compareChannels(h, h)
assert.EqualValues(t, []string(nil), added)
assert.EqualValues(t, []string(nil), removed)
}
func TestToChannelConfig_EmptyList(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Channels["test"] = &config.Channel{Enabled: true, Settings: config.RawNode(`{}`)}
cc, err := toChannelConfig(cfg, []string{})
assert.NoError(t, err)
assert.Equal(t, 0, len(*cc))
}
func TestToChannelHashes_NonEnabledSkipped(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Channels["test"] = &config.Channel{Enabled: false, Settings: config.RawNode(`{"enabled":false}`)}
h := toChannelHashes(cfg)
assert.Equal(t, 0, len(h))
}
func TestToChannelHashes_InvalidJSON(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Channels["test"] = &config.Channel{
Enabled: true,
Settings: config.RawNode(`invalid-json`),
}
// Should not panic, just skip the invalid entry
h := toChannelHashes(cfg)
assert.Equal(t, 0, len(h))
}
func TestToChannelHashes_RealWorldChannel(t *testing.T) {
cfg := config.DefaultConfig()
// Simulate a telegram channel config
telegramSettings, _ := json.Marshal(map[string]any{
"enabled": true,
"token": "123456:ABC-DEF",
})
cfg.Channels["telegram"] = &config.Channel{
Enabled: true,
Type: config.ChannelTelegram,
Settings: config.RawNode(telegramSettings),
}
h := toChannelHashes(cfg)
assert.Equal(t, 1, len(h))
assert.Contains(t, h, "telegram")
}
+143 -52
View File
@@ -175,11 +175,11 @@ func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := m.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
if err := m.bus.PublishOutbound(pubCtx, testOutboundMessage(bus.OutboundMessage{
Channel: "good",
ChatID: "chat-1",
Content: "hello",
}); err != nil {
})); err != nil {
t.Fatalf("PublishOutbound() error = %v", err)
}
@@ -197,6 +197,20 @@ func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
}
}
func testOutboundMessage(msg bus.OutboundMessage) bus.OutboundMessage {
if msg.Context.Channel == "" && msg.Context.ChatID == "" {
msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, msg.ReplyToMessageID)
}
return bus.NormalizeOutboundMessage(msg)
}
func testOutboundMediaMessage(msg bus.OutboundMediaMessage) bus.OutboundMediaMessage {
if msg.Context.Channel == "" && msg.Context.ChatID == "" {
msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, "")
}
return bus.NormalizeOutboundMediaMessage(msg)
}
func TestSendWithRetry_Success(t *testing.T) {
m := newTestManager()
var callCount int
@@ -212,7 +226,7 @@ func TestSendWithRetry_Success(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
m.sendWithRetry(ctx, "test", w, msg)
@@ -239,7 +253,7 @@ func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
m.sendWithRetry(ctx, "test", w, msg)
@@ -263,7 +277,7 @@ func TestSendWithRetry_PermanentFailure(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
m.sendWithRetry(ctx, "test", w, msg)
@@ -287,7 +301,7 @@ func TestSendWithRetry_NotRunning(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
m.sendWithRetry(ctx, "test", w, msg)
@@ -314,7 +328,7 @@ func TestSendWithRetry_RateLimitRetry(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
start := time.Now()
m.sendWithRetry(ctx, "test", w, msg)
@@ -344,7 +358,7 @@ func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
m.sendWithRetry(ctx, "test", w, msg)
@@ -370,11 +384,11 @@ func TestSendMedia_Success(t *testing.T) {
m.channels["test"] = ch
m.workers["test"] = w
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
Channel: "test",
ChatID: "chat1",
Parts: []bus.MediaPart{{Ref: "media://abc"}},
})
}))
if err != nil {
t.Fatalf("SendMedia() error = %v", err)
}
@@ -397,11 +411,11 @@ func TestSendMedia_PropagatesFailure(t *testing.T) {
m.channels["test"] = ch
m.workers["test"] = w
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
Channel: "test",
ChatID: "chat1",
Parts: []bus.MediaPart{{Ref: "media://abc"}},
})
}))
if err == nil {
t.Fatal("expected SendMedia to return error")
}
@@ -424,11 +438,11 @@ func TestSendMedia_UnsupportedChannelReturnsError(t *testing.T) {
m.channels["test"] = ch
m.workers["test"] = w
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
Channel: "test",
ChatID: "chat1",
Parts: []bus.MediaPart{{Ref: "media://abc"}},
})
}))
if err == nil {
t.Fatal("expected SendMedia to return error for unsupported channel")
}
@@ -454,11 +468,11 @@ func TestSendMedia_DeletesPlaceholderBeforeSending(t *testing.T) {
m.workers["test"] = w
m.RecordPlaceholder("test", "chat1", "placeholder-1")
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
Channel: "test",
ChatID: "chat1",
Parts: []bus.MediaPart{{Ref: "media://abc"}},
})
}))
if err != nil {
t.Fatalf("SendMedia() error = %v", err)
}
@@ -491,7 +505,7 @@ func TestSendWithRetry_UnknownError(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
m.sendWithRetry(ctx, "test", w, msg)
@@ -515,7 +529,7 @@ func TestSendWithRetry_ContextCancelled(t *testing.T) {
}
ctx, cancel := context.WithCancel(context.Background())
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
// Cancel context after first Send attempt returns
ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error {
@@ -561,7 +575,7 @@ func TestWorkerRateLimiter(t *testing.T) {
// Enqueue 4 messages
for i := range 4 {
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}
w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)})
}
// Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin)
@@ -586,7 +600,7 @@ func TestWorkerRateLimiter(t *testing.T) {
func TestNewChannelWorker_DefaultRate(t *testing.T) {
ch := &mockChannel{}
w := newChannelWorker("unknown_channel", ch)
w := newChannelWorker("unknown_channel", ch, "unknown_channel")
if w.limiter == nil {
t.Fatal("expected limiter to be non-nil")
@@ -599,10 +613,10 @@ func TestNewChannelWorker_DefaultRate(t *testing.T) {
func TestNewChannelWorker_ConfiguredRate(t *testing.T) {
ch := &mockChannel{}
for name, expectedRate := range channelRateConfig {
w := newChannelWorker(name, ch)
for channelType, expectedRate := range channelRateConfig {
w := newChannelWorker(channelType, ch, channelType)
if w.limiter.Limit() != rate.Limit(expectedRate) {
t.Fatalf("channel %s: expected rate %v, got %v", name, expectedRate, w.limiter.Limit())
t.Fatalf("channel %s: expected rate %v, got %v", channelType, expectedRate, w.limiter.Limit())
}
}
}
@@ -637,7 +651,7 @@ func TestRunWorker_MessageSplitting(t *testing.T) {
go m.runWorker(ctx, "test", w)
// Send a message that should be split
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"}
w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"})
time.Sleep(100 * time.Millisecond)
@@ -678,7 +692,7 @@ func TestSendWithRetry_ExponentialBackoff(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
start := time.Now()
m.sendWithRetry(ctx, "test", w, msg)
@@ -738,7 +752,7 @@ func TestPreSend_PlaceholderEditSuccess(t *testing.T) {
// Register placeholder
m.RecordPlaceholder("test", "123", "456")
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
_, edited := m.preSend(context.Background(), "test", msg, ch)
if !edited {
@@ -768,7 +782,7 @@ func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) {
m.RecordPlaceholder("test", "123", "456")
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
_, edited := m.preSend(context.Background(), "test", msg, ch)
if edited {
@@ -827,7 +841,7 @@ func TestPreSend_TypingStopCalled(t *testing.T) {
stopCalled = true
})
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
m.preSend(context.Background(), "test", msg, ch)
if !stopCalled {
@@ -844,7 +858,7 @@ func TestPreSend_NoRegisteredState(t *testing.T) {
},
}
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
_, edited := m.preSend(context.Background(), "test", msg, ch)
if edited {
@@ -874,7 +888,7 @@ func TestPreSend_TypingAndPlaceholder(t *testing.T) {
})
m.RecordPlaceholder("test", "123", "456")
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
_, edited := m.preSend(context.Background(), "test", msg, ch)
if !stopCalled {
@@ -938,7 +952,7 @@ func TestRecordTypingStop_ReplacesExistingStop(t *testing.T) {
t.Fatalf("expected replacement typing stop to stay active until preSend, got %d calls", newStopCalls)
}
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
m.preSend(context.Background(), "test", msg, &mockChannel{})
if newStopCalls != 1 {
@@ -972,7 +986,7 @@ func TestSendWithRetry_PreSendEditsPlaceholder(t *testing.T) {
limiter: rate.NewLimiter(rate.Inf, 1),
}
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
m.sendWithRetry(context.Background(), "test", w, msg)
if sendCalled {
@@ -1135,7 +1149,7 @@ func TestPreSendStillWorksWithWrappedTypes(t *testing.T) {
})
m.RecordPlaceholder("test", "chat1", "ph_id")
msg := bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"})
_, edited := m.preSend(context.Background(), "test", msg, ch)
if !stopCalled {
@@ -1222,7 +1236,7 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) {
return nil
},
}
worker := newChannelWorker("mock", mockCh)
worker := newChannelWorker("mock", mockCh, "mock")
mgr.channels["mock"] = mockCh
mgr.workers["mock"] = worker
@@ -1238,11 +1252,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) {
// Transcription feedback arrives first — it should consume the placeholder
// and be delivered via EditMessage, not Send.
msgTranscript := bus.OutboundMessage{
msgTranscript := testOutboundMessage(bus.OutboundMessage{
Channel: "mock",
ChatID: "chat-1",
Content: "Transcript: hello",
}
})
mgr.sendWithRetry(ctx, "mock", worker, msgTranscript)
if mockCh.editedMessages != 1 {
@@ -1258,11 +1272,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) {
}
// Final LLM response arrives — no placeholder left, so it goes through Send
msgFinal := bus.OutboundMessage{
msgFinal := testOutboundMessage(bus.OutboundMessage{
Channel: "mock",
ChatID: "chat-1",
Content: "Final Answer",
}
})
mgr.sendWithRetry(ctx, "mock", worker, msgFinal)
if len(mockCh.sentMessages) != 1 {
@@ -1288,12 +1302,12 @@ func TestSendMessage_Synchronous(t *testing.T) {
m.channels["test"] = ch
m.workers["test"] = w
msg := bus.OutboundMessage{
msg := testOutboundMessage(bus.OutboundMessage{
Channel: "test",
ChatID: "123",
Content: "hello world",
ReplyToMessageID: "msg-456",
}
})
err := m.SendMessage(context.Background(), msg)
if err != nil {
@@ -1315,11 +1329,11 @@ func TestSendMessage_Synchronous(t *testing.T) {
func TestSendMessage_UnknownChannel(t *testing.T) {
m := newTestManager()
msg := bus.OutboundMessage{
msg := testOutboundMessage(bus.OutboundMessage{
Channel: "nonexistent",
ChatID: "123",
Content: "hello",
}
})
err := m.SendMessage(context.Background(), msg)
if err == nil {
@@ -1336,11 +1350,11 @@ func TestSendMessage_NoWorker(t *testing.T) {
m.channels["test"] = ch
// No worker registered
msg := bus.OutboundMessage{
msg := testOutboundMessage(bus.OutboundMessage{
Channel: "test",
ChatID: "123",
Content: "hello",
}
})
err := m.SendMessage(context.Background(), msg)
if err == nil {
@@ -1369,11 +1383,11 @@ func TestSendMessage_WithRetry(t *testing.T) {
m.channels["test"] = ch
m.workers["test"] = w
msg := bus.OutboundMessage{
msg := testOutboundMessage(bus.OutboundMessage{
Channel: "test",
ChatID: "123",
Content: "retry me",
}
})
err := m.SendMessage(context.Background(), msg)
if err != nil {
@@ -1385,6 +1399,46 @@ func TestSendMessage_WithRetry(t *testing.T) {
}
}
func TestSendMessage_ContextOnlyUsesContextAddressing(t *testing.T) {
m := newTestManager()
var received []bus.OutboundMessage
ch := &mockChannel{
sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
received = append(received, msg)
return nil
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
m.channels["test"] = ch
m.workers["test"] = w
msg := testOutboundMessage(bus.OutboundMessage{
Context: bus.NewOutboundContext("test", "123", "msg-9"),
Content: "hello",
})
if err := m.SendMessage(context.Background(), msg); err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(received) != 1 {
t.Fatalf("expected 1 message sent, got %d", len(received))
}
if received[0].Channel != "test" || received[0].ChatID != "123" {
t.Fatalf("expected mirrored legacy address, got %+v", received[0])
}
if received[0].Context.Channel != "test" || received[0].Context.ChatID != "123" {
t.Fatalf("expected context address to be preserved, got %+v", received[0].Context)
}
if received[0].ReplyToMessageID != "msg-9" {
t.Fatalf("expected reply_to_message_id msg-9, got %q", received[0].ReplyToMessageID)
}
}
func TestSendMessage_WithSplitting(t *testing.T) {
m := newTestManager()
@@ -1406,11 +1460,11 @@ func TestSendMessage_WithSplitting(t *testing.T) {
m.channels["test"] = ch
m.workers["test"] = w
msg := bus.OutboundMessage{
msg := testOutboundMessage(bus.OutboundMessage{
Channel: "test",
ChatID: "123",
Content: "hello world",
}
})
err := m.SendMessage(context.Background(), msg)
if err != nil {
@@ -1422,6 +1476,43 @@ func TestSendMessage_WithSplitting(t *testing.T) {
}
}
func TestSendMedia_ContextOnlyUsesContextAddressing(t *testing.T) {
m := newTestManager()
var received []bus.OutboundMediaMessage
ch := &mockMediaChannel{
sendMediaFn: func(_ context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
received = append(received, msg)
return nil, nil
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
m.channels["test"] = ch
m.workers["test"] = w
msg := testOutboundMediaMessage(bus.OutboundMediaMessage{
Context: bus.NewOutboundContext("test", "media-chat", ""),
Parts: []bus.MediaPart{{Type: "image", Ref: "media://1"}},
})
if err := m.SendMedia(context.Background(), msg); err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(received) != 1 {
t.Fatalf("expected 1 media message sent, got %d", len(received))
}
if received[0].Channel != "test" || received[0].ChatID != "media-chat" {
t.Fatalf("expected mirrored legacy media address, got %+v", received[0])
}
if received[0].Context.Channel != "test" || received[0].Context.ChatID != "media-chat" {
t.Fatalf("expected media context address to be preserved, got %+v", received[0].Context)
}
}
func TestSendMessage_PreservesOrdering(t *testing.T) {
m := newTestManager()
@@ -1441,12 +1532,12 @@ func TestSendMessage_PreservesOrdering(t *testing.T) {
m.workers["test"] = w
// Send two messages sequentially — they must arrive in order
_ = m.SendMessage(context.Background(), bus.OutboundMessage{
_ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{
Channel: "test", ChatID: "1", Content: "first",
})
_ = m.SendMessage(context.Background(), bus.OutboundMessage{
}))
_ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{
Channel: "test", ChatID: "1", Content: "second",
})
}))
if len(order) != 2 {
t.Fatalf("expected 2 messages, got %d", len(order))
+26 -8
View File
@@ -9,12 +9,30 @@ import (
)
func init() {
channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
matrixCfg := cfg.Channels.Matrix
cryptoDatabasePath := matrixCfg.CryptoDatabasePath
if cryptoDatabasePath == "" {
cryptoDatabasePath = filepath.Join(cfg.WorkspacePath(), "matrix")
}
return NewMatrixChannel(matrixCfg, b, cryptoDatabasePath)
})
channels.RegisterFactory(
config.ChannelMatrix,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.MatrixSettings)
if !ok {
return nil, channels.ErrSendFailed
}
cryptoDatabasePath := c.CryptoDatabasePath
if cryptoDatabasePath == "" {
cryptoDatabasePath = filepath.Join(cfg.WorkspacePath(), "matrix")
}
ch, err := NewMatrixChannel(bc, c, b, cryptoDatabasePath)
if err != nil {
return nil, err
}
if channelName != config.ChannelMatrix {
ch.SetName(channelName)
}
return ch, nil
},
)
}
+25 -22
View File
@@ -174,9 +174,10 @@ func (s *typingSession) stop() {
// MatrixChannel implements the Channel interface for Matrix.
type MatrixChannel struct {
*channels.BaseChannel
bc *config.Channel
client *mautrix.Client
config config.MatrixConfig
config *config.MatrixSettings
syncer *mautrix.DefaultSyncer
ctx context.Context
@@ -194,7 +195,8 @@ type MatrixChannel struct {
}
func NewMatrixChannel(
cfg config.MatrixConfig,
bc *config.Channel,
cfg *config.MatrixSettings,
messageBus *bus.MessageBus,
cryptoDatabasePath string,
) (*MatrixChannel, error) {
@@ -228,14 +230,15 @@ func NewMatrixChannel(
"matrix",
cfg,
messageBus,
cfg.AllowFrom,
bc.AllowFrom,
channels.WithMaxMessageLength(65536),
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
channels.WithGroupTrigger(bc.GroupTrigger),
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
return &MatrixChannel{
BaseChannel: base,
bc: bc,
client: client,
config: cfg,
syncer: syncer,
@@ -570,7 +573,7 @@ func (c *MatrixChannel) StartTyping(ctx context.Context, chatID string) (func(),
// SendPlaceholder implements channels.PlaceholderCapable.
func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
if !c.config.Placeholder.Enabled {
if !c.bc.Placeholder.Enabled {
return "", nil
}
@@ -579,7 +582,7 @@ func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (str
return "", fmt.Errorf("matrix room ID is empty")
}
text := c.config.Placeholder.GetRandomText()
text := c.bc.Placeholder.GetRandomText()
resp, err := c.client.SendMessageEvent(ctx, roomID, event.EventMessage, &event.MessageEventContent{
MsgType: event.MsgNotice,
@@ -720,8 +723,8 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
logger.DebugCF("matrix", "Ignoring group message by trigger rules", map[string]any{
"room_id": roomID,
"is_mentioned": isMentioned,
"mention_only": c.config.GroupTrigger.MentionOnly,
"prefixes": c.config.GroupTrigger.Prefixes,
"mention_only": c.bc.GroupTrigger.MentionOnly,
"prefixes": c.bc.GroupTrigger.Prefixes,
})
return
}
@@ -736,10 +739,8 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
}
peerKind := "direct"
peerID := senderID
if isGroup {
peerKind = "group"
peerID = roomID
}
metadata := map[string]string{
@@ -752,17 +753,19 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
metadata["reply_to_msg_id"] = replyTo.String()
}
c.HandleMessage(
c.baseContext(),
bus.Peer{Kind: peerKind, ID: peerID},
evt.ID.String(),
senderID,
roomID,
content,
mediaPaths,
metadata,
sender,
)
inboundCtx := bus.InboundContext{
Channel: "matrix",
ChatID: roomID,
ChatType: peerKind,
SenderID: senderID,
MessageID: evt.ID.String(),
Raw: metadata,
}
if replyTo := msgEvt.GetRelatesTo().GetReplyTo(); replyTo != "" {
inboundCtx.ReplyToMessageID = replyTo.String()
}
c.HandleInboundContext(c.baseContext(), roomID, content, mediaPaths, inboundCtx, sender)
}
// decryptEvent decrypts an encrypted event and returns the decrypted message event content.
+3 -3
View File
@@ -437,9 +437,9 @@ func TestMarkdownToHTML(t *testing.T) {
}
func TestMessageContent(t *testing.T) {
richtext := &MatrixChannel{config: config.MatrixConfig{MessageFormat: "richtext"}}
plain := &MatrixChannel{config: config.MatrixConfig{MessageFormat: "plain"}}
defaultt := &MatrixChannel{config: config.MatrixConfig{}}
richtext := &MatrixChannel{config: &config.MatrixSettings{MessageFormat: "richtext"}}
plain := &MatrixChannel{config: &config.MatrixSettings{MessageFormat: "plain"}}
defaultt := &MatrixChannel{config: &config.MatrixSettings{}}
for _, c := range []*MatrixChannel{richtext, defaultt} {
mc := c.messageContent("**hi**")
+15 -3
View File
@@ -7,7 +7,19 @@ import (
)
func init() {
channels.RegisterFactory("onebot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewOneBotChannel(cfg.Channels.OneBot, b)
})
channels.RegisterFactory(
config.ChannelOneBot,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.OneBotSettings)
if !ok {
return nil, channels.ErrSendFailed
}
return NewOneBotChannel(bc, c, b)
},
)
}
+27 -10
View File
@@ -23,7 +23,7 @@ import (
type OneBotChannel struct {
*channels.BaseChannel
config config.OneBotConfig
config *config.OneBotSettings
conn *websocket.Conn
ctx context.Context
cancel context.CancelFunc
@@ -96,10 +96,14 @@ type oneBotMessageSegment struct {
Data map[string]any `json:"data"`
}
func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) {
base := channels.NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom,
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
func NewOneBotChannel(
bc *config.Channel,
cfg *config.OneBotSettings,
messageBus *bus.MessageBus,
) (*OneBotChannel, error) {
base := channels.NewBaseChannel("onebot", cfg, messageBus, bc.AllowFrom,
channels.WithGroupTrigger(bc.GroupTrigger),
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
const dedupSize = 1024
@@ -991,8 +995,8 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
senderID := strconv.FormatInt(userID, 10)
var chatID string
var peer bus.Peer
var contextChatID string
var contextChatType string
metadata := map[string]string{}
@@ -1003,12 +1007,14 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
switch raw.MessageType {
case "private":
chatID = "private:" + senderID
peer = bus.Peer{Kind: "direct", ID: senderID}
contextChatID = senderID
contextChatType = "direct"
case "group":
groupIDStr := strconv.FormatInt(groupID, 10)
chatID = "group:" + groupIDStr
peer = bus.Peer{Kind: "group", ID: groupIDStr}
contextChatID = groupIDStr
contextChatType = "group"
metadata["group_id"] = groupIDStr
senderUserID, _ := parseJSONInt64(sender.UserID)
@@ -1072,7 +1078,18 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
return
}
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, parsed.Media, metadata, senderInfo)
inboundCtx := bus.InboundContext{
Channel: c.Name(),
ChatID: contextChatID,
ChatType: contextChatType,
SenderID: senderID,
MessageID: messageID,
Mentioned: isBotMentioned,
ReplyToMessageID: parsed.ReplyTo,
Raw: metadata,
}
c.HandleInboundContext(c.ctx, chatID, content, parsed.Media, inboundCtx, senderInfo)
}
func (c *OneBotChannel) isDuplicate(messageID string) bool {
+23 -11
View File
@@ -22,7 +22,7 @@ import (
// PicoClientChannel connects to a remote Pico Protocol WebSocket server.
type PicoClientChannel struct {
*channels.BaseChannel
config config.PicoClientConfig
config *config.PicoClientSettings
conn *picoConn
mu sync.Mutex
ctx context.Context
@@ -31,14 +31,15 @@ type PicoClientChannel struct {
// NewPicoClientChannel creates a new Pico Protocol client channel.
func NewPicoClientChannel(
cfg config.PicoClientConfig,
bc *config.Channel,
cfg *config.PicoClientSettings,
messageBus *bus.MessageBus,
) (*PicoClientChannel, error) {
if cfg.URL == "" {
return nil, fmt.Errorf("pico_client url is required")
}
base := channels.NewBaseChannel("pico_client", cfg, messageBus, cfg.AllowFrom)
base := channels.NewBaseChannel("pico_client", cfg, messageBus, bc.AllowFrom)
return &PicoClientChannel{
BaseChannel: base,
@@ -242,7 +243,11 @@ func (c *PicoClientChannel) handleInbound(pc *picoConn, msg PicoMessage) {
}
func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
content, _ := msg.Payload["content"].(string)
if isThoughtPayload(msg.Payload) {
return
}
content, _ := msg.Payload[PayloadKeyContent].(string)
if strings.TrimSpace(content) == "" {
return
}
@@ -254,8 +259,6 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
chatID := "pico_client:" + sessionID
senderID := "pico-remote"
peer := bus.Peer{Kind: "direct", ID: chatID}
sender := bus.SenderInfo{
Platform: "pico_client",
PlatformID: senderID,
@@ -266,10 +269,19 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
return
}
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, map[string]string{
"platform": "pico_client",
"session_id": sessionID,
}, sender)
inboundCtx := bus.InboundContext{
Channel: "pico_client",
ChatID: chatID,
ChatType: "direct",
SenderID: senderID,
MessageID: msg.ID,
Raw: map[string]string{
"platform": "pico_client",
"session_id": sessionID,
},
}
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
}
// Send sends a message to the remote server.
@@ -285,7 +297,7 @@ func (c *PicoClientChannel) Send(ctx context.Context, msg bus.OutboundMessage) (
}
outMsg := newMessage(TypeMessageSend, map[string]any{
"content": msg.Content,
PayloadKeyContent: msg.Content,
})
outMsg.SessionID = strings.TrimPrefix(msg.ChatID, "pico_client:")
return nil, pc.writeJSON(outMsg)
+83 -9
View File
@@ -18,7 +18,8 @@ import (
)
func TestNewPicoClientChannel_MissingURL(t *testing.T) {
_, err := NewPicoClientChannel(config.PicoClientConfig{}, bus.NewMessageBus())
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
_, err := NewPicoClientChannel(bc, &config.PicoClientSettings{}, bus.NewMessageBus())
if err == nil {
t.Fatal("expected error for missing URL")
}
@@ -28,7 +29,8 @@ func TestNewPicoClientChannel_MissingURL(t *testing.T) {
}
func TestNewPicoClientChannel_OK(t *testing.T) {
ch, err := NewPicoClientChannel(config.PicoClientConfig{
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
URL: "ws://localhost:9999/ws",
}, bus.NewMessageBus())
if err != nil {
@@ -40,7 +42,8 @@ func TestNewPicoClientChannel_OK(t *testing.T) {
}
func TestSend_NotRunning(t *testing.T) {
ch, err := NewPicoClientChannel(config.PicoClientConfig{
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
URL: "ws://localhost:9999/ws",
}, bus.NewMessageBus())
if err != nil {
@@ -104,7 +107,8 @@ func TestClientChannel_ConnectAndSend(t *testing.T) {
defer srv.Close()
mb := bus.NewMessageBus()
ch, err := NewPicoClientChannel(config.PicoClientConfig{
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
URL: wsURL(srv.URL),
Token: *config.NewSecureString("test-token"),
SessionID: "sess-1",
@@ -137,7 +141,8 @@ func TestClientChannel_AuthFailure(t *testing.T) {
srv := testServer(t, "correct-token")
defer srv.Close()
ch, err := NewPicoClientChannel(config.PicoClientConfig{
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
URL: wsURL(srv.URL),
Token: *config.NewSecureString("wrong-token"),
}, bus.NewMessageBus())
@@ -161,7 +166,8 @@ func TestClientChannel_ReceivesServerMessage(t *testing.T) {
mb := bus.NewMessageBus()
ch, err := NewPicoClientChannel(config.PicoClientConfig{
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
URL: wsURL(srv.URL),
SessionID: "sess-echo",
ReadTimeout: 10,
@@ -203,7 +209,8 @@ func TestClientChannel_StartTyping(t *testing.T) {
srv := testServer(t, "")
defer srv.Close()
ch, err := NewPicoClientChannel(config.PicoClientConfig{
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
URL: wsURL(srv.URL),
SessionID: "sess-type",
ReadTimeout: 10,
@@ -231,7 +238,8 @@ func TestSend_ClosedConnection(t *testing.T) {
srv := testServer(t, "")
defer srv.Close()
ch, err := NewPicoClientChannel(config.PicoClientConfig{
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
URL: wsURL(srv.URL),
SessionID: "sess-close",
ReadTimeout: 10,
@@ -279,7 +287,8 @@ func TestParseInlineImageMedia_Valid(t *testing.T) {
func TestPicoChannel_HandleMessageSend_AllowsMediaOnly(t *testing.T) {
mb := bus.NewMessageBus()
ch, err := NewPicoChannel(config.PicoConfig{
bc := &config.Channel{Type: "pico", Enabled: true}
ch, err := NewPicoChannel(bc, &config.PicoSettings{
Token: *config.NewSecureString("test-token"),
}, mb)
if err != nil {
@@ -316,3 +325,68 @@ func TestPicoChannel_HandleMessageSend_AllowsMediaOnly(t *testing.T) {
t.Fatal("timed out waiting for inbound media message")
}
}
func TestIsThoughtPayload(t *testing.T) {
tests := []struct {
name string
payload map[string]any
want bool
}{
{
name: "explicit thought bool",
payload: map[string]any{PayloadKeyThought: true},
want: true,
},
{
name: "thought false",
payload: map[string]any{PayloadKeyThought: false},
want: false,
},
{
name: "thought string ignored",
payload: map[string]any{PayloadKeyThought: "true"},
want: false,
},
{
name: "default normal",
payload: map[string]any{PayloadKeyContent: "hello"},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isThoughtPayload(tt.payload); got != tt.want {
t.Fatalf("isThoughtPayload() = %v, want %v", got, tt.want)
}
})
}
}
func TestPicoClientChannel_HandleServerMessage_IgnoresThought(t *testing.T) {
mb := bus.NewMessageBus()
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
URL: "ws://localhost:8080/ws",
}, mb)
if err != nil {
t.Fatalf("NewPicoClientChannel() error = %v", err)
}
ch.ctx = context.Background()
pc := &picoConn{sessionID: "sess-thought"}
ch.handleServerMessage(pc, PicoMessage{
Type: TypeMessageCreate,
Payload: map[string]any{
PayloadKeyContent: "internal reasoning",
PayloadKeyThought: true,
},
})
select {
case msg := <-mb.InboundChan():
t.Fatalf("expected no inbound publish for thought payload, got %+v", msg)
case <-time.After(150 * time.Millisecond):
}
}
+44 -6
View File
@@ -7,10 +7,48 @@ import (
)
func init() {
channels.RegisterFactory("pico", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewPicoChannel(cfg.Channels.Pico, b)
})
channels.RegisterFactory("pico_client", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewPicoClientChannel(cfg.Channels.PicoClient, b)
})
channels.RegisterFactory(
config.ChannelPico,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.PicoSettings)
if !ok {
return nil, channels.ErrSendFailed
}
ch, err := NewPicoChannel(bc, c, b)
if err != nil {
return nil, err
}
if channelName != config.ChannelPico {
ch.SetName(channelName)
}
return ch, nil
},
)
channels.RegisterFactory(
config.ChannelPicoClient,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.PicoClientSettings)
if !ok {
return nil, channels.ErrSendFailed
}
ch, err := NewPicoClientChannel(bc, c, b)
if err != nil {
return nil, err
}
if channelName != config.ChannelPicoClient {
ch.SetName(channelName)
}
return ch, nil
},
)
}
+34 -11
View File
@@ -39,6 +39,13 @@ var allowedInlineImageMIMETypes = map[string]struct{}{
"image/bmp": {},
}
func outboundMessageIsThought(msg bus.OutboundMessage) bool {
if len(msg.Context.Raw) == 0 {
return false
}
return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), MessageKindThought)
}
// writeJSON sends a JSON message to the connection with write locking.
func (pc *picoConn) writeJSON(v any) error {
if pc.closed.Load() {
@@ -63,7 +70,8 @@ func (pc *picoConn) close() {
// It serves as the reference implementation for all optional capability interfaces.
type PicoChannel struct {
*channels.BaseChannel
config config.PicoConfig
bc *config.Channel
config *config.PicoSettings
upgrader websocket.Upgrader
connections map[string]*picoConn // connID -> *picoConn
sessionConnections map[string]map[string]*picoConn // sessionID -> connID -> *picoConn
@@ -73,12 +81,16 @@ type PicoChannel struct {
}
// NewPicoChannel creates a new Pico Protocol channel.
func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoChannel, error) {
func NewPicoChannel(
bc *config.Channel,
cfg *config.PicoSettings,
messageBus *bus.MessageBus,
) (*PicoChannel, error) {
if cfg.Token.String() == "" {
return nil, fmt.Errorf("pico token is required")
}
base := channels.NewBaseChannel("pico", cfg, messageBus, cfg.AllowFrom)
base := channels.NewBaseChannel("pico", cfg, messageBus, bc.AllowFrom)
allowOrigins := cfg.AllowOrigins
checkOrigin := func(r *http.Request) bool {
@@ -96,6 +108,7 @@ func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoCha
return &PicoChannel{
BaseChannel: base,
bc: bc,
config: cfg,
upgrader: websocket.Upgrader{
CheckOrigin: checkOrigin,
@@ -247,9 +260,11 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
if !c.IsRunning() {
return nil, channels.ErrNotRunning
}
isThought := outboundMessageIsThought(msg)
outMsg := newMessage(TypeMessageCreate, map[string]any{
"content": msg.Content,
PayloadKeyContent: msg.Content,
PayloadKeyThought: isThought,
})
return nil, c.broadcastToSession(msg.ChatID, outMsg)
@@ -280,16 +295,17 @@ func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), e
// It sends a placeholder message via the Pico Protocol that will later be
// edited to the actual response via EditMessage (channels.MessageEditor).
func (c *PicoChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
if !c.config.Placeholder.Enabled {
if !c.bc.Placeholder.Enabled {
return "", nil
}
text := c.config.Placeholder.GetRandomText()
text := c.bc.Placeholder.GetRandomText()
msgID := uuid.New().String()
outMsg := newMessage(TypeMessageCreate, map[string]any{
"content": text,
"message_id": msgID,
PayloadKeyContent: text,
PayloadKeyThought: false,
"message_id": msgID,
})
if err := c.broadcastToSession(chatID, outMsg); err != nil {
@@ -562,8 +578,6 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
chatID := "pico:" + sessionID
senderID := "pico-user"
peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID}
metadata := map[string]string{
"platform": "pico",
"session_id": sessionID,
@@ -586,7 +600,16 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
return
}
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, media, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: "pico",
ChatID: chatID,
ChatType: "direct",
SenderID: senderID,
MessageID: msg.ID,
Raw: metadata,
}
c.HandleInboundContext(c.ctx, chatID, content, media, inboundCtx, sender)
}
// truncate truncates a string to maxLen runes.
+3 -2
View File
@@ -15,9 +15,10 @@ import (
func newTestPicoChannel(t *testing.T) *PicoChannel {
t.Helper()
cfg := config.PicoConfig{}
bc := &config.Channel{Type: config.ChannelPico, Enabled: true}
cfg := &config.PicoSettings{}
cfg.SetToken("test-token")
ch, err := NewPicoChannel(cfg, bus.NewMessageBus())
ch, err := NewPicoChannel(bc, cfg, bus.NewMessageBus())
if err != nil {
t.Fatalf("NewPicoChannel: %v", err)
}
+10
View File
@@ -19,6 +19,11 @@ const (
TypePong = "pong"
PicoTokenPrefix = "pico-"
PayloadKeyContent = "content"
PayloadKeyThought = "thought"
MessageKindThought = "thought"
)
// PicoMessage is the wire format for all Pico Protocol messages.
@@ -39,6 +44,11 @@ func newMessage(msgType string, payload map[string]any) PicoMessage {
}
}
func isThoughtPayload(payload map[string]any) bool {
thought, _ := payload[PayloadKeyThought].(bool)
return thought
}
func newErrorWithPayload(code, message string, extra map[string]any) PicoMessage {
payload := map[string]any{
"code": code,
+15 -3
View File
@@ -7,7 +7,19 @@ import (
)
func init() {
channels.RegisterFactory("qq", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewQQChannel(cfg.Channels.QQ, b)
})
channels.RegisterFactory(
config.ChannelQQ,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.QQSettings)
if !ok {
return nil, channels.ErrSendFailed
}
return NewQQChannel(bc, c, b)
},
)
}
+40 -27
View File
@@ -56,7 +56,8 @@ type qqAPI interface {
type QQChannel struct {
*channels.BaseChannel
config config.QQConfig
bc *config.Channel
config *config.QQSettings
api qqAPI
tokenSource oauth2.TokenSource
ctx context.Context
@@ -82,15 +83,16 @@ type QQChannel struct {
stopOnce sync.Once
}
func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) {
base := channels.NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom,
func NewQQChannel(bc *config.Channel, cfg *config.QQSettings, messageBus *bus.MessageBus) (*QQChannel, error) {
base := channels.NewBaseChannel("qq", cfg, messageBus, bc.AllowFrom,
channels.WithMaxMessageLength(cfg.MaxMessageLength),
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
channels.WithGroupTrigger(bc.GroupTrigger),
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
return &QQChannel{
BaseChannel: base,
bc: bc,
config: cfg,
dedup: make(map[string]time.Time),
done: make(chan struct{}),
@@ -161,8 +163,8 @@ func (c *QQChannel) Start(ctx context.Context) error {
// Pre-register reasoning_channel_id as group chat if configured,
// so outbound-only destinations are routed correctly.
if c.config.ReasoningChannelID != "" {
c.chatType.Store(c.config.ReasoningChannelID, "group")
if c.bc.ReasoningChannelID != "" {
c.chatType.Store(c.bc.ReasoningChannelID, "group")
}
c.SetRunning(true)
@@ -588,12 +590,22 @@ func qqFileType(partType string) uint64 {
}
func (c *QQChannel) maxBase64FileSizeBytes() int64 {
if c.config == nil {
return 0
}
if c.config.MaxBase64FileSizeMiB <= 0 {
return 0
}
return c.config.MaxBase64FileSizeMiB * bytesPerMiB
}
func (c *QQChannel) accountID() string {
if c.config == nil {
return ""
}
return c.config.AppID
}
// handleC2CMessage handles QQ private messages.
func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
return func(event *dto.WSPayload, data *dto.WSC2CMessageData) error {
@@ -647,17 +659,17 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
metadata := map[string]string{
"account_id": senderID,
}
inboundCtx := bus.InboundContext{
Channel: c.Name(),
Account: c.accountID(),
ChatID: senderID,
ChatType: "direct",
SenderID: senderID,
MessageID: data.ID,
Raw: metadata,
}
c.HandleMessage(c.ctx,
bus.Peer{Kind: "direct", ID: senderID},
data.ID,
senderID,
senderID,
content,
mediaPaths,
metadata,
sender,
)
c.HandleInboundContext(c.ctx, senderID, content, mediaPaths, inboundCtx, sender)
return nil
}
@@ -725,17 +737,18 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
"account_id": senderID,
"group_id": data.GroupID,
}
inboundCtx := bus.InboundContext{
Channel: c.Name(),
Account: c.accountID(),
ChatID: data.GroupID,
ChatType: "group",
SenderID: senderID,
MessageID: data.ID,
Mentioned: true,
Raw: metadata,
}
c.HandleMessage(c.ctx,
bus.Peer{Kind: "group", ID: data.GroupID},
data.ID,
senderID,
data.GroupID,
content,
mediaPaths,
metadata,
sender,
)
c.HandleInboundContext(c.ctx, data.GroupID, content, mediaPaths, inboundCtx, sender)
return nil
}
+12 -5
View File
@@ -54,8 +54,8 @@ func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) {
if !ok {
t.Fatal("expected inbound message")
}
if inbound.Metadata["account_id"] != "7750283E123456" {
t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456")
if inbound.Context.Raw["account_id"] != "7750283E123456" {
t.Fatalf("account_id raw = %q, want %q", inbound.Context.Raw["account_id"], "7750283E123456")
}
return
}
@@ -165,8 +165,8 @@ func TestHandleGroupATMessage_AttachmentOnlyPublishesMedia(t *testing.T) {
if !strings.HasPrefix(inbound.Media[0], "media://") {
t.Fatalf("inbound.Media[0] = %q, want media:// ref", inbound.Media[0])
}
if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-1" {
t.Fatalf("inbound.Peer = %+v, want group/group-1", inbound.Peer)
if inbound.Context.ChatType != "group" {
t.Fatalf("inbound.Context.ChatType = %q, want group", inbound.Context.ChatType)
}
}
@@ -198,6 +198,7 @@ func TestSendMedia_UploadsLocalFileAsBase64(t *testing.T) {
}
ch := &QQChannel{
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
config: &config.QQSettings{},
api: api,
dedup: make(map[string]time.Time),
done: make(chan struct{}),
@@ -294,6 +295,7 @@ func assertAudioWAVUploadType(t *testing.T, duration time.Duration, wantFileType
}
ch := &QQChannel{
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
config: &config.QQSettings{},
api: api,
dedup: make(map[string]time.Time),
done: make(chan struct{}),
@@ -329,6 +331,7 @@ func TestSendMedia_RemoteAudioFallsBackToFileUpload(t *testing.T) {
}
ch := &QQChannel{
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
config: &config.QQSettings{},
api: api,
dedup: make(map[string]time.Time),
done: make(chan struct{}),
@@ -374,6 +377,7 @@ func TestSendMedia_LocalAudioWithUnknownDurationFallsBackToFileUpload(t *testing
}
ch := &QQChannel{
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
config: &config.QQSettings{},
api: api,
dedup: make(map[string]time.Time),
done: make(chan struct{}),
@@ -409,6 +413,7 @@ func TestSendMedia_UsesRemoteURLUploadForC2C(t *testing.T) {
}
ch := &QQChannel{
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
config: &config.QQSettings{},
api: api,
dedup: make(map[string]time.Time),
done: make(chan struct{}),
@@ -481,6 +486,7 @@ func TestSendMedia_LocalFileUploadIncludesStoredFilename(t *testing.T) {
}
ch := &QQChannel{
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
config: &config.QQSettings{},
api: api,
dedup: make(map[string]time.Time),
done: make(chan struct{}),
@@ -520,6 +526,7 @@ func TestSendMedia_ReturnsSendFailedWithoutMediaStore(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &QQChannel{
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
config: &config.QQSettings{},
api: &fakeQQAPI{},
dedup: make(map[string]time.Time),
done: make(chan struct{}),
@@ -566,7 +573,7 @@ func TestSendMedia_ReturnsSendFailedWhenLocalFileExceedsBase64MiBLimit(t *testin
api := &fakeQQAPI{}
ch := &QQChannel{
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
config: config.QQConfig{
config: &config.QQSettings{
MaxBase64FileSizeMiB: 1,
},
api: api,
+47 -1
View File
@@ -1,6 +1,7 @@
package channels
import (
"fmt"
"sync"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -9,7 +10,9 @@ import (
// ChannelFactory is a constructor function that creates a Channel from config and message bus.
// Each channel subpackage registers one or more factories via init().
type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error)
// channelName is the config map key for this channel instance (may differ from the channel type).
// channelType is the channel type string used to look up the Channel config.
type ChannelFactory func(channelName, channelType string, cfg *config.Config, bus *bus.MessageBus) (Channel, error)
var (
factoriesMu sync.RWMutex
@@ -23,6 +26,38 @@ func RegisterFactory(name string, f ChannelFactory) {
factories[name] = f
}
// RegisterSafeFactory is a convenience wrapper that handles GetDecoded() error checking
// and type assertion, reducing boilerplate in channel init() functions.
//
// Usage:
//
// func init() {
// channels.RegisterSafeFactory(config.ChannelTelegram,
// func(bc *config.Channel, c *config.TelegramSettings, b *bus.MessageBus) (channels.Channel, error) {
// return NewTelegramChannel(bc, c, b)
// })
// }
func RegisterSafeFactory[S any](
channelType string,
ctor func(bc *config.Channel, settings *S, bus *bus.MessageBus) (Channel, error),
) {
RegisterFactory(channelType, func(channelName, _ string, cfg *config.Config, b *bus.MessageBus) (Channel, error) {
bc := cfg.Channels[channelName]
if bc == nil {
return nil, fmt.Errorf("channel %q: config not found", channelName)
}
decoded, err := bc.GetDecoded()
if err != nil {
return nil, fmt.Errorf("channel %q: failed to decode settings: %w", channelName, err)
}
settings, ok := decoded.(*S)
if !ok {
return nil, fmt.Errorf("channel %q: expected %T settings, got %T", channelName, (*S)(nil), decoded)
}
return ctor(bc, settings, b)
})
}
// getFactory looks up a channel factory by name.
func getFactory(name string) (ChannelFactory, bool) {
factoriesMu.RLock()
@@ -30,3 +65,14 @@ func getFactory(name string) (ChannelFactory, bool) {
f, ok := factories[name]
return f, ok
}
// GetRegisteredFactoryNames returns a slice of all registered channel factory names.
func GetRegisteredFactoryNames() []string {
factoriesMu.RLock()
defer factoriesMu.RUnlock()
names := make([]string, 0, len(factories))
for name := range factories {
names = append(names, name)
}
return names
}
+15 -3
View File
@@ -7,7 +7,19 @@ import (
)
func init() {
channels.RegisterFactory("slack", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewSlackChannel(cfg.Channels.Slack, b)
})
channels.RegisterFactory(
config.ChannelSlack,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.SlackSettings)
if !ok {
return nil, channels.ErrSendFailed
}
return NewSlackChannel(bc, c, b)
},
)
}
+92 -33
View File
@@ -21,7 +21,7 @@ import (
type SlackChannel struct {
*channels.BaseChannel
config config.SlackConfig
config *config.SlackSettings
api *slack.Client
socketClient *socketmode.Client
botUserID string
@@ -36,7 +36,11 @@ type slackMessageRef struct {
Timestamp string
}
func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) {
func NewSlackChannel(
bc *config.Channel,
cfg *config.SlackSettings,
messageBus *bus.MessageBus,
) (*SlackChannel, error) {
if cfg.BotToken.String() == "" || cfg.AppToken.String() == "" {
return nil, fmt.Errorf("slack bot_token and app_token are required")
}
@@ -48,10 +52,10 @@ func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*Slack
socketClient := socketmode.New(api)
base := channels.NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom,
base := channels.NewBaseChannel("slack", cfg, messageBus, bc.AllowFrom,
channels.WithMaxMessageLength(40000),
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
channels.WithGroupTrigger(bc.GroupTrigger),
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
return &SlackChannel{
@@ -113,7 +117,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]str
return nil, channels.ErrNotRunning
}
channelID, threadTS := parseSlackChatID(msg.ChatID)
deliveryChatID, channelID, threadTS := resolveSlackOutboundTarget(msg.ChatID, &msg.Context)
if channelID == "" {
return nil, fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
}
@@ -135,7 +139,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]str
return nil, fmt.Errorf("slack send: %w", channels.ErrTemporary)
}
if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok {
if ref, ok := c.pendingAcks.LoadAndDelete(deliveryChatID); ok {
msgRef := ref.(slackMessageRef)
c.api.AddReaction("white_check_mark", slack.ItemRef{
Channel: msgRef.ChannelID,
@@ -157,7 +161,7 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
return nil, channels.ErrNotRunning
}
channelID, _ := parseSlackChatID(msg.ChatID)
_, channelID, threadTS := resolveSlackMediaOutboundTarget(msg.ChatID, &msg.Context)
if channelID == "" {
return nil, fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
}
@@ -188,10 +192,11 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
}
_, err = c.api.UploadFileV2Context(ctx, slack.UploadFileV2Parameters{
Channel: channelID,
File: localPath,
Filename: filename,
Title: title,
Channel: channelID,
ThreadTimestamp: threadTS,
File: localPath,
Filename: filename,
Title: title,
})
if err != nil {
logger.ErrorCF("slack", "Failed to upload media", map[string]any{
@@ -356,14 +361,10 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
}
peerKind := "channel"
peerID := channelID
if strings.HasPrefix(channelID, "D") {
peerKind = "direct"
peerID = senderID
}
peer := bus.Peer{Kind: peerKind, ID: peerID}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
@@ -379,7 +380,22 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
"has_thread": threadTS != "",
})
c.HandleMessage(c.ctx, peer, messageTS, senderID, chatID, content, mediaPaths, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: c.Name(),
Account: c.teamID,
ChatID: channelID,
ChatType: peerKind,
SenderID: senderID,
MessageID: messageTS,
SpaceID: c.teamID,
SpaceType: "workspace",
Raw: metadata,
}
if threadTS != "" {
inboundCtx.TopicID = threadTS
}
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
}
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
@@ -427,14 +443,10 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
}
mentionPeerKind := "channel"
mentionPeerID := channelID
if strings.HasPrefix(channelID, "D") {
mentionPeerKind = "direct"
mentionPeerID = senderID
}
mentionPeer := bus.Peer{Kind: mentionPeerKind, ID: mentionPeerID}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
@@ -443,8 +455,21 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
"is_mention": "true",
"team_id": c.teamID,
}
inboundCtx := bus.InboundContext{
Channel: c.Name(),
Account: c.teamID,
ChatID: channelID,
ChatType: mentionPeerKind,
TopicID: threadTS,
SenderID: senderID,
MessageID: messageTS,
SpaceID: c.teamID,
SpaceType: "workspace",
Mentioned: true,
Raw: metadata,
}
c.HandleMessage(c.ctx, mentionPeer, messageTS, senderID, chatID, content, nil, metadata, mentionSender)
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, mentionSender)
}
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
@@ -491,18 +516,22 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
"command": cmd.Command,
"text": utils.Truncate(content, 50),
})
peerKind := "channel"
if strings.HasPrefix(channelID, "D") {
peerKind = "direct"
}
inboundCtx := bus.InboundContext{
Channel: c.Name(),
Account: c.teamID,
ChatID: channelID,
ChatType: peerKind,
SenderID: senderID,
SpaceID: c.teamID,
SpaceType: "workspace",
Raw: metadata,
}
c.HandleMessage(
c.ctx,
bus.Peer{Kind: "channel", ID: channelID},
"",
senderID,
chatID,
content,
nil,
metadata,
cmdSender,
)
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, cmdSender)
}
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
@@ -537,3 +566,33 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) {
}
return channelID, threadTS
}
func resolveSlackOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (string, string, string) {
deliveryChatID := strings.TrimSpace(chatID)
if deliveryChatID == "" && outboundCtx != nil {
deliveryChatID = strings.TrimSpace(outboundCtx.ChatID)
}
channelID, threadTS := parseSlackChatID(deliveryChatID)
if threadTS == "" && outboundCtx != nil {
threadTS = strings.TrimSpace(outboundCtx.TopicID)
if threadTS != "" && channelID != "" {
deliveryChatID = channelID + "/" + threadTS
}
}
return deliveryChatID, channelID, threadTS
}
func resolveSlackMediaOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (string, string, string) {
deliveryChatID := strings.TrimSpace(chatID)
if deliveryChatID == "" && outboundCtx != nil {
deliveryChatID = strings.TrimSpace(outboundCtx.ChatID)
}
channelID, threadTS := parseSlackChatID(deliveryChatID)
if threadTS == "" && outboundCtx != nil {
threadTS = strings.TrimSpace(outboundCtx.TopicID)
if threadTS != "" && channelID != "" {
deliveryChatID = channelID + "/" + threadTS
}
}
return deliveryChatID, channelID, threadTS
}
+32 -16
View File
@@ -53,6 +53,24 @@ func TestParseSlackChatID(t *testing.T) {
}
}
func TestResolveSlackOutboundTarget_PrefersContextTopicID(t *testing.T) {
deliveryChatID, channelID, threadTS := resolveSlackOutboundTarget("C123456", &bus.InboundContext{
Channel: "slack",
ChatID: "C123456",
TopicID: "1234567890.123456",
})
if deliveryChatID != "C123456/1234567890.123456" {
t.Fatalf("deliveryChatID = %q, want %q", deliveryChatID, "C123456/1234567890.123456")
}
if channelID != "C123456" {
t.Fatalf("channelID = %q, want %q", channelID, "C123456")
}
if threadTS != "1234567890.123456" {
t.Fatalf("threadTS = %q, want %q", threadTS, "1234567890.123456")
}
}
func TestStripBotMention(t *testing.T) {
ch := &SlackChannel{botUserID: "U12345BOT"}
@@ -100,32 +118,32 @@ func TestStripBotMention(t *testing.T) {
func TestNewSlackChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
bc := &config.Channel{Type: "slack", Enabled: true}
t.Run("missing bot token", func(t *testing.T) {
cfg := config.SlackConfig{}
cfg := &config.SlackSettings{}
cfg.AppToken = *config.NewSecureString("xapp-test")
_, err := NewSlackChannel(cfg, msgBus)
_, err := NewSlackChannel(bc, cfg, msgBus)
if err == nil {
t.Error("expected error for missing bot_token, got nil")
}
})
t.Run("missing app token", func(t *testing.T) {
cfg := config.SlackConfig{}
cfg := &config.SlackSettings{}
cfg.BotToken = *config.NewSecureString("xoxb-test")
_, err := NewSlackChannel(cfg, msgBus)
_, err := NewSlackChannel(bc, cfg, msgBus)
if err == nil {
t.Error("expected error for missing app_token, got nil")
}
})
t.Run("valid config", func(t *testing.T) {
cfg := config.SlackConfig{
AllowFrom: []string{"U123"},
}
cfg := &config.SlackSettings{}
cfg.BotToken = *config.NewSecureString("xoxb-test")
cfg.AppToken = *config.NewSecureString("xapp-test")
ch, err := NewSlackChannel(cfg, msgBus)
bc := &config.Channel{Type: "slack", Enabled: true, AllowFrom: []string{"U123"}}
ch, err := NewSlackChannel(bc, cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -142,24 +160,22 @@ func TestSlackChannelIsAllowed(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("empty allowlist allows all", func(t *testing.T) {
cfg := config.SlackConfig{
AllowFrom: []string{},
}
bc := &config.Channel{Type: config.ChannelSlack, Enabled: true, AllowFrom: []string{}}
cfg := &config.SlackSettings{}
cfg.BotToken = *config.NewSecureString("xoxb-test")
cfg.AppToken = *config.NewSecureString("xapp-test")
ch, _ := NewSlackChannel(cfg, msgBus)
ch, _ := NewSlackChannel(bc, cfg, msgBus)
if !ch.IsAllowed("U_ANYONE") {
t.Error("empty allowlist should allow all users")
}
})
t.Run("allowlist restricts users", func(t *testing.T) {
cfg := config.SlackConfig{
AllowFrom: []string{"U_ALLOWED"},
}
bc := &config.Channel{Type: config.ChannelSlack, Enabled: true, AllowFrom: []string{"U_ALLOWED"}}
cfg := &config.SlackSettings{}
cfg.BotToken = *config.NewSecureString("xoxb-test")
cfg.AppToken = *config.NewSecureString("xapp-test")
ch, _ := NewSlackChannel(cfg, msgBus)
ch, _ := NewSlackChannel(bc, cfg, msgBus)
if !ch.IsAllowed("U_ALLOWED") {
t.Error("allowed user should pass allowlist check")
}
+22 -3
View File
@@ -7,7 +7,26 @@ import (
)
func init() {
channels.RegisterFactory("teams_webhook", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewTeamsWebhookChannel(cfg.Channels.TeamsWebhook, b)
})
channels.RegisterFactory(
config.ChannelTeamsWebHook,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.TeamsWebhookSettings)
if !ok {
return nil, channels.ErrSendFailed
}
ch, err := NewTeamsWebhookChannel(bc, c, b)
if err != nil {
return nil, err
}
if channelName != config.ChannelTeamsWebHook {
ch.SetName(channelName)
}
return ch, nil
},
)
}
+5 -2
View File
@@ -52,13 +52,15 @@ func classifyTeamsError(err error) error {
// Multiple webhook targets can be configured and selected via ChatID.
type TeamsWebhookChannel struct {
*channels.BaseChannel
config config.TeamsWebhookConfig
bc *config.Channel
config *config.TeamsWebhookSettings
client teamsMessageSender
}
// NewTeamsWebhookChannel creates a new Teams webhook channel.
func NewTeamsWebhookChannel(
cfg config.TeamsWebhookConfig,
bc *config.Channel,
cfg *config.TeamsWebhookSettings,
bus *bus.MessageBus,
) (*TeamsWebhookChannel, error) {
if len(cfg.Webhooks) == 0 {
@@ -99,6 +101,7 @@ func NewTeamsWebhookChannel(
return &TeamsWebhookChannel{
BaseChannel: base,
bc: bc,
config: cfg,
client: client,
}, nil
@@ -31,67 +31,60 @@ func TestNewTeamsWebhookChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
// Test missing webhooks
_, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
cfg := config.TeamsWebhookSettings{
Webhooks: nil,
}, msgBus)
}
_, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
if err == nil {
t.Error("expected error for missing webhooks")
}
// Test missing "default" webhook
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
Title: "Alerts",
},
cfg.Webhooks = map[string]config.TeamsWebhookTarget{
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
Title: "Alerts",
},
}, msgBus)
}
_, err = NewTeamsWebhookChannel(bc, &cfg, msgBus)
if err == nil {
t.Error("expected error for missing 'default' webhook")
}
// Test empty webhook URL
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {Title: "Default"},
},
}, msgBus)
cfg.Webhooks = map[string]config.TeamsWebhookTarget{
"default": {Title: "Default"},
}
_, err = NewTeamsWebhookChannel(bc, &cfg, msgBus)
if err == nil {
t.Error("expected error for empty webhook_url")
}
// Test HTTP URL (should fail, must be HTTPS)
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("http://example.com/webhook"),
Title: "Default",
},
cfg.Webhooks = map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("http://example.com/webhook"),
Title: "Default",
},
}, msgBus)
}
_, err = NewTeamsWebhookChannel(bc, &cfg, msgBus)
if err == nil {
t.Error("expected error for HTTP webhook URL (must be HTTPS)")
}
// Test valid config with HTTPS (must include "default")
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
Title: "Default",
},
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook1"),
Title: "Alerts",
},
cfg.Webhooks = map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
Title: "Default",
},
}, msgBus)
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook1"),
Title: "Alerts",
},
}
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -103,14 +96,15 @@ func TestNewTeamsWebhookChannel(t *testing.T) {
func TestTeamsWebhookChannel_StartStop(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
cfg := config.TeamsWebhookSettings{
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
},
},
}, msgBus)
}
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -140,8 +134,8 @@ func TestTeamsWebhookChannel_StartStop(t *testing.T) {
func TestTeamsWebhookChannel_BuildAdaptiveCard(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
cfg := config.TeamsWebhookSettings{
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
@@ -152,7 +146,8 @@ func TestTeamsWebhookChannel_BuildAdaptiveCard(t *testing.T) {
Title: "Custom Title",
},
},
}, msgBus)
}
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -175,14 +170,15 @@ func TestTeamsWebhookChannel_BuildAdaptiveCard(t *testing.T) {
func TestTeamsWebhookChannel_SendNotRunning(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
cfg := config.TeamsWebhookSettings{
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
},
},
}, msgBus)
}
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -208,8 +204,8 @@ func TestTeamsWebhookChannel_SendDefaultTargetFallback(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
cfg := config.TeamsWebhookSettings{
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
@@ -218,7 +214,8 @@ func TestTeamsWebhookChannel_SendDefaultTargetFallback(t *testing.T) {
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
},
},
}, msgBus)
}
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -250,8 +247,8 @@ func TestTeamsWebhookChannel_SendDefaultTargetFallback(t *testing.T) {
func TestTeamsWebhookChannel_SendSuccess(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
cfg := config.TeamsWebhookSettings{
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
@@ -262,7 +259,8 @@ func TestTeamsWebhookChannel_SendSuccess(t *testing.T) {
Title: "Test Alerts",
},
},
}, msgBus)
}
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -294,8 +292,8 @@ func TestTeamsWebhookChannel_SendSuccess(t *testing.T) {
func TestTeamsWebhookChannel_SendError(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
cfg := config.TeamsWebhookSettings{
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
@@ -304,7 +302,8 @@ func TestTeamsWebhookChannel_SendError(t *testing.T) {
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
},
},
}, msgBus)
}
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
+15 -3
View File
@@ -7,7 +7,19 @@ import (
)
func init() {
channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewTelegramChannel(cfg, b)
})
channels.RegisterFactory(
config.ChannelTelegram,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.TelegramSettings)
if !ok {
return nil, channels.ErrSendFailed
}
return NewTelegramChannel(bc, c, b)
},
)
}
+61 -31
View File
@@ -47,18 +47,23 @@ type TelegramChannel struct {
*channels.BaseChannel
bot *telego.Bot
bh *th.BotHandler
config *config.Config
bc *config.Channel
chatIDs map[string]int64
ctx context.Context
cancel context.CancelFunc
tgCfg *config.TelegramSettings
registerFunc func(context.Context, []commands.Definition) error
commandRegCancel context.CancelFunc
}
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
func NewTelegramChannel(
bc *config.Channel,
telegramCfg *config.TelegramSettings,
bus *bus.MessageBus,
) (*TelegramChannel, error) {
channelName := bc.Name()
var opts []telego.BotOption
telegramCfg := cfg.Channels.Telegram
if telegramCfg.Proxy != "" {
proxyURL, parseErr := url.Parse(telegramCfg.Proxy)
@@ -90,20 +95,21 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
}
base := channels.NewBaseChannel(
"telegram",
channelName,
telegramCfg,
bus,
telegramCfg.AllowFrom,
bc.AllowFrom,
channels.WithMaxMessageLength(4000),
channels.WithGroupTrigger(telegramCfg.GroupTrigger),
channels.WithReasoningChannelID(telegramCfg.ReasoningChannelID),
channels.WithGroupTrigger(bc.GroupTrigger),
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
return &TelegramChannel{
BaseChannel: base,
bot: bot,
config: cfg,
bc: bc,
chatIDs: make(map[string]int64),
tgCfg: telegramCfg,
}, nil
}
@@ -174,9 +180,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]
return nil, channels.ErrNotRunning
}
useMarkdownV2 := c.config.Channels.Telegram.UseMarkdownV2
useMarkdownV2 := c.tgCfg.UseMarkdownV2
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context)
if err != nil {
return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
}
@@ -360,7 +366,7 @@ func (c *TelegramChannel) StartTyping(ctx context.Context, chatID string) (func(
// EditMessage implements channels.MessageEditor.
func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
useMarkdownV2 := c.config.Channels.Telegram.UseMarkdownV2
useMarkdownV2 := c.tgCfg.UseMarkdownV2
cid, _, err := parseTelegramChatID(chatID)
if err != nil {
return err
@@ -435,7 +441,7 @@ func (c *TelegramChannel) DeleteMessage(ctx context.Context, chatID string, mess
// It sends a placeholder message (e.g. "Thinking... 💭") that will later be
// edited to the actual response via EditMessage (channels.MessageEditor).
func (c *TelegramChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
phCfg := c.config.Channels.Telegram.Placeholder
phCfg := c.bc.Placeholder
if !phCfg.Enabled {
return "", nil
}
@@ -463,7 +469,7 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
return nil, channels.ErrNotRunning
}
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context)
if err != nil {
return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
}
@@ -691,8 +697,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
}
// In group chats, apply unified group trigger filtering
isMentioned := false
if message.Chat.Type != "private" {
isMentioned := c.isBotMentioned(message)
isMentioned = c.isBotMentioned(message)
if isMentioned {
content = c.stripBotMention(content)
}
@@ -738,13 +745,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
})
peerKind := "direct"
peerID := fmt.Sprintf("%d", user.ID)
if message.Chat.Type != "private" {
peerKind = "group"
peerID = compositeChatID
}
peer := bus.Peer{Kind: peerKind, ID: peerID}
messageID := fmt.Sprintf("%d", message.MessageID)
metadata := map[string]string{
@@ -753,24 +756,29 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
"first_name": user.FirstName,
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
}
if message.ReplyToMessage != nil {
metadata["reply_to_message_id"] = fmt.Sprintf("%d", message.ReplyToMessage.MessageID)
}
// Set parent_peer metadata for per-topic agent binding.
inboundCtx := bus.InboundContext{
Channel: c.Name(),
ChatID: fmt.Sprintf("%d", chatID),
ChatType: peerKind,
SenderID: platformID,
MessageID: messageID,
Mentioned: isMentioned,
Raw: metadata,
}
if message.Chat.IsForum && threadID != 0 {
metadata["parent_peer_kind"] = "topic"
metadata["parent_peer_id"] = fmt.Sprintf("%d", threadID)
inboundCtx.TopicID = fmt.Sprintf("%d", threadID)
}
if message.ReplyToMessage != nil {
inboundCtx.ReplyToMessageID = fmt.Sprintf("%d", message.ReplyToMessage.MessageID)
}
c.HandleMessage(c.ctx,
peer,
messageID,
platformID,
c.HandleMessageWithContext(
c.ctx,
compositeChatID,
content,
mediaPaths,
metadata,
inboundCtx,
sender,
)
return nil
@@ -958,6 +966,28 @@ func parseTelegramChatID(chatID string) (int64, int, error) {
return cid, tid, nil
}
func resolveTelegramOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (int64, int, error) {
targetChatID := strings.TrimSpace(chatID)
if targetChatID == "" && outboundCtx != nil {
targetChatID = strings.TrimSpace(outboundCtx.ChatID)
}
resolvedChatID, resolvedThreadID, err := parseTelegramChatID(targetChatID)
if err != nil {
return 0, 0, err
}
if resolvedThreadID != 0 || outboundCtx == nil {
return resolvedChatID, resolvedThreadID, nil
}
topicID := strings.TrimSpace(outboundCtx.TopicID)
if topicID == "" {
return resolvedChatID, resolvedThreadID, nil
}
if threadID, convErr := strconv.Atoi(topicID); convErr == nil {
return resolvedChatID, threadID, nil
}
return resolvedChatID, resolvedThreadID, nil
}
func logParseFailed(err error, useMarkdownV2 bool) {
parsingName := "HTML"
if useMarkdownV2 {
@@ -1063,7 +1093,7 @@ func (c *TelegramChannel) stripBotMention(content string) string {
// BeginStream implements channels.StreamingCapable.
func (c *TelegramChannel) BeginStream(ctx context.Context, chatID string) (channels.Streamer, error) {
if !c.config.Channels.Telegram.Streaming.Enabled {
if !c.tgCfg.Streaming.Enabled {
return nil, fmt.Errorf("streaming disabled in config")
}
@@ -1072,7 +1102,7 @@ func (c *TelegramChannel) BeginStream(ctx context.Context, chatID string) (chann
return nil, err
}
streamCfg := c.config.Channels.Telegram.Streaming
streamCfg := c.tgCfg.Streaming
return &telegramStreamer{
bot: c.bot,
chatID: cid,
+44 -27
View File
@@ -140,7 +140,8 @@ func newTestChannelWithConstructor(
BaseChannel: base,
bot: bot,
chatIDs: make(map[string]int64),
config: config.DefaultConfig(),
bc: &config.Channel{Type: config.ChannelTelegram, Enabled: true},
tgCfg: &config.TelegramSettings{},
}
}
@@ -527,6 +528,38 @@ func TestSend_WithForumThreadID(t *testing.T) {
assert.Len(t, caller.calls, 1)
}
func TestSend_UsesContextTopicIDWhenChatIDDoesNotIncludeThread(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
_, err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "-1001234567890",
Content: "Hello from topic context",
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "-1001234567890",
TopicID: "42",
},
})
require.NoError(t, err)
require.Len(t, caller.calls, 1)
var params struct {
ChatID int64 `json:"chat_id"`
MessageThreadID int `json:"message_thread_id"`
Text string `json:"text"`
}
require.NoError(t, json.Unmarshal(caller.calls[0].Data.BodyRaw, &params))
assert.Equal(t, int64(-1001234567890), params.ChatID)
assert.Equal(t, 42, params.MessageThreadID)
assert.Equal(t, "Hello from topic context", params.Text)
}
func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
@@ -556,16 +589,10 @@ func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok, "expected inbound message")
// Composite chatID should include thread ID
assert.Equal(t, "-1001234567890/42", inbound.ChatID)
// Peer ID should include thread ID for session key isolation
assert.Equal(t, "group", inbound.Peer.Kind)
assert.Equal(t, "-1001234567890/42", inbound.Peer.ID)
// Parent peer metadata should be set for agent binding
assert.Equal(t, "topic", inbound.Metadata["parent_peer_kind"])
assert.Equal(t, "42", inbound.Metadata["parent_peer_id"])
// ChatID remains the parent chat; TopicID isolates the sub-conversation.
assert.Equal(t, "-1001234567890", inbound.ChatID)
assert.Equal(t, "group", inbound.Context.ChatType)
assert.Equal(t, "42", inbound.Context.TopicID)
}
func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
@@ -598,13 +625,8 @@ func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
// Plain chatID without thread suffix
assert.Equal(t, "-100999", inbound.ChatID)
// Peer ID should be raw chat ID (no thread suffix)
assert.Equal(t, "group", inbound.Peer.Kind)
assert.Equal(t, "-100999", inbound.Peer.ID)
// No parent peer metadata
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
assert.Empty(t, inbound.Metadata["parent_peer_id"])
assert.Equal(t, "group", inbound.Context.ChatType)
assert.Empty(t, inbound.Context.TopicID)
}
func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
@@ -641,13 +663,8 @@ func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
// chatID should NOT include thread suffix for non-forum groups
assert.Equal(t, "-100999", inbound.ChatID)
// Peer ID should be raw chat ID (shared session for whole group)
assert.Equal(t, "group", inbound.Peer.Kind)
assert.Equal(t, "-100999", inbound.Peer.ID)
// No parent peer metadata
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
assert.Empty(t, inbound.Metadata["parent_peer_id"])
assert.Equal(t, "group", inbound.Context.ChatType)
assert.Empty(t, inbound.Context.TopicID)
}
func assertHandleMessageQuotedUserReply(
@@ -700,7 +717,7 @@ func assertHandleMessageQuotedUserReply(
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok)
assert.Equal(t, strconv.Itoa(replyMessageID), inbound.Metadata["reply_to_message_id"])
assert.Equal(t, strconv.Itoa(replyMessageID), inbound.Context.ReplyToMessageID)
assert.Equal(t, expectedContent, inbound.Content)
}
@@ -786,7 +803,7 @@ func TestHandleMessage_ReplyToOwnBotMessage_UsesAssistantRole(t *testing.T) {
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok)
assert.Equal(t, "101", inbound.Metadata["reply_to_message_id"])
assert.Equal(t, "101", inbound.Context.ReplyToMessageID)
assert.Equal(
t,
"[quoted assistant message from afjcjsbx_picoclaw_bot]: Fatto! Ho creato il file notizie_2026_03_28.md\n\nti ricordi questo file?",
+10 -3
View File
@@ -7,7 +7,14 @@ import (
)
func init() {
channels.RegisterFactory("vk", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewVKChannel(cfg, b)
})
channels.RegisterFactory(
config.ChannelVK,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
if bc == nil {
return nil, channels.ErrSendFailed
}
return NewVKChannel(channelName, bc, b)
},
)
}
+39 -30
View File
@@ -21,41 +21,54 @@ import (
type VKChannel struct {
*channels.BaseChannel
vk *api.VK
lp *longpoll.LongPoll
config *config.Config
ctx context.Context
cancel context.CancelFunc
vk *api.VK
lp *longpoll.LongPoll
channelName string
bc *config.Channel
ctx context.Context
cancel context.CancelFunc
}
func NewVKChannel(cfg *config.Config, bus *bus.MessageBus) (*VKChannel, error) {
vkCfg := cfg.Channels.VK
func NewVKChannel(channelName string, bc *config.Channel, bus *bus.MessageBus) (*VKChannel, error) {
var vkCfg config.VKSettings
if err := bc.Decode(&vkCfg); err != nil {
return nil, err
}
vk := api.NewVK(vkCfg.Token.String())
base := channels.NewBaseChannel(
"vk",
vkCfg,
channelName,
&vkCfg,
bus,
vkCfg.AllowFrom,
bc.AllowFrom,
channels.WithMaxMessageLength(4000),
channels.WithGroupTrigger(vkCfg.GroupTrigger),
channels.WithReasoningChannelID(vkCfg.ReasoningChannelID),
channels.WithGroupTrigger(bc.GroupTrigger),
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
return &VKChannel{
BaseChannel: base,
vk: vk,
config: cfg,
channelName: channelName,
bc: bc,
}, nil
}
func (c *VKChannel) getVKCfg() *config.VKSettings {
var v config.VKSettings
if err := c.bc.Decode(&v); err != nil {
return nil
}
return &v
}
func (c *VKChannel) Start(ctx context.Context) error {
logger.InfoC("vk", "Starting VK bot (Long Poll mode)...")
c.ctx, c.cancel = context.WithCancel(ctx)
groupID := c.config.Channels.VK.GroupID
groupID := c.getVKCfg().GroupID
if groupID == 0 {
c.cancel()
return fmt.Errorf("group_id is required for VK bot")
@@ -143,7 +156,7 @@ func (c *VKChannel) handleMessage(msg object.MessagesMessage) {
return
}
groupTrigger := c.config.Channels.VK.GroupTrigger
groupTrigger := c.bc.GroupTrigger
isGroupChat := peerID != fromID
if isGroupChat {
@@ -159,14 +172,11 @@ func (c *VKChannel) handleMessage(msg object.MessagesMessage) {
_ = groupTrigger
}
peerKind := "direct"
peerIDStr := userID
chatType := "direct"
if isGroupChat {
peerKind = "group"
peerIDStr = chatID
chatType = "group"
}
peer := bus.Peer{Kind: peerKind, ID: peerIDStr}
messageID := strconv.Itoa(msg.ConversationMessageID)
metadata := map[string]string{
@@ -174,16 +184,15 @@ func (c *VKChannel) handleMessage(msg object.MessagesMessage) {
"is_group": fmt.Sprintf("%t", isGroupChat),
}
c.HandleMessage(c.ctx,
peer,
messageID,
userID,
chatID,
text,
nil,
metadata,
sender,
)
c.HandleInboundContext(c.ctx, chatID, text, nil, bus.InboundContext{
Channel: "vk",
ChatID: chatID,
ChatType: chatType,
SenderID: userID,
MessageID: messageID,
Mentioned: isGroupChat && c.isMentioned(msg),
Raw: metadata,
}, sender)
}
func (c *VKChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
+54 -62
View File
@@ -1,6 +1,7 @@
package vk
import (
"encoding/json"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -8,19 +9,23 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
)
func makeVKTestBaseChannel(vkCfg config.VKSettings) *config.Channel {
settings, _ := json.Marshal(vkCfg)
return &config.Channel{
Enabled: true,
Type: config.ChannelVK,
Settings: settings,
}
}
func TestNewVKChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("missing group_id", func(t *testing.T) {
cfg := &config.Config{
Channels: config.ChannelsConfig{
VK: config.VKConfig{
Enabled: true,
Token: *config.NewSecureString("test_token"),
},
},
}
ch, err := NewVKChannel(cfg, msgBus)
bc := makeVKTestBaseChannel(config.VKSettings{
Token: *config.NewSecureString("test_token"),
})
ch, err := NewVKChannel("vk", bc, msgBus)
if err != nil {
t.Fatalf("unexpected error during creation: %v", err)
}
@@ -33,16 +38,11 @@ func TestNewVKChannel(t *testing.T) {
})
t.Run("valid config with group_id", func(t *testing.T) {
cfg := &config.Config{
Channels: config.ChannelsConfig{
VK: config.VKConfig{
Enabled: true,
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
},
},
}
ch, err := NewVKChannel(cfg, msgBus)
bc := makeVKTestBaseChannel(config.VKSettings{
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
})
ch, err := NewVKChannel("vk", bc, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -55,17 +55,18 @@ func TestNewVKChannel(t *testing.T) {
})
t.Run("with allow_from", func(t *testing.T) {
cfg := &config.Config{
Channels: config.ChannelsConfig{
VK: config.VKConfig{
Enabled: true,
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
AllowFrom: []string{"123456789"},
},
},
vkCfg := config.VKSettings{
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
}
ch, err := NewVKChannel(cfg, msgBus)
settings, _ := json.Marshal(vkCfg)
bc := &config.Channel{
Enabled: true,
Type: "vk",
AllowFrom: []string{"123456789"},
Settings: settings,
}
ch, err := NewVKChannel("vk", bc, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -78,20 +79,21 @@ func TestNewVKChannel(t *testing.T) {
})
t.Run("with group_trigger", func(t *testing.T) {
cfg := &config.Config{
Channels: config.ChannelsConfig{
VK: config.VKConfig{
Enabled: true,
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
GroupTrigger: config.GroupTriggerConfig{
MentionOnly: false,
Prefixes: []string{"/bot", "!bot"},
},
},
},
vkCfg := config.VKSettings{
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
}
ch, err := NewVKChannel(cfg, msgBus)
settings, _ := json.Marshal(vkCfg)
bc := &config.Channel{
Enabled: true,
Type: "vk",
GroupTrigger: config.GroupTriggerConfig{
MentionOnly: false,
Prefixes: []string{"/bot", "!bot"},
},
Settings: settings,
}
ch, err := NewVKChannel("vk", bc, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -103,16 +105,11 @@ func TestNewVKChannel(t *testing.T) {
func TestVKChannel_MaxMessageLength(t *testing.T) {
msgBus := bus.NewMessageBus()
cfg := &config.Config{
Channels: config.ChannelsConfig{
VK: config.VKConfig{
Enabled: true,
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
},
},
}
ch, err := NewVKChannel(cfg, msgBus)
bc := makeVKTestBaseChannel(config.VKSettings{
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
})
ch, err := NewVKChannel("vk", bc, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -236,16 +233,11 @@ func TestVKChannel_ProcessAttachments(t *testing.T) {
func TestVKChannel_VoiceCapabilities(t *testing.T) {
msgBus := bus.NewMessageBus()
cfg := &config.Config{
Channels: config.ChannelsConfig{
VK: config.VKConfig{
Enabled: true,
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
},
},
}
ch, err := NewVKChannel(cfg, msgBus)
bc := makeVKTestBaseChannel(config.VKSettings{
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
})
ch, err := NewVKChannel("vk", bc, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
+15 -3
View File
@@ -7,7 +7,19 @@ import (
)
func init() {
channels.RegisterFactory("wecom", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewChannel(cfg.Channels.WeCom, b)
})
channels.RegisterFactory(
config.ChannelWeCom,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.WeComSettings)
if !ok {
return nil, channels.ErrSendFailed
}
return NewChannel(bc, c, b)
},
)
}
+18 -6
View File
@@ -34,7 +34,7 @@ const (
type WeComChannel struct {
*channels.BaseChannel
config config.WeComConfig
config *config.WeComSettings
ctx context.Context
cancel context.CancelFunc
@@ -108,7 +108,7 @@ func (s *recentMessageSet) Mark(id string) bool {
return true
}
func NewChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComChannel, error) {
func NewChannel(bc *config.Channel, cfg *config.WeComSettings, messageBus *bus.MessageBus) (*WeComChannel, error) {
if cfg.BotID == "" || cfg.Secret.String() == "" {
return nil, fmt.Errorf("wecom bot_id and secret are required")
}
@@ -120,8 +120,8 @@ func NewChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComChann
"wecom",
cfg,
messageBus,
cfg.AllowFrom,
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
bc.AllowFrom,
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
ch := &WeComChannel{
@@ -570,7 +570,6 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage)
return err
}
peer := bus.Peer{Kind: peerKind, ID: actualChatID}
metadata := map[string]string{
"channel": "wecom",
"req_id": reqID,
@@ -583,7 +582,20 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage)
metadata["quote_text"] = quoteText
}
c.HandleMessage(c.ctx, peer, msg.MsgID, senderID, actualChatID, content, mediaRefs, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: c.Name(),
Account: strings.TrimSpace(msg.AIBotID),
ChatID: actualChatID,
ChatType: peerKind,
SenderID: senderID,
MessageID: msg.MsgID,
ReplyHandles: map[string]string{
"req_id": reqID,
},
Raw: metadata,
}
c.HandleInboundContext(c.ctx, actualChatID, content, mediaRefs, inboundCtx, sender)
return nil
}
+7 -6
View File
@@ -50,11 +50,11 @@ func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) {
if inbound.MessageID != "msg-1" {
t.Fatalf("inbound MessageID = %q, want msg-1", inbound.MessageID)
}
if inbound.Peer.ID != "chat-1" {
t.Fatalf("inbound Peer.ID = %q, want chat-1", inbound.Peer.ID)
if inbound.Context.ChatType != "direct" {
t.Fatalf("inbound Context.ChatType = %q, want direct", inbound.Context.ChatType)
}
if inbound.Metadata["req_id"] != "req-1" {
t.Fatalf("inbound req_id = %q, want req-1", inbound.Metadata["req_id"])
if inbound.Context.ReplyHandles["req_id"] != "req-1" {
t.Fatalf("inbound req_id = %q, want req-1", inbound.Context.ReplyHandles["req_id"])
}
default:
t.Fatal("expected inbound message to be published")
@@ -605,9 +605,10 @@ func TestSendMedia_SendsActiveFile(t *testing.T) {
func newTestWeComChannel(t *testing.T, messageBus *bus.MessageBus) *WeComChannel {
t.Helper()
cfg := config.WeComConfig{BotID: "bot-1"}
cfg := &config.WeComSettings{BotID: "bot-1"}
cfg.SetSecret("secret-1")
ch, err := NewChannel(cfg, messageBus)
bc := &config.Channel{Type: config.ChannelWeCom, Enabled: true}
ch, err := NewChannel(bc, cfg, messageBus)
if err != nil {
t.Fatalf("NewChannel() error = %v", err)
}
+3 -3
View File
@@ -44,7 +44,7 @@ func picoclawHomeDir() string {
return config.GetHome()
}
func genWeixinAccountKey(cfg config.WeixinConfig) string {
func genWeixinAccountKey(cfg *config.WeixinSettings) string {
token := strings.TrimSpace(cfg.Token.String())
if token == "" {
return "default"
@@ -53,11 +53,11 @@ func genWeixinAccountKey(cfg config.WeixinConfig) string {
return hex.EncodeToString(sum[:8])
}
func buildWeixinSyncBufPath(cfg config.WeixinConfig) string {
func buildWeixinSyncBufPath(cfg *config.WeixinSettings) string {
return filepath.Join(picoclawHomeDir(), "channels", "weixin", "sync", genWeixinAccountKey(cfg)+".json")
}
func buildWeixinContextTokensPath(cfg config.WeixinConfig) string {
func buildWeixinContextTokensPath(cfg *config.WeixinSettings) string {
return filepath.Join(picoclawHomeDir(), "channels", "weixin", "context-tokens", genWeixinAccountKey(cfg)+".json")
}
+46 -11
View File
@@ -20,7 +20,7 @@ import (
type WeixinChannel struct {
*channels.BaseChannel
api *ApiClient
config config.WeixinConfig
config *config.WeixinSettings
ctx context.Context
cancel context.CancelFunc
bus *bus.MessageBus
@@ -36,25 +36,48 @@ type WeixinChannel struct {
}
func init() {
channels.RegisterFactory("weixin", func(cfg *config.Config, bus *bus.MessageBus) (channels.Channel, error) {
return NewWeixinChannel(cfg.Channels.Weixin, bus)
})
channels.RegisterFactory(
config.ChannelWeixin,
func(channelName, channelType string, cfg *config.Config, bus *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
weixinCfg, ok := decoded.(*config.WeixinSettings)
if !ok {
return nil, channels.ErrSendFailed
}
ch, err := NewWeixinChannel(bc, weixinCfg, bus)
if err != nil {
return nil, err
}
if channelName != config.ChannelWeixin {
ch.SetName(channelName)
}
return ch, nil
},
)
}
// NewWeixinChannel creates a new WeixinChannel from config.
func NewWeixinChannel(cfg config.WeixinConfig, messageBus *bus.MessageBus) (*WeixinChannel, error) {
func NewWeixinChannel(
bc *config.Channel,
cfg *config.WeixinSettings,
messageBus *bus.MessageBus,
) (*WeixinChannel, error) {
api, err := NewApiClient(cfg.BaseURL, cfg.Token.String(), cfg.Proxy)
if err != nil {
return nil, fmt.Errorf("weixin: failed to create API client: %w", err)
}
base := channels.NewBaseChannel(
"weixin",
bc.Name(),
cfg,
messageBus,
cfg.AllowFrom,
bc.AllowFrom,
channels.WithMaxMessageLength(4000),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
return &WeixinChannel{
@@ -334,8 +357,6 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess
return
}
peer := bus.Peer{Kind: "direct", ID: fromUserID}
metadata := map[string]string{
"from_user_id": fromUserID,
"context_token": msg.ContextToken,
@@ -354,7 +375,21 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess
c.persistContextTokens()
}
c.HandleMessage(ctx, peer, messageID, fromUserID, fromUserID, content, mediaRefs, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: "weixin",
ChatID: fromUserID,
ChatType: "direct",
SenderID: fromUserID,
MessageID: messageID,
Raw: metadata,
}
if msg.ContextToken != "" {
inboundCtx.ReplyHandles = map[string]string{
"context_token": msg.ContextToken,
}
}
c.HandleInboundContext(ctx, fromUserID, content, mediaRefs, inboundCtx, sender)
}
// Send implements channels.Channel by sending a text message to the WeChat user.
+5 -5
View File
@@ -66,7 +66,7 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) {
}, nil
})},
},
config: config.WeixinConfig{
config: &config.WeixinSettings{
CDNBaseURL: "https://cdn.example.com",
},
typingCache: make(map[string]typingTicketCacheEntry),
@@ -105,7 +105,7 @@ func TestDownloadAndDecryptCDNBufferUsesFullURLWhenProvided(t *testing.T) {
return nil, nil
})},
},
config: config.WeixinConfig{
config: &config.WeixinSettings{
CDNBaseURL: "https://cdn.example.com",
},
typingCache: make(map[string]typingTicketCacheEntry),
@@ -155,7 +155,7 @@ func TestDownloadAndDecryptCDNBufferFallsBackToConstructedURLWhenFullURLFails(t
}, nil
})},
},
config: config.WeixinConfig{
config: &config.WeixinSettings{
CDNBaseURL: "https://cdn.example.com",
},
typingCache: make(map[string]typingTicketCacheEntry),
@@ -224,7 +224,7 @@ func TestUploadBufferToCDN(t *testing.T) {
}, nil
})},
},
config: config.WeixinConfig{
config: &config.WeixinSettings{
CDNBaseURL: "https://cdn.example.com",
},
typingCache: make(map[string]typingTicketCacheEntry),
@@ -259,7 +259,7 @@ func TestBuildWeixinSyncBufPathUsesPicoclawHome(t *testing.T) {
home := t.TempDir()
t.Setenv(config.EnvHome, home)
wxCfg := config.WeixinConfig{
wxCfg := &config.WeixinSettings{
BaseURL: "https://ilinkai.weixin.qq.com/",
}
wxCfg.SetToken("token-123")
+15 -3
View File
@@ -7,7 +7,19 @@ import (
)
func init() {
channels.RegisterFactory("whatsapp", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewWhatsAppChannel(cfg.Channels.WhatsApp, b)
})
channels.RegisterFactory(
config.ChannelWhatsApp,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.WhatsAppSettings)
if !ok {
return nil, channels.ErrSendFailed
}
return NewWhatsAppChannel(bc, c, b)
},
)
}
+22 -12
View File
@@ -20,7 +20,7 @@ import (
type WhatsAppChannel struct {
*channels.BaseChannel
conn *websocket.Conn
config config.WhatsAppConfig
config *config.WhatsAppSettings
url string
ctx context.Context
cancel context.CancelFunc
@@ -28,14 +28,18 @@ type WhatsAppChannel struct {
connected bool
}
func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsAppChannel, error) {
func NewWhatsAppChannel(
bc *config.Channel,
cfg *config.WhatsAppSettings,
bus *bus.MessageBus,
) (*WhatsAppChannel, error) {
base := channels.NewBaseChannel(
"whatsapp",
cfg,
bus,
cfg.AllowFrom,
bc.AllowFrom,
channels.WithMaxMessageLength(65536),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
channels.WithReasoningChannelID(bc.ReasoningChannelID),
)
return &WhatsAppChannel{
@@ -223,13 +227,6 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
metadata["user_name"] = userName
}
var peer bus.Peer
if chatID == senderID {
peer = bus.Peer{Kind: "direct", ID: senderID}
} else {
peer = bus.Peer{Kind: "group", ID: chatID}
}
logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{
"sender": senderID,
"preview": utils.Truncate(content, 50),
@@ -248,5 +245,18 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
return
}
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: "whatsapp",
ChatID: chatID,
SenderID: senderID,
MessageID: messageID,
Raw: metadata,
}
if chatID == senderID {
inboundCtx.ChatType = "direct"
} else {
inboundCtx.ChatType = "group"
}
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
}
@@ -12,7 +12,7 @@ import (
func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &WhatsAppChannel{
BaseChannel: channels.NewBaseChannel("whatsapp", config.WhatsAppConfig{}, messageBus, nil),
BaseChannel: channels.NewBaseChannel("whatsapp", config.WhatsAppSettings{}, messageBus, nil),
ctx: context.Background(),
}
+23 -8
View File
@@ -9,12 +9,27 @@ import (
)
func init() {
channels.RegisterFactory("whatsapp_native", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
waCfg := cfg.Channels.WhatsApp
storePath := waCfg.SessionStorePath
if storePath == "" {
storePath = filepath.Join(cfg.WorkspacePath(), "whatsapp")
}
return NewWhatsAppNativeChannel(waCfg, b, storePath)
})
channels.RegisterFactory(
config.ChannelWhatsAppNative,
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
bc := cfg.Channels[channelName]
decoded, err := bc.GetDecoded()
if err != nil {
return nil, err
}
c, ok := decoded.(*config.WhatsAppSettings)
if !ok {
return nil, channels.ErrSendFailed
}
storePath := c.SessionStorePath
if storePath == "" {
storePath = filepath.Join(cfg.WorkspacePath(), "whatsapp")
}
ch, err := NewWhatsAppNativeChannel(bc, channelName, c, b, storePath)
if err != nil {
return nil, err
}
return ch, nil
},
)
}
@@ -20,7 +20,7 @@ import (
func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &WhatsAppNativeChannel{
BaseChannel: channels.NewBaseChannel("whatsapp_native", config.WhatsAppConfig{}, messageBus, nil),
BaseChannel: channels.NewBaseChannel("whatsapp_native", config.WhatsAppSettings{}, messageBus, nil),
runCtx: context.Background(),
}
@@ -48,7 +48,7 @@ const (
// WhatsAppNativeChannel implements the WhatsApp channel using whatsmeow (in-process, no external bridge).
type WhatsAppNativeChannel struct {
*channels.BaseChannel
config config.WhatsAppConfig
config *config.WhatsAppSettings
storePath string
client *whatsmeow.Client
container *sqlstore.Container
@@ -64,11 +64,13 @@ type WhatsAppNativeChannel struct {
// NewWhatsAppNativeChannel creates a WhatsApp channel that uses whatsmeow for connection.
// storePath is the directory for the SQLite session store (e.g. workspace/whatsapp).
func NewWhatsAppNativeChannel(
cfg config.WhatsAppConfig,
bc *config.Channel,
name string,
cfg *config.WhatsAppSettings,
bus *bus.MessageBus,
storePath string,
) (channels.Channel, error) {
base := channels.NewBaseChannel("whatsapp_native", cfg, bus, cfg.AllowFrom, channels.WithMaxMessageLength(65536))
base := channels.NewBaseChannel(name, cfg, bus, bc.AllowFrom, channels.WithMaxMessageLength(65536))
if storePath == "" {
storePath = "whatsapp"
}
@@ -375,7 +377,6 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) {
if evt.Info.Chat.Server == types.GroupServer {
peerKind = "group"
}
peer := bus.Peer{Kind: peerKind, ID: chatID}
messageID := evt.Info.ID
sender := bus.SenderInfo{
Platform: "whatsapp",
@@ -393,7 +394,17 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) {
"WhatsApp message received",
map[string]any{"sender_id": senderID, "content_preview": utils.Truncate(content, 50)},
)
c.HandleMessage(c.runCtx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: "whatsapp",
ChatID: chatID,
SenderID: senderID,
MessageID: messageID,
ChatType: peerKind,
Raw: metadata,
}
c.HandleInboundContext(c.runCtx, chatID, content, mediaPaths, inboundCtx, sender)
}
func (c *WhatsAppNativeChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {

Some files were not shown because too many files have changed in this diff Show More