mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge remote-tracking branch 'origin/main' into refactor/line-sdk
# Conflicts: # pkg/channels/line/line.go
This commit is contained in:
+54
-25
@@ -685,43 +685,60 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
|
||||
// tool result messages following it. This is required by strict providers
|
||||
// like DeepSeek that enforce: "An assistant message with 'tool_calls' must
|
||||
// be followed by tool messages responding to each 'tool_call_id'."
|
||||
//
|
||||
// Deduplication is scoped to the contiguous tool-result block that follows a
|
||||
// single assistant tool-call message. Some providers legitimately reuse call
|
||||
// IDs across separate turns (for example "call_0"), so global deduplication
|
||||
// would incorrectly delete later valid tool results and leave an
|
||||
// assistant(tool_calls) -> assistant sequence behind.
|
||||
final := make([]providers.Message, 0, len(sanitized))
|
||||
seenToolCallID := make(map[string]bool)
|
||||
for i := 0; i < len(sanitized); i++ {
|
||||
msg := sanitized[i]
|
||||
|
||||
// Deduplicate tool results by ToolCallID
|
||||
if msg.Role == "tool" && msg.ToolCallID != "" {
|
||||
if seenToolCallID[msg.ToolCallID] {
|
||||
logger.DebugCF("agent", "Dropping duplicate tool result", map[string]any{
|
||||
"tool_call_id": msg.ToolCallID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
seenToolCallID[msg.ToolCallID] = true
|
||||
}
|
||||
|
||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||
// Collect expected tool_call IDs
|
||||
expected := make(map[string]bool, len(msg.ToolCalls))
|
||||
invalidToolCallID := false
|
||||
for _, tc := range msg.ToolCalls {
|
||||
if tc.ID == "" {
|
||||
invalidToolCallID = true
|
||||
continue
|
||||
}
|
||||
expected[tc.ID] = false
|
||||
}
|
||||
|
||||
// Check following messages for matching tool results
|
||||
toolMsgCount := 0
|
||||
for j := i + 1; j < len(sanitized); j++ {
|
||||
if sanitized[j].Role != "tool" {
|
||||
block := make([]providers.Message, 0, len(expected))
|
||||
seenInBlock := make(map[string]bool, len(expected))
|
||||
j := i + 1
|
||||
for ; j < len(sanitized); j++ {
|
||||
next := sanitized[j]
|
||||
if next.Role != "tool" {
|
||||
break
|
||||
}
|
||||
toolMsgCount++
|
||||
if _, exists := expected[sanitized[j].ToolCallID]; exists {
|
||||
expected[sanitized[j].ToolCallID] = true
|
||||
if next.ToolCallID == "" {
|
||||
logger.DebugCF("agent", "Dropping tool result without tool_call_id", map[string]any{})
|
||||
continue
|
||||
}
|
||||
if _, ok := expected[next.ToolCallID]; !ok {
|
||||
logger.DebugCF("agent", "Dropping unexpected tool result", map[string]any{
|
||||
"tool_call_id": next.ToolCallID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if seenInBlock[next.ToolCallID] {
|
||||
logger.DebugCF("agent", "Dropping duplicate tool result in tool block", map[string]any{
|
||||
"tool_call_id": next.ToolCallID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
seenInBlock[next.ToolCallID] = true
|
||||
expected[next.ToolCallID] = true
|
||||
block = append(block, next)
|
||||
}
|
||||
|
||||
// If any tool_call_id is missing, drop this assistant message and its partial tool messages
|
||||
allFound := true
|
||||
allFound := !invalidToolCallID
|
||||
if invalidToolCallID {
|
||||
logger.DebugCF("agent", "Dropping assistant message with empty tool_call_id", map[string]any{})
|
||||
}
|
||||
for toolCallID, found := range expected {
|
||||
if !found {
|
||||
allFound = false
|
||||
@@ -731,7 +748,7 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
|
||||
map[string]any{
|
||||
"missing_tool_call_id": toolCallID,
|
||||
"expected_count": len(expected),
|
||||
"found_count": toolMsgCount,
|
||||
"found_count": len(block),
|
||||
},
|
||||
)
|
||||
break
|
||||
@@ -739,11 +756,23 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
|
||||
}
|
||||
|
||||
if !allFound {
|
||||
// Skip this assistant message and its tool messages
|
||||
i += toolMsgCount
|
||||
i = j - 1
|
||||
continue
|
||||
}
|
||||
|
||||
final = append(final, msg)
|
||||
final = append(final, block...)
|
||||
i = j - 1
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.Role == "tool" {
|
||||
logger.DebugCF("agent", "Dropping orphaned tool message after validation", map[string]any{
|
||||
"tool_call_id": msg.ToolCallID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
final = append(final, msg)
|
||||
}
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ func (m *legacyContextManager) Compact(_ context.Context, req *CompactRequest) e
|
||||
if result, ok := m.forceCompression(req.SessionKey); ok {
|
||||
m.al.emitEvent(
|
||||
EventKindContextCompress,
|
||||
m.al.newTurnEventScope("", req.SessionKey).meta(0, "forceCompression", "turn.context.compress"),
|
||||
m.al.newTurnEventScope("", req.SessionKey, nil).meta(0, "forceCompression", "turn.context.compress"),
|
||||
ContextCompressPayload{
|
||||
Reason: req.Reason,
|
||||
DroppedMessages: result.DroppedMessages,
|
||||
@@ -61,6 +61,16 @@ func (m *legacyContextManager) Ingest(_ context.Context, _ *IngestRequest) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *legacyContextManager) Clear(_ context.Context, sessionKey string) error {
|
||||
agent := m.al.registry.GetDefaultAgent()
|
||||
if agent == nil || agent.Sessions == nil {
|
||||
return fmt.Errorf("sessions not initialized")
|
||||
}
|
||||
agent.Sessions.SetHistory(sessionKey, []providers.Message{})
|
||||
agent.Sessions.SetSummary(sessionKey, "")
|
||||
return agent.Sessions.Save(sessionKey)
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
// It runs asynchronously in a goroutine.
|
||||
func (m *legacyContextManager) maybeSummarize(sessionKey string) {
|
||||
@@ -237,7 +247,7 @@ func (m *legacyContextManager) summarizeSession(agent *AgentInstance, sessionKey
|
||||
agent.Sessions.Save(sessionKey)
|
||||
m.al.emitEvent(
|
||||
EventKindSessionSummarize,
|
||||
m.al.newTurnEventScope(agent.ID, sessionKey).meta(0, "summarizeSession", "turn.session.summarize"),
|
||||
m.al.newTurnEventScope(agent.ID, sessionKey, nil).meta(0, "summarizeSession", "turn.session.summarize"),
|
||||
SessionSummarizePayload{
|
||||
SummarizedMessages: len(validMessages),
|
||||
KeptMessages: keepCount,
|
||||
|
||||
@@ -24,6 +24,10 @@ type ContextManager interface {
|
||||
// Ingest records a message into the ContextManager's own storage.
|
||||
// Called after each message is persisted to session JSONL.
|
||||
Ingest(ctx context.Context, req *IngestRequest) error
|
||||
|
||||
// Clear removes all stored context for a session (messages, summaries, etc.).
|
||||
// Called when the user issues /clear or /reset.
|
||||
Clear(ctx context.Context, sessionKey string) error
|
||||
}
|
||||
|
||||
// AssembleRequest is the input to Assemble.
|
||||
|
||||
@@ -690,6 +690,7 @@ func (m *noopContextManager) Assemble(_ context.Context, req *AssembleRequest) (
|
||||
}
|
||||
func (m *noopContextManager) Compact(_ context.Context, _ *CompactRequest) error { return nil }
|
||||
func (m *noopContextManager) Ingest(_ context.Context, _ *IngestRequest) error { return nil }
|
||||
func (m *noopContextManager) Clear(_ context.Context, _ string) error { return nil }
|
||||
|
||||
// trackingContextManager tracks call counts for each method.
|
||||
type trackingContextManager struct {
|
||||
@@ -726,6 +727,8 @@ func (m *trackingContextManager) Ingest(_ context.Context, req *IngestRequest) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *trackingContextManager) Clear(_ context.Context, _ string) error { return nil }
|
||||
|
||||
// resetCMRegistry clears the global factory registry and returns a cleanup
|
||||
// function that restores the original state after the test.
|
||||
func resetCMRegistry() func() {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !mipsle && !netbsd
|
||||
//go:build !mipsle && !netbsd && !(freebsd && arm)
|
||||
|
||||
package agent
|
||||
|
||||
@@ -154,6 +154,19 @@ func (m *seahorseContextManager) Ingest(ctx context.Context, req *IngestRequest)
|
||||
return err
|
||||
}
|
||||
|
||||
// Clear removes all stored context for a session (seahorse DB + JSONL).
|
||||
func (m *seahorseContextManager) Clear(ctx context.Context, sessionKey string) error {
|
||||
if err := m.engine.ClearSession(ctx, sessionKey); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.sessions != nil {
|
||||
m.sessions.SetHistory(sessionKey, []providers.Message{})
|
||||
m.sessions.SetSummary(sessionKey, "")
|
||||
return m.sessions.Save(sessionKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// bootstrapSession reconciles JSONL session history into seahorse SQLite.
|
||||
func (m *seahorseContextManager) bootstrapSession(ctx context.Context, sessionKey string) {
|
||||
if m.sessions == nil {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build mipsle || netbsd
|
||||
//go:build mipsle || netbsd || (freebsd && arm)
|
||||
|
||||
package agent
|
||||
|
||||
|
||||
@@ -213,6 +213,47 @@ func TestSanitizeHistoryForProvider_DuplicateToolResults(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeHistoryForProvider_ReusedToolCallIDAcrossRounds(t *testing.T) {
|
||||
history := []providers.Message{
|
||||
msg("user", "first"),
|
||||
assistantWithTools("call_0"),
|
||||
toolResult("call_0"),
|
||||
msg("assistant", "first done"),
|
||||
msg("user", "second"),
|
||||
assistantWithTools("call_0"),
|
||||
toolResult("call_0"),
|
||||
msg("assistant", "second done"),
|
||||
}
|
||||
|
||||
result := sanitizeHistoryForProvider(history)
|
||||
if len(result) != 8 {
|
||||
t.Fatalf("expected 8 messages, got %d: %+v", len(result), roles(result))
|
||||
}
|
||||
assertRoles(t, result, "user", "assistant", "tool", "assistant", "user", "assistant", "tool", "assistant")
|
||||
if result[2].ToolCallID != "call_0" || result[6].ToolCallID != "call_0" {
|
||||
t.Fatalf(
|
||||
"expected both tool results to be preserved, got IDs %q and %q",
|
||||
result[2].ToolCallID,
|
||||
result[6].ToolCallID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeHistoryForProvider_DropsAssistantWithEmptyToolCallID(t *testing.T) {
|
||||
history := []providers.Message{
|
||||
msg("user", "do something"),
|
||||
assistantWithTools(""),
|
||||
toolResult(""),
|
||||
msg("assistant", "done"),
|
||||
}
|
||||
|
||||
result := sanitizeHistoryForProvider(history)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d: %+v", len(result), roles(result))
|
||||
}
|
||||
assertRoles(t, result, "user", "assistant")
|
||||
}
|
||||
|
||||
func roles(msgs []providers.Message) []string {
|
||||
r := make([]string, len(msgs))
|
||||
for i, m := range msgs {
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
// DispatchRequest is the normalized runtime input passed into the agent loop
|
||||
// after routing and session allocation have completed.
|
||||
type DispatchRequest struct {
|
||||
SessionKey string
|
||||
SessionAliases []string
|
||||
InboundContext *bus.InboundContext
|
||||
RouteResult *routing.ResolvedRoute
|
||||
SessionScope *session.SessionScope
|
||||
UserMessage string
|
||||
Media []string
|
||||
}
|
||||
|
||||
func (r DispatchRequest) Channel() string {
|
||||
if r.InboundContext == nil {
|
||||
return ""
|
||||
}
|
||||
return r.InboundContext.Channel
|
||||
}
|
||||
|
||||
func (r DispatchRequest) ChatID() string {
|
||||
if r.InboundContext == nil {
|
||||
return ""
|
||||
}
|
||||
return r.InboundContext.ChatID
|
||||
}
|
||||
|
||||
func (r DispatchRequest) MessageID() string {
|
||||
if r.InboundContext == nil {
|
||||
return ""
|
||||
}
|
||||
return r.InboundContext.MessageID
|
||||
}
|
||||
|
||||
func (r DispatchRequest) ReplyToMessageID() string {
|
||||
if r.InboundContext == nil {
|
||||
return ""
|
||||
}
|
||||
return r.InboundContext.ReplyToMessageID
|
||||
}
|
||||
|
||||
func (r DispatchRequest) SenderID() string {
|
||||
if r.InboundContext == nil {
|
||||
return ""
|
||||
}
|
||||
return r.InboundContext.SenderID
|
||||
}
|
||||
|
||||
func normalizeProcessOptionsInPlace(opts *processOptions) {
|
||||
if opts == nil {
|
||||
return
|
||||
}
|
||||
*opts = normalizeProcessOptions(*opts)
|
||||
}
|
||||
|
||||
func normalizeProcessOptions(opts processOptions) processOptions {
|
||||
if opts.Dispatch.SessionKey == "" {
|
||||
opts.Dispatch.SessionKey = strings.TrimSpace(opts.SessionKey)
|
||||
}
|
||||
if len(opts.Dispatch.SessionAliases) == 0 && len(opts.SessionAliases) > 0 {
|
||||
opts.Dispatch.SessionAliases = append([]string(nil), opts.SessionAliases...)
|
||||
}
|
||||
if opts.Dispatch.UserMessage == "" {
|
||||
opts.Dispatch.UserMessage = opts.UserMessage
|
||||
}
|
||||
if len(opts.Dispatch.Media) == 0 && len(opts.Media) > 0 {
|
||||
opts.Dispatch.Media = append([]string(nil), opts.Media...)
|
||||
}
|
||||
if opts.Dispatch.RouteResult == nil {
|
||||
opts.Dispatch.RouteResult = cloneResolvedRoute(opts.RouteResult)
|
||||
}
|
||||
if opts.Dispatch.SessionScope == nil {
|
||||
opts.Dispatch.SessionScope = session.CloneScope(opts.SessionScope)
|
||||
}
|
||||
if opts.Dispatch.InboundContext == nil {
|
||||
if opts.InboundContext != nil {
|
||||
opts.Dispatch.InboundContext = cloneInboundContext(opts.InboundContext)
|
||||
} else if opts.Channel != "" || opts.ChatID != "" || opts.SenderID != "" ||
|
||||
opts.MessageID != "" || opts.ReplyToMessageID != "" {
|
||||
inbound := bus.InboundContext{
|
||||
Channel: strings.TrimSpace(opts.Channel),
|
||||
ChatID: strings.TrimSpace(opts.ChatID),
|
||||
SenderID: strings.TrimSpace(opts.SenderID),
|
||||
MessageID: strings.TrimSpace(opts.MessageID),
|
||||
ReplyToMessageID: strings.TrimSpace(opts.ReplyToMessageID),
|
||||
}
|
||||
inbound.ChatType = inferChatTypeFromSessionScope(opts.Dispatch.SessionScope)
|
||||
if inbound.Channel != "" || inbound.ChatID != "" || inbound.SenderID != "" ||
|
||||
inbound.MessageID != "" || inbound.ReplyToMessageID != "" {
|
||||
inbound = bus.NormalizeInboundMessage(bus.InboundMessage{Context: inbound}).Context
|
||||
opts.Dispatch.InboundContext = &inbound
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Keep legacy mirrors populated while the rest of the runtime migrates.
|
||||
opts.SessionKey = opts.Dispatch.SessionKey
|
||||
opts.SessionAliases = append([]string(nil), opts.Dispatch.SessionAliases...)
|
||||
opts.UserMessage = opts.Dispatch.UserMessage
|
||||
opts.Media = append([]string(nil), opts.Dispatch.Media...)
|
||||
opts.InboundContext = cloneInboundContext(opts.Dispatch.InboundContext)
|
||||
opts.RouteResult = cloneResolvedRoute(opts.Dispatch.RouteResult)
|
||||
opts.SessionScope = session.CloneScope(opts.Dispatch.SessionScope)
|
||||
if opts.InboundContext != nil {
|
||||
if opts.Channel == "" {
|
||||
opts.Channel = opts.InboundContext.Channel
|
||||
}
|
||||
if opts.ChatID == "" {
|
||||
opts.ChatID = opts.InboundContext.ChatID
|
||||
}
|
||||
if opts.MessageID == "" {
|
||||
opts.MessageID = opts.InboundContext.MessageID
|
||||
}
|
||||
if opts.ReplyToMessageID == "" {
|
||||
opts.ReplyToMessageID = opts.InboundContext.ReplyToMessageID
|
||||
}
|
||||
if opts.SenderID == "" {
|
||||
opts.SenderID = opts.InboundContext.SenderID
|
||||
}
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
func inferChatTypeFromSessionScope(scope *session.SessionScope) string {
|
||||
if scope == nil || len(scope.Values) == 0 {
|
||||
return ""
|
||||
}
|
||||
chatValue := strings.TrimSpace(scope.Values["chat"])
|
||||
if chatValue == "" {
|
||||
return ""
|
||||
}
|
||||
chatType, _, ok := strings.Cut(chatValue, ":")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(chatType))
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
func TestNormalizeProcessOptions_PopulatesDispatchFromLegacyFields(t *testing.T) {
|
||||
opts := normalizeProcessOptions(processOptions{
|
||||
SessionKey: "session-1",
|
||||
SessionAliases: []string{"legacy:one"},
|
||||
Channel: "telegram",
|
||||
ChatID: "chat-1",
|
||||
MessageID: "msg-1",
|
||||
ReplyToMessageID: "reply-1",
|
||||
SenderID: "user-1",
|
||||
UserMessage: "hello",
|
||||
Media: []string{"media://one"},
|
||||
})
|
||||
|
||||
if opts.Dispatch.SessionKey != "session-1" {
|
||||
t.Fatalf("Dispatch.SessionKey = %q, want session-1", opts.Dispatch.SessionKey)
|
||||
}
|
||||
if len(opts.Dispatch.SessionAliases) != 1 || opts.Dispatch.SessionAliases[0] != "legacy:one" {
|
||||
t.Fatalf("Dispatch.SessionAliases = %v, want [legacy:one]", opts.Dispatch.SessionAliases)
|
||||
}
|
||||
if opts.Dispatch.Channel() != "telegram" || opts.Dispatch.ChatID() != "chat-1" {
|
||||
t.Fatalf(
|
||||
"dispatch addressing = (%q,%q), want (telegram,chat-1)",
|
||||
opts.Dispatch.Channel(),
|
||||
opts.Dispatch.ChatID(),
|
||||
)
|
||||
}
|
||||
if opts.Dispatch.SenderID() != "user-1" || opts.Dispatch.MessageID() != "msg-1" {
|
||||
t.Fatalf("dispatch sender/message = (%q,%q)", opts.Dispatch.SenderID(), opts.Dispatch.MessageID())
|
||||
}
|
||||
if opts.Dispatch.ReplyToMessageID() != "reply-1" {
|
||||
t.Fatalf("Dispatch.ReplyToMessageID() = %q, want reply-1", opts.Dispatch.ReplyToMessageID())
|
||||
}
|
||||
if opts.Dispatch.UserMessage != "hello" {
|
||||
t.Fatalf("Dispatch.UserMessage = %q, want hello", opts.Dispatch.UserMessage)
|
||||
}
|
||||
if len(opts.Dispatch.Media) != 1 || opts.Dispatch.Media[0] != "media://one" {
|
||||
t.Fatalf("Dispatch.Media = %v, want [media://one]", opts.Dispatch.Media)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProcessOptions_UsesDispatchAsSourceOfTruth(t *testing.T) {
|
||||
inbound := &bus.InboundContext{
|
||||
Channel: "slack",
|
||||
ChatID: "C123",
|
||||
ChatType: "channel",
|
||||
SenderID: "U123",
|
||||
MessageID: "m-1",
|
||||
ReplyToMessageID: "parent-1",
|
||||
}
|
||||
route := &routing.ResolvedRoute{
|
||||
AgentID: "support",
|
||||
Channel: "slack",
|
||||
AccountID: "workspace-a",
|
||||
MatchedBy: "dispatch.rule:test",
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"chat", "sender"},
|
||||
},
|
||||
}
|
||||
scope := &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "support",
|
||||
Channel: "slack",
|
||||
Account: "workspace-a",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "channel:c123",
|
||||
},
|
||||
}
|
||||
|
||||
opts := normalizeProcessOptions(processOptions{
|
||||
Dispatch: DispatchRequest{
|
||||
SessionKey: "sk_v1_example",
|
||||
SessionAliases: []string{"agent:support:slack:channel:c123"},
|
||||
InboundContext: inbound,
|
||||
RouteResult: route,
|
||||
SessionScope: scope,
|
||||
UserMessage: "hello",
|
||||
Media: []string{"media://one"},
|
||||
},
|
||||
})
|
||||
|
||||
if opts.SessionKey != "sk_v1_example" {
|
||||
t.Fatalf("SessionKey = %q, want sk_v1_example", opts.SessionKey)
|
||||
}
|
||||
if opts.Channel != "slack" || opts.ChatID != "C123" {
|
||||
t.Fatalf("legacy mirrors = (%q,%q), want (slack,C123)", opts.Channel, opts.ChatID)
|
||||
}
|
||||
if opts.SenderID != "U123" || opts.MessageID != "m-1" {
|
||||
t.Fatalf("legacy sender/message = (%q,%q)", opts.SenderID, opts.MessageID)
|
||||
}
|
||||
if opts.ReplyToMessageID != "parent-1" {
|
||||
t.Fatalf("ReplyToMessageID = %q, want parent-1", opts.ReplyToMessageID)
|
||||
}
|
||||
if opts.RouteResult == nil || opts.RouteResult.AgentID != "support" {
|
||||
t.Fatalf("RouteResult = %#v, want support route", opts.RouteResult)
|
||||
}
|
||||
if opts.SessionScope == nil || opts.SessionScope.AgentID != "support" {
|
||||
t.Fatalf("SessionScope = %#v, want support scope", opts.SessionScope)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProcessOptions_InfersLegacyChatTypeFromSessionScope(t *testing.T) {
|
||||
opts := normalizeProcessOptions(processOptions{
|
||||
Channel: "telegram",
|
||||
ChatID: "-100123",
|
||||
SenderID: "user-1",
|
||||
UserMessage: "hello",
|
||||
SessionScope: &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "group:-100123",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if opts.Dispatch.InboundContext == nil {
|
||||
t.Fatal("Dispatch.InboundContext is nil")
|
||||
}
|
||||
if opts.Dispatch.InboundContext.ChatType != "group" {
|
||||
t.Fatalf("Dispatch.InboundContext.ChatType = %q, want group", opts.Dispatch.InboundContext.ChatType)
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -136,6 +138,31 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
InboundContext: &bus.InboundContext{
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
ChatType: "direct",
|
||||
SenderID: "tester",
|
||||
},
|
||||
RouteResult: &routing.ResolvedRoute{
|
||||
AgentID: "main",
|
||||
Channel: "cli",
|
||||
AccountID: routing.DefaultAccountID,
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"sender"},
|
||||
},
|
||||
MatchedBy: "default",
|
||||
},
|
||||
SessionScope: &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "cli",
|
||||
Account: routing.DefaultAccountID,
|
||||
Dimensions: []string{"sender"},
|
||||
Values: map[string]string{
|
||||
"sender": "tester",
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
@@ -176,6 +203,18 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
|
||||
if evt.Meta.SessionKey != "session-1" {
|
||||
t.Fatalf("event %d has session key %q, want session-1", i, evt.Meta.SessionKey)
|
||||
}
|
||||
if evt.Context == nil || evt.Context.Inbound == nil {
|
||||
t.Fatalf("event %d missing inbound turn context", i)
|
||||
}
|
||||
if evt.Context.Inbound.Channel != "cli" || evt.Context.Inbound.SenderID != "tester" {
|
||||
t.Fatalf("event %d inbound context = %+v", i, evt.Context.Inbound)
|
||||
}
|
||||
if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" {
|
||||
t.Fatalf("event %d missing route context: %+v", i, evt.Context.Route)
|
||||
}
|
||||
if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "tester" {
|
||||
t.Fatalf("event %d missing session scope: %+v", i, evt.Context.Scope)
|
||||
}
|
||||
}
|
||||
|
||||
startPayload, ok := events[0].Payload.(TurnStartPayload)
|
||||
@@ -472,7 +511,6 @@ func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
// Use legacyContextManager's summarizeSession via contextManager interface
|
||||
lcm := &legacyContextManager{al: al}
|
||||
lcm.summarizeSession(defaultAgent, "session-1")
|
||||
|
||||
@@ -572,12 +610,6 @@ func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
|
||||
if payload.SourceTool != "async_followup" {
|
||||
t.Fatalf("expected source tool async_followup, got %q", payload.SourceTool)
|
||||
}
|
||||
if payload.Channel != "cli" {
|
||||
t.Fatalf("expected channel cli, got %q", payload.Channel)
|
||||
}
|
||||
if payload.ChatID != "direct" {
|
||||
t.Fatalf("expected chat id direct, got %q", payload.ChatID)
|
||||
}
|
||||
if payload.ContentLen != len("background result") {
|
||||
t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen)
|
||||
}
|
||||
|
||||
+2
-4
@@ -86,6 +86,7 @@ type Event struct {
|
||||
Kind EventKind
|
||||
Time time.Time
|
||||
Meta EventMeta
|
||||
Context *TurnContext
|
||||
Payload any
|
||||
}
|
||||
|
||||
@@ -98,6 +99,7 @@ type EventMeta struct {
|
||||
Iteration int
|
||||
TracePath string
|
||||
Source string
|
||||
turnContext *TurnContext
|
||||
}
|
||||
|
||||
// TurnEndStatus describes the terminal state of a turn.
|
||||
@@ -114,8 +116,6 @@ const (
|
||||
|
||||
// TurnStartPayload describes the start of a turn.
|
||||
type TurnStartPayload struct {
|
||||
Channel string
|
||||
ChatID string
|
||||
UserMessage string
|
||||
MediaCount int
|
||||
}
|
||||
@@ -217,8 +217,6 @@ type SteeringInjectedPayload struct {
|
||||
// FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus.
|
||||
type FollowUpQueuedPayload struct {
|
||||
SourceTool string
|
||||
Channel string
|
||||
ChatID string
|
||||
ContentLen int
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,9 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -90,7 +92,8 @@ type processHookAfterLLMResponse struct {
|
||||
|
||||
type processHookBeforeToolResponse struct {
|
||||
processHookDecisionResponse
|
||||
Call *ToolCallHookRequest `json:"call,omitempty"`
|
||||
Call *ToolCallHookRequest `json:"call,omitempty"`
|
||||
Result *tools.ToolResult `json:"result,omitempty"` // Result returned directly by hook (for respond action)
|
||||
}
|
||||
|
||||
type processHookAfterToolResponse struct {
|
||||
@@ -120,7 +123,9 @@ func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process hook stderr: %w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
// Route hook subprocess startup through the shared isolation entry point so
|
||||
// process hooks inherit the same isolation behavior as other child processes.
|
||||
if err := isolation.Start(cmd); err != nil {
|
||||
return nil, fmt.Errorf("start process hook: %w", err)
|
||||
}
|
||||
|
||||
@@ -241,6 +246,10 @@ func (ph *ProcessHook) BeforeTool(
|
||||
if resp.Call == nil {
|
||||
resp.Call = call
|
||||
}
|
||||
// If hook returned a Result, carry it in ToolCallHookRequest
|
||||
if resp.Result != nil {
|
||||
resp.Call.HookResult = resp.Result
|
||||
}
|
||||
return resp.Call, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,10 +7,13 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
@@ -178,6 +181,76 @@ func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_MountProcessHook_IsolationSupportsRelativeDirAndCommand(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("linux-only isolation path handling")
|
||||
}
|
||||
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
root := t.TempDir()
|
||||
t.Setenv(config.EnvHome, filepath.Join(root, "picoclaw-home"))
|
||||
binDir := filepath.Join(root, "bin")
|
||||
hookDir := filepath.Join(root, "hooks")
|
||||
if err := os.MkdirAll(binDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(hookDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
writeFakeBwrap(t, filepath.Join(binDir, "bwrap"))
|
||||
t.Setenv("PATH", binDir+string(os.PathListSeparator)+os.Getenv("PATH"))
|
||||
linkTestBinary(t, os.Args[0], filepath.Join(hookDir, "hook-helper"))
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Isolation.Enabled = true
|
||||
isolation.Configure(cfg)
|
||||
t.Cleanup(func() { isolation.Configure(config.DefaultConfig()) })
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
relHookDir, err := filepath.Rel(cwd, hookDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mountErr := al.MountProcessHook(context.Background(), "ipc-relative", ProcessHookOptions{
|
||||
Command: []string{"./hook-helper", "-test.run=TestProcessHook_HelperProcess", "--"},
|
||||
Dir: relHookDir,
|
||||
Env: processHookHelperEnv("rewrite", ""),
|
||||
InterceptLLM: true,
|
||||
})
|
||||
if mountErr != nil {
|
||||
t.Fatalf("MountProcessHook failed with relative dir/command under isolation: %v", mountErr)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-relative",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "hello",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "provider content|ipc" {
|
||||
t.Fatalf("expected process-hooked llm content, got %q", resp)
|
||||
}
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "process-model" {
|
||||
t.Fatalf("expected process model, got %q", lastModel)
|
||||
}
|
||||
}
|
||||
|
||||
func processHookHelperCommand() []string {
|
||||
return []string{os.Args[0], "-test.run=TestProcessHook_HelperProcess", "--"}
|
||||
}
|
||||
@@ -193,6 +266,59 @@ func processHookHelperEnv(mode, eventLog string) []string {
|
||||
return env
|
||||
}
|
||||
|
||||
func writeFakeBwrap(t *testing.T, path string) {
|
||||
t.Helper()
|
||||
script := `#!/bin/sh
|
||||
set -eu
|
||||
workdir=
|
||||
while [ "$#" -gt 0 ]; do
|
||||
case "$1" in
|
||||
--)
|
||||
shift
|
||||
break
|
||||
;;
|
||||
--chdir)
|
||||
workdir="$2"
|
||||
shift 2
|
||||
;;
|
||||
--bind|--ro-bind)
|
||||
shift 3
|
||||
;;
|
||||
--proc|--dev)
|
||||
shift 2
|
||||
;;
|
||||
--die-with-parent|--unshare-ipc)
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
if [ -n "$workdir" ]; then
|
||||
cd "$workdir"
|
||||
fi
|
||||
exec "$@"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(script), 0o755); err != nil {
|
||||
t.Fatalf("write fake bwrap: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func linkTestBinary(t *testing.T, source, target string) {
|
||||
t.Helper()
|
||||
if err := os.Symlink(source, target); err == nil {
|
||||
return
|
||||
}
|
||||
data, err := os.ReadFile(source)
|
||||
if err != nil {
|
||||
t.Fatalf("read test binary: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(target, data, 0o755); err != nil {
|
||||
t.Fatalf("create hook helper binary: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForFileContains(t *testing.T, path, substring string) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
+34
-13
@@ -25,6 +25,7 @@ type HookAction string
|
||||
const (
|
||||
HookActionContinue HookAction = "continue"
|
||||
HookActionModify HookAction = "modify"
|
||||
HookActionRespond HookAction = "respond" // Return result directly, skip tool execution. SECURITY: This bypasses ApproveTool checks, allowing hooks to return results for any tool (including sensitive ones like bash) without approval. Use with caution.
|
||||
HookActionDenyTool HookAction = "deny_tool"
|
||||
HookActionAbortTurn HookAction = "abort_turn"
|
||||
HookActionHardAbort HookAction = "hard_abort"
|
||||
@@ -89,12 +90,11 @@ type ToolApprover interface {
|
||||
|
||||
type LLMHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []providers.Message `json:"messages,omitempty"`
|
||||
Tools []providers.ToolDefinition `json:"tools,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
GracefulTerminal bool `json:"graceful_terminal,omitempty"`
|
||||
}
|
||||
|
||||
@@ -103,6 +103,8 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Messages = cloneProviderMessages(r.Messages)
|
||||
cloned.Tools = cloneToolDefinitions(r.Tools)
|
||||
cloned.Options = cloneStringAnyMap(r.Options)
|
||||
@@ -111,10 +113,9 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
|
||||
|
||||
type LLMHookResponse struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Response *providers.LLMResponse `json:"response,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *LLMHookResponse) Clone() *LLMHookResponse {
|
||||
@@ -122,16 +123,20 @@ func (r *LLMHookResponse) Clone() *LLMHookResponse {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Response = cloneLLMResponse(r.Response)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolCallHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
Meta EventMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
HookResult *tools.ToolResult `json:"hook_result,omitempty"` // Result returned directly by hook (for respond action). Media is supported - see Media handling section in docs.
|
||||
}
|
||||
|
||||
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
@@ -139,16 +144,18 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
cloned.HookResult = cloneToolResult(r.HookResult)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolApprovalRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
|
||||
@@ -156,18 +163,19 @@ func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolResultHookResponse struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Result *tools.ToolResult `json:"result,omitempty"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
|
||||
@@ -175,6 +183,8 @@ func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
cloned.Result = cloneToolResult(r.Result)
|
||||
return &cloned
|
||||
@@ -382,6 +392,10 @@ func (hm *HookManager) BeforeTool(
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionRespond:
|
||||
// Hook returns result directly, skip tool execution
|
||||
// Carry HookResult in ToolCallHookRequest and return
|
||||
return next, decision
|
||||
case HookActionDenyTool, HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
@@ -793,6 +807,13 @@ func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
|
||||
if len(result.Media) > 0 {
|
||||
cloned.Media = append([]string(nil), result.Media...)
|
||||
}
|
||||
if len(result.ArtifactTags) > 0 {
|
||||
cloned.ArtifactTags = append([]string(nil), result.ArtifactTags...)
|
||||
}
|
||||
if len(result.Messages) > 0 {
|
||||
cloned.Messages = make([]providers.Message, len(result.Messages))
|
||||
copy(cloned.Messages, result.Messages)
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
|
||||
+582
-1
@@ -2,6 +2,7 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -10,6 +11,8 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -106,7 +109,8 @@ func (p *llmHookTestProvider) GetDefaultModel() string {
|
||||
}
|
||||
|
||||
type llmObserverHook struct {
|
||||
eventCh chan Event
|
||||
eventCh chan Event
|
||||
lastInbound *bus.InboundContext
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
@@ -123,6 +127,9 @@ func (h *llmObserverHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
if req.Context != nil {
|
||||
h.lastInbound = cloneInboundContext(req.Context.Inbound)
|
||||
}
|
||||
next := req.Clone()
|
||||
next.Model = "hook-model"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
@@ -155,6 +162,31 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
InboundContext: &bus.InboundContext{
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
ChatType: "direct",
|
||||
SenderID: "hook-user",
|
||||
},
|
||||
RouteResult: &routing.ResolvedRoute{
|
||||
AgentID: "main",
|
||||
Channel: "cli",
|
||||
AccountID: routing.DefaultAccountID,
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"sender"},
|
||||
},
|
||||
MatchedBy: "default",
|
||||
},
|
||||
SessionScope: &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "cli",
|
||||
Account: routing.DefaultAccountID,
|
||||
Dimensions: []string{"sender"},
|
||||
Values: map[string]string{
|
||||
"sender": "hook-user",
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
@@ -169,12 +201,30 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
||||
if lastModel != "hook-model" {
|
||||
t.Fatalf("expected model hook-model, got %q", lastModel)
|
||||
}
|
||||
if hook.lastInbound == nil {
|
||||
t.Fatal("expected hook to receive inbound context")
|
||||
}
|
||||
if hook.lastInbound.Channel != "cli" || hook.lastInbound.SenderID != "hook-user" {
|
||||
t.Fatalf("hook inbound context = %+v", hook.lastInbound)
|
||||
}
|
||||
if hook.lastInbound != nil && hook.lastInbound.ChatID != "direct" {
|
||||
t.Fatalf("hook inbound chat ID = %q, want direct", hook.lastInbound.ChatID)
|
||||
}
|
||||
|
||||
select {
|
||||
case evt := <-hook.eventCh:
|
||||
if evt.Kind != EventKindTurnEnd {
|
||||
t.Fatalf("expected turn end event, got %v", evt.Kind)
|
||||
}
|
||||
if evt.Context == nil || evt.Context.Inbound == nil {
|
||||
t.Fatal("expected observer event to carry inbound context")
|
||||
}
|
||||
if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" {
|
||||
t.Fatalf("expected observer event to carry route context, got %+v", evt.Context.Route)
|
||||
}
|
||||
if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "hook-user" {
|
||||
t.Fatalf("expected observer event to carry session scope, got %+v", evt.Context.Scope)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for hook observer event")
|
||||
}
|
||||
@@ -343,3 +393,534 @@ func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
|
||||
t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
// respondHook is a test hook for testing HookActionRespond functionality
|
||||
type respondHook struct {
|
||||
respondTools map[string]bool // tool names to respond to
|
||||
}
|
||||
|
||||
func (h *respondHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
if h.respondTools[call.Tool] {
|
||||
next := call.Clone()
|
||||
next.HookResult = &tools.ToolResult{
|
||||
ForLLM: "hook-responded: " + call.Tool,
|
||||
ForUser: "",
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
}
|
||||
return next, HookDecision{Action: HookActionRespond}, nil
|
||||
}
|
||||
return call, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func (h *respondHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
// Should not be called since respond skips tool execution
|
||||
return result, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolRespondAction(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("respond-hook", &respondHook{
|
||||
respondTools: map[string]bool{"echo_text": true},
|
||||
})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify response comes from hook, not tool
|
||||
expected := "hook-responded: echo_text"
|
||||
if resp != expected {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
|
||||
// Verify event stream has ToolExecEnd, not actual tool execution
|
||||
events := collectEventStream(sub.C)
|
||||
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected tool exec end event")
|
||||
}
|
||||
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
||||
}
|
||||
if payload.Tool != "echo_text" {
|
||||
t.Fatalf("expected tool echo_text, got %q", payload.Tool)
|
||||
}
|
||||
if payload.ForLLMLen != len(expected) {
|
||||
t.Fatalf("expected ForLLMLen %d, got %d", len(expected), payload.ForLLMLen)
|
||||
}
|
||||
}
|
||||
|
||||
// denyToolHook tests HookActionDenyTool functionality
|
||||
type denyToolHook struct {
|
||||
denyTools map[string]bool
|
||||
}
|
||||
|
||||
func (h *denyToolHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
if h.denyTools[call.Tool] {
|
||||
return call, HookDecision{Action: HookActionDenyTool, Reason: "tool denied by hook"}, nil
|
||||
}
|
||||
return call, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func (h *denyToolHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
return result, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolDenyAction(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("deny-hook", &denyToolHook{
|
||||
denyTools: map[string]bool{"echo_text": true},
|
||||
})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
expected := "Tool execution denied by hook: tool denied by hook"
|
||||
if resp != expected {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookManager_BeforeTool_RespondAction(t *testing.T) {
|
||||
hm := NewHookManager(nil)
|
||||
defer hm.Close()
|
||||
|
||||
hook := &respondHook{
|
||||
respondTools: map[string]bool{"test_tool": true},
|
||||
}
|
||||
if err := hm.Mount(NamedHook("respond-test", hook)); err != nil {
|
||||
t.Fatalf("mount hook: %v", err)
|
||||
}
|
||||
|
||||
req := &ToolCallHookRequest{
|
||||
Tool: "test_tool",
|
||||
Arguments: map[string]any{"arg": "value"},
|
||||
}
|
||||
result, decision := hm.BeforeTool(context.Background(), req)
|
||||
|
||||
if decision.Action != HookActionRespond {
|
||||
t.Fatalf("expected action %q, got %q", HookActionRespond, decision.Action)
|
||||
}
|
||||
|
||||
if result.HookResult == nil {
|
||||
t.Fatal("expected HookResult to be set")
|
||||
}
|
||||
if result.HookResult.ForLLM != "hook-responded: test_tool" {
|
||||
t.Fatalf("unexpected HookResult.ForLLM: %q", result.HookResult.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
type respondWithMediaHook struct {
|
||||
respondTools map[string]bool
|
||||
media []string
|
||||
responseHandled bool
|
||||
forLLM string
|
||||
}
|
||||
|
||||
func (h *respondWithMediaHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
if h.respondTools[call.Tool] {
|
||||
next := call.Clone()
|
||||
next.HookResult = &tools.ToolResult{
|
||||
ForLLM: h.forLLM,
|
||||
ForUser: "media result",
|
||||
Media: h.media,
|
||||
ResponseHandled: h.responseHandled,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
}
|
||||
return next, HookDecision{Action: HookActionRespond}, nil
|
||||
}
|
||||
return call, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func (h *respondWithMediaHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
return result, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
type errorMediaChannel struct {
|
||||
fakeChannel
|
||||
sendErr error
|
||||
}
|
||||
|
||||
func (f *errorMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
|
||||
return nil, f.sendErr
|
||||
}
|
||||
|
||||
func TestAgentLoop_HookRespond_MediaError(t *testing.T) {
|
||||
provider := &multiToolProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
|
||||
},
|
||||
finalContent: "done",
|
||||
}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
hook := &respondWithMediaHook{
|
||||
respondTools: map[string]bool{"media_tool": true},
|
||||
media: []string{"media://test/image.png"},
|
||||
responseHandled: true,
|
||||
forLLM: "media sent successfully",
|
||||
}
|
||||
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
al.channelManager = newStartedTestChannelManager(t, al.bus, al.mediaStore, "discord", &errorMediaChannel{
|
||||
sendErr: errors.New("channel unavailable"),
|
||||
})
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
_, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-media-err",
|
||||
Channel: "discord",
|
||||
ChatID: "chat1",
|
||||
UserMessage: "send media",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected ToolExecEnd event")
|
||||
}
|
||||
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
||||
}
|
||||
|
||||
if !payload.IsError {
|
||||
t.Fatal("expected IsError=true when SendMedia fails")
|
||||
}
|
||||
|
||||
if payload.ForLLMLen < 30 {
|
||||
t.Fatalf("expected ForLLM to contain error message, got ForLLMLen=%d", payload.ForLLMLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_HookRespond_BusFallback(t *testing.T) {
|
||||
provider := &multiToolProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
|
||||
},
|
||||
finalContent: "done",
|
||||
}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
hook := &respondWithMediaHook{
|
||||
respondTools: map[string]bool{"media_tool": true},
|
||||
media: []string{"media://test/image.png"},
|
||||
responseHandled: true,
|
||||
forLLM: "media queued",
|
||||
}
|
||||
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-bus-fallback",
|
||||
Channel: "cli",
|
||||
ChatID: "chat1",
|
||||
UserMessage: "send media",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected ToolExecEnd event")
|
||||
}
|
||||
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
||||
}
|
||||
|
||||
if payload.IsError {
|
||||
t.Fatal("expected IsError=false for bus fallback (media queued, not delivered)")
|
||||
}
|
||||
|
||||
if resp != "done" {
|
||||
t.Fatalf("expected response 'done', got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
type multiToolProvider struct {
|
||||
mu sync.Mutex
|
||||
callCount int
|
||||
toolCalls []providers.ToolCall
|
||||
finalContent string
|
||||
}
|
||||
|
||||
func (p *multiToolProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.callCount++
|
||||
if p.callCount == 1 && len(p.toolCalls) > 0 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: p.toolCalls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: p.finalContent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *multiToolProvider) GetDefaultModel() string {
|
||||
return "multi-tool-provider"
|
||||
}
|
||||
|
||||
func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) {
|
||||
provider := &multiToolProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
|
||||
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
|
||||
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
|
||||
},
|
||||
finalContent: "done",
|
||||
}
|
||||
al, _, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
tool1ExecCh := make(chan struct{}, 1)
|
||||
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond, execCh: tool1ExecCh})
|
||||
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
|
||||
|
||||
hook := &respondHook{
|
||||
respondTools: map[string]bool{"tool_one": true},
|
||||
}
|
||||
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"run tools",
|
||||
sessionKey,
|
||||
"cli",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- result{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if err := al.InterruptGraceful("stop now"); err != nil {
|
||||
t.Fatalf("InterruptGraceful failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case r := <-resultCh:
|
||||
if r.err != nil {
|
||||
t.Fatalf("unexpected error: %v", r.err)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for result")
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
|
||||
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
|
||||
if len(skippedEvts) < 1 {
|
||||
t.Fatal("expected at least one ToolExecSkipped event after interrupt")
|
||||
}
|
||||
|
||||
for _, evt := range skippedEvts {
|
||||
payload, ok := evt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
|
||||
}
|
||||
if payload.Reason != "graceful interrupt requested" {
|
||||
t.Fatalf("expected skip reason 'graceful interrupt requested', got %q", payload.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
|
||||
provider := &multiToolProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
|
||||
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
|
||||
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
|
||||
},
|
||||
finalContent: "done",
|
||||
}
|
||||
al, _, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond})
|
||||
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
|
||||
|
||||
hook := &respondHook{
|
||||
respondTools: map[string]bool{"tool_one": true},
|
||||
}
|
||||
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"run tools",
|
||||
sessionKey,
|
||||
"cli",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- result{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
collectedEvents := make([]Event, 0, 8)
|
||||
steered := false
|
||||
deadline := time.After(3 * time.Second)
|
||||
for !steered {
|
||||
select {
|
||||
case evt := <-sub.C:
|
||||
collectedEvents = append(collectedEvents, evt)
|
||||
if evt.Kind != EventKindToolExecEnd {
|
||||
continue
|
||||
}
|
||||
payload, ok := evt.Payload.(ToolExecEndPayload)
|
||||
if !ok || payload.Tool != "tool_one" {
|
||||
continue
|
||||
}
|
||||
al.Steer(providers.Message{Role: "user", Content: "change direction"})
|
||||
steered = true
|
||||
case <-deadline:
|
||||
t.Fatal("timeout waiting for tool_one to finish before steering")
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case r := <-resultCh:
|
||||
if r.err != nil {
|
||||
t.Fatalf("unexpected error: %v", r.err)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for result")
|
||||
}
|
||||
|
||||
events := append(collectedEvents, collectEventStream(sub.C)...)
|
||||
|
||||
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
|
||||
if len(skippedEvts) < 1 {
|
||||
t.Fatal("expected at least one ToolExecSkipped event after steering")
|
||||
}
|
||||
|
||||
for _, evt := range skippedEvts {
|
||||
payload, ok := evt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
|
||||
}
|
||||
if payload.Reason != "queued user steering message" {
|
||||
t.Fatalf("expected skip reason 'queued user steering message', got %q", payload.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterEvents(events []Event, kind EventKind) []Event {
|
||||
var result []Event
|
||||
for _, evt := range events {
|
||||
if evt.Kind == kind {
|
||||
result = append(result, evt)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
@@ -64,6 +65,12 @@ func NewAgentInstance(
|
||||
cfg *config.Config,
|
||||
provider providers.LLMProvider,
|
||||
) *AgentInstance {
|
||||
if cfg != nil {
|
||||
// Keep the subprocess isolation runtime aligned with the latest loaded config
|
||||
// before any tools or providers start spawning child processes.
|
||||
isolation.Configure(cfg)
|
||||
}
|
||||
|
||||
workspace := resolveAgentWorkspace(agentCfg, defaults)
|
||||
os.MkdirAll(workspace, 0o755)
|
||||
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func messagesContainMedia(messages []providers.Message) bool {
|
||||
for _, msg := range messages {
|
||||
for _, ref := range msg.Media {
|
||||
if strings.TrimSpace(ref) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func stripMessageMedia(messages []providers.Message) []providers.Message {
|
||||
if !messagesContainMedia(messages) {
|
||||
return messages
|
||||
}
|
||||
stripped := make([]providers.Message, len(messages))
|
||||
for i, msg := range messages {
|
||||
stripped[i] = msg
|
||||
stripped[i].Media = nil
|
||||
}
|
||||
return stripped
|
||||
}
|
||||
|
||||
func isVisionUnsupportedError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
|
||||
// OpenRouter (and OpenAI-compatible) style.
|
||||
if strings.Contains(msg, "no endpoints found that support image input") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Common provider variants.
|
||||
if strings.Contains(msg, "does not support image input") ||
|
||||
strings.Contains(msg, "does not support image inputs") ||
|
||||
strings.Contains(msg, "does not support images") ||
|
||||
strings.Contains(msg, "image input is not supported") ||
|
||||
strings.Contains(msg, "images are not supported") ||
|
||||
strings.Contains(msg, "does not support vision") ||
|
||||
strings.Contains(msg, "unsupported content type: image_url") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Some providers return a generic "invalid" message that still mentions image_url.
|
||||
if strings.Contains(msg, "image_url") && strings.Contains(msg, "invalid") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
+817
-239
File diff suppressed because it is too large
Load Diff
@@ -24,6 +24,16 @@ type mcpRuntime struct {
|
||||
initErr error
|
||||
}
|
||||
|
||||
func (r *mcpRuntime) reset() *mcp.Manager {
|
||||
r.mu.Lock()
|
||||
manager := r.manager
|
||||
r.manager = nil
|
||||
r.initErr = nil
|
||||
r.initOnce = sync.Once{}
|
||||
r.mu.Unlock()
|
||||
return manager
|
||||
}
|
||||
|
||||
func (r *mcpRuntime) setManager(manager *mcp.Manager) {
|
||||
r.mu.Lock()
|
||||
r.manager = manager
|
||||
|
||||
@@ -7,13 +7,73 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
)
|
||||
|
||||
func boolPtr(b bool) *bool { return &b }
|
||||
|
||||
func TestMCPRuntimeResetClearsState(t *testing.T) {
|
||||
var rt mcpRuntime
|
||||
manager := mcp.NewManager()
|
||||
rt.setManager(manager)
|
||||
rt.setInitErr(errors.New("stale init error"))
|
||||
rt.initOnce.Do(func() {})
|
||||
|
||||
got := rt.reset()
|
||||
if got != manager {
|
||||
t.Fatalf("reset() manager = %p, want %p", got, manager)
|
||||
}
|
||||
if rt.hasManager() {
|
||||
t.Fatal("expected manager to be cleared after reset")
|
||||
}
|
||||
if err := rt.getInitErr(); err != nil {
|
||||
t.Fatalf("getInitErr() = %v, want nil", err)
|
||||
}
|
||||
|
||||
reran := false
|
||||
rt.initOnce.Do(func() { reran = true })
|
||||
if !reran {
|
||||
t.Fatal("expected initOnce to be reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadProviderAndConfig_ResetsMCPRuntime(t *testing.T) {
|
||||
al, cfg, _, _, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
defer al.Close()
|
||||
|
||||
manager := mcp.NewManager()
|
||||
al.mcp.setManager(manager)
|
||||
al.mcp.setInitErr(errors.New("stale init error"))
|
||||
al.mcp.initOnce.Do(func() {})
|
||||
|
||||
if !al.mcp.hasManager() {
|
||||
t.Fatal("expected MCP manager to exist before reload")
|
||||
}
|
||||
|
||||
if err := al.ReloadProviderAndConfig(context.Background(), &mockProvider{}, cfg); err != nil {
|
||||
t.Fatalf("ReloadProviderAndConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if al.mcp.hasManager() {
|
||||
t.Fatal("expected MCP manager to be cleared when reloaded config has MCP disabled")
|
||||
}
|
||||
if err := al.mcp.getInitErr(); err != nil {
|
||||
t.Fatalf("getInitErr() = %v, want nil", err)
|
||||
}
|
||||
|
||||
reran := false
|
||||
al.mcp.initOnce.Do(func() { reran = true })
|
||||
if !reran {
|
||||
t.Fatal("expected MCP initOnce to be reset after reload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerIsDeferred(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
+859
-105
File diff suppressed because it is too large
Load Diff
@@ -3,6 +3,7 @@ package agent
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
@@ -64,9 +65,9 @@ func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) {
|
||||
return agent, ok
|
||||
}
|
||||
|
||||
// ResolveRoute determines which agent handles the message.
|
||||
func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute {
|
||||
return r.resolver.ResolveRoute(input)
|
||||
// ResolveRoute determines which agent handles the normalized inbound context.
|
||||
func (r *AgentRegistry) ResolveRoute(inbound bus.InboundContext) routing.ResolvedRoute {
|
||||
return r.resolver.ResolveRoute(inbound)
|
||||
}
|
||||
|
||||
// ListAgentIDs returns all registered agent IDs.
|
||||
|
||||
+35
-8
@@ -3,12 +3,14 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -290,12 +292,22 @@ func (al *AgentLoop) continueWithSteeringMessages(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
sessionKey, channel, chatID string,
|
||||
scope *session.SessionScope,
|
||||
steeringMsgs []providers.Message,
|
||||
) (string, error) {
|
||||
dispatch := DispatchRequest{
|
||||
SessionKey: sessionKey,
|
||||
SessionScope: session.CloneScope(scope),
|
||||
}
|
||||
if channel != "" || chatID != "" {
|
||||
dispatch.InboundContext = &bus.InboundContext{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
ChatType: inferChatTypeFromSessionScope(scope),
|
||||
}
|
||||
}
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Dispatch: dispatch,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
@@ -310,9 +322,19 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
|
||||
return nil
|
||||
}
|
||||
|
||||
if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil {
|
||||
if agent, ok := registry.GetAgent(parsed.AgentID); ok {
|
||||
return agent
|
||||
agentIDs := registry.ListAgentIDs()
|
||||
sort.Strings(agentIDs)
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := registry.GetAgent(agentID)
|
||||
if !ok || agent == nil {
|
||||
continue
|
||||
}
|
||||
resolvedAgentID := session.ResolveAgentID(agent.Sessions, sessionKey)
|
||||
if resolvedAgentID == "" {
|
||||
continue
|
||||
}
|
||||
if scopedAgent, ok := registry.GetAgent(resolvedAgentID); ok {
|
||||
return scopedAgent
|
||||
}
|
||||
}
|
||||
|
||||
@@ -352,7 +374,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
|
||||
}
|
||||
}
|
||||
|
||||
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs)
|
||||
var scope *session.SessionScope
|
||||
if metaStore, ok := agent.Sessions.(session.MetadataAwareSessionStore); ok {
|
||||
scope = metaStore.GetSessionScope(sessionKey)
|
||||
}
|
||||
|
||||
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, scope, steeringMsgs)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) InterruptGraceful(hint string) error {
|
||||
|
||||
+89
-36
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -357,7 +358,7 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
},
|
||||
},
|
||||
Session: config.SessionConfig{
|
||||
DMScope: "per-peer",
|
||||
Dimensions: []string{"sender"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -365,14 +366,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
al := NewAgentLoop(cfg, msgBus, &mockProvider{})
|
||||
|
||||
activeMsg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "active turn",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "active turn",
|
||||
}
|
||||
activeScope, activeAgentID, ok := al.resolveSteeringTarget(activeMsg)
|
||||
if !ok {
|
||||
@@ -380,14 +380,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
}
|
||||
|
||||
otherMsg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user2",
|
||||
ChatID: "chat2",
|
||||
Content: "other session",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user2",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat2",
|
||||
ChatType: "direct",
|
||||
SenderID: "user2",
|
||||
},
|
||||
Content: "other session",
|
||||
}
|
||||
otherScope, _, ok := al.resolveSteeringTarget(otherMsg)
|
||||
if !ok {
|
||||
@@ -422,9 +421,9 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("timeout waiting for requeued message on outbound bus")
|
||||
case requeued := <-msgBus.OutboundChan():
|
||||
if requeued.Channel != otherMsg.Channel || requeued.ChatID != otherMsg.ChatID ||
|
||||
t.Fatalf("timeout waiting for requeued message on inbound bus")
|
||||
case requeued := <-msgBus.InboundChan():
|
||||
if requeued.Context.Channel != otherMsg.Context.Channel || requeued.Context.ChatID != otherMsg.Context.ChatID ||
|
||||
requeued.Content != otherMsg.Content {
|
||||
t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg)
|
||||
}
|
||||
@@ -841,24 +840,22 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
|
||||
}()
|
||||
|
||||
first := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "first message",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "first message",
|
||||
}
|
||||
late := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "late append",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "late append",
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
@@ -949,7 +946,7 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
|
||||
},
|
||||
}
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
provider := &blockingDirectProvider{
|
||||
firstStarted: make(chan struct{}),
|
||||
releaseFirst: make(chan struct{}),
|
||||
@@ -1013,6 +1010,62 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_AgentForSession_UsesStoredScopeMetadata(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
List: []config.AgentConfig{
|
||||
{ID: "sales", Default: true},
|
||||
{ID: "support"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{})
|
||||
support, ok := al.registry.GetAgent("support")
|
||||
if !ok || support == nil {
|
||||
t.Fatal("expected support agent")
|
||||
}
|
||||
|
||||
metaStore, ok := support.Sessions.(session.MetadataAwareSessionStore)
|
||||
if !ok {
|
||||
t.Fatal("support session store does not support metadata")
|
||||
}
|
||||
|
||||
alias := "agent:support:slack:channel:c001"
|
||||
key := session.BuildOpaqueSessionKey(alias)
|
||||
scope := &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "support",
|
||||
Channel: "slack",
|
||||
Account: "default",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "channel:c001",
|
||||
},
|
||||
}
|
||||
metaStore.EnsureSessionMetadata(key, scope, []string{alias})
|
||||
|
||||
got := al.agentForSession(key)
|
||||
if got == nil {
|
||||
t.Fatal("agentForSession() returned nil")
|
||||
}
|
||||
if got.ID != "support" {
|
||||
t.Fatalf("agentForSession() = %q, want %q", got.ID, "support")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
@@ -1060,7 +1113,7 @@ func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.SetMediaStore(store)
|
||||
@@ -1168,7 +1221,7 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(tool1)
|
||||
al.RegisterTool(tool2)
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
@@ -1322,7 +1375,7 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
started := make(chan struct{})
|
||||
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
|
||||
+13
-7
@@ -351,15 +351,17 @@ func spawnSubTurn(
|
||||
}
|
||||
|
||||
// Create processOptions for the child turn
|
||||
dispatch := DispatchRequest{
|
||||
SessionKey: childID,
|
||||
UserMessage: cfg.SystemPrompt,
|
||||
Media: nil,
|
||||
InboundContext: cloneInboundContext(parentTS.opts.Dispatch.InboundContext),
|
||||
}
|
||||
opts := processOptions{
|
||||
SessionKey: childID,
|
||||
Channel: parentTS.channel,
|
||||
ChatID: parentTS.chatID,
|
||||
SenderID: parentTS.opts.SenderID,
|
||||
Dispatch: dispatch,
|
||||
SenderID: parentTS.opts.Dispatch.SenderID(),
|
||||
SenderDisplayName: parentTS.opts.SenderDisplayName,
|
||||
UserMessage: cfg.SystemPrompt, // Task description becomes the first user message
|
||||
SystemPromptOverride: cfg.ActualSystemPrompt,
|
||||
Media: nil,
|
||||
InitialSteeringMessages: cfg.InitialMessages,
|
||||
DefaultResponse: "",
|
||||
EnableSummary: false,
|
||||
@@ -369,7 +371,11 @@ func spawnSubTurn(
|
||||
}
|
||||
|
||||
// Create event scope for the child turn
|
||||
scope := al.newTurnEventScope(agent.ID, childID)
|
||||
scope := al.newTurnEventScope(
|
||||
agent.ID,
|
||||
childID,
|
||||
newTurnContext(opts.Dispatch.InboundContext, opts.Dispatch.RouteResult, opts.Dispatch.SessionScope),
|
||||
)
|
||||
|
||||
// Create child turnState using the new API
|
||||
childTS := newTurnState(&agent, opts, scope)
|
||||
|
||||
+15
-12
@@ -56,6 +56,7 @@ type turnState struct {
|
||||
turnID string
|
||||
agentID string
|
||||
sessionKey string
|
||||
turnCtx *TurnContext
|
||||
|
||||
channel string
|
||||
chatID string
|
||||
@@ -115,11 +116,12 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
|
||||
scope: scope,
|
||||
turnID: scope.turnID,
|
||||
agentID: agent.ID,
|
||||
sessionKey: opts.SessionKey,
|
||||
channel: opts.Channel,
|
||||
chatID: opts.ChatID,
|
||||
userMessage: opts.UserMessage,
|
||||
media: append([]string(nil), opts.Media...),
|
||||
sessionKey: opts.Dispatch.SessionKey,
|
||||
turnCtx: cloneTurnContext(scope.context),
|
||||
channel: opts.Dispatch.Channel(),
|
||||
chatID: opts.Dispatch.ChatID(),
|
||||
userMessage: opts.Dispatch.UserMessage,
|
||||
media: append([]string(nil), opts.Dispatch.Media...),
|
||||
phase: TurnPhaseSetup,
|
||||
startedAt: time.Now(),
|
||||
}
|
||||
@@ -127,7 +129,7 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
|
||||
// Bind session store and capture initial history length for rollback logic
|
||||
if agent != nil && agent.Sessions != nil {
|
||||
ts.session = agent.Sessions
|
||||
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.SessionKey))
|
||||
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.Dispatch.SessionKey))
|
||||
}
|
||||
|
||||
return ts
|
||||
@@ -302,12 +304,13 @@ func (ts *turnState) hardAbortRequested() bool {
|
||||
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
|
||||
snap := ts.snapshot()
|
||||
return EventMeta{
|
||||
AgentID: snap.AgentID,
|
||||
TurnID: snap.TurnID,
|
||||
SessionKey: snap.SessionKey,
|
||||
Iteration: snap.Iteration,
|
||||
Source: source,
|
||||
TracePath: tracePath,
|
||||
AgentID: snap.AgentID,
|
||||
TurnID: snap.TurnID,
|
||||
SessionKey: snap.SessionKey,
|
||||
Iteration: snap.Iteration,
|
||||
Source: source,
|
||||
TracePath: tracePath,
|
||||
turnContext: cloneTurnContext(ts.turnCtx),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
// TurnContext carries normalized turn-scoped facts that can be shared across
|
||||
// events, hooks, and other runtime observers without re-parsing legacy fields.
|
||||
type TurnContext struct {
|
||||
Inbound *bus.InboundContext `json:"inbound,omitempty"`
|
||||
Route *routing.ResolvedRoute `json:"route,omitempty"`
|
||||
Scope *session.SessionScope `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
func newTurnContext(
|
||||
inbound *bus.InboundContext,
|
||||
route *routing.ResolvedRoute,
|
||||
scope *session.SessionScope,
|
||||
) *TurnContext {
|
||||
if inbound == nil && route == nil && scope == nil {
|
||||
return nil
|
||||
}
|
||||
return &TurnContext{
|
||||
Inbound: cloneInboundContext(inbound),
|
||||
Route: cloneResolvedRoute(route),
|
||||
Scope: session.CloneScope(scope),
|
||||
}
|
||||
}
|
||||
|
||||
func cloneTurnContext(ctx *TurnContext) *TurnContext {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *ctx
|
||||
cloned.Inbound = cloneInboundContext(ctx.Inbound)
|
||||
cloned.Route = cloneResolvedRoute(ctx.Route)
|
||||
cloned.Scope = session.CloneScope(ctx.Scope)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneInboundContext(ctx *bus.InboundContext) *bus.InboundContext {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *ctx
|
||||
cloned.ReplyHandles = cloneStringMap(ctx.ReplyHandles)
|
||||
cloned.Raw = cloneStringMap(ctx.Raw)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneStringMap(src map[string]string) map[string]string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make(map[string]string, len(src))
|
||||
for k, v := range src {
|
||||
cloned[k] = v
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneEventMeta(meta EventMeta) EventMeta {
|
||||
meta.turnContext = cloneTurnContext(meta.turnContext)
|
||||
return meta
|
||||
}
|
||||
|
||||
func cloneResolvedRoute(route *routing.ResolvedRoute) *routing.ResolvedRoute {
|
||||
if route == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *route
|
||||
cloned.SessionPolicy = routing.SessionPolicy{
|
||||
Dimensions: append([]string(nil), route.SessionPolicy.Dimensions...),
|
||||
IdentityLinks: cloneIdentityLinks(route.SessionPolicy.IdentityLinks),
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneIdentityLinks(src map[string][]string) map[string][]string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make(map[string][]string, len(src))
|
||||
for canonical, ids := range src {
|
||||
dup := make([]string, len(ids))
|
||||
copy(dup, ids)
|
||||
cloned[canonical] = dup
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
+10
-9
@@ -226,8 +226,7 @@ func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) {
|
||||
logger.ErrorCF("voice-agent", "Failed to publish leave control", map[string]any{"error": err})
|
||||
}
|
||||
if err := a.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: channelType,
|
||||
ChatID: acc.chatID,
|
||||
Context: bus.NewOutboundContext(channelType, acc.chatID, ""),
|
||||
Content: "Goodbye! Leaving the voice channel.",
|
||||
}); err != nil {
|
||||
logger.ErrorCF("voice-agent", "Failed to publish goodbye message", map[string]any{"error": err})
|
||||
@@ -238,14 +237,16 @@ func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) {
|
||||
oralPrompt := "\n\n[SYSTEM]: The user just spoke this to you over voice chat. Please reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally."
|
||||
|
||||
if err := a.bus.PublishInbound(ctx, bus.InboundMessage{
|
||||
Channel: channelType,
|
||||
SenderID: acc.speakerID,
|
||||
ChatID: acc.chatID,
|
||||
Content: res.Text + oralPrompt,
|
||||
Peer: bus.Peer{Kind: "channel", ID: acc.chatID},
|
||||
Metadata: map[string]string{
|
||||
"is_voice": "true",
|
||||
Context: bus.InboundContext{
|
||||
Channel: channelType,
|
||||
ChatID: acc.chatID,
|
||||
ChatType: "channel",
|
||||
SenderID: acc.speakerID,
|
||||
Raw: map[string]string{
|
||||
"is_voice": "true",
|
||||
},
|
||||
},
|
||||
Content: res.Text + oralPrompt,
|
||||
}); err != nil {
|
||||
logger.ErrorCF("voice-agent", "Failed to publish inbound message", map[string]any{"error": err})
|
||||
}
|
||||
|
||||
@@ -185,8 +185,8 @@ func TestAgentCheckSilencePublishesInboundAndCleansUp(t *testing.T) {
|
||||
if !strings.Contains(msg.Content, "hello there") {
|
||||
t.Fatalf("unexpected inbound content: %q", msg.Content)
|
||||
}
|
||||
if msg.Metadata["is_voice"] != "true" {
|
||||
t.Fatalf("expected is_voice metadata, got %#v", msg.Metadata)
|
||||
if msg.Context.Raw["is_voice"] != "true" {
|
||||
t.Fatalf("expected is_voice metadata, got %#v", msg.Context.Raw)
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("expected inbound publish")
|
||||
|
||||
+19
-1
@@ -12,6 +12,12 @@ import (
|
||||
// ErrBusClosed is returned when publishing to a closed MessageBus.
|
||||
var ErrBusClosed = errors.New("message bus closed")
|
||||
|
||||
var (
|
||||
ErrMissingInboundContext = errors.New("inbound message context is required")
|
||||
ErrMissingOutboundContext = errors.New("outbound message context is required")
|
||||
ErrMissingOutboundMediaContext = errors.New("outbound media context is required")
|
||||
)
|
||||
|
||||
const defaultBusBufferSize = 64
|
||||
|
||||
// StreamDelegate is implemented by the channel Manager to provide streaming
|
||||
@@ -49,7 +55,7 @@ func NewMessageBus() *MessageBus {
|
||||
inbound: make(chan InboundMessage, defaultBusBufferSize),
|
||||
outbound: make(chan OutboundMessage, defaultBusBufferSize),
|
||||
outboundMedia: make(chan OutboundMediaMessage, defaultBusBufferSize),
|
||||
audioChunks: make(chan AudioChunk, defaultBusBufferSize*4), // Audio chunks need more buffer
|
||||
audioChunks: make(chan AudioChunk, defaultBusBufferSize*4), // Audio chunks need more buffer.
|
||||
voiceControls: make(chan VoiceControl, defaultBusBufferSize),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
@@ -84,6 +90,10 @@ func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
|
||||
msg = NormalizeInboundMessage(msg)
|
||||
if msg.Context.isZero() {
|
||||
return ErrMissingInboundContext
|
||||
}
|
||||
return publish(ctx, mb, mb.inbound, msg)
|
||||
}
|
||||
|
||||
@@ -92,6 +102,10 @@ func (mb *MessageBus) InboundChan() <-chan InboundMessage {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
|
||||
msg = NormalizeOutboundMessage(msg)
|
||||
if msg.Context.isZero() {
|
||||
return ErrMissingOutboundContext
|
||||
}
|
||||
return publish(ctx, mb, mb.outbound, msg)
|
||||
}
|
||||
|
||||
@@ -100,6 +114,10 @@ func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
|
||||
msg = NormalizeOutboundMediaMessage(msg)
|
||||
if msg.Context.isZero() {
|
||||
return ErrMissingOutboundMediaContext
|
||||
}
|
||||
return publish(ctx, mb, mb.outboundMedia, msg)
|
||||
}
|
||||
|
||||
|
||||
+438
-15
@@ -14,10 +14,13 @@ func TestPublishConsume(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
msg := InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello",
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "hello",
|
||||
}
|
||||
|
||||
if err := mb.PublishInbound(ctx, msg); err != nil {
|
||||
@@ -34,6 +37,138 @@ func TestPublishConsume(t *testing.T) {
|
||||
if got.Channel != "test" {
|
||||
t.Fatalf("expected channel 'test', got %q", got.Channel)
|
||||
}
|
||||
if got.Context.Channel != "test" {
|
||||
t.Fatalf("expected context channel 'test', got %q", got.Context.Channel)
|
||||
}
|
||||
if got.Context.ChatID != "chat1" {
|
||||
t.Fatalf("expected context chat ID 'chat1', got %q", got.Context.ChatID)
|
||||
}
|
||||
if got.Context.SenderID != "user1" {
|
||||
t.Fatalf("expected context sender ID 'user1', got %q", got.Context.SenderID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_NormalizesContext(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "slack",
|
||||
Account: "workspace-a",
|
||||
ChatID: "C456/1712",
|
||||
ChatType: "group",
|
||||
TopicID: "1712",
|
||||
SpaceID: "T001",
|
||||
SpaceType: "team",
|
||||
SenderID: "U123",
|
||||
MessageID: "1712.01",
|
||||
ReplyToMessageID: "1700.01",
|
||||
Mentioned: true,
|
||||
},
|
||||
Content: "hello",
|
||||
}
|
||||
|
||||
if err := mb.PublishInbound(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishInbound failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.InboundChan()
|
||||
if got.Context.Channel != "slack" {
|
||||
t.Fatalf("expected context channel slack, got %q", got.Context.Channel)
|
||||
}
|
||||
if got.Context.Account != "workspace-a" {
|
||||
t.Fatalf("expected context account workspace-a, got %q", got.Context.Account)
|
||||
}
|
||||
if got.Context.ChatType != "group" {
|
||||
t.Fatalf("expected context chat type group, got %q", got.Context.ChatType)
|
||||
}
|
||||
if got.Context.TopicID != "1712" {
|
||||
t.Fatalf("expected topic 1712, got %q", got.Context.TopicID)
|
||||
}
|
||||
if got.Context.SpaceType != "team" || got.Context.SpaceID != "T001" {
|
||||
t.Fatalf("expected team space T001, got %q/%q", got.Context.SpaceType, got.Context.SpaceID)
|
||||
}
|
||||
if !got.Context.Mentioned {
|
||||
t.Fatal("expected mentioned=true in context")
|
||||
}
|
||||
if got.Context.ReplyToMessageID != "1700.01" {
|
||||
t.Fatalf("expected reply_to_message_id 1700.01, got %q", got.Context.ReplyToMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_MirrorsContextIntoConvenienceFields(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "telegram",
|
||||
Account: "bot-a",
|
||||
ChatID: "-1001",
|
||||
ChatType: "group",
|
||||
TopicID: "42",
|
||||
SpaceID: "guild-9",
|
||||
SpaceType: "guild",
|
||||
SenderID: "user-1",
|
||||
MessageID: "777",
|
||||
Mentioned: true,
|
||||
ReplyToMessageID: "666",
|
||||
},
|
||||
Content: "hi",
|
||||
}
|
||||
|
||||
if err := mb.PublishInbound(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishInbound failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.InboundChan()
|
||||
if got.Channel != "telegram" {
|
||||
t.Fatalf("expected legacy channel telegram, got %q", got.Channel)
|
||||
}
|
||||
if got.ChatID != "-1001" {
|
||||
t.Fatalf("expected legacy chat ID -1001, got %q", got.ChatID)
|
||||
}
|
||||
if got.SenderID != "user-1" {
|
||||
t.Fatalf("expected legacy sender ID user-1, got %q", got.SenderID)
|
||||
}
|
||||
if got.MessageID != "777" {
|
||||
t.Fatalf("expected legacy message ID 777, got %q", got.MessageID)
|
||||
}
|
||||
if got.Context.Account != "bot-a" || got.Context.SpaceID != "guild-9" || got.Context.TopicID != "42" {
|
||||
t.Fatalf("unexpected normalized context: %+v", got.Context)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_BackfillsContextFromLegacyFields(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := InboundMessage{
|
||||
Channel: "pico",
|
||||
ChatID: "session-1",
|
||||
SenderID: "user-1",
|
||||
MessageID: "msg-1",
|
||||
Content: "hello",
|
||||
}
|
||||
|
||||
if err := mb.PublishInbound(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishInbound failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.InboundChan()
|
||||
if got.Context.Channel != "pico" {
|
||||
t.Fatalf("expected context channel pico, got %q", got.Context.Channel)
|
||||
}
|
||||
if got.Context.ChatID != "session-1" {
|
||||
t.Fatalf("expected context chat ID session-1, got %q", got.Context.ChatID)
|
||||
}
|
||||
if got.Context.SenderID != "user-1" {
|
||||
t.Fatalf("expected context sender ID user-1, got %q", got.Context.SenderID)
|
||||
}
|
||||
if got.Context.MessageID != "msg-1" {
|
||||
t.Fatalf("expected context message ID msg-1, got %q", got.Context.MessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutboundSubscribe(t *testing.T) {
|
||||
@@ -43,8 +178,10 @@ func TestPublishOutboundSubscribe(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
msg := OutboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "123",
|
||||
Context: InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "123",
|
||||
},
|
||||
Content: "world",
|
||||
}
|
||||
|
||||
@@ -59,6 +196,222 @@ func TestPublishOutboundSubscribe(t *testing.T) {
|
||||
if got.Content != "world" {
|
||||
t.Fatalf("expected content 'world', got %q", got.Content)
|
||||
}
|
||||
if got.Context.Channel != "telegram" || got.Context.ChatID != "123" {
|
||||
t.Fatalf("expected normalized outbound context, got %+v", got.Context)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutbound_MirrorsContextToLegacyFields(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := OutboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat-42",
|
||||
ReplyToMessageID: "msg-9",
|
||||
},
|
||||
AgentID: "main",
|
||||
SessionKey: "sk_v1_123",
|
||||
Scope: &OutboundScope{
|
||||
Version: 1,
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Account: "bot-a",
|
||||
Dimensions: []string{"chat", "sender"},
|
||||
Values: map[string]string{
|
||||
"chat": "direct:chat-42",
|
||||
"sender": "user-1",
|
||||
},
|
||||
},
|
||||
Content: "reply",
|
||||
}
|
||||
|
||||
if err := mb.PublishOutbound(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishOutbound failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.OutboundChan()
|
||||
if got.Channel != "telegram" {
|
||||
t.Fatalf("expected legacy channel telegram, got %q", got.Channel)
|
||||
}
|
||||
if got.ChatID != "chat-42" {
|
||||
t.Fatalf("expected legacy chat ID chat-42, got %q", got.ChatID)
|
||||
}
|
||||
if got.ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID)
|
||||
}
|
||||
if got.AgentID != "main" || got.SessionKey != "sk_v1_123" {
|
||||
t.Fatalf("unexpected outbound turn metadata: agent=%q session=%q", got.AgentID, got.SessionKey)
|
||||
}
|
||||
if got.Scope == nil || got.Scope.AgentID != "main" || got.Scope.Values["chat"] != "direct:chat-42" {
|
||||
t.Fatalf("unexpected outbound scope: %+v", got.Scope)
|
||||
}
|
||||
if got.Context.Channel != "telegram" || got.Context.ChatID != "chat-42" {
|
||||
t.Fatalf("unexpected outbound context: %+v", got.Context)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutbound_PreservesExplicitReplyToMessageID(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := OutboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat-42",
|
||||
},
|
||||
ReplyToMessageID: "msg-9",
|
||||
Content: "reply",
|
||||
}
|
||||
|
||||
if err := mb.PublishOutbound(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishOutbound failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.OutboundChan()
|
||||
if got.ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID)
|
||||
}
|
||||
if got.Context.ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected context reply_to_message_id msg-9, got %q", got.Context.ReplyToMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutbound_PreservesExplicitReplyToMessageIDWhenContextReplyIsBlank(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := OutboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat-42",
|
||||
ReplyToMessageID: " ",
|
||||
},
|
||||
ReplyToMessageID: "msg-9",
|
||||
Content: "reply",
|
||||
}
|
||||
|
||||
if err := mb.PublishOutbound(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishOutbound failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.OutboundChan()
|
||||
if got.ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID)
|
||||
}
|
||||
if got.Context.ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected context reply_to_message_id msg-9, got %q", got.Context.ReplyToMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutboundMedia_MirrorsContextToLegacyFields(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := OutboundMediaMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "slack",
|
||||
ChatID: "C001",
|
||||
},
|
||||
AgentID: "support",
|
||||
SessionKey: "sk_v1_media",
|
||||
Scope: &OutboundScope{
|
||||
Version: 1,
|
||||
AgentID: "support",
|
||||
Channel: "slack",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "channel:c001",
|
||||
},
|
||||
},
|
||||
Parts: []MediaPart{{Type: "image", Ref: "media://1"}},
|
||||
}
|
||||
|
||||
if err := mb.PublishOutboundMedia(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishOutboundMedia failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.OutboundMediaChan()
|
||||
if got.Channel != "slack" {
|
||||
t.Fatalf("expected legacy channel slack, got %q", got.Channel)
|
||||
}
|
||||
if got.ChatID != "C001" {
|
||||
t.Fatalf("expected legacy chat ID C001, got %q", got.ChatID)
|
||||
}
|
||||
if got.AgentID != "support" || got.SessionKey != "sk_v1_media" {
|
||||
t.Fatalf("unexpected outbound media turn metadata: agent=%q session=%q", got.AgentID, got.SessionKey)
|
||||
}
|
||||
if got.Scope == nil || got.Scope.Values["chat"] != "channel:c001" {
|
||||
t.Fatalf("unexpected outbound media scope: %+v", got.Scope)
|
||||
}
|
||||
if got.Context.Channel != "slack" || got.Context.ChatID != "C001" {
|
||||
t.Fatalf("unexpected outbound media context: %+v", got.Context)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishAudioChunkSubscribe(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
chunk := AudioChunk{
|
||||
SessionID: "voice-1",
|
||||
SpeakerID: "speaker-1",
|
||||
ChatID: "chat-1",
|
||||
Channel: "discord",
|
||||
Sequence: 7,
|
||||
Format: "opus",
|
||||
Data: []byte{0x01, 0x02},
|
||||
}
|
||||
|
||||
if err := mb.PublishAudioChunk(context.Background(), chunk); err != nil {
|
||||
t.Fatalf("PublishAudioChunk failed: %v", err)
|
||||
}
|
||||
|
||||
got, ok := <-mb.AudioChunksChan()
|
||||
if !ok {
|
||||
t.Fatal("AudioChunksChan returned ok=false")
|
||||
}
|
||||
if got.SessionID != "voice-1" || got.Sequence != 7 {
|
||||
t.Fatalf("unexpected audio chunk: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishVoiceControlSubscribe(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
ctrl := VoiceControl{
|
||||
SessionID: "voice-1",
|
||||
ChatID: "chat-1",
|
||||
Type: "command",
|
||||
Action: "start",
|
||||
}
|
||||
|
||||
if err := mb.PublishVoiceControl(context.Background(), ctrl); err != nil {
|
||||
t.Fatalf("PublishVoiceControl failed: %v", err)
|
||||
}
|
||||
|
||||
got, ok := <-mb.VoiceControlsChan()
|
||||
if !ok {
|
||||
t.Fatal("VoiceControlsChan returned ok=false")
|
||||
}
|
||||
if got.Type != "command" || got.Action != "start" {
|
||||
t.Fatalf("unexpected voice control: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOutboundContext_NormalizesReplyAddress(t *testing.T) {
|
||||
ctx := NewOutboundContext(" telegram ", " chat-42 ", " msg-9 ")
|
||||
if ctx.Channel != "telegram" {
|
||||
t.Fatalf("expected channel telegram, got %q", ctx.Channel)
|
||||
}
|
||||
if ctx.ChatID != "chat-42" {
|
||||
t.Fatalf("expected chat_id chat-42, got %q", ctx.ChatID)
|
||||
}
|
||||
if ctx.ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected reply_to_message_id msg-9, got %q", ctx.ReplyToMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_ContextCancel(t *testing.T) {
|
||||
@@ -68,7 +421,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
|
||||
// Fill the buffer
|
||||
ctx := context.Background()
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-fill",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-fill",
|
||||
},
|
||||
Content: "fill",
|
||||
}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
@@ -77,7 +438,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
|
||||
cancelCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"})
|
||||
err := mb.PublishInbound(cancelCtx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-overflow",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-overflow",
|
||||
},
|
||||
Content: "overflow",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from canceled context, got nil")
|
||||
}
|
||||
@@ -90,7 +459,15 @@ func TestPublishInbound_BusClosed(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
mb.Close()
|
||||
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "test",
|
||||
})
|
||||
if err != ErrBusClosed {
|
||||
t.Fatalf("expected ErrBusClosed, got %v", err)
|
||||
}
|
||||
@@ -100,7 +477,13 @@ func TestPublishOutbound_BusClosed(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
mb.Close()
|
||||
|
||||
err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"})
|
||||
err := mb.PublishOutbound(context.Background(), OutboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
},
|
||||
Content: "test",
|
||||
})
|
||||
if err != ErrBusClosed {
|
||||
t.Fatalf("expected ErrBusClosed, got %v", err)
|
||||
}
|
||||
@@ -112,14 +495,30 @@ func TestConsumeInbound_ContextCancel(t *testing.T) {
|
||||
defer mb.Close()
|
||||
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil {
|
||||
if err := mb.PublishInbound(context.Background(), InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-fill",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-fill",
|
||||
},
|
||||
Content: "fill",
|
||||
}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"})
|
||||
mb.PublishInbound(ctx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-cancel",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-cancel",
|
||||
},
|
||||
Content: "ContextCancel",
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -213,7 +612,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
|
||||
|
||||
// Fill the buffer
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-fill",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-fill",
|
||||
},
|
||||
Content: "fill",
|
||||
}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
@@ -222,7 +629,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"})
|
||||
err := mb.PublishInbound(timeoutCtx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-overflow",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-overflow",
|
||||
},
|
||||
Content: "overflow",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when buffer is full and context times out")
|
||||
}
|
||||
@@ -240,7 +655,15 @@ func TestCloseIdempotent(t *testing.T) {
|
||||
mb.Close()
|
||||
|
||||
// After close, publish should return ErrBusClosed
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "test",
|
||||
})
|
||||
if err != ErrBusClosed {
|
||||
t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package bus
|
||||
|
||||
import "strings"
|
||||
|
||||
// NormalizeInboundMessage ensures the inbound context is normalized and keeps
|
||||
// convenience mirrors in sync for runtime consumers.
|
||||
func NormalizeInboundMessage(msg InboundMessage) InboundMessage {
|
||||
if msg.Context.Channel == "" {
|
||||
msg.Context.Channel = msg.Channel
|
||||
}
|
||||
if msg.Context.ChatID == "" {
|
||||
msg.Context.ChatID = msg.ChatID
|
||||
}
|
||||
if msg.Context.SenderID == "" {
|
||||
msg.Context.SenderID = msg.SenderID
|
||||
}
|
||||
if msg.Context.MessageID == "" {
|
||||
msg.Context.MessageID = msg.MessageID
|
||||
}
|
||||
msg.Context = normalizeInboundContext(msg.Context)
|
||||
msg.Channel = msg.Context.Channel
|
||||
msg.SenderID = msg.Context.SenderID
|
||||
msg.ChatID = msg.Context.ChatID
|
||||
if msg.MessageID == "" {
|
||||
msg.MessageID = msg.Context.MessageID
|
||||
}
|
||||
if msg.Context.MessageID == "" {
|
||||
msg.Context.MessageID = msg.MessageID
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func (ctx InboundContext) isZero() bool {
|
||||
return ctx.Channel == "" &&
|
||||
ctx.Account == "" &&
|
||||
ctx.ChatID == "" &&
|
||||
ctx.ChatType == "" &&
|
||||
ctx.TopicID == "" &&
|
||||
ctx.SpaceID == "" &&
|
||||
ctx.SpaceType == "" &&
|
||||
ctx.SenderID == "" &&
|
||||
ctx.MessageID == "" &&
|
||||
!ctx.Mentioned &&
|
||||
ctx.ReplyToMessageID == "" &&
|
||||
ctx.ReplyToSenderID == "" &&
|
||||
len(ctx.ReplyHandles) == 0 &&
|
||||
len(ctx.Raw) == 0
|
||||
}
|
||||
|
||||
func normalizeInboundContext(ctx InboundContext) InboundContext {
|
||||
ctx.Channel = strings.TrimSpace(ctx.Channel)
|
||||
ctx.Account = strings.TrimSpace(ctx.Account)
|
||||
ctx.ChatID = strings.TrimSpace(ctx.ChatID)
|
||||
ctx.ChatType = normalizeKind(ctx.ChatType)
|
||||
ctx.TopicID = strings.TrimSpace(ctx.TopicID)
|
||||
ctx.SpaceID = strings.TrimSpace(ctx.SpaceID)
|
||||
ctx.SpaceType = normalizeKind(ctx.SpaceType)
|
||||
ctx.SenderID = strings.TrimSpace(ctx.SenderID)
|
||||
ctx.MessageID = strings.TrimSpace(ctx.MessageID)
|
||||
ctx.ReplyToMessageID = strings.TrimSpace(ctx.ReplyToMessageID)
|
||||
ctx.ReplyToSenderID = strings.TrimSpace(ctx.ReplyToSenderID)
|
||||
ctx.ReplyHandles = cloneStringMap(ctx.ReplyHandles)
|
||||
ctx.Raw = cloneStringMap(ctx.Raw)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func cloneStringMap(src map[string]string) map[string]string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
dst := make(map[string]string, len(src))
|
||||
for k, v := range src {
|
||||
dst[k] = v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func normalizeKind(kind string) string {
|
||||
return strings.ToLower(strings.TrimSpace(kind))
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package bus
|
||||
|
||||
import "strings"
|
||||
|
||||
// NewOutboundContext builds the minimal normalized addressing context required
|
||||
// to deliver an outbound text message or reply.
|
||||
func NewOutboundContext(channel, chatID, replyToMessageID string) InboundContext {
|
||||
return normalizeInboundContext(InboundContext{
|
||||
Channel: strings.TrimSpace(channel),
|
||||
ChatID: strings.TrimSpace(chatID),
|
||||
ReplyToMessageID: strings.TrimSpace(replyToMessageID),
|
||||
})
|
||||
}
|
||||
|
||||
// NormalizeOutboundMessage ensures Context is normalized and keeps convenience
|
||||
// mirrors in sync for runtime consumers.
|
||||
func NormalizeOutboundMessage(msg OutboundMessage) OutboundMessage {
|
||||
msg.Channel = strings.TrimSpace(msg.Channel)
|
||||
msg.ChatID = strings.TrimSpace(msg.ChatID)
|
||||
msg.ReplyToMessageID = strings.TrimSpace(msg.ReplyToMessageID)
|
||||
if msg.Context.Channel == "" {
|
||||
msg.Context.Channel = msg.Channel
|
||||
}
|
||||
if msg.Context.ChatID == "" {
|
||||
msg.Context.ChatID = msg.ChatID
|
||||
}
|
||||
if msg.Context.ReplyToMessageID == "" {
|
||||
msg.Context.ReplyToMessageID = msg.ReplyToMessageID
|
||||
}
|
||||
msg.Context = normalizeInboundContext(msg.Context)
|
||||
if msg.Channel == "" {
|
||||
msg.Channel = msg.Context.Channel
|
||||
}
|
||||
if msg.ChatID == "" {
|
||||
msg.ChatID = msg.Context.ChatID
|
||||
}
|
||||
if msg.ReplyToMessageID == "" {
|
||||
msg.ReplyToMessageID = msg.Context.ReplyToMessageID
|
||||
}
|
||||
if msg.Context.ReplyToMessageID == "" {
|
||||
msg.Context.ReplyToMessageID = msg.ReplyToMessageID
|
||||
}
|
||||
msg.Scope = cloneOutboundScope(msg.Scope)
|
||||
return msg
|
||||
}
|
||||
|
||||
// NormalizeOutboundMediaMessage ensures media outbound messages also carry a
|
||||
// normalized context while keeping convenience mirrors in sync.
|
||||
func NormalizeOutboundMediaMessage(msg OutboundMediaMessage) OutboundMediaMessage {
|
||||
msg.Channel = strings.TrimSpace(msg.Channel)
|
||||
msg.ChatID = strings.TrimSpace(msg.ChatID)
|
||||
if msg.Context.Channel == "" {
|
||||
msg.Context.Channel = msg.Channel
|
||||
}
|
||||
if msg.Context.ChatID == "" {
|
||||
msg.Context.ChatID = msg.ChatID
|
||||
}
|
||||
msg.Context = normalizeInboundContext(msg.Context)
|
||||
if msg.Channel == "" {
|
||||
msg.Channel = msg.Context.Channel
|
||||
}
|
||||
if msg.ChatID == "" {
|
||||
msg.ChatID = msg.Context.ChatID
|
||||
}
|
||||
msg.Scope = cloneOutboundScope(msg.Scope)
|
||||
return msg
|
||||
}
|
||||
|
||||
func cloneOutboundScope(scope *OutboundScope) *OutboundScope {
|
||||
if scope == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *scope
|
||||
if len(scope.Dimensions) > 0 {
|
||||
cloned.Dimensions = append([]string(nil), scope.Dimensions...)
|
||||
}
|
||||
if len(scope.Values) > 0 {
|
||||
cloned.Values = make(map[string]string, len(scope.Values))
|
||||
for key, value := range scope.Values {
|
||||
cloned.Values[key] = value
|
||||
}
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
+64
-25
@@ -1,11 +1,5 @@
|
||||
package bus
|
||||
|
||||
// Peer identifies the routing peer for a message (direct, group, channel, etc.)
|
||||
type Peer struct {
|
||||
Kind string `json:"kind"` // "direct" | "group" | "channel" | ""
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// SenderInfo provides structured sender identity information.
|
||||
type SenderInfo struct {
|
||||
Platform string `json:"platform,omitempty"` // "telegram", "discord", "slack", ...
|
||||
@@ -15,26 +9,67 @@ type SenderInfo struct {
|
||||
DisplayName string `json:"display_name,omitempty"` // display name
|
||||
}
|
||||
|
||||
// InboundContext captures the normalized, platform-agnostic facts about an
|
||||
// inbound message. This is the source of truth for routing and session
|
||||
// allocation.
|
||||
type InboundContext struct {
|
||||
Channel string `json:"channel"`
|
||||
Account string `json:"account,omitempty"`
|
||||
|
||||
ChatID string `json:"chat_id"`
|
||||
ChatType string `json:"chat_type,omitempty"` // direct / group / channel
|
||||
TopicID string `json:"topic_id,omitempty"`
|
||||
|
||||
SpaceID string `json:"space_id,omitempty"`
|
||||
SpaceType string `json:"space_type,omitempty"` // guild / team / workspace / tenant
|
||||
|
||||
SenderID string `json:"sender_id"`
|
||||
MessageID string `json:"message_id,omitempty"`
|
||||
|
||||
Mentioned bool `json:"mentioned,omitempty"`
|
||||
|
||||
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
|
||||
ReplyToSenderID string `json:"reply_to_sender_id,omitempty"`
|
||||
|
||||
ReplyHandles map[string]string `json:"reply_handles,omitempty"`
|
||||
Raw map[string]string `json:"raw,omitempty"`
|
||||
}
|
||||
|
||||
type InboundMessage struct {
|
||||
Channel string `json:"channel"`
|
||||
SenderID string `json:"sender_id"`
|
||||
Sender SenderInfo `json:"sender"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Content string `json:"content"`
|
||||
Media []string `json:"media,omitempty"`
|
||||
Peer Peer `json:"peer"` // routing peer
|
||||
MessageID string `json:"message_id,omitempty"` // platform message ID
|
||||
MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope
|
||||
SessionKey string `json:"session_key"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Context InboundContext `json:"context"`
|
||||
Sender SenderInfo `json:"sender"`
|
||||
Content string `json:"content"`
|
||||
Media []string `json:"media,omitempty"`
|
||||
MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope
|
||||
SessionKey string `json:"session_key"`
|
||||
|
||||
// Convenience mirrors derived from Context for runtime consumers.
|
||||
Channel string `json:"channel"`
|
||||
SenderID string `json:"sender_id"`
|
||||
ChatID string `json:"chat_id"`
|
||||
MessageID string `json:"message_id,omitempty"` // platform message ID
|
||||
}
|
||||
|
||||
// OutboundScope captures the structured session scope associated with an
|
||||
// outbound turn result without depending on the session package.
|
||||
type OutboundScope struct {
|
||||
Version int `json:"version,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
Account string `json:"account,omitempty"`
|
||||
Dimensions []string `json:"dimensions,omitempty"`
|
||||
Values map[string]string `json:"values,omitempty"`
|
||||
}
|
||||
|
||||
type OutboundMessage struct {
|
||||
Channel string `json:"channel"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Content string `json:"content"`
|
||||
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Channel string `json:"channel"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Context InboundContext `json:"context"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
SessionKey string `json:"session_key,omitempty"`
|
||||
Scope *OutboundScope `json:"scope,omitempty"`
|
||||
Content string `json:"content"`
|
||||
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
|
||||
}
|
||||
|
||||
// MediaPart describes a single media attachment to send.
|
||||
@@ -48,9 +83,13 @@ type MediaPart struct {
|
||||
|
||||
// OutboundMediaMessage carries media attachments from Agent to channels via the bus.
|
||||
type OutboundMediaMessage struct {
|
||||
Channel string `json:"channel"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Parts []MediaPart `json:"parts"`
|
||||
Channel string `json:"channel"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Context InboundContext `json:"context"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
SessionKey string `json:"session_key,omitempty"`
|
||||
Scope *OutboundScope `json:"scope,omitempty"`
|
||||
Parts []MediaPart `json:"parts"`
|
||||
}
|
||||
|
||||
// AudioChunk represents a chunk of streaming voice data.
|
||||
|
||||
+80
-33
@@ -327,8 +327,13 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewTelegramChannel(cfg, b)
|
||||
channels.RegisterFactory(config.ChannelTelegram, func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil { return nil, err }
|
||||
c, ok := decoded.(*config.TelegramSettings)
|
||||
if !ok { return nil, channels.ErrSendFailed }
|
||||
return NewTelegramChannel(bc, c, b)
|
||||
})
|
||||
}
|
||||
```
|
||||
@@ -427,8 +432,13 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewMatrixChannel(cfg, b)
|
||||
channels.RegisterFactory(config.ChannelMatrix, func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil { return nil, err }
|
||||
c, ok := decoded.(*config.MatrixSettings)
|
||||
if !ok { return nil, channels.ErrSendFailed }
|
||||
return NewMatrixChannel(bc, c, b)
|
||||
})
|
||||
}
|
||||
```
|
||||
@@ -773,41 +783,59 @@ When the Agent finishes processing a message, Manager's `preSend` automatically:
|
||||
|
||||
### 3.5 Register Configuration and Gateway Integration
|
||||
|
||||
#### Add configuration in `pkg/config/config.go`
|
||||
#### Add configuration entry
|
||||
|
||||
Channels now use a unified map-based configuration (`map[string]*config.Channel`).
|
||||
Each channel entry stores common fields (`enabled`, `type`, `allow_from`, etc.) at
|
||||
the top level, with channel-specific settings in the `settings` sub-key:
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"matrix": {
|
||||
"enabled": true,
|
||||
"type": "matrix",
|
||||
"allow_from": ["@user:example.com"],
|
||||
"settings": {
|
||||
"home_server": "https://matrix.org",
|
||||
"user_id": "@bot:example.com",
|
||||
"access_token": "enc://..."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Secure fields (tokens, passwords, API keys) go into `.security.yml`:
|
||||
|
||||
```yaml
|
||||
channels:
|
||||
matrix:
|
||||
access_token: "your-matrix-access-token"
|
||||
```
|
||||
|
||||
Channel types must be registered in `channelSettingsFactory` in
|
||||
`pkg/config/config_channel.go`:
|
||||
|
||||
```go
|
||||
type ChannelsConfig struct {
|
||||
var channelSettingsFactory = map[string]any{
|
||||
// ... existing channels
|
||||
Matrix MatrixChannelConfig `json:"matrix"`
|
||||
}
|
||||
|
||||
type MatrixChannelConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
HomeServer string `json:"home_server"`
|
||||
Token string `json:"token"`
|
||||
AllowFrom []string `json:"allow_from"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id"`
|
||||
ChannelMatrix: (MatrixSettings{}),
|
||||
}
|
||||
```
|
||||
|
||||
#### Add entry in Manager.initChannels()
|
||||
#### No Manager changes needed
|
||||
|
||||
```go
|
||||
// In the initChannels() method of pkg/channels/manager.go
|
||||
if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
|
||||
m.initChannel("matrix", "Matrix")
|
||||
}
|
||||
```
|
||||
The Manager uses `InitChannelList()` to validate types and decode settings,
|
||||
then looks up factories by `bc.Type`. No per-channel entry needed in Manager —
|
||||
just register the factory and the config entry.
|
||||
|
||||
> **Note**: If your channel has multiple modes (like WhatsApp Bridge vs Native), branch in initChannels based on config:
|
||||
> **Note**: If your channel has multiple modes (like WhatsApp Bridge vs Native),
|
||||
> register both types in `channelSettingsFactory` and branch on config:
|
||||
> ```go
|
||||
> if cfg.UseNative {
|
||||
> m.initChannel("whatsapp_native", "WhatsApp Native")
|
||||
> } else {
|
||||
> m.initChannel("whatsapp", "WhatsApp")
|
||||
> }
|
||||
> // In config_channel.go:
|
||||
> ChannelWhatsApp: (WhatsAppSettings{}),
|
||||
> ChannelWhatsAppNative: (WhatsAppSettings{}),
|
||||
> ```
|
||||
|
||||
#### Add blank import in Gateway
|
||||
@@ -947,10 +975,29 @@ channels.WithReasoningChannelID(id) // Set reasoning chain routing target
|
||||
**File**: `pkg/channels/registry.go`
|
||||
|
||||
```go
|
||||
type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error)
|
||||
type ChannelFactory func(channelName, channelType string, cfg *config.Config, bus *bus.MessageBus) (Channel, error)
|
||||
|
||||
func RegisterFactory(name string, f ChannelFactory) // Called in sub-package init()
|
||||
func getFactory(name string) (ChannelFactory, bool) // Called internally by Manager
|
||||
func RegisterFactory(name string, f ChannelFactory) // Called in sub-package init()
|
||||
func getFactory(name string) (ChannelFactory, bool) // Called internally by Manager
|
||||
func GetRegisteredFactoryNames() []string // Returns all registered factory names
|
||||
```
|
||||
|
||||
For convenience, `RegisterSafeFactory[S any]` provides automatic type-safe settings decoding:
|
||||
|
||||
```go
|
||||
// Instead of manual GetDecoded() + type assertion:
|
||||
channels.RegisterFactory(config.ChannelTelegram,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil { return nil, err }
|
||||
c, ok := decoded.(*config.TelegramSettings)
|
||||
if !ok { return nil, ErrSendFailed }
|
||||
return NewTelegramChannel(bc, c, b)
|
||||
})
|
||||
|
||||
// You can use RegisterSafeFactory (same safety, less boilerplate):
|
||||
channels.RegisterSafeFactory(config.ChannelTelegram, NewTelegramChannel)
|
||||
```
|
||||
|
||||
The factory registry is protected by `sync.RWMutex` and registrations occur during `init()` phase (completed at process startup). Manager looks up factories by name in `initChannel()` and calls them.
|
||||
|
||||
+79
-33
@@ -327,8 +327,13 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewTelegramChannel(cfg, b)
|
||||
channels.RegisterFactory(config.ChannelTelegram, func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil { return nil, err }
|
||||
c, ok := decoded.(*config.TelegramSettings)
|
||||
if !ok { return nil, channels.ErrSendFailed }
|
||||
return NewTelegramChannel(bc, c, b)
|
||||
})
|
||||
}
|
||||
```
|
||||
@@ -427,8 +432,13 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewMatrixChannel(cfg, b)
|
||||
channels.RegisterFactory(config.ChannelMatrix, func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil { return nil, err }
|
||||
c, ok := decoded.(*config.MatrixSettings)
|
||||
if !ok { return nil, channels.ErrSendFailed }
|
||||
return NewMatrixChannel(bc, c, b)
|
||||
})
|
||||
}
|
||||
```
|
||||
@@ -772,41 +782,58 @@ if c.owner != nil && c.placeholderRecorder != nil {
|
||||
|
||||
### 3.5 注册配置和 Gateway 接入
|
||||
|
||||
#### 在 `pkg/config/config.go` 中添加配置
|
||||
#### 添加配置入口
|
||||
|
||||
Channels 现在使用统一的 map 类型配置(`map[string]*config.Channel`)。
|
||||
每个 channel 条目将通用字段(`enabled`、`type`、`allow_from` 等)放在顶层,
|
||||
channel 特定的设置放在 `settings` 子键中:
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"matrix": {
|
||||
"enabled": true,
|
||||
"type": "matrix",
|
||||
"allow_from": ["@user:example.com"],
|
||||
"settings": {
|
||||
"home_server": "https://matrix.org",
|
||||
"user_id": "@bot:example.com",
|
||||
"access_token": "enc://..."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
安全字段(token、密码、API 密钥)放入 `.security.yml`:
|
||||
|
||||
```yaml
|
||||
channels:
|
||||
matrix:
|
||||
access_token: "your-matrix-access-token"
|
||||
```
|
||||
|
||||
Channel 类型必须在 `pkg/config/config_channel.go` 的 `channelSettingsFactory` 中注册:
|
||||
|
||||
```go
|
||||
type ChannelsConfig struct {
|
||||
var channelSettingsFactory = map[string]any{
|
||||
// ... 现有 channels
|
||||
Matrix MatrixChannelConfig `json:"matrix"`
|
||||
}
|
||||
|
||||
type MatrixChannelConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
HomeServer string `json:"home_server"`
|
||||
Token string `json:"token"`
|
||||
AllowFrom []string `json:"allow_from"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id"`
|
||||
ChannelMatrix: (MatrixSettings{}),
|
||||
}
|
||||
```
|
||||
|
||||
#### 在 Manager.initChannels() 中添加入口
|
||||
#### 无需修改 Manager
|
||||
|
||||
```go
|
||||
// pkg/channels/manager.go 的 initChannels() 方法中
|
||||
if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
|
||||
m.initChannel("matrix", "Matrix")
|
||||
}
|
||||
```
|
||||
Manager 使用 `InitChannelList()` 来验证类型和解码设置,
|
||||
然后通过 `bc.Type` 查找工厂。不需要在 Manager 中添加每个 channel 的条目——
|
||||
只需注册工厂和配置条目即可。
|
||||
|
||||
> **注意**:如果你的 channel 有多种模式(如 WhatsApp Bridge vs Native),需要在 initChannels 中根据配置分支:
|
||||
> **注意**:如果你的 channel 有多种模式(如 WhatsApp Bridge vs Native),
|
||||
> 在 `channelSettingsFactory` 中注册两种类型,并根据配置分支:
|
||||
> ```go
|
||||
> if cfg.UseNative {
|
||||
> m.initChannel("whatsapp_native", "WhatsApp Native")
|
||||
> } else {
|
||||
> m.initChannel("whatsapp", "WhatsApp")
|
||||
> }
|
||||
> // 在 config_channel.go 中:
|
||||
> ChannelWhatsApp: (WhatsAppSettings{}),
|
||||
> ChannelWhatsAppNative: (WhatsAppSettings{}),
|
||||
> ```
|
||||
|
||||
#### 在 Gateway 中添加 blank import
|
||||
@@ -946,10 +973,29 @@ channels.WithReasoningChannelID(id) // 设置思维链路由目标 channe
|
||||
**文件**:`pkg/channels/registry.go`
|
||||
|
||||
```go
|
||||
type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error)
|
||||
type ChannelFactory func(channelName, channelType string, cfg *config.Config, bus *bus.MessageBus) (Channel, error)
|
||||
|
||||
func RegisterFactory(name string, f ChannelFactory) // 子包 init() 中调用
|
||||
func getFactory(name string) (ChannelFactory, bool) // Manager 内部调用
|
||||
func RegisterFactory(name string, f ChannelFactory) // 子包 init() 中调用
|
||||
func getFactory(name string) (ChannelFactory, bool) // Manager 内部调用
|
||||
func GetRegisteredFactoryNames() []string // 返回所有已注册的工厂名称
|
||||
```
|
||||
|
||||
为方便使用,`RegisterSafeFactory[S any]` 提供自动类型安全的设置解码:
|
||||
|
||||
```go
|
||||
// 不使用 RegisterSafeFactory(手动 GetDecoded() + 类型断言):
|
||||
channels.RegisterFactory(config.ChannelTelegram,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil { return nil, err }
|
||||
c, ok := decoded.(*config.TelegramSettings)
|
||||
if !ok { return nil, ErrSendFailed }
|
||||
return NewTelegramChannel(bc, c, b)
|
||||
})
|
||||
|
||||
// 使用 RegisterSafeFactory(同等安全,减少样板代码):
|
||||
channels.RegisterSafeFactory(config.ChannelTelegram, NewTelegramChannel)
|
||||
```
|
||||
|
||||
工厂注册表使用 `sync.RWMutex` 保护,在 `init()` 阶段注册(进程启动时完成)。Manager 在 `initChannel()` 中通过名字查找工厂并调用它。
|
||||
|
||||
+55
-19
@@ -103,6 +103,16 @@ func NewBaseChannel(
|
||||
allowList []string,
|
||||
opts ...BaseChannelOption,
|
||||
) *BaseChannel {
|
||||
isEmpty := true
|
||||
for _, s := range allowList {
|
||||
if s != "" {
|
||||
isEmpty = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isEmpty {
|
||||
allowList = []string{}
|
||||
}
|
||||
bc := &BaseChannel{
|
||||
config: config,
|
||||
bus: bus,
|
||||
@@ -177,6 +187,12 @@ func (c *BaseChannel) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// SetName updates the channel name. Used by the manager after channel creation
|
||||
// to ensure the name matches the config key (which may differ from the type).
|
||||
func (c *BaseChannel) SetName(name string) {
|
||||
c.name = name
|
||||
}
|
||||
|
||||
func (c *BaseChannel) ReasoningChannelID() string {
|
||||
return c.reasoningChannelID
|
||||
}
|
||||
@@ -244,12 +260,11 @@ func (c *BaseChannel) IsAllowedSender(sender bus.SenderInfo) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *BaseChannel) HandleMessage(
|
||||
func (c *BaseChannel) HandleMessageWithContext(
|
||||
ctx context.Context,
|
||||
peer bus.Peer,
|
||||
messageID, senderID, chatID, content string,
|
||||
deliveryChatID, content string,
|
||||
media []string,
|
||||
metadata map[string]string,
|
||||
inboundCtx bus.InboundContext,
|
||||
senderOpts ...bus.SenderInfo,
|
||||
) {
|
||||
// Use SenderInfo-based allow check when available, else fall back to string
|
||||
@@ -257,6 +272,7 @@ func (c *BaseChannel) HandleMessage(
|
||||
if len(senderOpts) > 0 {
|
||||
sender = senderOpts[0]
|
||||
}
|
||||
senderID := strings.TrimSpace(inboundCtx.SenderID)
|
||||
if sender.CanonicalID != "" || sender.PlatformID != "" {
|
||||
if !c.IsAllowedSender(sender) {
|
||||
return
|
||||
@@ -273,20 +289,28 @@ func (c *BaseChannel) HandleMessage(
|
||||
resolvedSenderID = sender.CanonicalID
|
||||
}
|
||||
|
||||
scope := BuildMediaScope(c.name, chatID, messageID)
|
||||
if resolvedSenderID == "" {
|
||||
resolvedSenderID = senderID
|
||||
}
|
||||
|
||||
inboundCtx.Channel = c.name
|
||||
if inboundCtx.ChatID == "" {
|
||||
inboundCtx.ChatID = deliveryChatID
|
||||
}
|
||||
if inboundCtx.SenderID == "" {
|
||||
inboundCtx.SenderID = resolvedSenderID
|
||||
}
|
||||
|
||||
scope := BuildMediaScope(c.name, deliveryChatID, inboundCtx.MessageID)
|
||||
|
||||
msg := bus.InboundMessage{
|
||||
Channel: c.name,
|
||||
SenderID: resolvedSenderID,
|
||||
Context: inboundCtx,
|
||||
Sender: sender,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
Media: media,
|
||||
Peer: peer,
|
||||
MessageID: messageID,
|
||||
MediaScope: scope,
|
||||
Metadata: metadata,
|
||||
}
|
||||
msg = bus.NormalizeInboundMessage(msg)
|
||||
|
||||
// Auto-trigger typing indicator, message reaction, and placeholder before publishing.
|
||||
// Each capability is independent — all three may fire for the same message.
|
||||
@@ -297,14 +321,14 @@ func (c *BaseChannel) HandleMessage(
|
||||
if c.owner != nil && c.placeholderRecorder != nil {
|
||||
// Typing
|
||||
if tc, ok := c.owner.(TypingCapable); ok {
|
||||
if stop, err := tc.StartTyping(ctx, chatID); err == nil {
|
||||
c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop)
|
||||
if stop, err := tc.StartTyping(ctx, deliveryChatID); err == nil {
|
||||
c.placeholderRecorder.RecordTypingStop(c.name, deliveryChatID, stop)
|
||||
}
|
||||
}
|
||||
// Reaction
|
||||
if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" {
|
||||
if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil {
|
||||
c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo)
|
||||
if rc, ok := c.owner.(ReactionCapable); ok && msg.MessageID != "" {
|
||||
if undo, err := rc.ReactToMessage(ctx, deliveryChatID, msg.MessageID); err == nil {
|
||||
c.placeholderRecorder.RecordReactionUndo(c.name, deliveryChatID, undo)
|
||||
}
|
||||
}
|
||||
// Placeholder — independent pipeline.
|
||||
@@ -313,8 +337,8 @@ func (c *BaseChannel) HandleMessage(
|
||||
// "Thinking…" only once the voice has been processed.
|
||||
if !audioAnnotationRe.MatchString(content) {
|
||||
if pc, ok := c.owner.(PlaceholderCapable); ok {
|
||||
if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" {
|
||||
c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID)
|
||||
if phID, err := pc.SendPlaceholder(ctx, deliveryChatID); err == nil && phID != "" {
|
||||
c.placeholderRecorder.RecordPlaceholder(c.name, deliveryChatID, phID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -323,12 +347,24 @@ func (c *BaseChannel) HandleMessage(
|
||||
if err := c.bus.PublishInbound(ctx, msg); err != nil {
|
||||
logger.ErrorCF("channels", "Failed to publish inbound message", map[string]any{
|
||||
"channel": c.name,
|
||||
"chat_id": chatID,
|
||||
"chat_id": deliveryChatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HandleInboundContext publishes a normalized inbound message using only the
|
||||
// structured context.
|
||||
func (c *BaseChannel) HandleInboundContext(
|
||||
ctx context.Context,
|
||||
deliveryChatID, content string,
|
||||
media []string,
|
||||
inboundCtx bus.InboundContext,
|
||||
senderOpts ...bus.SenderInfo,
|
||||
) {
|
||||
c.HandleMessageWithContext(ctx, deliveryChatID, content, media, inboundCtx, senderOpts...)
|
||||
}
|
||||
|
||||
func (c *BaseChannel) SetRunning(running bool) {
|
||||
c.running.Store(running)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -263,3 +264,58 @@ func TestIsAllowedSender(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleInboundContext_PublishesNormalizedContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inbound bus.InboundContext
|
||||
wantChat string
|
||||
wantSender string
|
||||
}{
|
||||
{
|
||||
name: "direct uses sender as peer",
|
||||
inbound: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-1",
|
||||
MessageID: "msg-1",
|
||||
},
|
||||
wantChat: "chat-1",
|
||||
wantSender: "user-1",
|
||||
},
|
||||
{
|
||||
name: "group uses chat as peer",
|
||||
inbound: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "group-1",
|
||||
ChatType: "group",
|
||||
SenderID: "user-2",
|
||||
MessageID: "msg-2",
|
||||
},
|
||||
wantChat: "group-1",
|
||||
wantSender: "user-2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
defer msgBus.Close()
|
||||
|
||||
ch := NewBaseChannel("test", nil, msgBus, nil)
|
||||
ch.HandleInboundContext(context.Background(), tt.inbound.ChatID, "hello", nil, tt.inbound)
|
||||
|
||||
msg := <-msgBus.InboundChan()
|
||||
if msg.ChatID != tt.wantChat {
|
||||
t.Fatalf("ChatID = %q, want %q", msg.ChatID, tt.wantChat)
|
||||
}
|
||||
if msg.SenderID != tt.wantSender {
|
||||
t.Fatalf("SenderID = %q, want %q", msg.SenderID, tt.wantSender)
|
||||
}
|
||||
if msg.Context.ChatType != tt.inbound.ChatType {
|
||||
t.Fatalf("ChatType = %q, want %q", msg.Context.ChatType, tt.inbound.ChatType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ import (
|
||||
// It uses WebSocket for receiving messages via stream mode and API for sending
|
||||
type DingTalkChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.DingTalkConfig
|
||||
config *config.DingTalkSettings
|
||||
clientID string
|
||||
clientSecret string
|
||||
streamClient *client.StreamClient
|
||||
@@ -36,7 +36,11 @@ type DingTalkChannel struct {
|
||||
}
|
||||
|
||||
// NewDingTalkChannel creates a new DingTalk channel instance
|
||||
func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (*DingTalkChannel, error) {
|
||||
func NewDingTalkChannel(
|
||||
bc *config.Channel,
|
||||
cfg *config.DingTalkSettings,
|
||||
messageBus *bus.MessageBus,
|
||||
) (*DingTalkChannel, error) {
|
||||
if cfg.ClientID == "" || cfg.ClientSecret.String() == "" {
|
||||
return nil, fmt.Errorf("dingtalk client_id and client_secret are required")
|
||||
}
|
||||
@@ -44,10 +48,10 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (
|
||||
// Set the logger for the Stream SDK
|
||||
dinglog.SetLogger(logger.NewLogger("dingtalk"))
|
||||
|
||||
base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom,
|
||||
base := channels.NewBaseChannel("dingtalk", cfg, messageBus, bc.AllowFrom,
|
||||
channels.WithMaxMessageLength(20000),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
channels.WithGroupTrigger(bc.GroupTrigger),
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &DingTalkChannel{
|
||||
@@ -181,16 +185,15 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
"session_webhook": data.SessionWebhook,
|
||||
}
|
||||
|
||||
var peer bus.Peer
|
||||
var (
|
||||
chatType string
|
||||
isMentioned bool
|
||||
)
|
||||
if data.ConversationType == "1" {
|
||||
peerID := senderID
|
||||
if peerID == "" {
|
||||
peerID = chatID
|
||||
}
|
||||
peer = bus.Peer{Kind: "direct", ID: peerID}
|
||||
chatType = "direct"
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "group", ID: data.ConversationId}
|
||||
isMentioned := data.IsInAtList
|
||||
chatType = "group"
|
||||
isMentioned = data.IsInAtList
|
||||
if isMentioned {
|
||||
content = stripLeadingAtMentions(content)
|
||||
}
|
||||
@@ -228,8 +231,21 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(ctx, peer, "", resolvedSenderID, chatID, content, nil, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "dingtalk",
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
SenderID: resolvedSenderID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if data.SessionWebhook != "" {
|
||||
inboundCtx.ReplyHandles = map[string]string{
|
||||
"session_webhook": data.SessionWebhook,
|
||||
}
|
||||
}
|
||||
|
||||
c.HandleInboundContext(ctx, chatID, content, nil, inboundCtx, sender)
|
||||
|
||||
// Return nil to indicate we've handled the message asynchronously
|
||||
// The response will be sent through the message bus
|
||||
|
||||
@@ -11,7 +11,11 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func newTestDingTalkChannel(t *testing.T, cfg config.DingTalkConfig) (*DingTalkChannel, *bus.MessageBus) {
|
||||
func newTestDingTalkChannel(
|
||||
t *testing.T,
|
||||
cfg config.DingTalkSettings,
|
||||
bc *config.Channel,
|
||||
) (*DingTalkChannel, *bus.MessageBus) {
|
||||
t.Helper()
|
||||
|
||||
if cfg.ClientID == "" {
|
||||
@@ -22,7 +26,10 @@ func newTestDingTalkChannel(t *testing.T, cfg config.DingTalkConfig) (*DingTalkC
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewDingTalkChannel(cfg, msgBus)
|
||||
if bc == nil {
|
||||
bc = &config.Channel{Type: config.ChannelDingTalk, Enabled: true}
|
||||
}
|
||||
ch, err := NewDingTalkChannel(bc, &cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("new channel: %v", err)
|
||||
}
|
||||
@@ -41,9 +48,12 @@ func mustReceiveInbound(t *testing.T, msgBus *bus.MessageBus) bus.InboundMessage
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceived_GroupMentionOnlyUsesIsInAtListAndStripsMention(t *testing.T) {
|
||||
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkConfig{
|
||||
bc := &config.Channel{
|
||||
Type: config.ChannelDingTalk,
|
||||
Enabled: true,
|
||||
GroupTrigger: config.GroupTriggerConfig{MentionOnly: true},
|
||||
})
|
||||
}
|
||||
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkSettings{}, bc)
|
||||
|
||||
_, err := ch.onChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{
|
||||
Text: chatbot.BotCallbackDataTextModel{Content: " @bot /help "},
|
||||
@@ -65,8 +75,8 @@ func TestOnChatBotMessageReceived_GroupMentionOnlyUsesIsInAtListAndStripsMention
|
||||
if inbound.ChatID != "group-abc" {
|
||||
t.Fatalf("chat_id=%q", inbound.ChatID)
|
||||
}
|
||||
if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-abc" {
|
||||
t.Fatalf("peer=%+v", inbound.Peer)
|
||||
if inbound.Context.ChatType != "group" {
|
||||
t.Fatalf("chat_type=%q", inbound.Context.ChatType)
|
||||
}
|
||||
if inbound.Content != "/help" {
|
||||
t.Fatalf("content=%q", inbound.Content)
|
||||
@@ -74,7 +84,7 @@ func TestOnChatBotMessageReceived_GroupMentionOnlyUsesIsInAtListAndStripsMention
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceived_DirectFallbackSenderIDUsesConversationID(t *testing.T) {
|
||||
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkConfig{})
|
||||
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkSettings{}, nil)
|
||||
|
||||
_, err := ch.onChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{
|
||||
Text: chatbot.BotCallbackDataTextModel{Content: "ping"},
|
||||
@@ -93,12 +103,15 @@ func TestOnChatBotMessageReceived_DirectFallbackSenderIDUsesConversationID(t *te
|
||||
if inbound.ChatID != "conv-direct-42" {
|
||||
t.Fatalf("chat_id=%q", inbound.ChatID)
|
||||
}
|
||||
if inbound.Peer.Kind != "direct" || inbound.Peer.ID != "openid-user-42" {
|
||||
t.Fatalf("peer=%+v", inbound.Peer)
|
||||
if inbound.Context.ChatType != "direct" {
|
||||
t.Fatalf("chat_type=%q", inbound.Context.ChatType)
|
||||
}
|
||||
if inbound.SenderID != "dingtalk:openid-user-42" {
|
||||
if inbound.SenderID != "openid-user-42" {
|
||||
t.Fatalf("sender_id=%q", inbound.SenderID)
|
||||
}
|
||||
if inbound.Sender.CanonicalID != "dingtalk:openid-user-42" {
|
||||
t.Fatalf("sender canonical_id=%q", inbound.Sender.CanonicalID)
|
||||
}
|
||||
|
||||
if _, ok := ch.sessionWebhooks.Load("conv-direct-42"); !ok {
|
||||
t.Fatal("expected session webhook keyed by conversation_id")
|
||||
|
||||
@@ -7,7 +7,26 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("dingtalk", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewDingTalkChannel(cfg.Channels.DingTalk, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelDingTalk,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.DingTalkSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
ch, err := NewDingTalkChannel(bc, c, b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channelName != config.ChannelDingTalk {
|
||||
ch.SetName(channelName)
|
||||
}
|
||||
return ch, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -38,8 +38,9 @@ var (
|
||||
|
||||
type DiscordChannel struct {
|
||||
*channels.BaseChannel
|
||||
bc *config.Channel
|
||||
session *discordgo.Session
|
||||
config config.DiscordConfig
|
||||
config *config.DiscordSettings
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
typingMu sync.Mutex
|
||||
@@ -56,7 +57,11 @@ type DiscordChannel struct {
|
||||
ttsPlayID uint64
|
||||
}
|
||||
|
||||
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
|
||||
func NewDiscordChannel(
|
||||
bc *config.Channel,
|
||||
cfg *config.DiscordSettings,
|
||||
bus *bus.MessageBus,
|
||||
) (*DiscordChannel, error) {
|
||||
discordgo.Logger = logger.NewLogger("discord").
|
||||
WithLevels(map[int]logger.LogLevel{
|
||||
discordgo.LogError: logger.ERROR,
|
||||
@@ -73,14 +78,15 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
|
||||
if err := applyDiscordProxy(session, cfg.Proxy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom,
|
||||
base := channels.NewBaseChannel("discord", cfg, bus, bc.AllowFrom,
|
||||
channels.WithMaxMessageLength(2000),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
channels.WithGroupTrigger(bc.GroupTrigger),
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &DiscordChannel{
|
||||
BaseChannel: base,
|
||||
bc: bc,
|
||||
session: session,
|
||||
config: cfg,
|
||||
ctx: context.Background(),
|
||||
@@ -297,11 +303,11 @@ func (c *DiscordChannel) EditMessage(ctx context.Context, chatID string, message
|
||||
// It sends a placeholder message that will later be edited to the actual
|
||||
// response via EditMessage (channels.MessageEditor).
|
||||
func (c *DiscordChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
if !c.config.Placeholder.Enabled {
|
||||
if !c.bc.Placeholder.Enabled {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
text := c.config.Placeholder.GetRandomText()
|
||||
text := c.bc.Placeholder.GetRandomText()
|
||||
|
||||
msg, err := c.session.ChannelMessageSend(chatID, text)
|
||||
if err != nil {
|
||||
@@ -402,8 +408,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
|
||||
// In guild (group) channels, apply unified group trigger filtering
|
||||
// DMs (GuildID is empty) always get a response
|
||||
isMentioned := false
|
||||
if m.GuildID != "" {
|
||||
isMentioned := false
|
||||
for _, mention := range m.Mentions {
|
||||
if mention.ID == c.botUserID {
|
||||
isMentioned = true
|
||||
@@ -500,14 +506,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
})
|
||||
|
||||
peerKind := "channel"
|
||||
peerID := m.ChannelID
|
||||
if m.GuildID == "" {
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"user_id": senderID,
|
||||
"username": m.Author.Username,
|
||||
@@ -516,8 +518,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
"channel_id": m.ChannelID,
|
||||
"is_dm": fmt.Sprintf("%t", m.GuildID == ""),
|
||||
}
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
ChatID: m.ChannelID,
|
||||
ChatType: peerKind,
|
||||
SenderID: senderID,
|
||||
MessageID: m.ID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if m.GuildID != "" {
|
||||
inboundCtx.SpaceID = m.GuildID
|
||||
inboundCtx.SpaceType = "guild"
|
||||
}
|
||||
if m.MessageReference != nil {
|
||||
inboundCtx.ReplyToMessageID = m.MessageReference.MessageID
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata, sender)
|
||||
c.HandleInboundContext(c.ctx, m.ChannelID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// startTyping starts a continuous typing indicator loop for the given chatID.
|
||||
|
||||
@@ -8,11 +8,23 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
ch, err := NewDiscordChannel(cfg.Channels.Discord, b)
|
||||
if err == nil {
|
||||
ch.tts = tts.DetectTTS(cfg)
|
||||
}
|
||||
return ch, err
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelDiscord,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.DiscordSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
ch, err := NewDiscordChannel(bc, c, b)
|
||||
if err == nil {
|
||||
ch.tts = tts.DetectTTS(cfg)
|
||||
}
|
||||
return ch, err
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ type FeishuChannel struct {
|
||||
var errUnsupported = errors.New("feishu channel is not supported on 32-bit architectures")
|
||||
|
||||
// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported
|
||||
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
|
||||
func NewFeishuChannel(bc *config.Channel, cfg *config.FeishuSettings, bus *bus.MessageBus) (*FeishuChannel, error) {
|
||||
return nil, errors.New(
|
||||
"feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). Please use a 64-bit system or disable feishu in your config",
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
lark "github.com/larksuite/oapi-sdk-go/v3"
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
@@ -37,21 +38,28 @@ const errCodeTenantTokenInvalid = 99991663
|
||||
|
||||
type FeishuChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.FeishuConfig
|
||||
bc *config.Channel
|
||||
config *config.FeishuSettings
|
||||
client *lark.Client
|
||||
wsClient *larkws.Client
|
||||
tokenCache *tokenCache // custom cache that supports invalidation
|
||||
|
||||
botOpenID atomic.Value // stores string; populated lazily for @mention detection
|
||||
botOpenID atomic.Value // stores string; populated lazily for @mention detection
|
||||
messageCache sync.Map // caches fetched messages (messageID -> *larkim.Message)
|
||||
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
|
||||
base := channels.NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom,
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
type cachedMessage struct {
|
||||
msg *larkim.Message
|
||||
expiry time.Time
|
||||
}
|
||||
|
||||
func NewFeishuChannel(bc *config.Channel, cfg *config.FeishuSettings, bus *bus.MessageBus) (*FeishuChannel, error) {
|
||||
base := channels.NewBaseChannel("feishu", cfg, bus, bc.AllowFrom,
|
||||
channels.WithGroupTrigger(bc.GroupTrigger),
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
tc := newTokenCache()
|
||||
@@ -61,6 +69,7 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
|
||||
}
|
||||
ch := &FeishuChannel{
|
||||
BaseChannel: base,
|
||||
bc: bc,
|
||||
config: cfg,
|
||||
tokenCache: tc,
|
||||
client: lark.NewClient(cfg.AppID, cfg.AppSecret.String(), opts...),
|
||||
@@ -204,14 +213,14 @@ func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, cont
|
||||
// SendPlaceholder implements channels.PlaceholderCapable.
|
||||
// Sends an interactive card with placeholder text and returns its message ID.
|
||||
func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
if !c.config.Placeholder.Enabled {
|
||||
if !c.bc.Placeholder.Enabled {
|
||||
logger.DebugCF("feishu", "Placeholder disabled, skipping", map[string]any{
|
||||
"chat_id": chatID,
|
||||
})
|
||||
return "", nil
|
||||
}
|
||||
|
||||
text := c.config.Placeholder.GetRandomText()
|
||||
text := c.bc.Placeholder.GetRandomText()
|
||||
|
||||
cardContent, err := buildMarkdownCard(text)
|
||||
if err != nil {
|
||||
@@ -439,30 +448,20 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
|
||||
if content == "" {
|
||||
content = "[empty message]"
|
||||
}
|
||||
|
||||
metadata := map[string]string{}
|
||||
if messageID != "" {
|
||||
metadata["message_id"] = messageID
|
||||
}
|
||||
if messageType != "" {
|
||||
metadata["message_type"] = messageType
|
||||
}
|
||||
chatType := stringValue(message.ChatType)
|
||||
if chatType != "" {
|
||||
metadata["chat_type"] = chatType
|
||||
}
|
||||
if sender != nil && sender.TenantKey != nil {
|
||||
metadata["tenant_key"] = *sender.TenantKey
|
||||
}
|
||||
metadata := buildInboundMetadata(message, sender)
|
||||
|
||||
var peer bus.Peer
|
||||
var (
|
||||
inboundChatType string
|
||||
isMentioned bool
|
||||
)
|
||||
if chatType == "p2p" {
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
inboundChatType = "direct"
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "group", ID: chatID}
|
||||
inboundChatType = "group"
|
||||
|
||||
// Check if bot was mentioned
|
||||
isMentioned := c.isBotMentioned(message)
|
||||
isMentioned = c.isBotMentioned(message)
|
||||
|
||||
// Strip mention placeholders from content before group trigger check
|
||||
if len(message.Mentions) > 0 {
|
||||
@@ -477,14 +476,41 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
|
||||
content = cleaned
|
||||
}
|
||||
|
||||
if replyTargetID(message) != "" || stringValue(message.ThreadId) != "" {
|
||||
content, mediaRefs = c.prependReplyContext(ctx, message, chatID, content, mediaRefs)
|
||||
}
|
||||
if content == "" {
|
||||
content = "[empty message]"
|
||||
}
|
||||
|
||||
logger.InfoCF("feishu", "Feishu message received", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"message_id": messageID,
|
||||
"preview": utils.Truncate(content, 80),
|
||||
})
|
||||
logger.InfoCF("feishu", "Feishu reply linkage", map[string]any{
|
||||
"message_id": messageID,
|
||||
"parent_id": stringValue(message.ParentId),
|
||||
"root_id": stringValue(message.RootId),
|
||||
"thread_id": stringValue(message.ThreadId),
|
||||
})
|
||||
|
||||
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "feishu",
|
||||
ChatID: chatID,
|
||||
ChatType: inboundChatType,
|
||||
SenderID: senderID,
|
||||
MessageID: messageID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if sender != nil && sender.TenantKey != nil && *sender.TenantKey != "" {
|
||||
inboundCtx.SpaceType = "tenant"
|
||||
inboundCtx.SpaceID = *sender.TenantKey
|
||||
}
|
||||
|
||||
c.HandleInboundContext(ctx, chatID, content, mediaRefs, inboundCtx, senderInfo)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,298 @@
|
||||
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
|
||||
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const messageCacheTTL = 30 * time.Second
|
||||
|
||||
const (
|
||||
maxReplyContextLen = 600
|
||||
)
|
||||
|
||||
func (c *FeishuChannel) prependReplyContext(
|
||||
ctx context.Context,
|
||||
message *larkim.EventMessage,
|
||||
chatID string,
|
||||
content string,
|
||||
mediaRefs []string,
|
||||
) (string, []string) {
|
||||
if message == nil {
|
||||
return content, mediaRefs
|
||||
}
|
||||
|
||||
lookupCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
targetMessageID := c.resolveReplyTargetMessageID(lookupCtx, message)
|
||||
if targetMessageID == "" {
|
||||
logger.DebugCF("feishu", "No reply target resolved; skip reply context", map[string]any{
|
||||
"message_id": stringValue(message.MessageId),
|
||||
"parent_id": stringValue(message.ParentId),
|
||||
"root_id": stringValue(message.RootId),
|
||||
"thread_id": stringValue(message.ThreadId),
|
||||
})
|
||||
return content, mediaRefs
|
||||
}
|
||||
|
||||
repliedMessage, err := c.fetchMessageByID(lookupCtx, targetMessageID)
|
||||
if err != nil {
|
||||
logger.DebugCF("feishu", "Failed to fetch replied message context", map[string]any{
|
||||
"target_message_id": targetMessageID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return content, mediaRefs
|
||||
}
|
||||
|
||||
messageType := stringValue(repliedMessage.MsgType)
|
||||
rawContent := ""
|
||||
if repliedMessage.Body != nil {
|
||||
rawContent = stringValue(repliedMessage.Body.Content)
|
||||
}
|
||||
|
||||
var repliedMediaRefs []string
|
||||
if store := c.GetMediaStore(); store != nil {
|
||||
repliedMediaRefs = c.downloadInboundMedia(lookupCtx, chatID, targetMessageID, messageType, rawContent, store)
|
||||
if messageType == larkim.MsgTypeInteractive {
|
||||
_, externalURLs := extractCardImageKeys(rawContent)
|
||||
if len(externalURLs) > 0 {
|
||||
repliedMediaRefs = append(repliedMediaRefs, externalURLs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
repliedContent := normalizeRepliedContent(messageType, rawContent, repliedMediaRefs)
|
||||
if len(repliedMediaRefs) > 0 {
|
||||
mediaRefs = append(repliedMediaRefs, mediaRefs...)
|
||||
}
|
||||
|
||||
return formatReplyContext(targetMessageID, repliedContent, content), mediaRefs
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) resolveReplyTargetMessageID(ctx context.Context, message *larkim.EventMessage) string {
|
||||
if targetID := replyTargetID(message); targetID != "" {
|
||||
logger.DebugCF("feishu", "Resolved reply target from event payload", map[string]any{
|
||||
"message_id": stringValue(message.MessageId),
|
||||
"parent_id": stringValue(message.ParentId),
|
||||
"root_id": stringValue(message.RootId),
|
||||
"target_id": targetID,
|
||||
})
|
||||
return targetID
|
||||
}
|
||||
|
||||
currentMessageID := stringValue(message.MessageId)
|
||||
if currentMessageID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if stringValue(message.ThreadId) == "" {
|
||||
logger.DebugCF("feishu", "No reply target found; message is not in a thread", map[string]any{
|
||||
"message_id": stringValue(message.MessageId),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
msg, err := c.fetchMessageByID(ctx, currentMessageID)
|
||||
if err != nil {
|
||||
logger.DebugCF("feishu", "Failed to query current message detail for reply info", map[string]any{
|
||||
"message_id": currentMessageID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
targetID := replyTargetIDFromMessage(msg)
|
||||
if targetID != "" {
|
||||
logger.DebugCF("feishu", "Resolved reply target from message detail", map[string]any{
|
||||
"message_id": currentMessageID,
|
||||
"parent_id": stringValue(msg.ParentId),
|
||||
"root_id": stringValue(msg.RootId),
|
||||
"target_id": targetID,
|
||||
})
|
||||
}
|
||||
return targetID
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) fetchMessageByID(ctx context.Context, messageID string) (*larkim.Message, error) {
|
||||
if cached, ok := c.messageCache.Load(messageID); ok {
|
||||
cm := cached.(*cachedMessage)
|
||||
if time.Now().Before(cm.expiry) {
|
||||
return cm.msg, nil
|
||||
}
|
||||
c.messageCache.Delete(messageID)
|
||||
}
|
||||
|
||||
req := larkim.NewGetMessageReqBuilder().
|
||||
MessageId(messageID).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Get(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("feishu get message: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
c.invalidateTokenOnAuthError(resp.Code)
|
||||
return nil, fmt.Errorf("feishu get message api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
if resp.Data == nil || len(resp.Data.Items) == 0 || resp.Data.Items[0] == nil {
|
||||
return nil, fmt.Errorf("feishu get message: empty response")
|
||||
}
|
||||
// Items[0] contains the target message - the Feishu API returns a list
|
||||
// but we request a single message by ID, so the list always has at most one item.
|
||||
msg := resp.Data.Items[0]
|
||||
c.messageCache.Store(messageID, &cachedMessage{msg: msg, expiry: time.Now().Add(messageCacheTTL)})
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func replyTargetID(message *larkim.EventMessage) string {
|
||||
if message == nil {
|
||||
return ""
|
||||
}
|
||||
if parentID := stringValue(message.ParentId); parentID != "" {
|
||||
return parentID
|
||||
}
|
||||
return stringValue(message.RootId)
|
||||
}
|
||||
|
||||
func replyTargetIDFromMessage(message *larkim.Message) string {
|
||||
if message == nil {
|
||||
return ""
|
||||
}
|
||||
if parentID := stringValue(message.ParentId); parentID != "" {
|
||||
return parentID
|
||||
}
|
||||
return stringValue(message.RootId)
|
||||
}
|
||||
|
||||
func buildInboundMetadata(message *larkim.EventMessage, sender *larkim.EventSender) map[string]string {
|
||||
metadata := map[string]string{}
|
||||
if message == nil {
|
||||
return metadata
|
||||
}
|
||||
|
||||
messageID := stringValue(message.MessageId)
|
||||
if messageID != "" {
|
||||
metadata["message_id"] = messageID
|
||||
}
|
||||
|
||||
messageType := stringValue(message.MessageType)
|
||||
if messageType != "" {
|
||||
metadata["message_type"] = messageType
|
||||
}
|
||||
|
||||
chatType := stringValue(message.ChatType)
|
||||
if chatType != "" {
|
||||
metadata["chat_type"] = chatType
|
||||
}
|
||||
|
||||
parentID := stringValue(message.ParentId)
|
||||
if parentID != "" {
|
||||
metadata["parent_id"] = parentID
|
||||
}
|
||||
|
||||
rootID := stringValue(message.RootId)
|
||||
if rootID != "" {
|
||||
metadata["root_id"] = rootID
|
||||
}
|
||||
|
||||
if replyTo := replyTargetID(message); replyTo != "" {
|
||||
metadata["reply_to_message_id"] = replyTo
|
||||
}
|
||||
|
||||
threadID := stringValue(message.ThreadId)
|
||||
if threadID != "" {
|
||||
metadata["thread_id"] = threadID
|
||||
}
|
||||
|
||||
if sender != nil && sender.TenantKey != nil && *sender.TenantKey != "" {
|
||||
metadata["tenant_key"] = *sender.TenantKey
|
||||
}
|
||||
|
||||
return metadata
|
||||
}
|
||||
|
||||
func normalizeRepliedContent(messageType, rawContent string, mediaRefs []string) string {
|
||||
content := extractContent(messageType, rawContent)
|
||||
|
||||
if containsFeishuUpgradePlaceholder(rawContent) || containsFeishuUpgradePlaceholder(content) {
|
||||
content = ""
|
||||
}
|
||||
|
||||
content = appendMediaTags(content, messageType, mediaRefs)
|
||||
if strings.TrimSpace(content) != "" {
|
||||
return content
|
||||
}
|
||||
|
||||
switch messageType {
|
||||
case larkim.MsgTypeImage:
|
||||
return "[replied image]"
|
||||
case larkim.MsgTypeFile:
|
||||
return "[replied file]"
|
||||
case larkim.MsgTypeAudio:
|
||||
return "[replied audio]"
|
||||
case larkim.MsgTypeMedia:
|
||||
return "[replied video]"
|
||||
case larkim.MsgTypeInteractive:
|
||||
return "[replied interactive card]"
|
||||
default:
|
||||
return "[replied message content unavailable]"
|
||||
}
|
||||
}
|
||||
|
||||
func containsFeishuUpgradePlaceholder(s string) bool {
|
||||
upgradePrompt := "\u8bf7\u5347\u7ea7\u81f3\u6700\u65b0\u7248\u672c\u5ba2\u6237\u7aef"
|
||||
upgradePromptEscaped := "\\u8bf7\\u5347\\u7ea7\\u81f3\\u6700\\u65b0\\u7248\\u672c\\u5ba2\\u6237\\u7aef"
|
||||
return strings.Contains(s, upgradePrompt) || strings.Contains(s, upgradePromptEscaped)
|
||||
}
|
||||
|
||||
func formatReplyContext(parentID, repliedContent, content string) string {
|
||||
parentID = strings.TrimSpace(parentID)
|
||||
repliedContent = strings.TrimSpace(repliedContent)
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
if parentID == "" || repliedContent == "" {
|
||||
return content
|
||||
}
|
||||
|
||||
repliedContent = utils.Truncate(repliedContent, maxReplyContextLen)
|
||||
repliedContent = sanitizeReplyContextContent(repliedContent)
|
||||
content = sanitizeReplyContextContent(content)
|
||||
header := fmt.Sprintf("[replied_message id=%q]", parentID)
|
||||
footer := "[/replied_message]"
|
||||
if content == "" {
|
||||
return header + "\n" + repliedContent + "\n" + footer
|
||||
}
|
||||
if hasLeadingCommandPrefix(content) {
|
||||
return content + "\n\n" + header + "\n" + repliedContent + "\n" + footer
|
||||
}
|
||||
return header + "\n" + repliedContent + "\n" + footer + "\n\n[current_message]\n" + content + "\n[/current_message]"
|
||||
}
|
||||
|
||||
func hasLeadingCommandPrefix(s string) bool {
|
||||
tokens := strings.Fields(strings.TrimSpace(s))
|
||||
if len(tokens) == 0 {
|
||||
return false
|
||||
}
|
||||
first := tokens[0]
|
||||
return strings.HasPrefix(first, "/") || strings.HasPrefix(first, "!")
|
||||
}
|
||||
|
||||
func sanitizeReplyContextContent(s string) string {
|
||||
tagEscaper := strings.NewReplacer(
|
||||
"[replied_message", `\[replied_message`,
|
||||
"[/replied_message]", `\[/replied_message]`,
|
||||
"[current_message]", `\[current_message]`,
|
||||
"[/current_message]", `\[/current_message]`,
|
||||
)
|
||||
return tagEscaper.Replace(s)
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
|
||||
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
)
|
||||
|
||||
func TestBuildInboundMetadata(t *testing.T) {
|
||||
strPtr := func(s string) *string { return &s }
|
||||
|
||||
t.Run("includes basic and reply fields", func(t *testing.T) {
|
||||
message := &larkim.EventMessage{
|
||||
MessageId: strPtr("om_msg_1"),
|
||||
MessageType: strPtr("text"),
|
||||
ChatType: strPtr("group"),
|
||||
ParentId: strPtr("om_parent_1"),
|
||||
RootId: strPtr("om_root_1"),
|
||||
ThreadId: strPtr("omt_thread_1"),
|
||||
}
|
||||
sender := &larkim.EventSender{TenantKey: strPtr("tenant_x")}
|
||||
|
||||
got := buildInboundMetadata(message, sender)
|
||||
|
||||
if got["message_id"] != "om_msg_1" {
|
||||
t.Fatalf("message_id = %q, want %q", got["message_id"], "om_msg_1")
|
||||
}
|
||||
if got["message_type"] != "text" {
|
||||
t.Fatalf("message_type = %q, want %q", got["message_type"], "text")
|
||||
}
|
||||
if got["chat_type"] != "group" {
|
||||
t.Fatalf("chat_type = %q, want %q", got["chat_type"], "group")
|
||||
}
|
||||
if got["parent_id"] != "om_parent_1" {
|
||||
t.Fatalf("parent_id = %q, want %q", got["parent_id"], "om_parent_1")
|
||||
}
|
||||
if got["reply_to_message_id"] != "om_parent_1" {
|
||||
t.Fatalf("reply_to_message_id = %q, want %q", got["reply_to_message_id"], "om_parent_1")
|
||||
}
|
||||
if got["root_id"] != "om_root_1" {
|
||||
t.Fatalf("root_id = %q, want %q", got["root_id"], "om_root_1")
|
||||
}
|
||||
if got["thread_id"] != "omt_thread_1" {
|
||||
t.Fatalf("thread_id = %q, want %q", got["thread_id"], "omt_thread_1")
|
||||
}
|
||||
if got["tenant_key"] != "tenant_x" {
|
||||
t.Fatalf("tenant_key = %q, want %q", got["tenant_key"], "tenant_x")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back reply_to_message_id to root_id", func(t *testing.T) {
|
||||
message := &larkim.EventMessage{
|
||||
MessageId: strPtr("om_msg_3"),
|
||||
RootId: strPtr("om_root_3"),
|
||||
}
|
||||
|
||||
got := buildInboundMetadata(message, nil)
|
||||
|
||||
if got["root_id"] != "om_root_3" {
|
||||
t.Fatalf("root_id = %q, want %q", got["root_id"], "om_root_3")
|
||||
}
|
||||
if got["reply_to_message_id"] != "om_root_3" {
|
||||
t.Fatalf("reply_to_message_id = %q, want %q", got["reply_to_message_id"], "om_root_3")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("omits empty values", func(t *testing.T) {
|
||||
message := &larkim.EventMessage{
|
||||
MessageId: strPtr("om_msg_2"),
|
||||
}
|
||||
|
||||
got := buildInboundMetadata(message, nil)
|
||||
|
||||
if got["message_id"] != "om_msg_2" {
|
||||
t.Fatalf("message_id = %q, want %q", got["message_id"], "om_msg_2")
|
||||
}
|
||||
if _, ok := got["parent_id"]; ok {
|
||||
t.Fatalf("parent_id should be absent, got %q", got["parent_id"])
|
||||
}
|
||||
if _, ok := got["reply_to_message_id"]; ok {
|
||||
t.Fatalf("reply_to_message_id should be absent, got %q", got["reply_to_message_id"])
|
||||
}
|
||||
if _, ok := got["tenant_key"]; ok {
|
||||
t.Fatalf("tenant_key should be absent, got %q", got["tenant_key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil message returns empty map", func(t *testing.T) {
|
||||
got := buildInboundMetadata(nil, nil)
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("len(metadata) = %d, want 0", len(got))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFormatReplyContext(t *testing.T) {
|
||||
t.Run("formats reply context with content", func(t *testing.T) {
|
||||
got := formatReplyContext("om_parent_1", "original message", "new reply")
|
||||
want := "[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]\n\n[current_message]\nnew reply\n[/current_message]"
|
||||
if got != want {
|
||||
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns reply context when current content is empty", func(t *testing.T) {
|
||||
got := formatReplyContext("om_parent_1", "original message", "")
|
||||
want := "[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]"
|
||||
if got != want {
|
||||
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns original content when parent or replied content missing", func(t *testing.T) {
|
||||
if got := formatReplyContext("", "original", "new reply"); got != "new reply" {
|
||||
t.Fatalf("missing parent: got %q, want %q", got, "new reply")
|
||||
}
|
||||
if got := formatReplyContext("om_parent_1", "", "new reply"); got != "new reply" {
|
||||
t.Fatalf("missing replied content: got %q, want %q", got, "new reply")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("escapes reserved wrapper tags in payload", func(t *testing.T) {
|
||||
replied := "payload [replied_message id=\"x\"] x [/replied_message]"
|
||||
current := "hello [current_message]injected[/current_message]"
|
||||
got := formatReplyContext("om_parent_1", replied, current)
|
||||
|
||||
if !strings.HasPrefix(got, "[replied_message id=\"om_parent_1\"]") {
|
||||
t.Fatalf("outer replied_message wrapper missing: %q", got)
|
||||
}
|
||||
if strings.Contains(got, "\n[replied_message id=\"x\"]") {
|
||||
t.Fatalf("nested replied_message tag should be escaped: %q", got)
|
||||
}
|
||||
if strings.Contains(got, "\n[current_message]injected") {
|
||||
t.Fatalf("nested current_message tag should be escaped: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, `\[replied_message id="x"]`) {
|
||||
t.Fatalf("escaped replied tag missing: %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves leading slash command prefix", func(t *testing.T) {
|
||||
got := formatReplyContext("om_parent_1", "original message", "/help")
|
||||
want := "/help\n\n[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]"
|
||||
if got != want {
|
||||
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves leading bang command prefix", func(t *testing.T) {
|
||||
got := formatReplyContext("om_parent_1", "original message", "!status now")
|
||||
want := "!status now\n\n[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]"
|
||||
if got != want {
|
||||
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReplyTargetID(t *testing.T) {
|
||||
strPtr := func(s string) *string { return &s }
|
||||
|
||||
t.Run("prefer parent_id", func(t *testing.T) {
|
||||
msg := &larkim.EventMessage{ParentId: strPtr("om_parent"), RootId: strPtr("om_root")}
|
||||
if got := replyTargetID(msg); got != "om_parent" {
|
||||
t.Fatalf("replyTargetID() = %q, want %q", got, "om_parent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fallback to root_id", func(t *testing.T) {
|
||||
msg := &larkim.EventMessage{RootId: strPtr("om_root")}
|
||||
if got := replyTargetID(msg); got != "om_root" {
|
||||
t.Fatalf("replyTargetID() = %q, want %q", got, "om_root")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty when no fields", func(t *testing.T) {
|
||||
if got := replyTargetID(&larkim.EventMessage{}); got != "" {
|
||||
t.Fatalf("replyTargetID() = %q, want empty", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNormalizeRepliedContent(t *testing.T) {
|
||||
t.Run("filters feishu upgrade placeholder for interactive", func(t *testing.T) {
|
||||
raw := `{"text":"\u8bf7\u5347\u7ea7\u81f3\u6700\u65b0\u7248\u672c\u5ba2\u6237\u7aef\uff0c\u4ee5\u67e5\u770b\u5185\u5bb9"}`
|
||||
got := normalizeRepliedContent("interactive", raw, nil)
|
||||
if got != "[replied interactive card]" {
|
||||
t.Fatalf("normalizeRepliedContent() = %q, want %q", got, "[replied interactive card]")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("keeps filename and file tag for replied file", func(t *testing.T) {
|
||||
got := normalizeRepliedContent("file", `{"file_key":"file_xxx","file_name":"doc.pdf"}`, []string{"media://r1"})
|
||||
if got != "doc.pdf [file]" {
|
||||
t.Fatalf("normalizeRepliedContent() = %q, want %q", got, "doc.pdf [file]")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back when file content missing", func(t *testing.T) {
|
||||
got := normalizeRepliedContent("file", `{"file_key":"file_xxx"}`, nil)
|
||||
if got != "[replied file]" {
|
||||
t.Fatalf("normalizeRepliedContent() = %q, want %q", got, "[replied file]")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHasLeadingCommandPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{name: "slash command", input: "/help", want: true},
|
||||
{name: "bang command", input: "!status", want: true},
|
||||
{name: "leading spaces slash", input: " /ping arg", want: true},
|
||||
{name: "normal text", input: "hello /help", want: false},
|
||||
{name: "empty", input: "", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := hasLeadingCommandPrefix(tt.input); got != tt.want {
|
||||
t.Fatalf("hasLeadingCommandPrefix(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,19 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("feishu", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewFeishuChannel(cfg.Channels.Feishu, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelFeishu,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.FeishuSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
return NewFeishuChannel(bc, c, b)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -51,14 +51,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
|
||||
isDM := !strings.HasPrefix(target, "#") && !strings.HasPrefix(target, "&")
|
||||
|
||||
var chatID string
|
||||
var peer bus.Peer
|
||||
|
||||
if isDM {
|
||||
chatID = nick
|
||||
peer = bus.Peer{Kind: "direct", ID: nick}
|
||||
} else {
|
||||
chatID = target
|
||||
peer = bus.Peer{Kind: "group", ID: target}
|
||||
}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
@@ -73,9 +70,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
|
||||
return
|
||||
}
|
||||
|
||||
isMentioned := false
|
||||
|
||||
// For channel messages, check group trigger (mention detection)
|
||||
if !isDM {
|
||||
isMentioned := isBotMentioned(content, currentNick)
|
||||
isMentioned = isBotMentioned(content, currentNick)
|
||||
if isMentioned {
|
||||
content = stripBotMention(content, currentNick)
|
||||
}
|
||||
@@ -100,7 +99,21 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
|
||||
metadata["channel"] = target
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, messageID, nick, chatID, content, nil, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "irc",
|
||||
ChatID: chatID,
|
||||
SenderID: nick,
|
||||
MessageID: messageID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if isDM {
|
||||
inboundCtx.ChatType = "direct"
|
||||
} else {
|
||||
inboundCtx.ChatType = "group"
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// nickMentionedAt returns the byte index where botNick is mentioned in content
|
||||
|
||||
@@ -7,10 +7,29 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("irc", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
if !cfg.Channels.IRC.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
return NewIRCChannel(cfg.Channels.IRC, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelIRC,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
if bc == nil || !bc.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.IRCSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
ch, err := NewIRCChannel(bc, c, b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channelName != config.ChannelIRC {
|
||||
ch.SetName(channelName)
|
||||
}
|
||||
return ch, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -18,14 +18,15 @@ import (
|
||||
// IRCChannel implements the Channel interface for IRC servers.
|
||||
type IRCChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.IRCConfig
|
||||
bc *config.Channel
|
||||
config *config.IRCSettings
|
||||
conn *ircevent.Connection
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewIRCChannel creates a new IRC channel.
|
||||
func NewIRCChannel(cfg config.IRCConfig, messageBus *bus.MessageBus) (*IRCChannel, error) {
|
||||
func NewIRCChannel(bc *config.Channel, cfg *config.IRCSettings, messageBus *bus.MessageBus) (*IRCChannel, error) {
|
||||
if cfg.Server == "" {
|
||||
return nil, fmt.Errorf("irc server is required")
|
||||
}
|
||||
@@ -33,14 +34,15 @@ func NewIRCChannel(cfg config.IRCConfig, messageBus *bus.MessageBus) (*IRCChanne
|
||||
return nil, fmt.Errorf("irc nick is required")
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel("irc", cfg, messageBus, cfg.AllowFrom,
|
||||
base := channels.NewBaseChannel("irc", cfg, messageBus, bc.AllowFrom,
|
||||
channels.WithMaxMessageLength(400),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
channels.WithGroupTrigger(bc.GroupTrigger),
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &IRCChannel{
|
||||
BaseChannel: base,
|
||||
bc: bc,
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
@@ -166,7 +168,7 @@ func (c *IRCChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]strin
|
||||
func (c *IRCChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
|
||||
noop := func() {}
|
||||
|
||||
if !c.config.Typing.Enabled || !c.IsRunning() || c.conn == nil {
|
||||
if !c.bc.Typing.Enabled || !c.IsRunning() || c.conn == nil {
|
||||
return noop, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,28 +11,31 @@ func TestNewIRCChannel(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("missing server", func(t *testing.T) {
|
||||
cfg := config.IRCConfig{Nick: "bot"}
|
||||
_, err := NewIRCChannel(cfg, msgBus)
|
||||
bc := &config.Channel{Type: config.ChannelIRC, Enabled: true}
|
||||
cfg := &config.IRCSettings{Nick: "bot"}
|
||||
_, err := NewIRCChannel(bc, cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing server, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing nick", func(t *testing.T) {
|
||||
cfg := config.IRCConfig{Server: "irc.example.com:6667"}
|
||||
_, err := NewIRCChannel(cfg, msgBus)
|
||||
bc := &config.Channel{Type: config.ChannelIRC, Enabled: true}
|
||||
cfg := &config.IRCSettings{Server: "irc.example.com:6667"}
|
||||
_, err := NewIRCChannel(bc, cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing nick, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
cfg := config.IRCConfig{
|
||||
bc := &config.Channel{Type: config.ChannelIRC, Enabled: true}
|
||||
cfg := &config.IRCSettings{
|
||||
Server: "irc.example.com:6667",
|
||||
Nick: "testbot",
|
||||
Channels: []string{"#test"},
|
||||
}
|
||||
ch, err := NewIRCChannel(cfg, msgBus)
|
||||
ch, err := NewIRCChannel(bc, cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,19 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("line", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewLINEChannel(cfg.Channels.LINE, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelLINE,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.LINESettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
return NewLINEChannel(bc, c, b)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
+30
-13
@@ -40,7 +40,7 @@ type replyTokenEntry struct {
|
||||
// and the official LINE Bot SDK for sending messages.
|
||||
type LINEChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.LINEConfig
|
||||
config *config.LINESettings
|
||||
client *messaging_api.MessagingApiAPI
|
||||
botUserID string // Bot's user ID
|
||||
botBasicID string // Bot's basic ID (e.g. @216ru...)
|
||||
@@ -52,7 +52,11 @@ type LINEChannel struct {
|
||||
}
|
||||
|
||||
// NewLINEChannel creates a new LINE channel instance.
|
||||
func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINEChannel, error) {
|
||||
func NewLINEChannel(
|
||||
bc *config.Channel,
|
||||
cfg *config.LINESettings,
|
||||
messageBus *bus.MessageBus,
|
||||
) (*LINEChannel, error) {
|
||||
if cfg.ChannelSecret.String() == "" || cfg.ChannelAccessToken.String() == "" {
|
||||
return nil, fmt.Errorf("line channel_secret and channel_access_token are required")
|
||||
}
|
||||
@@ -62,10 +66,10 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha
|
||||
return nil, fmt.Errorf("failed to create LINE messaging client: %w", err)
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom,
|
||||
base := channels.NewBaseChannel("line", cfg, messageBus, bc.AllowFrom,
|
||||
channels.WithMaxMessageLength(5000),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
channels.WithGroupTrigger(bc.GroupTrigger),
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &LINEChannel{
|
||||
@@ -191,6 +195,7 @@ func (c *LINEChannel) processEvent(event webhook.EventInterface) {
|
||||
var content string
|
||||
var mediaPaths []string
|
||||
var messageID string
|
||||
var quoteToken string
|
||||
var isMentioned bool
|
||||
|
||||
// Helper to register a local file with the media store
|
||||
@@ -214,6 +219,7 @@ func (c *LINEChannel) processEvent(event webhook.EventInterface) {
|
||||
isMentioned = c.isBotMentioned(msg)
|
||||
// Store quote token for quoting the original message in reply
|
||||
if msg.QuoteToken != "" {
|
||||
quoteToken = msg.QuoteToken
|
||||
c.quoteTokens.Store(chatID, msg.QuoteToken)
|
||||
}
|
||||
// Strip bot mention from text in group chats
|
||||
@@ -275,13 +281,6 @@ func (c *LINEChannel) processEvent(event webhook.EventInterface) {
|
||||
"source_type": sourceType,
|
||||
}
|
||||
|
||||
var peer bus.Peer
|
||||
if isGroup {
|
||||
peer = bus.Peer{Kind: "group", ID: chatID}
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
}
|
||||
|
||||
logger.DebugCF("line", "Received message", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
@@ -300,7 +299,25 @@ func (c *LINEChannel) processEvent(event webhook.EventInterface) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
ChatID: chatID,
|
||||
ChatType: map[bool]string{true: "group", false: "direct"}[isGroup],
|
||||
SenderID: senderID,
|
||||
MessageID: messageID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if msgEvent.ReplyToken != "" {
|
||||
inboundCtx.ReplyHandles = map[string]string{
|
||||
"reply_token": msgEvent.ReplyToken,
|
||||
}
|
||||
if quoteToken != "" {
|
||||
inboundCtx.ReplyHandles["quote_token"] = quoteToken
|
||||
}
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// isBotMentioned checks if the bot is mentioned in the message.
|
||||
|
||||
@@ -6,10 +6,12 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestWebhookRejectsOversizedBody(t *testing.T) {
|
||||
ch := &LINEChannel{}
|
||||
ch := &LINEChannel{config: &config.LINESettings{}}
|
||||
|
||||
oversized := bytes.Repeat([]byte("A"), maxWebhookBodySize+1)
|
||||
req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(oversized))
|
||||
@@ -23,7 +25,7 @@ func TestWebhookRejectsOversizedBody(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWebhookAcceptsMaxBodySize(t *testing.T) {
|
||||
ch := &LINEChannel{}
|
||||
ch := &LINEChannel{config: &config.LINESettings{}}
|
||||
|
||||
body := bytes.Repeat([]byte("A"), maxWebhookBodySize)
|
||||
req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(body))
|
||||
@@ -38,7 +40,7 @@ func TestWebhookAcceptsMaxBodySize(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWebhookRejectsOversizedBodyBeforeSignatureCheck(t *testing.T) {
|
||||
ch := &LINEChannel{}
|
||||
ch := &LINEChannel{config: &config.LINESettings{}}
|
||||
|
||||
oversized := bytes.Repeat([]byte("A"), maxWebhookBodySize+1)
|
||||
req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(oversized))
|
||||
@@ -53,7 +55,7 @@ func TestWebhookRejectsOversizedBodyBeforeSignatureCheck(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWebhookRejectsNonPostMethod(t *testing.T) {
|
||||
ch := &LINEChannel{}
|
||||
ch := &LINEChannel{config: &config.LINESettings{}}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/webhook", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -66,7 +68,9 @@ func TestWebhookRejectsNonPostMethod(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWebhookRejectsInvalidSignature(t *testing.T) {
|
||||
ch := &LINEChannel{}
|
||||
ch := &LINEChannel{
|
||||
config: &config.LINESettings{},
|
||||
}
|
||||
|
||||
body := `{"events":[]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/webhook", strings.NewReader(body))
|
||||
|
||||
@@ -7,7 +7,19 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("maixcam", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewMaixCamChannel(cfg.Channels.MaixCam, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelMaixCam,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.MaixCamSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
return NewMaixCamChannel(bc, c, b)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
|
||||
type MaixCamChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.MaixCamConfig
|
||||
config *config.MaixCamSettings
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -32,13 +32,17 @@ type MaixCamMessage struct {
|
||||
Data map[string]any `json:"data"`
|
||||
}
|
||||
|
||||
func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) {
|
||||
func NewMaixCamChannel(
|
||||
bc *config.Channel,
|
||||
cfg *config.MaixCamSettings,
|
||||
bus *bus.MessageBus,
|
||||
) (*MaixCamChannel, error) {
|
||||
base := channels.NewBaseChannel(
|
||||
"maixcam",
|
||||
cfg,
|
||||
bus,
|
||||
cfg.AllowFrom,
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
bc.AllowFrom,
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &MaixCamChannel{
|
||||
@@ -196,17 +200,15 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(
|
||||
c.ctx,
|
||||
bus.Peer{Kind: "channel", ID: "default"},
|
||||
"",
|
||||
senderID,
|
||||
chatID,
|
||||
content,
|
||||
[]string{},
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "maixcam",
|
||||
ChatID: chatID,
|
||||
ChatType: "channel",
|
||||
SenderID: senderID,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
|
||||
}
|
||||
|
||||
func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) {
|
||||
|
||||
+195
-130
@@ -11,6 +11,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
@@ -86,6 +87,7 @@ type Manager struct {
|
||||
dispatchTask *asyncTask
|
||||
mux *dynamicServeMux
|
||||
httpServer *http.Server
|
||||
httpListeners []net.Listener
|
||||
mu sync.RWMutex
|
||||
placeholders sync.Map // "channel:chatID" → placeholderID (string)
|
||||
typingStops sync.Map // "channel:chatID" → func()
|
||||
@@ -98,6 +100,22 @@ type asyncTask struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func outboundMessageChannel(msg bus.OutboundMessage) string {
|
||||
return msg.Context.Channel
|
||||
}
|
||||
|
||||
func outboundMessageChatID(msg bus.OutboundMessage) string {
|
||||
return msg.ChatID
|
||||
}
|
||||
|
||||
func outboundMediaChannel(msg bus.OutboundMediaMessage) string {
|
||||
return msg.Context.Channel
|
||||
}
|
||||
|
||||
func outboundMediaChatID(msg bus.OutboundMediaMessage) string {
|
||||
return msg.ChatID
|
||||
}
|
||||
|
||||
// RecordPlaceholder registers a placeholder message for later editing.
|
||||
// Implements PlaceholderRecorder.
|
||||
func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) {
|
||||
@@ -161,7 +179,8 @@ func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) {
|
||||
// preSend handles typing stop, reaction undo, and placeholder editing before sending a message.
|
||||
// Returns the delivered message IDs and true when delivery completed before a normal Send.
|
||||
func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) ([]string, bool) {
|
||||
key := name + ":" + msg.ChatID
|
||||
chatID := outboundMessageChatID(msg)
|
||||
key := name + ":" + chatID
|
||||
|
||||
// 1. Stop typing
|
||||
if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
|
||||
@@ -183,9 +202,9 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
// Prefer deleting the placeholder (cleaner UX than editing to same content)
|
||||
if deleter, ok := ch.(MessageDeleter); ok {
|
||||
deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort
|
||||
deleter.DeleteMessage(ctx, chatID, entry.id) // best effort
|
||||
} else if editor, ok := ch.(MessageEditor); ok {
|
||||
editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content) // fallback
|
||||
editor.EditMessage(ctx, chatID, entry.id, msg.Content) // fallback
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -196,7 +215,7 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
if editor, ok := ch.(MessageEditor); ok {
|
||||
if err := editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content); err == nil {
|
||||
if err := editor.EditMessage(ctx, chatID, entry.id, msg.Content); err == nil {
|
||||
return []string{entry.id}, true
|
||||
}
|
||||
// edit failed → fall through to normal Send
|
||||
@@ -212,7 +231,8 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
// delivery never edits the placeholder because there is no text payload to
|
||||
// replace it with; it only attempts to delete the placeholder when possible.
|
||||
func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.OutboundMediaMessage, ch Channel) {
|
||||
key := name + ":" + msg.ChatID
|
||||
chatID := outboundMediaChatID(msg)
|
||||
key := name + ":" + chatID
|
||||
|
||||
// 1. Stop typing
|
||||
if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
|
||||
@@ -235,7 +255,7 @@ func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.Outboun
|
||||
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
if deleter, ok := ch.(MessageDeleter); ok {
|
||||
deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort
|
||||
deleter.DeleteMessage(ctx, chatID, entry.id) // best effort
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -311,22 +331,27 @@ func (s *finalizeHookStreamer) Finalize(ctx context.Context, content string) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// initChannel is a helper that looks up a factory by name and creates the channel.
|
||||
func (m *Manager) initChannel(name, displayName string) {
|
||||
f, ok := getFactory(name)
|
||||
// initChannel is a helper that looks up a factory by type name and creates the channel.
|
||||
// typeName is the channel type used for factory lookup (e.g., "telegram").
|
||||
// channelName is the config map key used as the channel's runtime name (e.g., "my_telegram").
|
||||
func (m *Manager) initChannel(typeName, channelName string) {
|
||||
f, ok := getFactory(typeName)
|
||||
if !ok {
|
||||
logger.WarnCF("channels", "Factory not registered", map[string]any{
|
||||
"channel": displayName,
|
||||
"channel": channelName,
|
||||
"type": typeName,
|
||||
})
|
||||
return
|
||||
}
|
||||
logger.DebugCF("channels", "Attempting to initialize channel", map[string]any{
|
||||
"channel": displayName,
|
||||
"channel": channelName,
|
||||
"type": typeName,
|
||||
})
|
||||
ch, err := f(m.config, m.bus)
|
||||
ch, err := f(channelName, typeName, m.config, m.bus)
|
||||
if err != nil {
|
||||
logger.ErrorCF("channels", "Failed to initialize channel", map[string]any{
|
||||
"channel": displayName,
|
||||
"channel": channelName,
|
||||
"type": typeName,
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
@@ -344,103 +369,100 @@ func (m *Manager) initChannel(name, displayName string) {
|
||||
if setter, ok := ch.(interface{ SetOwner(ch Channel) }); ok {
|
||||
setter.SetOwner(ch)
|
||||
}
|
||||
m.channels[name] = ch
|
||||
m.channels[channelName] = ch
|
||||
logger.InfoCF("channels", "Channel enabled successfully", map[string]any{
|
||||
"channel": displayName,
|
||||
"channel": channelName,
|
||||
"type": typeName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) getChannelConfigAndEnabled(channelName string) (*config.Channel, bool) {
|
||||
bc, ok := m.config.Channels[channelName]
|
||||
if !ok || bc == nil {
|
||||
return nil, false
|
||||
}
|
||||
if !bc.Enabled {
|
||||
return bc, false
|
||||
}
|
||||
|
||||
// Use Type to determine the config struct for validation.
|
||||
// The map key (channelName) is the config key, which may differ from the type.
|
||||
channelType := bc.Type
|
||||
if channelType == "" {
|
||||
channelType = channelName
|
||||
}
|
||||
|
||||
// Settings have already been decoded by InitChannelList, so we just need to
|
||||
// type-assert and check the relevant fields.
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return bc, false
|
||||
}
|
||||
//nolint:revive
|
||||
switch settings := decoded.(type) {
|
||||
case *config.WhatsAppSettings:
|
||||
if channelType == config.ChannelWhatsApp {
|
||||
return bc, settings.BridgeURL != ""
|
||||
}
|
||||
return bc, channelType == config.ChannelWhatsAppNative && settings.UseNative
|
||||
case *config.MatrixSettings:
|
||||
return bc, settings.Homeserver != "" && settings.UserID != "" && settings.AccessToken.String() != ""
|
||||
case *config.WeComSettings:
|
||||
return bc, settings.BotID != "" && settings.Secret.String() != ""
|
||||
case *config.PicoClientSettings:
|
||||
return bc, settings.URL != ""
|
||||
case *config.DingTalkSettings:
|
||||
return bc, settings.ClientID != ""
|
||||
case *config.SlackSettings:
|
||||
return bc, settings.BotToken.String() != ""
|
||||
case *config.WeixinSettings:
|
||||
return bc, settings.Token.String() != ""
|
||||
case *config.PicoSettings:
|
||||
return bc, settings.Token.String() != ""
|
||||
case *config.IRCSettings:
|
||||
return bc, settings.Server != ""
|
||||
case *config.LINESettings:
|
||||
return bc, settings.ChannelAccessToken.String() != ""
|
||||
case *config.OneBotSettings:
|
||||
return bc, settings.WSUrl != ""
|
||||
case *config.QQSettings:
|
||||
return bc, settings.AppSecret.String() != ""
|
||||
case *config.TelegramSettings:
|
||||
return bc, settings.Token.String() != ""
|
||||
case *config.FeishuSettings:
|
||||
return bc, settings.AppSecret.String() != ""
|
||||
case *config.MaixCamSettings:
|
||||
return bc, true
|
||||
case *config.TeamsWebhookSettings:
|
||||
return bc, true
|
||||
case *config.DiscordSettings:
|
||||
return bc, settings.Token.String() != ""
|
||||
case *config.VKSettings:
|
||||
return bc, settings.GroupID != 0 && settings.Token.String() != ""
|
||||
}
|
||||
|
||||
return bc, bc.Enabled
|
||||
}
|
||||
|
||||
// initChannels initializes all enabled channels based on the configuration.
|
||||
// It iterates config entries and uses bc.Type to look up the appropriate factory.
|
||||
func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
|
||||
logger.InfoC("channels", "Initializing channel manager")
|
||||
|
||||
if channels.Telegram.Enabled && channels.Telegram.Token.String() != "" {
|
||||
m.initChannel("telegram", "Telegram")
|
||||
}
|
||||
|
||||
if channels.WhatsApp.Enabled {
|
||||
waCfg := channels.WhatsApp
|
||||
if waCfg.UseNative {
|
||||
m.initChannel("whatsapp_native", "WhatsApp Native")
|
||||
} else if waCfg.BridgeURL != "" {
|
||||
m.initChannel("whatsapp", "WhatsApp")
|
||||
for name, bc := range *channels {
|
||||
if !bc.Enabled {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if channels.Feishu.Enabled {
|
||||
m.initChannel("feishu", "Feishu")
|
||||
}
|
||||
|
||||
if channels.Discord.Enabled && channels.Discord.Token.String() != "" {
|
||||
m.initChannel("discord", "Discord")
|
||||
}
|
||||
|
||||
if channels.MaixCam.Enabled {
|
||||
m.initChannel("maixcam", "MaixCam")
|
||||
}
|
||||
|
||||
if channels.QQ.Enabled {
|
||||
m.initChannel("qq", "QQ")
|
||||
}
|
||||
|
||||
if channels.DingTalk.Enabled && channels.DingTalk.ClientID != "" {
|
||||
m.initChannel("dingtalk", "DingTalk")
|
||||
}
|
||||
|
||||
if channels.Slack.Enabled && channels.Slack.BotToken.String() != "" {
|
||||
m.initChannel("slack", "Slack")
|
||||
}
|
||||
|
||||
if channels.Matrix.Enabled &&
|
||||
m.config.Channels.Matrix.Homeserver != "" &&
|
||||
m.config.Channels.Matrix.UserID != "" &&
|
||||
m.config.Channels.Matrix.AccessToken.String() != "" {
|
||||
m.initChannel("matrix", "Matrix")
|
||||
}
|
||||
|
||||
if channels.LINE.Enabled && channels.LINE.ChannelAccessToken.String() != "" {
|
||||
m.initChannel("line", "LINE")
|
||||
}
|
||||
|
||||
if channels.OneBot.Enabled && channels.OneBot.WSUrl != "" {
|
||||
m.initChannel("onebot", "OneBot")
|
||||
}
|
||||
|
||||
if channels.WeCom.Enabled && channels.WeCom.BotID != "" && channels.WeCom.Secret.String() != "" {
|
||||
m.initChannel("wecom", "WeCom")
|
||||
}
|
||||
|
||||
if channels.Weixin.Enabled && channels.Weixin.Token.String() != "" {
|
||||
m.initChannel("weixin", "Weixin")
|
||||
}
|
||||
|
||||
if channels.Pico.Enabled && channels.Pico.Token.String() != "" {
|
||||
m.initChannel("pico", "Pico")
|
||||
}
|
||||
|
||||
if channels.PicoClient.Enabled && channels.PicoClient.URL != "" {
|
||||
m.initChannel("pico_client", "Pico Client")
|
||||
}
|
||||
|
||||
if channels.IRC.Enabled && channels.IRC.Server != "" {
|
||||
m.initChannel("irc", "IRC")
|
||||
}
|
||||
|
||||
if channels.VK.Enabled && channels.VK.Token.String() != "" && channels.VK.GroupID != 0 {
|
||||
m.initChannel("vk", "VK")
|
||||
}
|
||||
|
||||
if channels.TeamsWebhook.Enabled && len(channels.TeamsWebhook.Webhooks) > 0 {
|
||||
hasValidTarget := false
|
||||
for _, target := range channels.TeamsWebhook.Webhooks {
|
||||
if target.WebhookURL.String() != "" {
|
||||
hasValidTarget = true
|
||||
break
|
||||
}
|
||||
_, ready := m.getChannelConfigAndEnabled(name)
|
||||
if !ready {
|
||||
continue
|
||||
}
|
||||
if hasValidTarget {
|
||||
m.initChannel("teams_webhook", "Teams Webhook")
|
||||
typeName := bc.Type
|
||||
if typeName == "" {
|
||||
typeName = name
|
||||
}
|
||||
m.initChannel(typeName, name)
|
||||
}
|
||||
|
||||
logger.InfoCF("channels", "Channel initialization completed", map[string]any{
|
||||
@@ -454,6 +476,12 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
|
||||
// It registers health endpoints from the health server and discovers channels
|
||||
// that implement WebhookHandler and/or HealthChecker to register their handlers.
|
||||
func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
|
||||
m.SetupHTTPServerListeners(nil, addr, healthServer)
|
||||
}
|
||||
|
||||
// SetupHTTPServerListeners creates a shared HTTP server on pre-opened listeners.
|
||||
// When listeners is empty it falls back to Addr-based ListenAndServe behavior.
|
||||
func (m *Manager) SetupHTTPServerListeners(listeners []net.Listener, addr string, healthServer *health.Server) {
|
||||
m.mux = newDynamicServeMux()
|
||||
|
||||
// Register health endpoints
|
||||
@@ -470,6 +498,7 @@ func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
}
|
||||
m.httpListeners = append([]net.Listener(nil), listeners...)
|
||||
}
|
||||
|
||||
// registerHTTPHandlersLocked registers webhook and health-check handlers for
|
||||
@@ -548,7 +577,13 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
continue
|
||||
}
|
||||
// Lazily create worker only after channel starts successfully
|
||||
w := newChannelWorker(name, channel)
|
||||
channelType := name
|
||||
if m.config != nil {
|
||||
if bc := m.config.Channels.Get(name); bc != nil && bc.Type != "" {
|
||||
channelType = bc.Type
|
||||
}
|
||||
}
|
||||
w := newChannelWorker(name, channel, channelType)
|
||||
m.workers[name] = w
|
||||
go m.runWorker(dispatchCtx, name, w)
|
||||
go m.runMediaWorker(dispatchCtx, name, w)
|
||||
@@ -593,16 +628,33 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
|
||||
// Start shared HTTP server if configured
|
||||
if m.httpServer != nil {
|
||||
go func() {
|
||||
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
|
||||
"addr": m.httpServer.Addr,
|
||||
})
|
||||
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
if len(m.httpListeners) > 0 {
|
||||
for _, listener := range m.httpListeners {
|
||||
ln := listener
|
||||
go func() {
|
||||
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
|
||||
"addr": ln.Addr().String(),
|
||||
})
|
||||
if err := m.httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
|
||||
"addr": ln.Addr().String(),
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
go func() {
|
||||
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
|
||||
"addr": m.httpServer.Addr,
|
||||
})
|
||||
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("channels", "Channel startup completed", map[string]any{
|
||||
@@ -629,6 +681,7 @@ func (m *Manager) StopAll(ctx context.Context) error {
|
||||
})
|
||||
}
|
||||
m.httpServer = nil
|
||||
m.httpListeners = nil
|
||||
}
|
||||
|
||||
// Cancel dispatcher
|
||||
@@ -678,10 +731,10 @@ func (m *Manager) StopAll(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// newChannelWorker creates a channelWorker with a rate limiter configured
|
||||
// for the given channel name.
|
||||
func newChannelWorker(name string, ch Channel) *channelWorker {
|
||||
// for the given channel type. channelType is used for rate limit lookup.
|
||||
func newChannelWorker(name string, ch Channel, channelType string) *channelWorker {
|
||||
rateVal := float64(defaultRateLimit)
|
||||
if r, ok := channelRateConfig[name]; ok {
|
||||
if r, ok := channelRateConfig[channelType]; ok {
|
||||
rateVal = r
|
||||
}
|
||||
burst := int(math.Max(1, math.Ceil(rateVal/2)))
|
||||
@@ -812,7 +865,7 @@ func (m *Manager) sendWithRetry(
|
||||
// All retries exhausted or permanent failure
|
||||
logger.ErrorCF("channels", "Send failed", map[string]any{
|
||||
"channel": name,
|
||||
"chat_id": msg.ChatID,
|
||||
"chat_id": outboundMessageChatID(msg),
|
||||
"error": lastErr.Error(),
|
||||
"retries": maxRetries,
|
||||
})
|
||||
@@ -874,7 +927,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.OutboundChan(),
|
||||
func(msg bus.OutboundMessage) string { return msg.Channel },
|
||||
func(msg bus.OutboundMessage) string { return outboundMessageChannel(msg) },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
|
||||
select {
|
||||
case w.queue <- msg:
|
||||
@@ -894,7 +947,7 @@ func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.OutboundMediaChan(),
|
||||
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
|
||||
func(msg bus.OutboundMediaMessage) string { return outboundMediaChannel(msg) },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
|
||||
select {
|
||||
case w.mediaQueue <- msg:
|
||||
@@ -993,7 +1046,7 @@ func (m *Manager) sendMediaWithRetry(
|
||||
// All retries exhausted or permanent failure
|
||||
logger.ErrorCF("channels", "SendMedia failed", map[string]any{
|
||||
"channel": name,
|
||||
"chat_id": msg.ChatID,
|
||||
"chat_id": outboundMediaChatID(msg),
|
||||
"error": lastErr.Error(),
|
||||
"retries": maxRetries,
|
||||
})
|
||||
@@ -1137,7 +1190,13 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
|
||||
continue
|
||||
}
|
||||
// Lazily create worker only after channel starts successfully
|
||||
w := newChannelWorker(name, channel)
|
||||
channelType := name
|
||||
if m.config != nil {
|
||||
if bc := m.config.Channels.Get(name); bc != nil && bc.Type != "" {
|
||||
channelType = bc.Type
|
||||
}
|
||||
}
|
||||
w := newChannelWorker(name, channel, channelType)
|
||||
m.workers[name] = w
|
||||
go m.runWorker(dispatchCtx, name, w)
|
||||
go m.runMediaWorker(dispatchCtx, name, w)
|
||||
@@ -1186,16 +1245,19 @@ func (m *Manager) UnregisterChannel(name string) {
|
||||
// delivered (or all retries are exhausted), which preserves ordering when
|
||||
// a subsequent operation depends on the message having been sent.
|
||||
func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
msg = bus.NormalizeOutboundMessage(msg)
|
||||
channelName := outboundMessageChannel(msg)
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[msg.Channel]
|
||||
w, wExists := m.workers[msg.Channel]
|
||||
_, exists := m.channels[channelName]
|
||||
w, wExists := m.workers[channelName]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("channel %s not found", msg.Channel)
|
||||
return fmt.Errorf("channel %s not found", channelName)
|
||||
}
|
||||
if !wExists || w == nil {
|
||||
return fmt.Errorf("channel %s has no active worker", msg.Channel)
|
||||
return fmt.Errorf("channel %s has no active worker", channelName)
|
||||
}
|
||||
|
||||
maxLen := 0
|
||||
@@ -1206,10 +1268,10 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
for _, chunk := range SplitMessage(msg.Content, maxLen) {
|
||||
chunkMsg := msg
|
||||
chunkMsg.Content = chunk
|
||||
m.sendWithRetry(ctx, msg.Channel, w, chunkMsg)
|
||||
m.sendWithRetry(ctx, channelName, w, chunkMsg)
|
||||
}
|
||||
} else {
|
||||
m.sendWithRetry(ctx, msg.Channel, w, msg)
|
||||
m.sendWithRetry(ctx, channelName, w, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1219,19 +1281,22 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
// retries are exhausted), which preserves ordering when later agent behavior
|
||||
// depends on actual media delivery.
|
||||
func (m *Manager) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
msg = bus.NormalizeOutboundMediaMessage(msg)
|
||||
channelName := outboundMediaChannel(msg)
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[msg.Channel]
|
||||
w, wExists := m.workers[msg.Channel]
|
||||
_, exists := m.channels[channelName]
|
||||
w, wExists := m.workers[channelName]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("channel %s not found", msg.Channel)
|
||||
return fmt.Errorf("channel %s not found", channelName)
|
||||
}
|
||||
if !wExists || w == nil {
|
||||
return fmt.Errorf("channel %s has no active worker", msg.Channel)
|
||||
return fmt.Errorf("channel %s has no active worker", channelName)
|
||||
}
|
||||
|
||||
_, err := m.sendMediaWithRetry(ctx, msg.Channel, w, msg)
|
||||
_, err := m.sendMediaWithRetry(ctx, channelName, w, msg)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1246,10 +1311,10 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten
|
||||
}
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
Channel: channelName,
|
||||
ChatID: chatID,
|
||||
Context: bus.NewOutboundContext(channelName, chatID, ""),
|
||||
Content: content,
|
||||
}
|
||||
msg = bus.NormalizeOutboundMessage(msg)
|
||||
|
||||
if wExists && w != nil {
|
||||
select {
|
||||
|
||||
+62
-110
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
func toChannelHashes(cfg *config.Config) map[string]string {
|
||||
@@ -21,7 +20,7 @@ func toChannelHashes(cfg *config.Config) map[string]string {
|
||||
if !value["enabled"].(bool) {
|
||||
continue
|
||||
}
|
||||
hiddenValues(key, value, ch)
|
||||
hiddenValues(key, value, ch.Get(key))
|
||||
valueBytes, _ := json.Marshal(value)
|
||||
hash := md5.Sum(valueBytes)
|
||||
result[key] = hex.EncodeToString(hash[:])
|
||||
@@ -30,43 +29,77 @@ func toChannelHashes(cfg *config.Config) map[string]string {
|
||||
return result
|
||||
}
|
||||
|
||||
func hiddenValues(key string, value map[string]any, ch config.ChannelsConfig) {
|
||||
func hiddenValues(key string, value map[string]any, ch *config.Channel) {
|
||||
v, err := ch.GetDecoded()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch key {
|
||||
case "pico":
|
||||
value["token"] = ch.Pico.Token.String()
|
||||
if settings, ok := v.(*config.PicoSettings); ok {
|
||||
value["token"] = settings.Token.String()
|
||||
}
|
||||
case "telegram":
|
||||
value["token"] = ch.Telegram.Token.String()
|
||||
if settings, ok := v.(*config.TelegramSettings); ok {
|
||||
value["token"] = settings.Token.String()
|
||||
}
|
||||
case "discord":
|
||||
value["token"] = ch.Discord.Token.String()
|
||||
if settings, ok := v.(*config.DiscordSettings); ok {
|
||||
value["token"] = settings.Token.String()
|
||||
}
|
||||
case "slack":
|
||||
value["bot_token"] = ch.Slack.BotToken.String()
|
||||
value["app_token"] = ch.Slack.AppToken.String()
|
||||
if settings, ok := v.(*config.SlackSettings); ok {
|
||||
value["bot_token"] = settings.BotToken.String()
|
||||
value["app_token"] = settings.AppToken.String()
|
||||
}
|
||||
case "matrix":
|
||||
value["token"] = ch.Matrix.AccessToken.String()
|
||||
if settings, ok := v.(*config.MatrixSettings); ok {
|
||||
value["token"] = settings.AccessToken.String()
|
||||
}
|
||||
case "onebot":
|
||||
value["token"] = ch.OneBot.AccessToken.String()
|
||||
if settings, ok := v.(*config.OneBotSettings); ok {
|
||||
value["token"] = settings.AccessToken.String()
|
||||
}
|
||||
case "line":
|
||||
value["token"] = ch.LINE.ChannelAccessToken.String()
|
||||
value["secret"] = ch.LINE.ChannelSecret.String()
|
||||
if settings, ok := v.(*config.LINESettings); ok {
|
||||
value["token"] = settings.ChannelAccessToken.String()
|
||||
value["secret"] = settings.ChannelSecret.String()
|
||||
}
|
||||
case "wecom":
|
||||
value["secret"] = ch.WeCom.Secret.String()
|
||||
if settings, ok := v.(*config.WeComSettings); ok {
|
||||
value["secret"] = settings.Secret.String()
|
||||
}
|
||||
case "dingtalk":
|
||||
value["secret"] = ch.DingTalk.ClientSecret.String()
|
||||
if settings, ok := v.(*config.DingTalkSettings); ok {
|
||||
value["secret"] = settings.ClientSecret.String()
|
||||
}
|
||||
case "qq":
|
||||
value["secret"] = ch.QQ.AppSecret.String()
|
||||
if settings, ok := v.(*config.QQSettings); ok {
|
||||
value["secret"] = settings.AppSecret.String()
|
||||
}
|
||||
case "irc":
|
||||
value["password"] = ch.IRC.Password.String()
|
||||
value["serv_password"] = ch.IRC.NickServPassword.String()
|
||||
value["sasl_password"] = ch.IRC.SASLPassword.String()
|
||||
if settings, ok := v.(*config.IRCSettings); ok {
|
||||
value["password"] = settings.Password.String()
|
||||
value["serv_password"] = settings.NickServPassword.String()
|
||||
value["sasl_password"] = settings.SASLPassword.String()
|
||||
}
|
||||
case "feishu":
|
||||
value["app_secret"] = ch.Feishu.AppSecret.String()
|
||||
value["encrypt_key"] = ch.Feishu.EncryptKey.String()
|
||||
value["verification_token"] = ch.Feishu.VerificationToken.String()
|
||||
if settings, ok := v.(*config.FeishuSettings); ok {
|
||||
value["app_secret"] = settings.AppSecret.String()
|
||||
value["encrypt_key"] = settings.EncryptKey.String()
|
||||
value["verification_token"] = settings.VerificationToken.String()
|
||||
}
|
||||
case "teams_webhook":
|
||||
// Expose webhook URLs for hash computation (they contain secrets)
|
||||
vv := value["webhooks"]
|
||||
webhooks := make(map[string]string)
|
||||
for name, target := range ch.TeamsWebhook.Webhooks {
|
||||
webhooks[name] = target.WebhookURL.String()
|
||||
if vv != nil {
|
||||
webhooks = vv.(map[string]string)
|
||||
}
|
||||
if settings, ok := v.(*config.TeamsWebhookSettings); ok {
|
||||
for name, target := range settings.Webhooks {
|
||||
webhooks[name] = target.WebhookURL.String()
|
||||
}
|
||||
}
|
||||
value["webhooks"] = webhooks
|
||||
}
|
||||
@@ -92,94 +125,13 @@ func compareChannels(old, news map[string]string) (added, removed []string) {
|
||||
}
|
||||
|
||||
func toChannelConfig(cfg *config.Config, list []string) (*config.ChannelsConfig, error) {
|
||||
result := &config.ChannelsConfig{}
|
||||
ch := cfg.Channels
|
||||
// should not be error
|
||||
marshal, _ := json.Marshal(ch)
|
||||
var channelConfig map[string]map[string]any
|
||||
_ = json.Unmarshal(marshal, &channelConfig)
|
||||
temp := make(map[string]map[string]any, 0)
|
||||
|
||||
for key, value := range channelConfig {
|
||||
found := false
|
||||
for _, s := range list {
|
||||
if key == s {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found || !value["enabled"].(bool) {
|
||||
result := make(config.ChannelsConfig)
|
||||
for _, name := range list {
|
||||
bc, ok := cfg.Channels[name]
|
||||
if !ok || !bc.Enabled {
|
||||
continue
|
||||
}
|
||||
temp[key] = value
|
||||
}
|
||||
|
||||
marshal, err := json.Marshal(temp)
|
||||
if err != nil {
|
||||
logger.Errorf("marshal error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
err = json.Unmarshal(marshal, result)
|
||||
if err != nil {
|
||||
logger.Errorf("unmarshal error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updateKeys(result, &ch)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func updateKeys(newcfg, old *config.ChannelsConfig) {
|
||||
if newcfg.Pico.Enabled {
|
||||
newcfg.Pico.Token = old.Pico.Token
|
||||
}
|
||||
if newcfg.Telegram.Enabled {
|
||||
newcfg.Telegram.Token = old.Telegram.Token
|
||||
}
|
||||
if newcfg.Discord.Enabled {
|
||||
newcfg.Discord.Token = old.Discord.Token
|
||||
}
|
||||
if newcfg.Slack.Enabled {
|
||||
newcfg.Slack.BotToken = old.Slack.BotToken
|
||||
newcfg.Slack.AppToken = old.Slack.AppToken
|
||||
}
|
||||
if newcfg.Matrix.Enabled {
|
||||
newcfg.Matrix.AccessToken = old.Matrix.AccessToken
|
||||
}
|
||||
if newcfg.OneBot.Enabled {
|
||||
newcfg.OneBot.AccessToken = old.OneBot.AccessToken
|
||||
}
|
||||
if newcfg.LINE.Enabled {
|
||||
newcfg.LINE.ChannelAccessToken = old.LINE.ChannelAccessToken
|
||||
newcfg.LINE.ChannelSecret = old.LINE.ChannelSecret
|
||||
}
|
||||
if newcfg.WeCom.Enabled {
|
||||
newcfg.WeCom.Secret = old.WeCom.Secret
|
||||
}
|
||||
if newcfg.DingTalk.Enabled {
|
||||
newcfg.DingTalk.ClientSecret = old.DingTalk.ClientSecret
|
||||
}
|
||||
if newcfg.QQ.Enabled {
|
||||
newcfg.QQ.AppSecret = old.QQ.AppSecret
|
||||
}
|
||||
if newcfg.IRC.Enabled {
|
||||
newcfg.IRC.Password = old.IRC.Password
|
||||
newcfg.IRC.NickServPassword = old.IRC.NickServPassword
|
||||
newcfg.IRC.SASLPassword = old.IRC.SASLPassword
|
||||
}
|
||||
if newcfg.Feishu.Enabled {
|
||||
newcfg.Feishu.AppSecret = old.Feishu.AppSecret
|
||||
newcfg.Feishu.EncryptKey = old.Feishu.EncryptKey
|
||||
newcfg.Feishu.VerificationToken = old.Feishu.VerificationToken
|
||||
}
|
||||
if newcfg.TeamsWebhook.Enabled {
|
||||
// Copy SecureString webhook URLs from old config
|
||||
for name, oldTarget := range old.TeamsWebhook.Webhooks {
|
||||
if newTarget, ok := newcfg.TeamsWebhook.Webhooks[name]; ok {
|
||||
newTarget.WebhookURL = oldTarget.WebhookURL
|
||||
newcfg.TeamsWebhook.Webhooks[name] = newTarget
|
||||
}
|
||||
}
|
||||
result[name] = bc
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -15,37 +16,138 @@ func TestToChannelHashes(t *testing.T) {
|
||||
results := toChannelHashes(cfg)
|
||||
assert.Equal(t, 0, len(results))
|
||||
logger.Debugf("results: %v", results)
|
||||
|
||||
// Add dingtalk channel via map
|
||||
cfg2 := config.DefaultConfig()
|
||||
cfg2.Channels.DingTalk.Enabled = true
|
||||
cfg2.Channels["dingtalk"] = &config.Channel{
|
||||
Enabled: true,
|
||||
Type: config.ChannelDingTalk,
|
||||
Settings: config.RawNode(`{"enabled":true}`),
|
||||
}
|
||||
results2 := toChannelHashes(cfg2)
|
||||
assert.Equal(t, 1, len(results2))
|
||||
logger.Debugf("results2: %v", results2)
|
||||
added, removed := compareChannels(results, results2)
|
||||
assert.EqualValues(t, []string{"dingtalk"}, added)
|
||||
assert.EqualValues(t, []string(nil), removed)
|
||||
|
||||
// Add telegram channel
|
||||
cfg3 := config.DefaultConfig()
|
||||
cfg3.Channels.Telegram.Enabled = true
|
||||
cfg3.Channels["telegram"] = &config.Channel{
|
||||
Enabled: true,
|
||||
Type: config.ChannelTelegram,
|
||||
Settings: config.RawNode(`{"enabled":true,"token":"test-token"}`),
|
||||
}
|
||||
results3 := toChannelHashes(cfg3)
|
||||
assert.Equal(t, 1, len(results3))
|
||||
logger.Debugf("results3: %v", results3)
|
||||
added, removed = compareChannels(results2, results3)
|
||||
assert.EqualValues(t, []string{"dingtalk"}, removed)
|
||||
assert.EqualValues(t, []string{"telegram"}, added)
|
||||
cfg3.Channels.Telegram.SetToken("114314")
|
||||
|
||||
// Modify telegram channel — hash should change
|
||||
cfg3.Channels["telegram"] = &config.Channel{
|
||||
Enabled: true,
|
||||
Type: config.ChannelTelegram,
|
||||
Settings: config.RawNode(`{"enabled":true,"token":"114314"}`),
|
||||
}
|
||||
results4 := toChannelHashes(cfg3)
|
||||
assert.Equal(t, 1, len(results4))
|
||||
logger.Debugf("results4: %v", results4)
|
||||
added, removed = compareChannels(results3, results4)
|
||||
assert.EqualValues(t, []string{"telegram"}, removed)
|
||||
assert.EqualValues(t, []string{"telegram"}, added)
|
||||
|
||||
// toChannelConfig with telegram
|
||||
cc, err := toChannelConfig(cfg3, added)
|
||||
assert.NoError(t, err)
|
||||
logger.Debugf("cc: %#v", cc.Telegram)
|
||||
assert.Equal(t, "114314", cc.Telegram.Token.String())
|
||||
assert.Equal(t, true, cc.Telegram.Enabled)
|
||||
bc := cc.Get("telegram")
|
||||
assert.NotNil(t, bc)
|
||||
var tc config.TelegramSettings
|
||||
bc.Decode(&tc)
|
||||
assert.Equal(t, "114314", tc.Token.String())
|
||||
assert.Equal(t, true, bc.Enabled)
|
||||
|
||||
// toChannelConfig with dingtalk (no telegram)
|
||||
cc, err = toChannelConfig(cfg2, added)
|
||||
assert.NoError(t, err)
|
||||
logger.Debugf("cc: %#v", cc.Telegram)
|
||||
assert.Equal(t, "", cc.Telegram.Token.String())
|
||||
assert.Equal(t, false, cc.Telegram.Enabled)
|
||||
bc = cc.Get("telegram")
|
||||
assert.Nil(t, bc)
|
||||
}
|
||||
|
||||
func TestToChannelHashes_SerializationStability(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Channels["test"] = &config.Channel{
|
||||
Enabled: true,
|
||||
Settings: config.RawNode(`{"enabled":true,"key":"value"}`),
|
||||
}
|
||||
h1 := toChannelHashes(cfg)
|
||||
|
||||
// Same config should produce same hash
|
||||
cfg2 := config.DefaultConfig()
|
||||
cfg2.Channels["test"] = &config.Channel{
|
||||
Enabled: true,
|
||||
Settings: config.RawNode(`{"enabled":true,"key":"value"}`),
|
||||
}
|
||||
h2 := toChannelHashes(cfg2)
|
||||
assert.Equal(t, h1["test"], h2["test"])
|
||||
}
|
||||
|
||||
func TestCompareChannels_NoChanges(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Channels["a"] = &config.Channel{Enabled: true, Settings: config.RawNode(`{}`)}
|
||||
cfg.Channels["b"] = &config.Channel{Enabled: true, Settings: config.RawNode(`{}`)}
|
||||
h := toChannelHashes(cfg)
|
||||
|
||||
added, removed := compareChannels(h, h)
|
||||
assert.EqualValues(t, []string(nil), added)
|
||||
assert.EqualValues(t, []string(nil), removed)
|
||||
}
|
||||
|
||||
func TestToChannelConfig_EmptyList(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Channels["test"] = &config.Channel{Enabled: true, Settings: config.RawNode(`{}`)}
|
||||
|
||||
cc, err := toChannelConfig(cfg, []string{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(*cc))
|
||||
}
|
||||
|
||||
func TestToChannelHashes_NonEnabledSkipped(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Channels["test"] = &config.Channel{Enabled: false, Settings: config.RawNode(`{"enabled":false}`)}
|
||||
|
||||
h := toChannelHashes(cfg)
|
||||
assert.Equal(t, 0, len(h))
|
||||
}
|
||||
|
||||
func TestToChannelHashes_InvalidJSON(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Channels["test"] = &config.Channel{
|
||||
Enabled: true,
|
||||
Settings: config.RawNode(`invalid-json`),
|
||||
}
|
||||
|
||||
// Should not panic, just skip the invalid entry
|
||||
h := toChannelHashes(cfg)
|
||||
assert.Equal(t, 0, len(h))
|
||||
}
|
||||
|
||||
func TestToChannelHashes_RealWorldChannel(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
|
||||
// Simulate a telegram channel config
|
||||
telegramSettings, _ := json.Marshal(map[string]any{
|
||||
"enabled": true,
|
||||
"token": "123456:ABC-DEF",
|
||||
})
|
||||
cfg.Channels["telegram"] = &config.Channel{
|
||||
Enabled: true,
|
||||
Type: config.ChannelTelegram,
|
||||
Settings: config.RawNode(telegramSettings),
|
||||
}
|
||||
|
||||
h := toChannelHashes(cfg)
|
||||
assert.Equal(t, 1, len(h))
|
||||
assert.Contains(t, h, "telegram")
|
||||
}
|
||||
|
||||
+143
-52
@@ -175,11 +175,11 @@ func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer pubCancel()
|
||||
if err := m.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
if err := m.bus.PublishOutbound(pubCtx, testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "good",
|
||||
ChatID: "chat-1",
|
||||
Content: "hello",
|
||||
}); err != nil {
|
||||
})); err != nil {
|
||||
t.Fatalf("PublishOutbound() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -197,6 +197,20 @@ func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func testOutboundMessage(msg bus.OutboundMessage) bus.OutboundMessage {
|
||||
if msg.Context.Channel == "" && msg.Context.ChatID == "" {
|
||||
msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, msg.ReplyToMessageID)
|
||||
}
|
||||
return bus.NormalizeOutboundMessage(msg)
|
||||
}
|
||||
|
||||
func testOutboundMediaMessage(msg bus.OutboundMediaMessage) bus.OutboundMediaMessage {
|
||||
if msg.Context.Channel == "" && msg.Context.ChatID == "" {
|
||||
msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, "")
|
||||
}
|
||||
return bus.NormalizeOutboundMediaMessage(msg)
|
||||
}
|
||||
|
||||
func TestSendWithRetry_Success(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
@@ -212,7 +226,7 @@ func TestSendWithRetry_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -239,7 +253,7 @@ func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -263,7 +277,7 @@ func TestSendWithRetry_PermanentFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -287,7 +301,7 @@ func TestSendWithRetry_NotRunning(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -314,7 +328,7 @@ func TestSendWithRetry_RateLimitRetry(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
start := time.Now()
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
@@ -344,7 +358,7 @@ func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -370,11 +384,11 @@ func TestSendMedia_Success(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
@@ -397,11 +411,11 @@ func TestSendMedia_PropagatesFailure(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
}))
|
||||
if err == nil {
|
||||
t.Fatal("expected SendMedia to return error")
|
||||
}
|
||||
@@ -424,11 +438,11 @@ func TestSendMedia_UnsupportedChannelReturnsError(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
}))
|
||||
if err == nil {
|
||||
t.Fatal("expected SendMedia to return error for unsupported channel")
|
||||
}
|
||||
@@ -454,11 +468,11 @@ func TestSendMedia_DeletesPlaceholderBeforeSending(t *testing.T) {
|
||||
m.workers["test"] = w
|
||||
m.RecordPlaceholder("test", "chat1", "placeholder-1")
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
@@ -491,7 +505,7 @@ func TestSendWithRetry_UnknownError(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -515,7 +529,7 @@ func TestSendWithRetry_ContextCancelled(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
// Cancel context after first Send attempt returns
|
||||
ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
@@ -561,7 +575,7 @@ func TestWorkerRateLimiter(t *testing.T) {
|
||||
|
||||
// Enqueue 4 messages
|
||||
for i := range 4 {
|
||||
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}
|
||||
w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)})
|
||||
}
|
||||
|
||||
// Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin)
|
||||
@@ -586,7 +600,7 @@ func TestWorkerRateLimiter(t *testing.T) {
|
||||
|
||||
func TestNewChannelWorker_DefaultRate(t *testing.T) {
|
||||
ch := &mockChannel{}
|
||||
w := newChannelWorker("unknown_channel", ch)
|
||||
w := newChannelWorker("unknown_channel", ch, "unknown_channel")
|
||||
|
||||
if w.limiter == nil {
|
||||
t.Fatal("expected limiter to be non-nil")
|
||||
@@ -599,10 +613,10 @@ func TestNewChannelWorker_DefaultRate(t *testing.T) {
|
||||
func TestNewChannelWorker_ConfiguredRate(t *testing.T) {
|
||||
ch := &mockChannel{}
|
||||
|
||||
for name, expectedRate := range channelRateConfig {
|
||||
w := newChannelWorker(name, ch)
|
||||
for channelType, expectedRate := range channelRateConfig {
|
||||
w := newChannelWorker(channelType, ch, channelType)
|
||||
if w.limiter.Limit() != rate.Limit(expectedRate) {
|
||||
t.Fatalf("channel %s: expected rate %v, got %v", name, expectedRate, w.limiter.Limit())
|
||||
t.Fatalf("channel %s: expected rate %v, got %v", channelType, expectedRate, w.limiter.Limit())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -637,7 +651,7 @@ func TestRunWorker_MessageSplitting(t *testing.T) {
|
||||
go m.runWorker(ctx, "test", w)
|
||||
|
||||
// Send a message that should be split
|
||||
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"}
|
||||
w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"})
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
@@ -678,7 +692,7 @@ func TestSendWithRetry_ExponentialBackoff(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
start := time.Now()
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
@@ -738,7 +752,7 @@ func TestPreSend_PlaceholderEditSuccess(t *testing.T) {
|
||||
// Register placeholder
|
||||
m.RecordPlaceholder("test", "123", "456")
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if !edited {
|
||||
@@ -768,7 +782,7 @@ func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) {
|
||||
|
||||
m.RecordPlaceholder("test", "123", "456")
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if edited {
|
||||
@@ -827,7 +841,7 @@ func TestPreSend_TypingStopCalled(t *testing.T) {
|
||||
stopCalled = true
|
||||
})
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if !stopCalled {
|
||||
@@ -844,7 +858,7 @@ func TestPreSend_NoRegisteredState(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if edited {
|
||||
@@ -874,7 +888,7 @@ func TestPreSend_TypingAndPlaceholder(t *testing.T) {
|
||||
})
|
||||
m.RecordPlaceholder("test", "123", "456")
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if !stopCalled {
|
||||
@@ -938,7 +952,7 @@ func TestRecordTypingStop_ReplacesExistingStop(t *testing.T) {
|
||||
t.Fatalf("expected replacement typing stop to stay active until preSend, got %d calls", newStopCalls)
|
||||
}
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
m.preSend(context.Background(), "test", msg, &mockChannel{})
|
||||
|
||||
if newStopCalls != 1 {
|
||||
@@ -972,7 +986,7 @@ func TestSendWithRetry_PreSendEditsPlaceholder(t *testing.T) {
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
m.sendWithRetry(context.Background(), "test", w, msg)
|
||||
|
||||
if sendCalled {
|
||||
@@ -1135,7 +1149,7 @@ func TestPreSendStillWorksWithWrappedTypes(t *testing.T) {
|
||||
})
|
||||
m.RecordPlaceholder("test", "chat1", "ph_id")
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if !stopCalled {
|
||||
@@ -1222,7 +1236,7 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
worker := newChannelWorker("mock", mockCh)
|
||||
worker := newChannelWorker("mock", mockCh, "mock")
|
||||
mgr.channels["mock"] = mockCh
|
||||
mgr.workers["mock"] = worker
|
||||
|
||||
@@ -1238,11 +1252,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) {
|
||||
|
||||
// Transcription feedback arrives first — it should consume the placeholder
|
||||
// and be delivered via EditMessage, not Send.
|
||||
msgTranscript := bus.OutboundMessage{
|
||||
msgTranscript := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "mock",
|
||||
ChatID: "chat-1",
|
||||
Content: "Transcript: hello",
|
||||
}
|
||||
})
|
||||
mgr.sendWithRetry(ctx, "mock", worker, msgTranscript)
|
||||
|
||||
if mockCh.editedMessages != 1 {
|
||||
@@ -1258,11 +1272,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) {
|
||||
}
|
||||
|
||||
// Final LLM response arrives — no placeholder left, so it goes through Send
|
||||
msgFinal := bus.OutboundMessage{
|
||||
msgFinal := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "mock",
|
||||
ChatID: "chat-1",
|
||||
Content: "Final Answer",
|
||||
}
|
||||
})
|
||||
mgr.sendWithRetry(ctx, "mock", worker, msgFinal)
|
||||
|
||||
if len(mockCh.sentMessages) != 1 {
|
||||
@@ -1288,12 +1302,12 @@ func TestSendMessage_Synchronous(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "hello world",
|
||||
ReplyToMessageID: "msg-456",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err != nil {
|
||||
@@ -1315,11 +1329,11 @@ func TestSendMessage_Synchronous(t *testing.T) {
|
||||
func TestSendMessage_UnknownChannel(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "nonexistent",
|
||||
ChatID: "123",
|
||||
Content: "hello",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err == nil {
|
||||
@@ -1336,11 +1350,11 @@ func TestSendMessage_NoWorker(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
// No worker registered
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "hello",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err == nil {
|
||||
@@ -1369,11 +1383,11 @@ func TestSendMessage_WithRetry(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "retry me",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err != nil {
|
||||
@@ -1385,6 +1399,46 @@ func TestSendMessage_WithRetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMessage_ContextOnlyUsesContextAddressing(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
var received []bus.OutboundMessage
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
|
||||
received = append(received, msg)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Context: bus.NewOutboundContext("test", "123", "msg-9"),
|
||||
Content: "hello",
|
||||
})
|
||||
|
||||
if err := m.SendMessage(context.Background(), msg); err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if len(received) != 1 {
|
||||
t.Fatalf("expected 1 message sent, got %d", len(received))
|
||||
}
|
||||
if received[0].Channel != "test" || received[0].ChatID != "123" {
|
||||
t.Fatalf("expected mirrored legacy address, got %+v", received[0])
|
||||
}
|
||||
if received[0].Context.Channel != "test" || received[0].Context.ChatID != "123" {
|
||||
t.Fatalf("expected context address to be preserved, got %+v", received[0].Context)
|
||||
}
|
||||
if received[0].ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected reply_to_message_id msg-9, got %q", received[0].ReplyToMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMessage_WithSplitting(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
@@ -1406,11 +1460,11 @@ func TestSendMessage_WithSplitting(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "hello world",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err != nil {
|
||||
@@ -1422,6 +1476,43 @@ func TestSendMessage_WithSplitting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_ContextOnlyUsesContextAddressing(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
var received []bus.OutboundMediaMessage
|
||||
ch := &mockMediaChannel{
|
||||
sendMediaFn: func(_ context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
|
||||
received = append(received, msg)
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Context: bus.NewOutboundContext("test", "media-chat", ""),
|
||||
Parts: []bus.MediaPart{{Type: "image", Ref: "media://1"}},
|
||||
})
|
||||
|
||||
if err := m.SendMedia(context.Background(), msg); err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if len(received) != 1 {
|
||||
t.Fatalf("expected 1 media message sent, got %d", len(received))
|
||||
}
|
||||
if received[0].Channel != "test" || received[0].ChatID != "media-chat" {
|
||||
t.Fatalf("expected mirrored legacy media address, got %+v", received[0])
|
||||
}
|
||||
if received[0].Context.Channel != "test" || received[0].Context.ChatID != "media-chat" {
|
||||
t.Fatalf("expected media context address to be preserved, got %+v", received[0].Context)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMessage_PreservesOrdering(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
@@ -1441,12 +1532,12 @@ func TestSendMessage_PreservesOrdering(t *testing.T) {
|
||||
m.workers["test"] = w
|
||||
|
||||
// Send two messages sequentially — they must arrive in order
|
||||
_ = m.SendMessage(context.Background(), bus.OutboundMessage{
|
||||
_ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test", ChatID: "1", Content: "first",
|
||||
})
|
||||
_ = m.SendMessage(context.Background(), bus.OutboundMessage{
|
||||
}))
|
||||
_ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test", ChatID: "1", Content: "second",
|
||||
})
|
||||
}))
|
||||
|
||||
if len(order) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(order))
|
||||
|
||||
@@ -9,12 +9,30 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
matrixCfg := cfg.Channels.Matrix
|
||||
cryptoDatabasePath := matrixCfg.CryptoDatabasePath
|
||||
if cryptoDatabasePath == "" {
|
||||
cryptoDatabasePath = filepath.Join(cfg.WorkspacePath(), "matrix")
|
||||
}
|
||||
return NewMatrixChannel(matrixCfg, b, cryptoDatabasePath)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelMatrix,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.MatrixSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
cryptoDatabasePath := c.CryptoDatabasePath
|
||||
if cryptoDatabasePath == "" {
|
||||
cryptoDatabasePath = filepath.Join(cfg.WorkspacePath(), "matrix")
|
||||
}
|
||||
ch, err := NewMatrixChannel(bc, c, b, cryptoDatabasePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channelName != config.ChannelMatrix {
|
||||
ch.SetName(channelName)
|
||||
}
|
||||
return ch, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -174,9 +174,10 @@ func (s *typingSession) stop() {
|
||||
// MatrixChannel implements the Channel interface for Matrix.
|
||||
type MatrixChannel struct {
|
||||
*channels.BaseChannel
|
||||
bc *config.Channel
|
||||
|
||||
client *mautrix.Client
|
||||
config config.MatrixConfig
|
||||
config *config.MatrixSettings
|
||||
syncer *mautrix.DefaultSyncer
|
||||
|
||||
ctx context.Context
|
||||
@@ -194,7 +195,8 @@ type MatrixChannel struct {
|
||||
}
|
||||
|
||||
func NewMatrixChannel(
|
||||
cfg config.MatrixConfig,
|
||||
bc *config.Channel,
|
||||
cfg *config.MatrixSettings,
|
||||
messageBus *bus.MessageBus,
|
||||
cryptoDatabasePath string,
|
||||
) (*MatrixChannel, error) {
|
||||
@@ -228,14 +230,15 @@ func NewMatrixChannel(
|
||||
"matrix",
|
||||
cfg,
|
||||
messageBus,
|
||||
cfg.AllowFrom,
|
||||
bc.AllowFrom,
|
||||
channels.WithMaxMessageLength(65536),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
channels.WithGroupTrigger(bc.GroupTrigger),
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &MatrixChannel{
|
||||
BaseChannel: base,
|
||||
bc: bc,
|
||||
client: client,
|
||||
config: cfg,
|
||||
syncer: syncer,
|
||||
@@ -570,7 +573,7 @@ func (c *MatrixChannel) StartTyping(ctx context.Context, chatID string) (func(),
|
||||
|
||||
// SendPlaceholder implements channels.PlaceholderCapable.
|
||||
func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
if !c.config.Placeholder.Enabled {
|
||||
if !c.bc.Placeholder.Enabled {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -579,7 +582,7 @@ func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (str
|
||||
return "", fmt.Errorf("matrix room ID is empty")
|
||||
}
|
||||
|
||||
text := c.config.Placeholder.GetRandomText()
|
||||
text := c.bc.Placeholder.GetRandomText()
|
||||
|
||||
resp, err := c.client.SendMessageEvent(ctx, roomID, event.EventMessage, &event.MessageEventContent{
|
||||
MsgType: event.MsgNotice,
|
||||
@@ -720,8 +723,8 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
|
||||
logger.DebugCF("matrix", "Ignoring group message by trigger rules", map[string]any{
|
||||
"room_id": roomID,
|
||||
"is_mentioned": isMentioned,
|
||||
"mention_only": c.config.GroupTrigger.MentionOnly,
|
||||
"prefixes": c.config.GroupTrigger.Prefixes,
|
||||
"mention_only": c.bc.GroupTrigger.MentionOnly,
|
||||
"prefixes": c.bc.GroupTrigger.Prefixes,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -736,10 +739,8 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
|
||||
}
|
||||
|
||||
peerKind := "direct"
|
||||
peerID := senderID
|
||||
if isGroup {
|
||||
peerKind = "group"
|
||||
peerID = roomID
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
@@ -752,17 +753,19 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
|
||||
metadata["reply_to_msg_id"] = replyTo.String()
|
||||
}
|
||||
|
||||
c.HandleMessage(
|
||||
c.baseContext(),
|
||||
bus.Peer{Kind: peerKind, ID: peerID},
|
||||
evt.ID.String(),
|
||||
senderID,
|
||||
roomID,
|
||||
content,
|
||||
mediaPaths,
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "matrix",
|
||||
ChatID: roomID,
|
||||
ChatType: peerKind,
|
||||
SenderID: senderID,
|
||||
MessageID: evt.ID.String(),
|
||||
Raw: metadata,
|
||||
}
|
||||
if replyTo := msgEvt.GetRelatesTo().GetReplyTo(); replyTo != "" {
|
||||
inboundCtx.ReplyToMessageID = replyTo.String()
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.baseContext(), roomID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// decryptEvent decrypts an encrypted event and returns the decrypted message event content.
|
||||
|
||||
@@ -437,9 +437,9 @@ func TestMarkdownToHTML(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMessageContent(t *testing.T) {
|
||||
richtext := &MatrixChannel{config: config.MatrixConfig{MessageFormat: "richtext"}}
|
||||
plain := &MatrixChannel{config: config.MatrixConfig{MessageFormat: "plain"}}
|
||||
defaultt := &MatrixChannel{config: config.MatrixConfig{}}
|
||||
richtext := &MatrixChannel{config: &config.MatrixSettings{MessageFormat: "richtext"}}
|
||||
plain := &MatrixChannel{config: &config.MatrixSettings{MessageFormat: "plain"}}
|
||||
defaultt := &MatrixChannel{config: &config.MatrixSettings{}}
|
||||
|
||||
for _, c := range []*MatrixChannel{richtext, defaultt} {
|
||||
mc := c.messageContent("**hi**")
|
||||
|
||||
@@ -7,7 +7,19 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("onebot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewOneBotChannel(cfg.Channels.OneBot, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelOneBot,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.OneBotSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
return NewOneBotChannel(bc, c, b)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
|
||||
type OneBotChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.OneBotConfig
|
||||
config *config.OneBotSettings
|
||||
conn *websocket.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -96,10 +96,14 @@ type oneBotMessageSegment struct {
|
||||
Data map[string]any `json:"data"`
|
||||
}
|
||||
|
||||
func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) {
|
||||
base := channels.NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom,
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
func NewOneBotChannel(
|
||||
bc *config.Channel,
|
||||
cfg *config.OneBotSettings,
|
||||
messageBus *bus.MessageBus,
|
||||
) (*OneBotChannel, error) {
|
||||
base := channels.NewBaseChannel("onebot", cfg, messageBus, bc.AllowFrom,
|
||||
channels.WithGroupTrigger(bc.GroupTrigger),
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
const dedupSize = 1024
|
||||
@@ -991,8 +995,8 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
|
||||
|
||||
senderID := strconv.FormatInt(userID, 10)
|
||||
var chatID string
|
||||
|
||||
var peer bus.Peer
|
||||
var contextChatID string
|
||||
var contextChatType string
|
||||
|
||||
metadata := map[string]string{}
|
||||
|
||||
@@ -1003,12 +1007,14 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
|
||||
switch raw.MessageType {
|
||||
case "private":
|
||||
chatID = "private:" + senderID
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
contextChatID = senderID
|
||||
contextChatType = "direct"
|
||||
|
||||
case "group":
|
||||
groupIDStr := strconv.FormatInt(groupID, 10)
|
||||
chatID = "group:" + groupIDStr
|
||||
peer = bus.Peer{Kind: "group", ID: groupIDStr}
|
||||
contextChatID = groupIDStr
|
||||
contextChatType = "group"
|
||||
metadata["group_id"] = groupIDStr
|
||||
|
||||
senderUserID, _ := parseJSONInt64(sender.UserID)
|
||||
@@ -1072,7 +1078,18 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, parsed.Media, metadata, senderInfo)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
ChatID: contextChatID,
|
||||
ChatType: contextChatType,
|
||||
SenderID: senderID,
|
||||
MessageID: messageID,
|
||||
Mentioned: isBotMentioned,
|
||||
ReplyToMessageID: parsed.ReplyTo,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, parsed.Media, inboundCtx, senderInfo)
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) isDuplicate(messageID string) bool {
|
||||
|
||||
+23
-11
@@ -22,7 +22,7 @@ import (
|
||||
// PicoClientChannel connects to a remote Pico Protocol WebSocket server.
|
||||
type PicoClientChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.PicoClientConfig
|
||||
config *config.PicoClientSettings
|
||||
conn *picoConn
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
@@ -31,14 +31,15 @@ type PicoClientChannel struct {
|
||||
|
||||
// NewPicoClientChannel creates a new Pico Protocol client channel.
|
||||
func NewPicoClientChannel(
|
||||
cfg config.PicoClientConfig,
|
||||
bc *config.Channel,
|
||||
cfg *config.PicoClientSettings,
|
||||
messageBus *bus.MessageBus,
|
||||
) (*PicoClientChannel, error) {
|
||||
if cfg.URL == "" {
|
||||
return nil, fmt.Errorf("pico_client url is required")
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel("pico_client", cfg, messageBus, cfg.AllowFrom)
|
||||
base := channels.NewBaseChannel("pico_client", cfg, messageBus, bc.AllowFrom)
|
||||
|
||||
return &PicoClientChannel{
|
||||
BaseChannel: base,
|
||||
@@ -242,7 +243,11 @@ func (c *PicoClientChannel) handleInbound(pc *picoConn, msg PicoMessage) {
|
||||
}
|
||||
|
||||
func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
|
||||
content, _ := msg.Payload["content"].(string)
|
||||
if isThoughtPayload(msg.Payload) {
|
||||
return
|
||||
}
|
||||
|
||||
content, _ := msg.Payload[PayloadKeyContent].(string)
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
@@ -254,8 +259,6 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
|
||||
|
||||
chatID := "pico_client:" + sessionID
|
||||
senderID := "pico-remote"
|
||||
peer := bus.Peer{Kind: "direct", ID: chatID}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "pico_client",
|
||||
PlatformID: senderID,
|
||||
@@ -266,10 +269,19 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, map[string]string{
|
||||
"platform": "pico_client",
|
||||
"session_id": sessionID,
|
||||
}, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "pico_client",
|
||||
ChatID: chatID,
|
||||
ChatType: "direct",
|
||||
SenderID: senderID,
|
||||
MessageID: msg.ID,
|
||||
Raw: map[string]string{
|
||||
"platform": "pico_client",
|
||||
"session_id": sessionID,
|
||||
},
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// Send sends a message to the remote server.
|
||||
@@ -285,7 +297,7 @@ func (c *PicoClientChannel) Send(ctx context.Context, msg bus.OutboundMessage) (
|
||||
}
|
||||
|
||||
outMsg := newMessage(TypeMessageSend, map[string]any{
|
||||
"content": msg.Content,
|
||||
PayloadKeyContent: msg.Content,
|
||||
})
|
||||
outMsg.SessionID = strings.TrimPrefix(msg.ChatID, "pico_client:")
|
||||
return nil, pc.writeJSON(outMsg)
|
||||
|
||||
@@ -18,7 +18,8 @@ import (
|
||||
)
|
||||
|
||||
func TestNewPicoClientChannel_MissingURL(t *testing.T) {
|
||||
_, err := NewPicoClientChannel(config.PicoClientConfig{}, bus.NewMessageBus())
|
||||
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
|
||||
_, err := NewPicoClientChannel(bc, &config.PicoClientSettings{}, bus.NewMessageBus())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing URL")
|
||||
}
|
||||
@@ -28,7 +29,8 @@ func TestNewPicoClientChannel_MissingURL(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewPicoClientChannel_OK(t *testing.T) {
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
|
||||
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
|
||||
URL: "ws://localhost:9999/ws",
|
||||
}, bus.NewMessageBus())
|
||||
if err != nil {
|
||||
@@ -40,7 +42,8 @@ func TestNewPicoClientChannel_OK(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSend_NotRunning(t *testing.T) {
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
|
||||
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
|
||||
URL: "ws://localhost:9999/ws",
|
||||
}, bus.NewMessageBus())
|
||||
if err != nil {
|
||||
@@ -104,7 +107,8 @@ func TestClientChannel_ConnectAndSend(t *testing.T) {
|
||||
defer srv.Close()
|
||||
|
||||
mb := bus.NewMessageBus()
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
|
||||
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
|
||||
URL: wsURL(srv.URL),
|
||||
Token: *config.NewSecureString("test-token"),
|
||||
SessionID: "sess-1",
|
||||
@@ -137,7 +141,8 @@ func TestClientChannel_AuthFailure(t *testing.T) {
|
||||
srv := testServer(t, "correct-token")
|
||||
defer srv.Close()
|
||||
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
|
||||
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
|
||||
URL: wsURL(srv.URL),
|
||||
Token: *config.NewSecureString("wrong-token"),
|
||||
}, bus.NewMessageBus())
|
||||
@@ -161,7 +166,8 @@ func TestClientChannel_ReceivesServerMessage(t *testing.T) {
|
||||
|
||||
mb := bus.NewMessageBus()
|
||||
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
|
||||
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
|
||||
URL: wsURL(srv.URL),
|
||||
SessionID: "sess-echo",
|
||||
ReadTimeout: 10,
|
||||
@@ -203,7 +209,8 @@ func TestClientChannel_StartTyping(t *testing.T) {
|
||||
srv := testServer(t, "")
|
||||
defer srv.Close()
|
||||
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
|
||||
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
|
||||
URL: wsURL(srv.URL),
|
||||
SessionID: "sess-type",
|
||||
ReadTimeout: 10,
|
||||
@@ -231,7 +238,8 @@ func TestSend_ClosedConnection(t *testing.T) {
|
||||
srv := testServer(t, "")
|
||||
defer srv.Close()
|
||||
|
||||
ch, err := NewPicoClientChannel(config.PicoClientConfig{
|
||||
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
|
||||
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
|
||||
URL: wsURL(srv.URL),
|
||||
SessionID: "sess-close",
|
||||
ReadTimeout: 10,
|
||||
@@ -279,7 +287,8 @@ func TestParseInlineImageMedia_Valid(t *testing.T) {
|
||||
|
||||
func TestPicoChannel_HandleMessageSend_AllowsMediaOnly(t *testing.T) {
|
||||
mb := bus.NewMessageBus()
|
||||
ch, err := NewPicoChannel(config.PicoConfig{
|
||||
bc := &config.Channel{Type: "pico", Enabled: true}
|
||||
ch, err := NewPicoChannel(bc, &config.PicoSettings{
|
||||
Token: *config.NewSecureString("test-token"),
|
||||
}, mb)
|
||||
if err != nil {
|
||||
@@ -316,3 +325,68 @@ func TestPicoChannel_HandleMessageSend_AllowsMediaOnly(t *testing.T) {
|
||||
t.Fatal("timed out waiting for inbound media message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsThoughtPayload(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload map[string]any
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "explicit thought bool",
|
||||
payload: map[string]any{PayloadKeyThought: true},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "thought false",
|
||||
payload: map[string]any{PayloadKeyThought: false},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "thought string ignored",
|
||||
payload: map[string]any{PayloadKeyThought: "true"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "default normal",
|
||||
payload: map[string]any{PayloadKeyContent: "hello"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isThoughtPayload(tt.payload); got != tt.want {
|
||||
t.Fatalf("isThoughtPayload() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPicoClientChannel_HandleServerMessage_IgnoresThought(t *testing.T) {
|
||||
mb := bus.NewMessageBus()
|
||||
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
|
||||
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
|
||||
URL: "ws://localhost:8080/ws",
|
||||
}, mb)
|
||||
if err != nil {
|
||||
t.Fatalf("NewPicoClientChannel() error = %v", err)
|
||||
}
|
||||
|
||||
ch.ctx = context.Background()
|
||||
pc := &picoConn{sessionID: "sess-thought"}
|
||||
|
||||
ch.handleServerMessage(pc, PicoMessage{
|
||||
Type: TypeMessageCreate,
|
||||
Payload: map[string]any{
|
||||
PayloadKeyContent: "internal reasoning",
|
||||
PayloadKeyThought: true,
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
case msg := <-mb.InboundChan():
|
||||
t.Fatalf("expected no inbound publish for thought payload, got %+v", msg)
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,10 +7,48 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("pico", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewPicoChannel(cfg.Channels.Pico, b)
|
||||
})
|
||||
channels.RegisterFactory("pico_client", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewPicoClientChannel(cfg.Channels.PicoClient, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelPico,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.PicoSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
ch, err := NewPicoChannel(bc, c, b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channelName != config.ChannelPico {
|
||||
ch.SetName(channelName)
|
||||
}
|
||||
return ch, nil
|
||||
},
|
||||
)
|
||||
channels.RegisterFactory(
|
||||
config.ChannelPicoClient,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.PicoClientSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
ch, err := NewPicoClientChannel(bc, c, b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channelName != config.ChannelPicoClient {
|
||||
ch.SetName(channelName)
|
||||
}
|
||||
return ch, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
+34
-11
@@ -39,6 +39,13 @@ var allowedInlineImageMIMETypes = map[string]struct{}{
|
||||
"image/bmp": {},
|
||||
}
|
||||
|
||||
func outboundMessageIsThought(msg bus.OutboundMessage) bool {
|
||||
if len(msg.Context.Raw) == 0 {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), MessageKindThought)
|
||||
}
|
||||
|
||||
// writeJSON sends a JSON message to the connection with write locking.
|
||||
func (pc *picoConn) writeJSON(v any) error {
|
||||
if pc.closed.Load() {
|
||||
@@ -63,7 +70,8 @@ func (pc *picoConn) close() {
|
||||
// It serves as the reference implementation for all optional capability interfaces.
|
||||
type PicoChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.PicoConfig
|
||||
bc *config.Channel
|
||||
config *config.PicoSettings
|
||||
upgrader websocket.Upgrader
|
||||
connections map[string]*picoConn // connID -> *picoConn
|
||||
sessionConnections map[string]map[string]*picoConn // sessionID -> connID -> *picoConn
|
||||
@@ -73,12 +81,16 @@ type PicoChannel struct {
|
||||
}
|
||||
|
||||
// NewPicoChannel creates a new Pico Protocol channel.
|
||||
func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoChannel, error) {
|
||||
func NewPicoChannel(
|
||||
bc *config.Channel,
|
||||
cfg *config.PicoSettings,
|
||||
messageBus *bus.MessageBus,
|
||||
) (*PicoChannel, error) {
|
||||
if cfg.Token.String() == "" {
|
||||
return nil, fmt.Errorf("pico token is required")
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel("pico", cfg, messageBus, cfg.AllowFrom)
|
||||
base := channels.NewBaseChannel("pico", cfg, messageBus, bc.AllowFrom)
|
||||
|
||||
allowOrigins := cfg.AllowOrigins
|
||||
checkOrigin := func(r *http.Request) bool {
|
||||
@@ -96,6 +108,7 @@ func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoCha
|
||||
|
||||
return &PicoChannel{
|
||||
BaseChannel: base,
|
||||
bc: bc,
|
||||
config: cfg,
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: checkOrigin,
|
||||
@@ -247,9 +260,11 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
|
||||
if !c.IsRunning() {
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
isThought := outboundMessageIsThought(msg)
|
||||
|
||||
outMsg := newMessage(TypeMessageCreate, map[string]any{
|
||||
"content": msg.Content,
|
||||
PayloadKeyContent: msg.Content,
|
||||
PayloadKeyThought: isThought,
|
||||
})
|
||||
|
||||
return nil, c.broadcastToSession(msg.ChatID, outMsg)
|
||||
@@ -280,16 +295,17 @@ func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), e
|
||||
// It sends a placeholder message via the Pico Protocol that will later be
|
||||
// edited to the actual response via EditMessage (channels.MessageEditor).
|
||||
func (c *PicoChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
if !c.config.Placeholder.Enabled {
|
||||
if !c.bc.Placeholder.Enabled {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
text := c.config.Placeholder.GetRandomText()
|
||||
text := c.bc.Placeholder.GetRandomText()
|
||||
|
||||
msgID := uuid.New().String()
|
||||
outMsg := newMessage(TypeMessageCreate, map[string]any{
|
||||
"content": text,
|
||||
"message_id": msgID,
|
||||
PayloadKeyContent: text,
|
||||
PayloadKeyThought: false,
|
||||
"message_id": msgID,
|
||||
})
|
||||
|
||||
if err := c.broadcastToSession(chatID, outMsg); err != nil {
|
||||
@@ -562,8 +578,6 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
|
||||
chatID := "pico:" + sessionID
|
||||
senderID := "pico-user"
|
||||
|
||||
peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"platform": "pico",
|
||||
"session_id": sessionID,
|
||||
@@ -586,7 +600,16 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, media, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "pico",
|
||||
ChatID: chatID,
|
||||
ChatType: "direct",
|
||||
SenderID: senderID,
|
||||
MessageID: msg.ID,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, media, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// truncate truncates a string to maxLen runes.
|
||||
|
||||
@@ -15,9 +15,10 @@ import (
|
||||
func newTestPicoChannel(t *testing.T) *PicoChannel {
|
||||
t.Helper()
|
||||
|
||||
cfg := config.PicoConfig{}
|
||||
bc := &config.Channel{Type: config.ChannelPico, Enabled: true}
|
||||
cfg := &config.PicoSettings{}
|
||||
cfg.SetToken("test-token")
|
||||
ch, err := NewPicoChannel(cfg, bus.NewMessageBus())
|
||||
ch, err := NewPicoChannel(bc, cfg, bus.NewMessageBus())
|
||||
if err != nil {
|
||||
t.Fatalf("NewPicoChannel: %v", err)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,11 @@ const (
|
||||
TypePong = "pong"
|
||||
|
||||
PicoTokenPrefix = "pico-"
|
||||
|
||||
PayloadKeyContent = "content"
|
||||
PayloadKeyThought = "thought"
|
||||
|
||||
MessageKindThought = "thought"
|
||||
)
|
||||
|
||||
// PicoMessage is the wire format for all Pico Protocol messages.
|
||||
@@ -39,6 +44,11 @@ func newMessage(msgType string, payload map[string]any) PicoMessage {
|
||||
}
|
||||
}
|
||||
|
||||
func isThoughtPayload(payload map[string]any) bool {
|
||||
thought, _ := payload[PayloadKeyThought].(bool)
|
||||
return thought
|
||||
}
|
||||
|
||||
func newErrorWithPayload(code, message string, extra map[string]any) PicoMessage {
|
||||
payload := map[string]any{
|
||||
"code": code,
|
||||
|
||||
+15
-3
@@ -7,7 +7,19 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("qq", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewQQChannel(cfg.Channels.QQ, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelQQ,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.QQSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
return NewQQChannel(bc, c, b)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
+40
-27
@@ -56,7 +56,8 @@ type qqAPI interface {
|
||||
|
||||
type QQChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.QQConfig
|
||||
bc *config.Channel
|
||||
config *config.QQSettings
|
||||
api qqAPI
|
||||
tokenSource oauth2.TokenSource
|
||||
ctx context.Context
|
||||
@@ -82,15 +83,16 @@ type QQChannel struct {
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) {
|
||||
base := channels.NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom,
|
||||
func NewQQChannel(bc *config.Channel, cfg *config.QQSettings, messageBus *bus.MessageBus) (*QQChannel, error) {
|
||||
base := channels.NewBaseChannel("qq", cfg, messageBus, bc.AllowFrom,
|
||||
channels.WithMaxMessageLength(cfg.MaxMessageLength),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
channels.WithGroupTrigger(bc.GroupTrigger),
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &QQChannel{
|
||||
BaseChannel: base,
|
||||
bc: bc,
|
||||
config: cfg,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
@@ -161,8 +163,8 @@ func (c *QQChannel) Start(ctx context.Context) error {
|
||||
|
||||
// Pre-register reasoning_channel_id as group chat if configured,
|
||||
// so outbound-only destinations are routed correctly.
|
||||
if c.config.ReasoningChannelID != "" {
|
||||
c.chatType.Store(c.config.ReasoningChannelID, "group")
|
||||
if c.bc.ReasoningChannelID != "" {
|
||||
c.chatType.Store(c.bc.ReasoningChannelID, "group")
|
||||
}
|
||||
|
||||
c.SetRunning(true)
|
||||
@@ -588,12 +590,22 @@ func qqFileType(partType string) uint64 {
|
||||
}
|
||||
|
||||
func (c *QQChannel) maxBase64FileSizeBytes() int64 {
|
||||
if c.config == nil {
|
||||
return 0
|
||||
}
|
||||
if c.config.MaxBase64FileSizeMiB <= 0 {
|
||||
return 0
|
||||
}
|
||||
return c.config.MaxBase64FileSizeMiB * bytesPerMiB
|
||||
}
|
||||
|
||||
func (c *QQChannel) accountID() string {
|
||||
if c.config == nil {
|
||||
return ""
|
||||
}
|
||||
return c.config.AppID
|
||||
}
|
||||
|
||||
// handleC2CMessage handles QQ private messages.
|
||||
func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
|
||||
return func(event *dto.WSPayload, data *dto.WSC2CMessageData) error {
|
||||
@@ -647,17 +659,17 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
|
||||
metadata := map[string]string{
|
||||
"account_id": senderID,
|
||||
}
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
Account: c.accountID(),
|
||||
ChatID: senderID,
|
||||
ChatType: "direct",
|
||||
SenderID: senderID,
|
||||
MessageID: data.ID,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx,
|
||||
bus.Peer{Kind: "direct", ID: senderID},
|
||||
data.ID,
|
||||
senderID,
|
||||
senderID,
|
||||
content,
|
||||
mediaPaths,
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
c.HandleInboundContext(c.ctx, senderID, content, mediaPaths, inboundCtx, sender)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -725,17 +737,18 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
|
||||
"account_id": senderID,
|
||||
"group_id": data.GroupID,
|
||||
}
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
Account: c.accountID(),
|
||||
ChatID: data.GroupID,
|
||||
ChatType: "group",
|
||||
SenderID: senderID,
|
||||
MessageID: data.ID,
|
||||
Mentioned: true,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx,
|
||||
bus.Peer{Kind: "group", ID: data.GroupID},
|
||||
data.ID,
|
||||
senderID,
|
||||
data.GroupID,
|
||||
content,
|
||||
mediaPaths,
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
c.HandleInboundContext(c.ctx, data.GroupID, content, mediaPaths, inboundCtx, sender)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -54,8 +54,8 @@ func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message")
|
||||
}
|
||||
if inbound.Metadata["account_id"] != "7750283E123456" {
|
||||
t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456")
|
||||
if inbound.Context.Raw["account_id"] != "7750283E123456" {
|
||||
t.Fatalf("account_id raw = %q, want %q", inbound.Context.Raw["account_id"], "7750283E123456")
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -165,8 +165,8 @@ func TestHandleGroupATMessage_AttachmentOnlyPublishesMedia(t *testing.T) {
|
||||
if !strings.HasPrefix(inbound.Media[0], "media://") {
|
||||
t.Fatalf("inbound.Media[0] = %q, want media:// ref", inbound.Media[0])
|
||||
}
|
||||
if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-1" {
|
||||
t.Fatalf("inbound.Peer = %+v, want group/group-1", inbound.Peer)
|
||||
if inbound.Context.ChatType != "group" {
|
||||
t.Fatalf("inbound.Context.ChatType = %q, want group", inbound.Context.ChatType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,6 +198,7 @@ func TestSendMedia_UploadsLocalFileAsBase64(t *testing.T) {
|
||||
}
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
config: &config.QQSettings{},
|
||||
api: api,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
@@ -294,6 +295,7 @@ func assertAudioWAVUploadType(t *testing.T, duration time.Duration, wantFileType
|
||||
}
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
config: &config.QQSettings{},
|
||||
api: api,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
@@ -329,6 +331,7 @@ func TestSendMedia_RemoteAudioFallsBackToFileUpload(t *testing.T) {
|
||||
}
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
config: &config.QQSettings{},
|
||||
api: api,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
@@ -374,6 +377,7 @@ func TestSendMedia_LocalAudioWithUnknownDurationFallsBackToFileUpload(t *testing
|
||||
}
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
config: &config.QQSettings{},
|
||||
api: api,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
@@ -409,6 +413,7 @@ func TestSendMedia_UsesRemoteURLUploadForC2C(t *testing.T) {
|
||||
}
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
config: &config.QQSettings{},
|
||||
api: api,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
@@ -481,6 +486,7 @@ func TestSendMedia_LocalFileUploadIncludesStoredFilename(t *testing.T) {
|
||||
}
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
config: &config.QQSettings{},
|
||||
api: api,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
@@ -520,6 +526,7 @@ func TestSendMedia_ReturnsSendFailedWithoutMediaStore(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
config: &config.QQSettings{},
|
||||
api: &fakeQQAPI{},
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
@@ -566,7 +573,7 @@ func TestSendMedia_ReturnsSendFailedWhenLocalFileExceedsBase64MiBLimit(t *testin
|
||||
api := &fakeQQAPI{}
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
config: config.QQConfig{
|
||||
config: &config.QQSettings{
|
||||
MaxBase64FileSizeMiB: 1,
|
||||
},
|
||||
api: api,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -9,7 +10,9 @@ import (
|
||||
|
||||
// ChannelFactory is a constructor function that creates a Channel from config and message bus.
|
||||
// Each channel subpackage registers one or more factories via init().
|
||||
type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error)
|
||||
// channelName is the config map key for this channel instance (may differ from the channel type).
|
||||
// channelType is the channel type string used to look up the Channel config.
|
||||
type ChannelFactory func(channelName, channelType string, cfg *config.Config, bus *bus.MessageBus) (Channel, error)
|
||||
|
||||
var (
|
||||
factoriesMu sync.RWMutex
|
||||
@@ -23,6 +26,38 @@ func RegisterFactory(name string, f ChannelFactory) {
|
||||
factories[name] = f
|
||||
}
|
||||
|
||||
// RegisterSafeFactory is a convenience wrapper that handles GetDecoded() error checking
|
||||
// and type assertion, reducing boilerplate in channel init() functions.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// func init() {
|
||||
// channels.RegisterSafeFactory(config.ChannelTelegram,
|
||||
// func(bc *config.Channel, c *config.TelegramSettings, b *bus.MessageBus) (channels.Channel, error) {
|
||||
// return NewTelegramChannel(bc, c, b)
|
||||
// })
|
||||
// }
|
||||
func RegisterSafeFactory[S any](
|
||||
channelType string,
|
||||
ctor func(bc *config.Channel, settings *S, bus *bus.MessageBus) (Channel, error),
|
||||
) {
|
||||
RegisterFactory(channelType, func(channelName, _ string, cfg *config.Config, b *bus.MessageBus) (Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
if bc == nil {
|
||||
return nil, fmt.Errorf("channel %q: config not found", channelName)
|
||||
}
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("channel %q: failed to decode settings: %w", channelName, err)
|
||||
}
|
||||
settings, ok := decoded.(*S)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("channel %q: expected %T settings, got %T", channelName, (*S)(nil), decoded)
|
||||
}
|
||||
return ctor(bc, settings, b)
|
||||
})
|
||||
}
|
||||
|
||||
// getFactory looks up a channel factory by name.
|
||||
func getFactory(name string) (ChannelFactory, bool) {
|
||||
factoriesMu.RLock()
|
||||
@@ -30,3 +65,14 @@ func getFactory(name string) (ChannelFactory, bool) {
|
||||
f, ok := factories[name]
|
||||
return f, ok
|
||||
}
|
||||
|
||||
// GetRegisteredFactoryNames returns a slice of all registered channel factory names.
|
||||
func GetRegisteredFactoryNames() []string {
|
||||
factoriesMu.RLock()
|
||||
defer factoriesMu.RUnlock()
|
||||
names := make([]string, 0, len(factories))
|
||||
for name := range factories {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
@@ -7,7 +7,19 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("slack", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewSlackChannel(cfg.Channels.Slack, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelSlack,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.SlackSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
return NewSlackChannel(bc, c, b)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
+92
-33
@@ -21,7 +21,7 @@ import (
|
||||
|
||||
type SlackChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.SlackConfig
|
||||
config *config.SlackSettings
|
||||
api *slack.Client
|
||||
socketClient *socketmode.Client
|
||||
botUserID string
|
||||
@@ -36,7 +36,11 @@ type slackMessageRef struct {
|
||||
Timestamp string
|
||||
}
|
||||
|
||||
func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) {
|
||||
func NewSlackChannel(
|
||||
bc *config.Channel,
|
||||
cfg *config.SlackSettings,
|
||||
messageBus *bus.MessageBus,
|
||||
) (*SlackChannel, error) {
|
||||
if cfg.BotToken.String() == "" || cfg.AppToken.String() == "" {
|
||||
return nil, fmt.Errorf("slack bot_token and app_token are required")
|
||||
}
|
||||
@@ -48,10 +52,10 @@ func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*Slack
|
||||
|
||||
socketClient := socketmode.New(api)
|
||||
|
||||
base := channels.NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom,
|
||||
base := channels.NewBaseChannel("slack", cfg, messageBus, bc.AllowFrom,
|
||||
channels.WithMaxMessageLength(40000),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
channels.WithGroupTrigger(bc.GroupTrigger),
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &SlackChannel{
|
||||
@@ -113,7 +117,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]str
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
|
||||
channelID, threadTS := parseSlackChatID(msg.ChatID)
|
||||
deliveryChatID, channelID, threadTS := resolveSlackOutboundTarget(msg.ChatID, &msg.Context)
|
||||
if channelID == "" {
|
||||
return nil, fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
|
||||
}
|
||||
@@ -135,7 +139,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]str
|
||||
return nil, fmt.Errorf("slack send: %w", channels.ErrTemporary)
|
||||
}
|
||||
|
||||
if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok {
|
||||
if ref, ok := c.pendingAcks.LoadAndDelete(deliveryChatID); ok {
|
||||
msgRef := ref.(slackMessageRef)
|
||||
c.api.AddReaction("white_check_mark", slack.ItemRef{
|
||||
Channel: msgRef.ChannelID,
|
||||
@@ -157,7 +161,7 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
|
||||
channelID, _ := parseSlackChatID(msg.ChatID)
|
||||
_, channelID, threadTS := resolveSlackMediaOutboundTarget(msg.ChatID, &msg.Context)
|
||||
if channelID == "" {
|
||||
return nil, fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
|
||||
}
|
||||
@@ -188,10 +192,11 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
|
||||
}
|
||||
|
||||
_, err = c.api.UploadFileV2Context(ctx, slack.UploadFileV2Parameters{
|
||||
Channel: channelID,
|
||||
File: localPath,
|
||||
Filename: filename,
|
||||
Title: title,
|
||||
Channel: channelID,
|
||||
ThreadTimestamp: threadTS,
|
||||
File: localPath,
|
||||
Filename: filename,
|
||||
Title: title,
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorCF("slack", "Failed to upload media", map[string]any{
|
||||
@@ -356,14 +361,10 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
}
|
||||
|
||||
peerKind := "channel"
|
||||
peerID := channelID
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
@@ -379,7 +380,22 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
"has_thread": threadTS != "",
|
||||
})
|
||||
|
||||
c.HandleMessage(c.ctx, peer, messageTS, senderID, chatID, content, mediaPaths, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
Account: c.teamID,
|
||||
ChatID: channelID,
|
||||
ChatType: peerKind,
|
||||
SenderID: senderID,
|
||||
MessageID: messageTS,
|
||||
SpaceID: c.teamID,
|
||||
SpaceType: "workspace",
|
||||
Raw: metadata,
|
||||
}
|
||||
if threadTS != "" {
|
||||
inboundCtx.TopicID = threadTS
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
@@ -427,14 +443,10 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
}
|
||||
|
||||
mentionPeerKind := "channel"
|
||||
mentionPeerID := channelID
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
mentionPeerKind = "direct"
|
||||
mentionPeerID = senderID
|
||||
}
|
||||
|
||||
mentionPeer := bus.Peer{Kind: mentionPeerKind, ID: mentionPeerID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
@@ -443,8 +455,21 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
"is_mention": "true",
|
||||
"team_id": c.teamID,
|
||||
}
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
Account: c.teamID,
|
||||
ChatID: channelID,
|
||||
ChatType: mentionPeerKind,
|
||||
TopicID: threadTS,
|
||||
SenderID: senderID,
|
||||
MessageID: messageTS,
|
||||
SpaceID: c.teamID,
|
||||
SpaceType: "workspace",
|
||||
Mentioned: true,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, mentionPeer, messageTS, senderID, chatID, content, nil, metadata, mentionSender)
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, mentionSender)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
@@ -491,18 +516,22 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
"command": cmd.Command,
|
||||
"text": utils.Truncate(content, 50),
|
||||
})
|
||||
peerKind := "channel"
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
peerKind = "direct"
|
||||
}
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
Account: c.teamID,
|
||||
ChatID: channelID,
|
||||
ChatType: peerKind,
|
||||
SenderID: senderID,
|
||||
SpaceID: c.teamID,
|
||||
SpaceType: "workspace",
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessage(
|
||||
c.ctx,
|
||||
bus.Peer{Kind: "channel", ID: channelID},
|
||||
"",
|
||||
senderID,
|
||||
chatID,
|
||||
content,
|
||||
nil,
|
||||
metadata,
|
||||
cmdSender,
|
||||
)
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, cmdSender)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
|
||||
@@ -537,3 +566,33 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) {
|
||||
}
|
||||
return channelID, threadTS
|
||||
}
|
||||
|
||||
func resolveSlackOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (string, string, string) {
|
||||
deliveryChatID := strings.TrimSpace(chatID)
|
||||
if deliveryChatID == "" && outboundCtx != nil {
|
||||
deliveryChatID = strings.TrimSpace(outboundCtx.ChatID)
|
||||
}
|
||||
channelID, threadTS := parseSlackChatID(deliveryChatID)
|
||||
if threadTS == "" && outboundCtx != nil {
|
||||
threadTS = strings.TrimSpace(outboundCtx.TopicID)
|
||||
if threadTS != "" && channelID != "" {
|
||||
deliveryChatID = channelID + "/" + threadTS
|
||||
}
|
||||
}
|
||||
return deliveryChatID, channelID, threadTS
|
||||
}
|
||||
|
||||
func resolveSlackMediaOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (string, string, string) {
|
||||
deliveryChatID := strings.TrimSpace(chatID)
|
||||
if deliveryChatID == "" && outboundCtx != nil {
|
||||
deliveryChatID = strings.TrimSpace(outboundCtx.ChatID)
|
||||
}
|
||||
channelID, threadTS := parseSlackChatID(deliveryChatID)
|
||||
if threadTS == "" && outboundCtx != nil {
|
||||
threadTS = strings.TrimSpace(outboundCtx.TopicID)
|
||||
if threadTS != "" && channelID != "" {
|
||||
deliveryChatID = channelID + "/" + threadTS
|
||||
}
|
||||
}
|
||||
return deliveryChatID, channelID, threadTS
|
||||
}
|
||||
|
||||
@@ -53,6 +53,24 @@ func TestParseSlackChatID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSlackOutboundTarget_PrefersContextTopicID(t *testing.T) {
|
||||
deliveryChatID, channelID, threadTS := resolveSlackOutboundTarget("C123456", &bus.InboundContext{
|
||||
Channel: "slack",
|
||||
ChatID: "C123456",
|
||||
TopicID: "1234567890.123456",
|
||||
})
|
||||
|
||||
if deliveryChatID != "C123456/1234567890.123456" {
|
||||
t.Fatalf("deliveryChatID = %q, want %q", deliveryChatID, "C123456/1234567890.123456")
|
||||
}
|
||||
if channelID != "C123456" {
|
||||
t.Fatalf("channelID = %q, want %q", channelID, "C123456")
|
||||
}
|
||||
if threadTS != "1234567890.123456" {
|
||||
t.Fatalf("threadTS = %q, want %q", threadTS, "1234567890.123456")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripBotMention(t *testing.T) {
|
||||
ch := &SlackChannel{botUserID: "U12345BOT"}
|
||||
|
||||
@@ -100,32 +118,32 @@ func TestStripBotMention(t *testing.T) {
|
||||
|
||||
func TestNewSlackChannel(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
bc := &config.Channel{Type: "slack", Enabled: true}
|
||||
|
||||
t.Run("missing bot token", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{}
|
||||
cfg := &config.SlackSettings{}
|
||||
cfg.AppToken = *config.NewSecureString("xapp-test")
|
||||
_, err := NewSlackChannel(cfg, msgBus)
|
||||
_, err := NewSlackChannel(bc, cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing bot_token, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing app token", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{}
|
||||
cfg := &config.SlackSettings{}
|
||||
cfg.BotToken = *config.NewSecureString("xoxb-test")
|
||||
_, err := NewSlackChannel(cfg, msgBus)
|
||||
_, err := NewSlackChannel(bc, cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing app_token, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{
|
||||
AllowFrom: []string{"U123"},
|
||||
}
|
||||
cfg := &config.SlackSettings{}
|
||||
cfg.BotToken = *config.NewSecureString("xoxb-test")
|
||||
cfg.AppToken = *config.NewSecureString("xapp-test")
|
||||
ch, err := NewSlackChannel(cfg, msgBus)
|
||||
bc := &config.Channel{Type: "slack", Enabled: true, AllowFrom: []string{"U123"}}
|
||||
ch, err := NewSlackChannel(bc, cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -142,24 +160,22 @@ func TestSlackChannelIsAllowed(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("empty allowlist allows all", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{
|
||||
AllowFrom: []string{},
|
||||
}
|
||||
bc := &config.Channel{Type: config.ChannelSlack, Enabled: true, AllowFrom: []string{}}
|
||||
cfg := &config.SlackSettings{}
|
||||
cfg.BotToken = *config.NewSecureString("xoxb-test")
|
||||
cfg.AppToken = *config.NewSecureString("xapp-test")
|
||||
ch, _ := NewSlackChannel(cfg, msgBus)
|
||||
ch, _ := NewSlackChannel(bc, cfg, msgBus)
|
||||
if !ch.IsAllowed("U_ANYONE") {
|
||||
t.Error("empty allowlist should allow all users")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("allowlist restricts users", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{
|
||||
AllowFrom: []string{"U_ALLOWED"},
|
||||
}
|
||||
bc := &config.Channel{Type: config.ChannelSlack, Enabled: true, AllowFrom: []string{"U_ALLOWED"}}
|
||||
cfg := &config.SlackSettings{}
|
||||
cfg.BotToken = *config.NewSecureString("xoxb-test")
|
||||
cfg.AppToken = *config.NewSecureString("xapp-test")
|
||||
ch, _ := NewSlackChannel(cfg, msgBus)
|
||||
ch, _ := NewSlackChannel(bc, cfg, msgBus)
|
||||
if !ch.IsAllowed("U_ALLOWED") {
|
||||
t.Error("allowed user should pass allowlist check")
|
||||
}
|
||||
|
||||
@@ -7,7 +7,26 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("teams_webhook", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewTeamsWebhookChannel(cfg.Channels.TeamsWebhook, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelTeamsWebHook,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.TeamsWebhookSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
ch, err := NewTeamsWebhookChannel(bc, c, b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channelName != config.ChannelTeamsWebHook {
|
||||
ch.SetName(channelName)
|
||||
}
|
||||
return ch, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -52,13 +52,15 @@ func classifyTeamsError(err error) error {
|
||||
// Multiple webhook targets can be configured and selected via ChatID.
|
||||
type TeamsWebhookChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.TeamsWebhookConfig
|
||||
bc *config.Channel
|
||||
config *config.TeamsWebhookSettings
|
||||
client teamsMessageSender
|
||||
}
|
||||
|
||||
// NewTeamsWebhookChannel creates a new Teams webhook channel.
|
||||
func NewTeamsWebhookChannel(
|
||||
cfg config.TeamsWebhookConfig,
|
||||
bc *config.Channel,
|
||||
cfg *config.TeamsWebhookSettings,
|
||||
bus *bus.MessageBus,
|
||||
) (*TeamsWebhookChannel, error) {
|
||||
if len(cfg.Webhooks) == 0 {
|
||||
@@ -99,6 +101,7 @@ func NewTeamsWebhookChannel(
|
||||
|
||||
return &TeamsWebhookChannel{
|
||||
BaseChannel: base,
|
||||
bc: bc,
|
||||
config: cfg,
|
||||
client: client,
|
||||
}, nil
|
||||
|
||||
@@ -31,67 +31,60 @@ func TestNewTeamsWebhookChannel(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
// Test missing webhooks
|
||||
_, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
|
||||
cfg := config.TeamsWebhookSettings{
|
||||
Webhooks: nil,
|
||||
}, msgBus)
|
||||
}
|
||||
_, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing webhooks")
|
||||
}
|
||||
|
||||
// Test missing "default" webhook
|
||||
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
|
||||
Title: "Alerts",
|
||||
},
|
||||
cfg.Webhooks = map[string]config.TeamsWebhookTarget{
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
|
||||
Title: "Alerts",
|
||||
},
|
||||
}, msgBus)
|
||||
}
|
||||
_, err = NewTeamsWebhookChannel(bc, &cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing 'default' webhook")
|
||||
}
|
||||
|
||||
// Test empty webhook URL
|
||||
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {Title: "Default"},
|
||||
},
|
||||
}, msgBus)
|
||||
cfg.Webhooks = map[string]config.TeamsWebhookTarget{
|
||||
"default": {Title: "Default"},
|
||||
}
|
||||
_, err = NewTeamsWebhookChannel(bc, &cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty webhook_url")
|
||||
}
|
||||
|
||||
// Test HTTP URL (should fail, must be HTTPS)
|
||||
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("http://example.com/webhook"),
|
||||
Title: "Default",
|
||||
},
|
||||
cfg.Webhooks = map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("http://example.com/webhook"),
|
||||
Title: "Default",
|
||||
},
|
||||
}, msgBus)
|
||||
}
|
||||
_, err = NewTeamsWebhookChannel(bc, &cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for HTTP webhook URL (must be HTTPS)")
|
||||
}
|
||||
|
||||
// Test valid config with HTTPS (must include "default")
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
Title: "Default",
|
||||
},
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook1"),
|
||||
Title: "Alerts",
|
||||
},
|
||||
cfg.Webhooks = map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
Title: "Default",
|
||||
},
|
||||
}, msgBus)
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook1"),
|
||||
Title: "Alerts",
|
||||
},
|
||||
}
|
||||
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -103,14 +96,15 @@ func TestNewTeamsWebhookChannel(t *testing.T) {
|
||||
|
||||
func TestTeamsWebhookChannel_StartStop(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
|
||||
cfg := config.TeamsWebhookSettings{
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
}
|
||||
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -140,8 +134,8 @@ func TestTeamsWebhookChannel_StartStop(t *testing.T) {
|
||||
|
||||
func TestTeamsWebhookChannel_BuildAdaptiveCard(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
|
||||
cfg := config.TeamsWebhookSettings{
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
@@ -152,7 +146,8 @@ func TestTeamsWebhookChannel_BuildAdaptiveCard(t *testing.T) {
|
||||
Title: "Custom Title",
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
}
|
||||
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -175,14 +170,15 @@ func TestTeamsWebhookChannel_BuildAdaptiveCard(t *testing.T) {
|
||||
|
||||
func TestTeamsWebhookChannel_SendNotRunning(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
|
||||
cfg := config.TeamsWebhookSettings{
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
}
|
||||
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -208,8 +204,8 @@ func TestTeamsWebhookChannel_SendDefaultTargetFallback(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
|
||||
cfg := config.TeamsWebhookSettings{
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
@@ -218,7 +214,8 @@ func TestTeamsWebhookChannel_SendDefaultTargetFallback(t *testing.T) {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
}
|
||||
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -250,8 +247,8 @@ func TestTeamsWebhookChannel_SendDefaultTargetFallback(t *testing.T) {
|
||||
|
||||
func TestTeamsWebhookChannel_SendSuccess(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
|
||||
cfg := config.TeamsWebhookSettings{
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
@@ -262,7 +259,8 @@ func TestTeamsWebhookChannel_SendSuccess(t *testing.T) {
|
||||
Title: "Test Alerts",
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
}
|
||||
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -294,8 +292,8 @@ func TestTeamsWebhookChannel_SendSuccess(t *testing.T) {
|
||||
|
||||
func TestTeamsWebhookChannel_SendError(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
bc := &config.Channel{Type: config.ChannelTeamsWebHook, Enabled: true}
|
||||
cfg := config.TeamsWebhookSettings{
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
@@ -304,7 +302,8 @@ func TestTeamsWebhookChannel_SendError(t *testing.T) {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
}
|
||||
ch, err := NewTeamsWebhookChannel(bc, &cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,19 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewTelegramChannel(cfg, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelTelegram,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.TelegramSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
return NewTelegramChannel(bc, c, b)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -47,18 +47,23 @@ type TelegramChannel struct {
|
||||
*channels.BaseChannel
|
||||
bot *telego.Bot
|
||||
bh *th.BotHandler
|
||||
config *config.Config
|
||||
bc *config.Channel
|
||||
chatIDs map[string]int64
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
tgCfg *config.TelegramSettings
|
||||
|
||||
registerFunc func(context.Context, []commands.Definition) error
|
||||
commandRegCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
|
||||
func NewTelegramChannel(
|
||||
bc *config.Channel,
|
||||
telegramCfg *config.TelegramSettings,
|
||||
bus *bus.MessageBus,
|
||||
) (*TelegramChannel, error) {
|
||||
channelName := bc.Name()
|
||||
var opts []telego.BotOption
|
||||
telegramCfg := cfg.Channels.Telegram
|
||||
|
||||
if telegramCfg.Proxy != "" {
|
||||
proxyURL, parseErr := url.Parse(telegramCfg.Proxy)
|
||||
@@ -90,20 +95,21 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel(
|
||||
"telegram",
|
||||
channelName,
|
||||
telegramCfg,
|
||||
bus,
|
||||
telegramCfg.AllowFrom,
|
||||
bc.AllowFrom,
|
||||
channels.WithMaxMessageLength(4000),
|
||||
channels.WithGroupTrigger(telegramCfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(telegramCfg.ReasoningChannelID),
|
||||
channels.WithGroupTrigger(bc.GroupTrigger),
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &TelegramChannel{
|
||||
BaseChannel: base,
|
||||
bot: bot,
|
||||
config: cfg,
|
||||
bc: bc,
|
||||
chatIDs: make(map[string]int64),
|
||||
tgCfg: telegramCfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -174,9 +180,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
|
||||
useMarkdownV2 := c.config.Channels.Telegram.UseMarkdownV2
|
||||
useMarkdownV2 := c.tgCfg.UseMarkdownV2
|
||||
|
||||
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
|
||||
chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
|
||||
}
|
||||
@@ -360,7 +366,7 @@ func (c *TelegramChannel) StartTyping(ctx context.Context, chatID string) (func(
|
||||
|
||||
// EditMessage implements channels.MessageEditor.
|
||||
func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
|
||||
useMarkdownV2 := c.config.Channels.Telegram.UseMarkdownV2
|
||||
useMarkdownV2 := c.tgCfg.UseMarkdownV2
|
||||
cid, _, err := parseTelegramChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -435,7 +441,7 @@ func (c *TelegramChannel) DeleteMessage(ctx context.Context, chatID string, mess
|
||||
// It sends a placeholder message (e.g. "Thinking... 💭") that will later be
|
||||
// edited to the actual response via EditMessage (channels.MessageEditor).
|
||||
func (c *TelegramChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
phCfg := c.config.Channels.Telegram.Placeholder
|
||||
phCfg := c.bc.Placeholder
|
||||
if !phCfg.Enabled {
|
||||
return "", nil
|
||||
}
|
||||
@@ -463,7 +469,7 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
|
||||
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
|
||||
chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
|
||||
}
|
||||
@@ -691,8 +697,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
|
||||
}
|
||||
|
||||
// In group chats, apply unified group trigger filtering
|
||||
isMentioned := false
|
||||
if message.Chat.Type != "private" {
|
||||
isMentioned := c.isBotMentioned(message)
|
||||
isMentioned = c.isBotMentioned(message)
|
||||
if isMentioned {
|
||||
content = c.stripBotMention(content)
|
||||
}
|
||||
@@ -738,13 +745,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
|
||||
})
|
||||
|
||||
peerKind := "direct"
|
||||
peerID := fmt.Sprintf("%d", user.ID)
|
||||
if message.Chat.Type != "private" {
|
||||
peerKind = "group"
|
||||
peerID = compositeChatID
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerID}
|
||||
messageID := fmt.Sprintf("%d", message.MessageID)
|
||||
|
||||
metadata := map[string]string{
|
||||
@@ -753,24 +756,29 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
|
||||
"first_name": user.FirstName,
|
||||
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
|
||||
}
|
||||
if message.ReplyToMessage != nil {
|
||||
metadata["reply_to_message_id"] = fmt.Sprintf("%d", message.ReplyToMessage.MessageID)
|
||||
}
|
||||
|
||||
// Set parent_peer metadata for per-topic agent binding.
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
ChatID: fmt.Sprintf("%d", chatID),
|
||||
ChatType: peerKind,
|
||||
SenderID: platformID,
|
||||
MessageID: messageID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if message.Chat.IsForum && threadID != 0 {
|
||||
metadata["parent_peer_kind"] = "topic"
|
||||
metadata["parent_peer_id"] = fmt.Sprintf("%d", threadID)
|
||||
inboundCtx.TopicID = fmt.Sprintf("%d", threadID)
|
||||
}
|
||||
if message.ReplyToMessage != nil {
|
||||
inboundCtx.ReplyToMessageID = fmt.Sprintf("%d", message.ReplyToMessage.MessageID)
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx,
|
||||
peer,
|
||||
messageID,
|
||||
platformID,
|
||||
c.HandleMessageWithContext(
|
||||
c.ctx,
|
||||
compositeChatID,
|
||||
content,
|
||||
mediaPaths,
|
||||
metadata,
|
||||
inboundCtx,
|
||||
sender,
|
||||
)
|
||||
return nil
|
||||
@@ -958,6 +966,28 @@ func parseTelegramChatID(chatID string) (int64, int, error) {
|
||||
return cid, tid, nil
|
||||
}
|
||||
|
||||
func resolveTelegramOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (int64, int, error) {
|
||||
targetChatID := strings.TrimSpace(chatID)
|
||||
if targetChatID == "" && outboundCtx != nil {
|
||||
targetChatID = strings.TrimSpace(outboundCtx.ChatID)
|
||||
}
|
||||
resolvedChatID, resolvedThreadID, err := parseTelegramChatID(targetChatID)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if resolvedThreadID != 0 || outboundCtx == nil {
|
||||
return resolvedChatID, resolvedThreadID, nil
|
||||
}
|
||||
topicID := strings.TrimSpace(outboundCtx.TopicID)
|
||||
if topicID == "" {
|
||||
return resolvedChatID, resolvedThreadID, nil
|
||||
}
|
||||
if threadID, convErr := strconv.Atoi(topicID); convErr == nil {
|
||||
return resolvedChatID, threadID, nil
|
||||
}
|
||||
return resolvedChatID, resolvedThreadID, nil
|
||||
}
|
||||
|
||||
func logParseFailed(err error, useMarkdownV2 bool) {
|
||||
parsingName := "HTML"
|
||||
if useMarkdownV2 {
|
||||
@@ -1063,7 +1093,7 @@ func (c *TelegramChannel) stripBotMention(content string) string {
|
||||
|
||||
// BeginStream implements channels.StreamingCapable.
|
||||
func (c *TelegramChannel) BeginStream(ctx context.Context, chatID string) (channels.Streamer, error) {
|
||||
if !c.config.Channels.Telegram.Streaming.Enabled {
|
||||
if !c.tgCfg.Streaming.Enabled {
|
||||
return nil, fmt.Errorf("streaming disabled in config")
|
||||
}
|
||||
|
||||
@@ -1072,7 +1102,7 @@ func (c *TelegramChannel) BeginStream(ctx context.Context, chatID string) (chann
|
||||
return nil, err
|
||||
}
|
||||
|
||||
streamCfg := c.config.Channels.Telegram.Streaming
|
||||
streamCfg := c.tgCfg.Streaming
|
||||
return &telegramStreamer{
|
||||
bot: c.bot,
|
||||
chatID: cid,
|
||||
|
||||
@@ -140,7 +140,8 @@ func newTestChannelWithConstructor(
|
||||
BaseChannel: base,
|
||||
bot: bot,
|
||||
chatIDs: make(map[string]int64),
|
||||
config: config.DefaultConfig(),
|
||||
bc: &config.Channel{Type: config.ChannelTelegram, Enabled: true},
|
||||
tgCfg: &config.TelegramSettings{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -527,6 +528,38 @@ func TestSend_WithForumThreadID(t *testing.T) {
|
||||
assert.Len(t, caller.calls, 1)
|
||||
}
|
||||
|
||||
func TestSend_UsesContextTopicIDWhenChatIDDoesNotIncludeThread(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
return successResponse(t), nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
|
||||
_, err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "-1001234567890",
|
||||
Content: "Hello from topic context",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "-1001234567890",
|
||||
TopicID: "42",
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, caller.calls, 1)
|
||||
|
||||
var params struct {
|
||||
ChatID int64 `json:"chat_id"`
|
||||
MessageThreadID int `json:"message_thread_id"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(caller.calls[0].Data.BodyRaw, ¶ms))
|
||||
assert.Equal(t, int64(-1001234567890), params.ChatID)
|
||||
assert.Equal(t, 42, params.MessageThreadID)
|
||||
assert.Equal(t, "Hello from topic context", params.Text)
|
||||
}
|
||||
|
||||
func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &TelegramChannel{
|
||||
@@ -556,16 +589,10 @@ func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
|
||||
inbound, ok := <-messageBus.InboundChan()
|
||||
require.True(t, ok, "expected inbound message")
|
||||
|
||||
// Composite chatID should include thread ID
|
||||
assert.Equal(t, "-1001234567890/42", inbound.ChatID)
|
||||
|
||||
// Peer ID should include thread ID for session key isolation
|
||||
assert.Equal(t, "group", inbound.Peer.Kind)
|
||||
assert.Equal(t, "-1001234567890/42", inbound.Peer.ID)
|
||||
|
||||
// Parent peer metadata should be set for agent binding
|
||||
assert.Equal(t, "topic", inbound.Metadata["parent_peer_kind"])
|
||||
assert.Equal(t, "42", inbound.Metadata["parent_peer_id"])
|
||||
// ChatID remains the parent chat; TopicID isolates the sub-conversation.
|
||||
assert.Equal(t, "-1001234567890", inbound.ChatID)
|
||||
assert.Equal(t, "group", inbound.Context.ChatType)
|
||||
assert.Equal(t, "42", inbound.Context.TopicID)
|
||||
}
|
||||
|
||||
func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
|
||||
@@ -598,13 +625,8 @@ func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
|
||||
// Plain chatID without thread suffix
|
||||
assert.Equal(t, "-100999", inbound.ChatID)
|
||||
|
||||
// Peer ID should be raw chat ID (no thread suffix)
|
||||
assert.Equal(t, "group", inbound.Peer.Kind)
|
||||
assert.Equal(t, "-100999", inbound.Peer.ID)
|
||||
|
||||
// No parent peer metadata
|
||||
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
|
||||
assert.Empty(t, inbound.Metadata["parent_peer_id"])
|
||||
assert.Equal(t, "group", inbound.Context.ChatType)
|
||||
assert.Empty(t, inbound.Context.TopicID)
|
||||
}
|
||||
|
||||
func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
|
||||
@@ -641,13 +663,8 @@ func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
|
||||
// chatID should NOT include thread suffix for non-forum groups
|
||||
assert.Equal(t, "-100999", inbound.ChatID)
|
||||
|
||||
// Peer ID should be raw chat ID (shared session for whole group)
|
||||
assert.Equal(t, "group", inbound.Peer.Kind)
|
||||
assert.Equal(t, "-100999", inbound.Peer.ID)
|
||||
|
||||
// No parent peer metadata
|
||||
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
|
||||
assert.Empty(t, inbound.Metadata["parent_peer_id"])
|
||||
assert.Equal(t, "group", inbound.Context.ChatType)
|
||||
assert.Empty(t, inbound.Context.TopicID)
|
||||
}
|
||||
|
||||
func assertHandleMessageQuotedUserReply(
|
||||
@@ -700,7 +717,7 @@ func assertHandleMessageQuotedUserReply(
|
||||
|
||||
inbound, ok := <-messageBus.InboundChan()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, strconv.Itoa(replyMessageID), inbound.Metadata["reply_to_message_id"])
|
||||
assert.Equal(t, strconv.Itoa(replyMessageID), inbound.Context.ReplyToMessageID)
|
||||
assert.Equal(t, expectedContent, inbound.Content)
|
||||
}
|
||||
|
||||
@@ -786,7 +803,7 @@ func TestHandleMessage_ReplyToOwnBotMessage_UsesAssistantRole(t *testing.T) {
|
||||
|
||||
inbound, ok := <-messageBus.InboundChan()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "101", inbound.Metadata["reply_to_message_id"])
|
||||
assert.Equal(t, "101", inbound.Context.ReplyToMessageID)
|
||||
assert.Equal(
|
||||
t,
|
||||
"[quoted assistant message from afjcjsbx_picoclaw_bot]: Fatto! Ho creato il file notizie_2026_03_28.md\n\nti ricordi questo file?",
|
||||
|
||||
+10
-3
@@ -7,7 +7,14 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("vk", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewVKChannel(cfg, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelVK,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
if bc == nil {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
return NewVKChannel(channelName, bc, b)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
+39
-30
@@ -21,41 +21,54 @@ import (
|
||||
|
||||
type VKChannel struct {
|
||||
*channels.BaseChannel
|
||||
vk *api.VK
|
||||
lp *longpoll.LongPoll
|
||||
config *config.Config
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
vk *api.VK
|
||||
lp *longpoll.LongPoll
|
||||
channelName string
|
||||
bc *config.Channel
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewVKChannel(cfg *config.Config, bus *bus.MessageBus) (*VKChannel, error) {
|
||||
vkCfg := cfg.Channels.VK
|
||||
func NewVKChannel(channelName string, bc *config.Channel, bus *bus.MessageBus) (*VKChannel, error) {
|
||||
var vkCfg config.VKSettings
|
||||
if err := bc.Decode(&vkCfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vk := api.NewVK(vkCfg.Token.String())
|
||||
|
||||
base := channels.NewBaseChannel(
|
||||
"vk",
|
||||
vkCfg,
|
||||
channelName,
|
||||
&vkCfg,
|
||||
bus,
|
||||
vkCfg.AllowFrom,
|
||||
bc.AllowFrom,
|
||||
channels.WithMaxMessageLength(4000),
|
||||
channels.WithGroupTrigger(vkCfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(vkCfg.ReasoningChannelID),
|
||||
channels.WithGroupTrigger(bc.GroupTrigger),
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &VKChannel{
|
||||
BaseChannel: base,
|
||||
vk: vk,
|
||||
config: cfg,
|
||||
channelName: channelName,
|
||||
bc: bc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *VKChannel) getVKCfg() *config.VKSettings {
|
||||
var v config.VKSettings
|
||||
if err := c.bc.Decode(&v); err != nil {
|
||||
return nil
|
||||
}
|
||||
return &v
|
||||
}
|
||||
|
||||
func (c *VKChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("vk", "Starting VK bot (Long Poll mode)...")
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
groupID := c.config.Channels.VK.GroupID
|
||||
groupID := c.getVKCfg().GroupID
|
||||
if groupID == 0 {
|
||||
c.cancel()
|
||||
return fmt.Errorf("group_id is required for VK bot")
|
||||
@@ -143,7 +156,7 @@ func (c *VKChannel) handleMessage(msg object.MessagesMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
groupTrigger := c.config.Channels.VK.GroupTrigger
|
||||
groupTrigger := c.bc.GroupTrigger
|
||||
isGroupChat := peerID != fromID
|
||||
|
||||
if isGroupChat {
|
||||
@@ -159,14 +172,11 @@ func (c *VKChannel) handleMessage(msg object.MessagesMessage) {
|
||||
_ = groupTrigger
|
||||
}
|
||||
|
||||
peerKind := "direct"
|
||||
peerIDStr := userID
|
||||
chatType := "direct"
|
||||
if isGroupChat {
|
||||
peerKind = "group"
|
||||
peerIDStr = chatID
|
||||
chatType = "group"
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerIDStr}
|
||||
messageID := strconv.Itoa(msg.ConversationMessageID)
|
||||
|
||||
metadata := map[string]string{
|
||||
@@ -174,16 +184,15 @@ func (c *VKChannel) handleMessage(msg object.MessagesMessage) {
|
||||
"is_group": fmt.Sprintf("%t", isGroupChat),
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx,
|
||||
peer,
|
||||
messageID,
|
||||
userID,
|
||||
chatID,
|
||||
text,
|
||||
nil,
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
c.HandleInboundContext(c.ctx, chatID, text, nil, bus.InboundContext{
|
||||
Channel: "vk",
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
SenderID: userID,
|
||||
MessageID: messageID,
|
||||
Mentioned: isGroupChat && c.isMentioned(msg),
|
||||
Raw: metadata,
|
||||
}, sender)
|
||||
}
|
||||
|
||||
func (c *VKChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
|
||||
|
||||
+54
-62
@@ -1,6 +1,7 @@
|
||||
package vk
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -8,19 +9,23 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func makeVKTestBaseChannel(vkCfg config.VKSettings) *config.Channel {
|
||||
settings, _ := json.Marshal(vkCfg)
|
||||
return &config.Channel{
|
||||
Enabled: true,
|
||||
Type: config.ChannelVK,
|
||||
Settings: settings,
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewVKChannel(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("missing group_id", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Channels: config.ChannelsConfig{
|
||||
VK: config.VKConfig{
|
||||
Enabled: true,
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
},
|
||||
},
|
||||
}
|
||||
ch, err := NewVKChannel(cfg, msgBus)
|
||||
bc := makeVKTestBaseChannel(config.VKSettings{
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
})
|
||||
ch, err := NewVKChannel("vk", bc, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during creation: %v", err)
|
||||
}
|
||||
@@ -33,16 +38,11 @@ func TestNewVKChannel(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("valid config with group_id", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Channels: config.ChannelsConfig{
|
||||
VK: config.VKConfig{
|
||||
Enabled: true,
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
},
|
||||
},
|
||||
}
|
||||
ch, err := NewVKChannel(cfg, msgBus)
|
||||
bc := makeVKTestBaseChannel(config.VKSettings{
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
})
|
||||
ch, err := NewVKChannel("vk", bc, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -55,17 +55,18 @@ func TestNewVKChannel(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("with allow_from", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Channels: config.ChannelsConfig{
|
||||
VK: config.VKConfig{
|
||||
Enabled: true,
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
AllowFrom: []string{"123456789"},
|
||||
},
|
||||
},
|
||||
vkCfg := config.VKSettings{
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
}
|
||||
ch, err := NewVKChannel(cfg, msgBus)
|
||||
settings, _ := json.Marshal(vkCfg)
|
||||
bc := &config.Channel{
|
||||
Enabled: true,
|
||||
Type: "vk",
|
||||
AllowFrom: []string{"123456789"},
|
||||
Settings: settings,
|
||||
}
|
||||
ch, err := NewVKChannel("vk", bc, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -78,20 +79,21 @@ func TestNewVKChannel(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("with group_trigger", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Channels: config.ChannelsConfig{
|
||||
VK: config.VKConfig{
|
||||
Enabled: true,
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
GroupTrigger: config.GroupTriggerConfig{
|
||||
MentionOnly: false,
|
||||
Prefixes: []string{"/bot", "!bot"},
|
||||
},
|
||||
},
|
||||
},
|
||||
vkCfg := config.VKSettings{
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
}
|
||||
ch, err := NewVKChannel(cfg, msgBus)
|
||||
settings, _ := json.Marshal(vkCfg)
|
||||
bc := &config.Channel{
|
||||
Enabled: true,
|
||||
Type: "vk",
|
||||
GroupTrigger: config.GroupTriggerConfig{
|
||||
MentionOnly: false,
|
||||
Prefixes: []string{"/bot", "!bot"},
|
||||
},
|
||||
Settings: settings,
|
||||
}
|
||||
ch, err := NewVKChannel("vk", bc, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -103,16 +105,11 @@ func TestNewVKChannel(t *testing.T) {
|
||||
|
||||
func TestVKChannel_MaxMessageLength(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := &config.Config{
|
||||
Channels: config.ChannelsConfig{
|
||||
VK: config.VKConfig{
|
||||
Enabled: true,
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
},
|
||||
},
|
||||
}
|
||||
ch, err := NewVKChannel(cfg, msgBus)
|
||||
bc := makeVKTestBaseChannel(config.VKSettings{
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
})
|
||||
ch, err := NewVKChannel("vk", bc, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -236,16 +233,11 @@ func TestVKChannel_ProcessAttachments(t *testing.T) {
|
||||
|
||||
func TestVKChannel_VoiceCapabilities(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := &config.Config{
|
||||
Channels: config.ChannelsConfig{
|
||||
VK: config.VKConfig{
|
||||
Enabled: true,
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
},
|
||||
},
|
||||
}
|
||||
ch, err := NewVKChannel(cfg, msgBus)
|
||||
bc := makeVKTestBaseChannel(config.VKSettings{
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
})
|
||||
ch, err := NewVKChannel("vk", bc, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,19 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("wecom", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewChannel(cfg.Channels.WeCom, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelWeCom,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.WeComSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
return NewChannel(bc, c, b)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ const (
|
||||
|
||||
type WeComChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.WeComConfig
|
||||
config *config.WeComSettings
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -108,7 +108,7 @@ func (s *recentMessageSet) Mark(id string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func NewChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComChannel, error) {
|
||||
func NewChannel(bc *config.Channel, cfg *config.WeComSettings, messageBus *bus.MessageBus) (*WeComChannel, error) {
|
||||
if cfg.BotID == "" || cfg.Secret.String() == "" {
|
||||
return nil, fmt.Errorf("wecom bot_id and secret are required")
|
||||
}
|
||||
@@ -120,8 +120,8 @@ func NewChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComChann
|
||||
"wecom",
|
||||
cfg,
|
||||
messageBus,
|
||||
cfg.AllowFrom,
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
bc.AllowFrom,
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
ch := &WeComChannel{
|
||||
@@ -570,7 +570,6 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage)
|
||||
return err
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: actualChatID}
|
||||
metadata := map[string]string{
|
||||
"channel": "wecom",
|
||||
"req_id": reqID,
|
||||
@@ -583,7 +582,20 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage)
|
||||
metadata["quote_text"] = quoteText
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, msg.MsgID, senderID, actualChatID, content, mediaRefs, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
Account: strings.TrimSpace(msg.AIBotID),
|
||||
ChatID: actualChatID,
|
||||
ChatType: peerKind,
|
||||
SenderID: senderID,
|
||||
MessageID: msg.MsgID,
|
||||
ReplyHandles: map[string]string{
|
||||
"req_id": reqID,
|
||||
},
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, actualChatID, content, mediaRefs, inboundCtx, sender)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -50,11 +50,11 @@ func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) {
|
||||
if inbound.MessageID != "msg-1" {
|
||||
t.Fatalf("inbound MessageID = %q, want msg-1", inbound.MessageID)
|
||||
}
|
||||
if inbound.Peer.ID != "chat-1" {
|
||||
t.Fatalf("inbound Peer.ID = %q, want chat-1", inbound.Peer.ID)
|
||||
if inbound.Context.ChatType != "direct" {
|
||||
t.Fatalf("inbound Context.ChatType = %q, want direct", inbound.Context.ChatType)
|
||||
}
|
||||
if inbound.Metadata["req_id"] != "req-1" {
|
||||
t.Fatalf("inbound req_id = %q, want req-1", inbound.Metadata["req_id"])
|
||||
if inbound.Context.ReplyHandles["req_id"] != "req-1" {
|
||||
t.Fatalf("inbound req_id = %q, want req-1", inbound.Context.ReplyHandles["req_id"])
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected inbound message to be published")
|
||||
@@ -605,9 +605,10 @@ func TestSendMedia_SendsActiveFile(t *testing.T) {
|
||||
func newTestWeComChannel(t *testing.T, messageBus *bus.MessageBus) *WeComChannel {
|
||||
t.Helper()
|
||||
|
||||
cfg := config.WeComConfig{BotID: "bot-1"}
|
||||
cfg := &config.WeComSettings{BotID: "bot-1"}
|
||||
cfg.SetSecret("secret-1")
|
||||
ch, err := NewChannel(cfg, messageBus)
|
||||
bc := &config.Channel{Type: config.ChannelWeCom, Enabled: true}
|
||||
ch, err := NewChannel(bc, cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("NewChannel() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func picoclawHomeDir() string {
|
||||
return config.GetHome()
|
||||
}
|
||||
|
||||
func genWeixinAccountKey(cfg config.WeixinConfig) string {
|
||||
func genWeixinAccountKey(cfg *config.WeixinSettings) string {
|
||||
token := strings.TrimSpace(cfg.Token.String())
|
||||
if token == "" {
|
||||
return "default"
|
||||
@@ -53,11 +53,11 @@ func genWeixinAccountKey(cfg config.WeixinConfig) string {
|
||||
return hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
func buildWeixinSyncBufPath(cfg config.WeixinConfig) string {
|
||||
func buildWeixinSyncBufPath(cfg *config.WeixinSettings) string {
|
||||
return filepath.Join(picoclawHomeDir(), "channels", "weixin", "sync", genWeixinAccountKey(cfg)+".json")
|
||||
}
|
||||
|
||||
func buildWeixinContextTokensPath(cfg config.WeixinConfig) string {
|
||||
func buildWeixinContextTokensPath(cfg *config.WeixinSettings) string {
|
||||
return filepath.Join(picoclawHomeDir(), "channels", "weixin", "context-tokens", genWeixinAccountKey(cfg)+".json")
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
type WeixinChannel struct {
|
||||
*channels.BaseChannel
|
||||
api *ApiClient
|
||||
config config.WeixinConfig
|
||||
config *config.WeixinSettings
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
bus *bus.MessageBus
|
||||
@@ -36,25 +36,48 @@ type WeixinChannel struct {
|
||||
}
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("weixin", func(cfg *config.Config, bus *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewWeixinChannel(cfg.Channels.Weixin, bus)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelWeixin,
|
||||
func(channelName, channelType string, cfg *config.Config, bus *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
weixinCfg, ok := decoded.(*config.WeixinSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
ch, err := NewWeixinChannel(bc, weixinCfg, bus)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channelName != config.ChannelWeixin {
|
||||
ch.SetName(channelName)
|
||||
}
|
||||
return ch, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// NewWeixinChannel creates a new WeixinChannel from config.
|
||||
func NewWeixinChannel(cfg config.WeixinConfig, messageBus *bus.MessageBus) (*WeixinChannel, error) {
|
||||
func NewWeixinChannel(
|
||||
bc *config.Channel,
|
||||
cfg *config.WeixinSettings,
|
||||
messageBus *bus.MessageBus,
|
||||
) (*WeixinChannel, error) {
|
||||
api, err := NewApiClient(cfg.BaseURL, cfg.Token.String(), cfg.Proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("weixin: failed to create API client: %w", err)
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel(
|
||||
"weixin",
|
||||
bc.Name(),
|
||||
cfg,
|
||||
messageBus,
|
||||
cfg.AllowFrom,
|
||||
bc.AllowFrom,
|
||||
channels.WithMaxMessageLength(4000),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &WeixinChannel{
|
||||
@@ -334,8 +357,6 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess
|
||||
return
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: "direct", ID: fromUserID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"from_user_id": fromUserID,
|
||||
"context_token": msg.ContextToken,
|
||||
@@ -354,7 +375,21 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess
|
||||
c.persistContextTokens()
|
||||
}
|
||||
|
||||
c.HandleMessage(ctx, peer, messageID, fromUserID, fromUserID, content, mediaRefs, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "weixin",
|
||||
ChatID: fromUserID,
|
||||
ChatType: "direct",
|
||||
SenderID: fromUserID,
|
||||
MessageID: messageID,
|
||||
Raw: metadata,
|
||||
}
|
||||
if msg.ContextToken != "" {
|
||||
inboundCtx.ReplyHandles = map[string]string{
|
||||
"context_token": msg.ContextToken,
|
||||
}
|
||||
}
|
||||
|
||||
c.HandleInboundContext(ctx, fromUserID, content, mediaRefs, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// Send implements channels.Channel by sending a text message to the WeChat user.
|
||||
|
||||
@@ -66,7 +66,7 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) {
|
||||
}, nil
|
||||
})},
|
||||
},
|
||||
config: config.WeixinConfig{
|
||||
config: &config.WeixinSettings{
|
||||
CDNBaseURL: "https://cdn.example.com",
|
||||
},
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
@@ -105,7 +105,7 @@ func TestDownloadAndDecryptCDNBufferUsesFullURLWhenProvided(t *testing.T) {
|
||||
return nil, nil
|
||||
})},
|
||||
},
|
||||
config: config.WeixinConfig{
|
||||
config: &config.WeixinSettings{
|
||||
CDNBaseURL: "https://cdn.example.com",
|
||||
},
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
@@ -155,7 +155,7 @@ func TestDownloadAndDecryptCDNBufferFallsBackToConstructedURLWhenFullURLFails(t
|
||||
}, nil
|
||||
})},
|
||||
},
|
||||
config: config.WeixinConfig{
|
||||
config: &config.WeixinSettings{
|
||||
CDNBaseURL: "https://cdn.example.com",
|
||||
},
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
@@ -224,7 +224,7 @@ func TestUploadBufferToCDN(t *testing.T) {
|
||||
}, nil
|
||||
})},
|
||||
},
|
||||
config: config.WeixinConfig{
|
||||
config: &config.WeixinSettings{
|
||||
CDNBaseURL: "https://cdn.example.com",
|
||||
},
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
@@ -259,7 +259,7 @@ func TestBuildWeixinSyncBufPathUsesPicoclawHome(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv(config.EnvHome, home)
|
||||
|
||||
wxCfg := config.WeixinConfig{
|
||||
wxCfg := &config.WeixinSettings{
|
||||
BaseURL: "https://ilinkai.weixin.qq.com/",
|
||||
}
|
||||
wxCfg.SetToken("token-123")
|
||||
|
||||
@@ -7,7 +7,19 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("whatsapp", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewWhatsAppChannel(cfg.Channels.WhatsApp, b)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelWhatsApp,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.WhatsAppSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
return NewWhatsAppChannel(bc, c, b)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
type WhatsAppChannel struct {
|
||||
*channels.BaseChannel
|
||||
conn *websocket.Conn
|
||||
config config.WhatsAppConfig
|
||||
config *config.WhatsAppSettings
|
||||
url string
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -28,14 +28,18 @@ type WhatsAppChannel struct {
|
||||
connected bool
|
||||
}
|
||||
|
||||
func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsAppChannel, error) {
|
||||
func NewWhatsAppChannel(
|
||||
bc *config.Channel,
|
||||
cfg *config.WhatsAppSettings,
|
||||
bus *bus.MessageBus,
|
||||
) (*WhatsAppChannel, error) {
|
||||
base := channels.NewBaseChannel(
|
||||
"whatsapp",
|
||||
cfg,
|
||||
bus,
|
||||
cfg.AllowFrom,
|
||||
bc.AllowFrom,
|
||||
channels.WithMaxMessageLength(65536),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
channels.WithReasoningChannelID(bc.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &WhatsAppChannel{
|
||||
@@ -223,13 +227,6 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
|
||||
metadata["user_name"] = userName
|
||||
}
|
||||
|
||||
var peer bus.Peer
|
||||
if chatID == senderID {
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "group", ID: chatID}
|
||||
}
|
||||
|
||||
logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{
|
||||
"sender": senderID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
@@ -248,5 +245,18 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "whatsapp",
|
||||
ChatID: chatID,
|
||||
SenderID: senderID,
|
||||
MessageID: messageID,
|
||||
Raw: metadata,
|
||||
}
|
||||
if chatID == senderID {
|
||||
inboundCtx.ChatType = "direct"
|
||||
} else {
|
||||
inboundCtx.ChatType = "group"
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &WhatsAppChannel{
|
||||
BaseChannel: channels.NewBaseChannel("whatsapp", config.WhatsAppConfig{}, messageBus, nil),
|
||||
BaseChannel: channels.NewBaseChannel("whatsapp", config.WhatsAppSettings{}, messageBus, nil),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
|
||||
|
||||
@@ -9,12 +9,27 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("whatsapp_native", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
waCfg := cfg.Channels.WhatsApp
|
||||
storePath := waCfg.SessionStorePath
|
||||
if storePath == "" {
|
||||
storePath = filepath.Join(cfg.WorkspacePath(), "whatsapp")
|
||||
}
|
||||
return NewWhatsAppNativeChannel(waCfg, b, storePath)
|
||||
})
|
||||
channels.RegisterFactory(
|
||||
config.ChannelWhatsAppNative,
|
||||
func(channelName, channelType string, cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
bc := cfg.Channels[channelName]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := decoded.(*config.WhatsAppSettings)
|
||||
if !ok {
|
||||
return nil, channels.ErrSendFailed
|
||||
}
|
||||
storePath := c.SessionStorePath
|
||||
if storePath == "" {
|
||||
storePath = filepath.Join(cfg.WorkspacePath(), "whatsapp")
|
||||
}
|
||||
ch, err := NewWhatsAppNativeChannel(bc, channelName, c, b, storePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ch, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &WhatsAppNativeChannel{
|
||||
BaseChannel: channels.NewBaseChannel("whatsapp_native", config.WhatsAppConfig{}, messageBus, nil),
|
||||
BaseChannel: channels.NewBaseChannel("whatsapp_native", config.WhatsAppSettings{}, messageBus, nil),
|
||||
runCtx: context.Background(),
|
||||
}
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ const (
|
||||
// WhatsAppNativeChannel implements the WhatsApp channel using whatsmeow (in-process, no external bridge).
|
||||
type WhatsAppNativeChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.WhatsAppConfig
|
||||
config *config.WhatsAppSettings
|
||||
storePath string
|
||||
client *whatsmeow.Client
|
||||
container *sqlstore.Container
|
||||
@@ -64,11 +64,13 @@ type WhatsAppNativeChannel struct {
|
||||
// NewWhatsAppNativeChannel creates a WhatsApp channel that uses whatsmeow for connection.
|
||||
// storePath is the directory for the SQLite session store (e.g. workspace/whatsapp).
|
||||
func NewWhatsAppNativeChannel(
|
||||
cfg config.WhatsAppConfig,
|
||||
bc *config.Channel,
|
||||
name string,
|
||||
cfg *config.WhatsAppSettings,
|
||||
bus *bus.MessageBus,
|
||||
storePath string,
|
||||
) (channels.Channel, error) {
|
||||
base := channels.NewBaseChannel("whatsapp_native", cfg, bus, cfg.AllowFrom, channels.WithMaxMessageLength(65536))
|
||||
base := channels.NewBaseChannel(name, cfg, bus, bc.AllowFrom, channels.WithMaxMessageLength(65536))
|
||||
if storePath == "" {
|
||||
storePath = "whatsapp"
|
||||
}
|
||||
@@ -375,7 +377,6 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) {
|
||||
if evt.Info.Chat.Server == types.GroupServer {
|
||||
peerKind = "group"
|
||||
}
|
||||
peer := bus.Peer{Kind: peerKind, ID: chatID}
|
||||
messageID := evt.Info.ID
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "whatsapp",
|
||||
@@ -393,7 +394,17 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) {
|
||||
"WhatsApp message received",
|
||||
map[string]any{"sender_id": senderID, "content_preview": utils.Truncate(content, 50)},
|
||||
)
|
||||
c.HandleMessage(c.runCtx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
|
||||
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "whatsapp",
|
||||
ChatID: chatID,
|
||||
SenderID: senderID,
|
||||
MessageID: messageID,
|
||||
ChatType: peerKind,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.runCtx, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
func (c *WhatsAppNativeChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user