Merge pull request #2249 from alexhoshina/refactor-inbound-context-routing-session

Refactor inbound context routing session
This commit is contained in:
daming大铭
2026-04-14 12:45:34 +08:00
committed by GitHub
75 changed files with 5976 additions and 1871 deletions
+61 -116
View File
@@ -120,133 +120,78 @@ dammi le ultime news
- Unknown slash command (for example `/foo`) passes through to normal LLM processing.
- Registered but unsupported command on the current channel (for example `/show` on WhatsApp) returns an explicit user-facing error and stops further processing.
### Agent Bindings (Route messages to specific agents)
### Routing
Use `bindings` in `config.json` to route incoming messages to different agents by channel/account/context.
Routing is configured through `agents.dispatch.rules`.
Each rule matches against the normalized inbound context produced by channels.
Rules are evaluated from top to bottom. The first matching rule wins. If no
rule matches, PicoClaw falls back to the configured default agent.
Supported match fields:
* `channel`
* `account`
* `space`
* `chat`
* `topic`
* `sender`
* `mentioned`
Match values use the same scope vocabulary as the session system:
* `space`: `workspace:t001`, `guild:123456`
* `chat`: `direct:user123`, `group:-100123`, `channel:c123`
* `topic`: `topic:42`
* `sender`: a normalized sender identifier for the platform
Rules may optionally override the global `session.dimensions` value through
`session_dimensions`. This allows routing and session allocation to stay aligned
without reintroducing the old `bindings` or `dm_scope` formats.
Example:
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model_name": "gpt-4o-mini"
},
"list": [
{ "id": "main", "default": true, "name": "Main Assistant" },
{ "id": "support", "name": "Support Assistant" },
{ "id": "sales", "name": "Sales Assistant" }
]
},
"bindings": [
{
"agent_id": "support",
"match": {
"channel": "telegram",
"account_id": "*",
"peer": { "kind": "direct", "id": "user123" }
}
},
{
"agent_id": "sales",
"match": {
"channel": "discord",
"account_id": "my-discord-bot",
"guild_id": "987654321"
}
{ "id": "main", "default": true },
{ "id": "support" },
{ "id": "sales" }
],
"dispatch": {
"rules": [
{
"name": "vip in support group",
"agent": "sales",
"when": {
"channel": "telegram",
"chat": "group:-1001234567890",
"sender": "12345"
},
"session_dimensions": ["chat", "sender"]
},
{
"name": "telegram support group",
"agent": "support",
"when": {
"channel": "telegram",
"chat": "group:-1001234567890"
},
"session_dimensions": ["chat"]
}
]
}
]
}
```
#### `bindings` fields
| Field | Required | Description |
|-------|----------|-------------|
| `agent_id` | Yes | Target agent id in `agents.list` |
| `match.channel` | Yes | Channel name (e.g. `telegram`, `discord`) |
| `match.account_id` | No | Channel account filter. Use `"*"` for all accounts of that channel. If omitted, only default account is matched |
| `match.peer.kind` + `match.peer.id` | No | Exact peer match (e.g. direct chat / topic / group id) |
| `match.guild_id` | No | Guild/server-level match |
| `match.team_id` | No | Team/workspace-level match |
#### Matching priority
When multiple bindings exist, PicoClaw resolves in this order:
1. `peer`
2. `parent_peer` (for thread/topic parent contexts)
3. `guild_id`
4. `team_id`
5. `account_id` (non-wildcard)
6. channel wildcard (`account_id: "*"`)
7. default agent
If a binding points to a missing `agent_id`, PicoClaw falls back to the default agent.
#### How matching works (step-by-step)
1. PicoClaw first filters bindings by `match.channel` (must equal current channel).
2. It then filters by `match.account_id`:
- omitted: match only the channel's default account
- `"*"`: match all accounts on this channel
- explicit value: exact account id match (case-insensitive)
3. From the remaining candidates, it applies the priority chain above and stops at the first hit.
In other words: **channel + account form the candidate set; peer/guild/team then decide final winner**.
#### Common recipes
**1) Route one specific DM user to a specialist agent**
```json
{
"agent_id": "support",
"match": {
"channel": "telegram",
"account_id": "*",
"peer": { "kind": "direct", "id": "user123" }
},
"session": {
"dimensions": ["chat"]
}
}
```
**2) Route one Discord server (guild) to a dedicated agent**
```json
{
"agent_id": "sales",
"match": {
"channel": "discord",
"account_id": "my-discord-bot",
"guild_id": "987654321"
}
}
```
**3) Route all remaining traffic of a channel to a fallback agent**
```json
{
"agent_id": "main",
"match": {
"channel": "discord",
"account_id": "*"
}
}
```
#### Authoring guidelines (important)
- Keep exactly one clear default agent in `agents.list` (`"default": true`).
- Put specific rules (`peer`, `guild_id`, `team_id`) and broad rules (`account_id: "*"` only) together safely; priority already guarantees specific rules win.
- Avoid duplicate rules with the same specificity and match values. If duplicates exist, the first matching entry in the config array wins.
- Ensure every `agent_id` exists in `agents.list`; unknown IDs silently fall back to default.
#### Troubleshooting checklist
- **Rule not taking effect?** Check `match.channel` spelling first (must be exact).
- **Expected account-specific routing but still using default?** Verify `match.account_id` equals actual runtime account id.
- **Wildcard catches too much traffic?** Add more specific `peer/guild/team` rules for critical paths.
- **Unexpected default fallback?** Confirm `agent_id` exists and is not misspelled.
In the example above, the VIP rule must appear before the broader group rule.
Because routing is strictly ordered, more specific rules should be placed
earlier and broader fallback rules later.
### 🔒 Security Sandbox
+2 -2
View File
@@ -42,7 +42,7 @@ func (m *legacyContextManager) Compact(_ context.Context, req *CompactRequest) e
if result, ok := m.forceCompression(req.SessionKey); ok {
m.al.emitEvent(
EventKindContextCompress,
m.al.newTurnEventScope("", req.SessionKey).meta(0, "forceCompression", "turn.context.compress"),
m.al.newTurnEventScope("", req.SessionKey, nil).meta(0, "forceCompression", "turn.context.compress"),
ContextCompressPayload{
Reason: req.Reason,
DroppedMessages: result.DroppedMessages,
@@ -247,7 +247,7 @@ func (m *legacyContextManager) summarizeSession(agent *AgentInstance, sessionKey
agent.Sessions.Save(sessionKey)
m.al.emitEvent(
EventKindSessionSummarize,
m.al.newTurnEventScope(agent.ID, sessionKey).meta(0, "summarizeSession", "turn.session.summarize"),
m.al.newTurnEventScope(agent.ID, sessionKey, nil).meta(0, "summarizeSession", "turn.session.summarize"),
SessionSummarizePayload{
SummarizedMessages: len(validMessages),
KeptMessages: keepCount,
+147
View File
@@ -0,0 +1,147 @@
package agent
import (
"strings"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
)
// DispatchRequest is the normalized runtime input passed into the agent loop
// after routing and session allocation have completed.
type DispatchRequest struct {
SessionKey string
SessionAliases []string
InboundContext *bus.InboundContext
RouteResult *routing.ResolvedRoute
SessionScope *session.SessionScope
UserMessage string
Media []string
}
func (r DispatchRequest) Channel() string {
if r.InboundContext == nil {
return ""
}
return r.InboundContext.Channel
}
func (r DispatchRequest) ChatID() string {
if r.InboundContext == nil {
return ""
}
return r.InboundContext.ChatID
}
func (r DispatchRequest) MessageID() string {
if r.InboundContext == nil {
return ""
}
return r.InboundContext.MessageID
}
func (r DispatchRequest) ReplyToMessageID() string {
if r.InboundContext == nil {
return ""
}
return r.InboundContext.ReplyToMessageID
}
func (r DispatchRequest) SenderID() string {
if r.InboundContext == nil {
return ""
}
return r.InboundContext.SenderID
}
func normalizeProcessOptionsInPlace(opts *processOptions) {
if opts == nil {
return
}
*opts = normalizeProcessOptions(*opts)
}
func normalizeProcessOptions(opts processOptions) processOptions {
if opts.Dispatch.SessionKey == "" {
opts.Dispatch.SessionKey = strings.TrimSpace(opts.SessionKey)
}
if len(opts.Dispatch.SessionAliases) == 0 && len(opts.SessionAliases) > 0 {
opts.Dispatch.SessionAliases = append([]string(nil), opts.SessionAliases...)
}
if opts.Dispatch.UserMessage == "" {
opts.Dispatch.UserMessage = opts.UserMessage
}
if len(opts.Dispatch.Media) == 0 && len(opts.Media) > 0 {
opts.Dispatch.Media = append([]string(nil), opts.Media...)
}
if opts.Dispatch.RouteResult == nil {
opts.Dispatch.RouteResult = cloneResolvedRoute(opts.RouteResult)
}
if opts.Dispatch.SessionScope == nil {
opts.Dispatch.SessionScope = session.CloneScope(opts.SessionScope)
}
if opts.Dispatch.InboundContext == nil {
if opts.InboundContext != nil {
opts.Dispatch.InboundContext = cloneInboundContext(opts.InboundContext)
} else if opts.Channel != "" || opts.ChatID != "" || opts.SenderID != "" ||
opts.MessageID != "" || opts.ReplyToMessageID != "" {
inbound := bus.InboundContext{
Channel: strings.TrimSpace(opts.Channel),
ChatID: strings.TrimSpace(opts.ChatID),
SenderID: strings.TrimSpace(opts.SenderID),
MessageID: strings.TrimSpace(opts.MessageID),
ReplyToMessageID: strings.TrimSpace(opts.ReplyToMessageID),
}
inbound.ChatType = inferChatTypeFromSessionScope(opts.Dispatch.SessionScope)
if inbound.Channel != "" || inbound.ChatID != "" || inbound.SenderID != "" ||
inbound.MessageID != "" || inbound.ReplyToMessageID != "" {
inbound = bus.NormalizeInboundMessage(bus.InboundMessage{Context: inbound}).Context
opts.Dispatch.InboundContext = &inbound
}
}
}
// Keep legacy mirrors populated while the rest of the runtime migrates.
opts.SessionKey = opts.Dispatch.SessionKey
opts.SessionAliases = append([]string(nil), opts.Dispatch.SessionAliases...)
opts.UserMessage = opts.Dispatch.UserMessage
opts.Media = append([]string(nil), opts.Dispatch.Media...)
opts.InboundContext = cloneInboundContext(opts.Dispatch.InboundContext)
opts.RouteResult = cloneResolvedRoute(opts.Dispatch.RouteResult)
opts.SessionScope = session.CloneScope(opts.Dispatch.SessionScope)
if opts.InboundContext != nil {
if opts.Channel == "" {
opts.Channel = opts.InboundContext.Channel
}
if opts.ChatID == "" {
opts.ChatID = opts.InboundContext.ChatID
}
if opts.MessageID == "" {
opts.MessageID = opts.InboundContext.MessageID
}
if opts.ReplyToMessageID == "" {
opts.ReplyToMessageID = opts.InboundContext.ReplyToMessageID
}
if opts.SenderID == "" {
opts.SenderID = opts.InboundContext.SenderID
}
}
return opts
}
func inferChatTypeFromSessionScope(scope *session.SessionScope) string {
if scope == nil || len(scope.Values) == 0 {
return ""
}
chatValue := strings.TrimSpace(scope.Values["chat"])
if chatValue == "" {
return ""
}
chatType, _, ok := strings.Cut(chatValue, ":")
if !ok {
return ""
}
return strings.ToLower(strings.TrimSpace(chatType))
}
+135
View File
@@ -0,0 +1,135 @@
package agent
import (
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
)
func TestNormalizeProcessOptions_PopulatesDispatchFromLegacyFields(t *testing.T) {
opts := normalizeProcessOptions(processOptions{
SessionKey: "session-1",
SessionAliases: []string{"legacy:one"},
Channel: "telegram",
ChatID: "chat-1",
MessageID: "msg-1",
ReplyToMessageID: "reply-1",
SenderID: "user-1",
UserMessage: "hello",
Media: []string{"media://one"},
})
if opts.Dispatch.SessionKey != "session-1" {
t.Fatalf("Dispatch.SessionKey = %q, want session-1", opts.Dispatch.SessionKey)
}
if len(opts.Dispatch.SessionAliases) != 1 || opts.Dispatch.SessionAliases[0] != "legacy:one" {
t.Fatalf("Dispatch.SessionAliases = %v, want [legacy:one]", opts.Dispatch.SessionAliases)
}
if opts.Dispatch.Channel() != "telegram" || opts.Dispatch.ChatID() != "chat-1" {
t.Fatalf(
"dispatch addressing = (%q,%q), want (telegram,chat-1)",
opts.Dispatch.Channel(),
opts.Dispatch.ChatID(),
)
}
if opts.Dispatch.SenderID() != "user-1" || opts.Dispatch.MessageID() != "msg-1" {
t.Fatalf("dispatch sender/message = (%q,%q)", opts.Dispatch.SenderID(), opts.Dispatch.MessageID())
}
if opts.Dispatch.ReplyToMessageID() != "reply-1" {
t.Fatalf("Dispatch.ReplyToMessageID() = %q, want reply-1", opts.Dispatch.ReplyToMessageID())
}
if opts.Dispatch.UserMessage != "hello" {
t.Fatalf("Dispatch.UserMessage = %q, want hello", opts.Dispatch.UserMessage)
}
if len(opts.Dispatch.Media) != 1 || opts.Dispatch.Media[0] != "media://one" {
t.Fatalf("Dispatch.Media = %v, want [media://one]", opts.Dispatch.Media)
}
}
func TestNormalizeProcessOptions_UsesDispatchAsSourceOfTruth(t *testing.T) {
inbound := &bus.InboundContext{
Channel: "slack",
ChatID: "C123",
ChatType: "channel",
SenderID: "U123",
MessageID: "m-1",
ReplyToMessageID: "parent-1",
}
route := &routing.ResolvedRoute{
AgentID: "support",
Channel: "slack",
AccountID: "workspace-a",
MatchedBy: "dispatch.rule:test",
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"chat", "sender"},
},
}
scope := &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "support",
Channel: "slack",
Account: "workspace-a",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "channel:c123",
},
}
opts := normalizeProcessOptions(processOptions{
Dispatch: DispatchRequest{
SessionKey: "sk_v1_example",
SessionAliases: []string{"agent:support:slack:channel:c123"},
InboundContext: inbound,
RouteResult: route,
SessionScope: scope,
UserMessage: "hello",
Media: []string{"media://one"},
},
})
if opts.SessionKey != "sk_v1_example" {
t.Fatalf("SessionKey = %q, want sk_v1_example", opts.SessionKey)
}
if opts.Channel != "slack" || opts.ChatID != "C123" {
t.Fatalf("legacy mirrors = (%q,%q), want (slack,C123)", opts.Channel, opts.ChatID)
}
if opts.SenderID != "U123" || opts.MessageID != "m-1" {
t.Fatalf("legacy sender/message = (%q,%q)", opts.SenderID, opts.MessageID)
}
if opts.ReplyToMessageID != "parent-1" {
t.Fatalf("ReplyToMessageID = %q, want parent-1", opts.ReplyToMessageID)
}
if opts.RouteResult == nil || opts.RouteResult.AgentID != "support" {
t.Fatalf("RouteResult = %#v, want support route", opts.RouteResult)
}
if opts.SessionScope == nil || opts.SessionScope.AgentID != "support" {
t.Fatalf("SessionScope = %#v, want support scope", opts.SessionScope)
}
}
func TestNormalizeProcessOptions_InfersLegacyChatTypeFromSessionScope(t *testing.T) {
opts := normalizeProcessOptions(processOptions{
Channel: "telegram",
ChatID: "-100123",
SenderID: "user-1",
UserMessage: "hello",
SessionScope: &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "telegram",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "group:-100123",
},
},
})
if opts.Dispatch.InboundContext == nil {
t.Fatal("Dispatch.InboundContext is nil")
}
if opts.Dispatch.InboundContext.ChatType != "group" {
t.Fatalf("Dispatch.InboundContext.ChatType = %q, want group", opts.Dispatch.InboundContext.ChatType)
}
}
+39 -7
View File
@@ -10,6 +10,8 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -136,6 +138,31 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
InboundContext: &bus.InboundContext{
Channel: "cli",
ChatID: "direct",
ChatType: "direct",
SenderID: "tester",
},
RouteResult: &routing.ResolvedRoute{
AgentID: "main",
Channel: "cli",
AccountID: routing.DefaultAccountID,
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"sender"},
},
MatchedBy: "default",
},
SessionScope: &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "cli",
Account: routing.DefaultAccountID,
Dimensions: []string{"sender"},
Values: map[string]string{
"sender": "tester",
},
},
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
@@ -176,6 +203,18 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
if evt.Meta.SessionKey != "session-1" {
t.Fatalf("event %d has session key %q, want session-1", i, evt.Meta.SessionKey)
}
if evt.Context == nil || evt.Context.Inbound == nil {
t.Fatalf("event %d missing inbound turn context", i)
}
if evt.Context.Inbound.Channel != "cli" || evt.Context.Inbound.SenderID != "tester" {
t.Fatalf("event %d inbound context = %+v", i, evt.Context.Inbound)
}
if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" {
t.Fatalf("event %d missing route context: %+v", i, evt.Context.Route)
}
if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "tester" {
t.Fatalf("event %d missing session scope: %+v", i, evt.Context.Scope)
}
}
startPayload, ok := events[0].Payload.(TurnStartPayload)
@@ -472,7 +511,6 @@ func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
// Use legacyContextManager's summarizeSession via contextManager interface
lcm := &legacyContextManager{al: al}
lcm.summarizeSession(defaultAgent, "session-1")
@@ -572,12 +610,6 @@ func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
if payload.SourceTool != "async_followup" {
t.Fatalf("expected source tool async_followup, got %q", payload.SourceTool)
}
if payload.Channel != "cli" {
t.Fatalf("expected channel cli, got %q", payload.Channel)
}
if payload.ChatID != "direct" {
t.Fatalf("expected chat id direct, got %q", payload.ChatID)
}
if payload.ContentLen != len("background result") {
t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen)
}
+2 -4
View File
@@ -86,6 +86,7 @@ type Event struct {
Kind EventKind
Time time.Time
Meta EventMeta
Context *TurnContext
Payload any
}
@@ -98,6 +99,7 @@ type EventMeta struct {
Iteration int
TracePath string
Source string
turnContext *TurnContext
}
// TurnEndStatus describes the terminal state of a turn.
@@ -114,8 +116,6 @@ const (
// TurnStartPayload describes the start of a turn.
type TurnStartPayload struct {
Channel string
ChatID string
UserMessage string
MediaCount int
}
@@ -217,8 +217,6 @@ type SteeringInjectedPayload struct {
// FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus.
type FollowUpQueuedPayload struct {
SourceTool string
Channel string
ChatID string
ContentLen int
}
+15 -8
View File
@@ -90,12 +90,11 @@ type ToolApprover interface {
type LLMHookRequest struct {
Meta EventMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Model string `json:"model"`
Messages []providers.Message `json:"messages,omitempty"`
Tools []providers.ToolDefinition `json:"tools,omitempty"`
Options map[string]any `json:"options,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
GracefulTerminal bool `json:"graceful_terminal,omitempty"`
}
@@ -104,6 +103,8 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Messages = cloneProviderMessages(r.Messages)
cloned.Tools = cloneToolDefinitions(r.Tools)
cloned.Options = cloneStringAnyMap(r.Options)
@@ -112,10 +113,9 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
type LLMHookResponse struct {
Meta EventMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Model string `json:"model"`
Response *providers.LLMResponse `json:"response,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
func (r *LLMHookResponse) Clone() *LLMHookResponse {
@@ -123,12 +123,15 @@ func (r *LLMHookResponse) Clone() *LLMHookResponse {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Response = cloneLLMResponse(r.Response)
return &cloned
}
type ToolCallHookRequest struct {
Meta EventMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
Channel string `json:"channel,omitempty"`
@@ -141,6 +144,8 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Arguments = cloneStringAnyMap(r.Arguments)
cloned.HookResult = cloneToolResult(r.HookResult)
return &cloned
@@ -148,10 +153,9 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
type ToolApprovalRequest struct {
Meta EventMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
@@ -159,18 +163,19 @@ func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Arguments = cloneStringAnyMap(r.Arguments)
return &cloned
}
type ToolResultHookResponse struct {
Meta EventMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
Result *tools.ToolResult `json:"result,omitempty"`
Duration time.Duration `json:"duration"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
@@ -178,6 +183,8 @@ func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Arguments = cloneStringAnyMap(r.Arguments)
cloned.Result = cloneToolResult(r.Result)
return &cloned
+51 -3
View File
@@ -12,6 +12,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -108,7 +109,8 @@ func (p *llmHookTestProvider) GetDefaultModel() string {
}
type llmObserverHook struct {
eventCh chan Event
eventCh chan Event
lastInbound *bus.InboundContext
}
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
@@ -125,6 +127,9 @@ func (h *llmObserverHook) BeforeLLM(
ctx context.Context,
req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
if req.Context != nil {
h.lastInbound = cloneInboundContext(req.Context.Inbound)
}
next := req.Clone()
next.Model = "hook-model"
return next, HookDecision{Action: HookActionModify}, nil
@@ -157,6 +162,31 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
InboundContext: &bus.InboundContext{
Channel: "cli",
ChatID: "direct",
ChatType: "direct",
SenderID: "hook-user",
},
RouteResult: &routing.ResolvedRoute{
AgentID: "main",
Channel: "cli",
AccountID: routing.DefaultAccountID,
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"sender"},
},
MatchedBy: "default",
},
SessionScope: &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "cli",
Account: routing.DefaultAccountID,
Dimensions: []string{"sender"},
Values: map[string]string{
"sender": "hook-user",
},
},
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
@@ -171,12 +201,30 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
if lastModel != "hook-model" {
t.Fatalf("expected model hook-model, got %q", lastModel)
}
if hook.lastInbound == nil {
t.Fatal("expected hook to receive inbound context")
}
if hook.lastInbound.Channel != "cli" || hook.lastInbound.SenderID != "hook-user" {
t.Fatalf("hook inbound context = %+v", hook.lastInbound)
}
if hook.lastInbound != nil && hook.lastInbound.ChatID != "direct" {
t.Fatalf("hook inbound chat ID = %q, want direct", hook.lastInbound.ChatID)
}
select {
case evt := <-hook.eventCh:
if evt.Kind != EventKindTurnEnd {
t.Fatalf("expected turn end event, got %v", evt.Kind)
}
if evt.Context == nil || evt.Context.Inbound == nil {
t.Fatal("expected observer event to carry inbound context")
}
if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" {
t.Fatalf("expected observer event to carry route context, got %+v", evt.Context.Route)
}
if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "hook-user" {
t.Fatalf("expected observer event to carry session scope, got %+v", evt.Context.Scope)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for hook observer event")
}
@@ -725,7 +773,7 @@ func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) {
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
type result struct {
resp string
@@ -801,7 +849,7 @@ func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
type result struct {
resp string
+450 -207
View File
File diff suppressed because it is too large Load Diff
+513 -106
View File
@@ -20,6 +20,7 @@ import (
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -38,7 +39,13 @@ func (f *fakeChannel) ReasoningChannelID() string { return f.id
type fakeMediaChannel struct {
fakeChannel
sentMedia []bus.OutboundMediaMessage
sentMessages []bus.OutboundMessage
sentMedia []bus.OutboundMediaMessage
}
func (f *fakeMediaChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
f.sentMessages = append(f.sentMessages, msg)
return nil, nil
}
func (f *fakeMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
@@ -139,7 +146,7 @@ func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
response, err := al.processMessage(context.Background(), bus.InboundMessage{
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "discord",
SenderID: "discord:123",
Sender: bus.SenderInfo{
@@ -147,7 +154,7 @@ func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
},
ChatID: "group-1",
Content: "hello",
})
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
@@ -198,12 +205,12 @@ func TestProcessMessage_UseCommandLoadsRequestedSkill(t *testing.T) {
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
response, err := al.processMessage(context.Background(), bus.InboundMessage{
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "/use shell explain how to list files",
})
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
@@ -288,12 +295,12 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) {
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
response, err := al.processMessage(context.Background(), bus.InboundMessage{
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "/use shell",
})
}))
if err != nil {
t.Fatalf("processMessage() arm error = %v", err)
}
@@ -301,12 +308,12 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) {
t.Fatalf("arm response = %q, want armed confirmation", response)
}
response, err = al.processMessage(context.Background(), bus.InboundMessage{
response, err = al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "explain how to list files",
})
}))
if err != nil {
t.Fatalf("processMessage() follow-up error = %v", err)
}
@@ -619,12 +626,12 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
path: imagePath,
})
response, err := al.processMessage(context.Background(), bus.InboundMessage{
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "telegram",
ChatID: "chat1",
SenderID: "user1",
Content: "take a screenshot of the screen and send it to me",
})
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
@@ -661,16 +668,21 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
if defaultAgent == nil {
t.Fatal("expected default agent")
}
route, _, err := al.resolveMessageRoute(bus.InboundMessage{
route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{
Channel: "telegram",
ChatID: "chat1",
SenderID: "user1",
Content: "take a screenshot of the screen and send it to me",
})
}))
if err != nil {
t.Fatalf("resolveMessageRoute() error = %v", err)
}
sessionKey := resolveScopeKey(route, "")
sessionKey := resolveScopeKey(al.allocateRouteSession(route, testInboundMessage(bus.InboundMessage{
Channel: "telegram",
ChatID: "chat1",
SenderID: "user1",
Content: "take a screenshot of the screen and send it to me",
})).SessionKey, "")
history := defaultAgent.Sessions.GetHistory(sessionKey)
if len(history) == 0 {
t.Fatal("expected session history to be saved")
@@ -714,12 +726,12 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes
loop: al,
})
response, err := al.processMessage(context.Background(), bus.InboundMessage{
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "telegram",
ChatID: "chat1",
SenderID: "user1",
Content: "take a screenshot of the screen and send it to me",
})
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
@@ -734,6 +746,263 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes
}
}
func TestRunAgentLoop_ResponseHandledToolPublishesForUserWhenSendResponseDisabled(t *testing.T) {
tmpDir := t.TempDir()
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Workspace = tmpDir
cfg.Agents.Defaults.ModelName = "test-model"
cfg.Agents.Defaults.MaxTokens = 4096
cfg.Agents.Defaults.MaxToolIterations = 10
msgBus := bus.NewMessageBus()
provider := &handledUserProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
store := media.NewFileMediaStore()
al.SetMediaStore(store)
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
al.RegisterTool(&handledUserTool{})
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
response, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
Dispatch: DispatchRequest{
SessionKey: "session-1",
UserMessage: "take a screenshot of the screen and send it to me",
SessionScope: &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: defaultAgent.ID,
Channel: "telegram",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "direct:chat1",
},
},
InboundContext: &bus.InboundContext{
Channel: "telegram",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
},
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop() error = %v", err)
}
if response != "" {
t.Fatalf("expected no final response when tool already handled delivery, got %q", response)
}
deadline := time.Now().Add(2 * time.Second)
for len(telegramChannel.sentMessages) == 0 && time.Now().Before(deadline) {
time.Sleep(10 * time.Millisecond)
}
if len(telegramChannel.sentMessages) != 1 {
t.Fatalf("expected exactly 1 sent text message, got %d", len(telegramChannel.sentMessages))
}
if telegramChannel.sentMessages[0].Content != "Handled user output from tool." {
t.Fatalf("unexpected sent text message: %+v", telegramChannel.sentMessages[0])
}
if telegramChannel.sentMessages[0].AgentID != defaultAgent.ID {
t.Fatalf("sent text agent_id = %q, want %q", telegramChannel.sentMessages[0].AgentID, defaultAgent.ID)
}
if telegramChannel.sentMessages[0].SessionKey != "session-1" {
t.Fatalf("sent text session_key = %q, want session-1", telegramChannel.sentMessages[0].SessionKey)
}
if telegramChannel.sentMessages[0].Scope == nil ||
telegramChannel.sentMessages[0].Scope.Values["chat"] != "direct:chat1" {
t.Fatalf("unexpected sent text scope: %+v", telegramChannel.sentMessages[0].Scope)
}
}
func TestAppendEventContextFields_IncludesInboundRouteAndScope(t *testing.T) {
fields := map[string]any{}
appendEventContextFields(fields, &TurnContext{
Inbound: &bus.InboundContext{
Channel: "slack",
Account: "workspace-a",
ChatID: "C123",
ChatType: "channel",
TopicID: "thread-42",
SpaceType: "workspace",
SpaceID: "T001",
SenderID: "U123",
Mentioned: true,
},
Route: &routing.ResolvedRoute{
AgentID: "support",
Channel: "slack",
AccountID: "workspace-a",
MatchedBy: "default",
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"chat", "sender"},
IdentityLinks: map[string][]string{
"canonical-user": {"slack:U123"},
},
},
},
Scope: &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "support",
Channel: "slack",
Account: "workspace-a",
Dimensions: []string{"chat", "sender"},
Values: map[string]string{
"chat": "channel:c123",
"sender": "u123",
},
},
})
if fields["inbound_channel"] != "slack" {
t.Fatalf("inbound_channel = %v, want slack", fields["inbound_channel"])
}
if fields["inbound_topic_id"] != "thread-42" {
t.Fatalf("inbound_topic_id = %v, want thread-42", fields["inbound_topic_id"])
}
if fields["route_matched_by"] != "default" {
t.Fatalf("route_matched_by = %v, want default", fields["route_matched_by"])
}
if fields["route_dimensions"] != "chat,sender" {
t.Fatalf("route_dimensions = %v, want chat,sender", fields["route_dimensions"])
}
if fields["route_identity_link_count"] != 1 {
t.Fatalf("route_identity_link_count = %v, want 1", fields["route_identity_link_count"])
}
if fields["scope_dimensions"] != "chat,sender" {
t.Fatalf("scope_dimensions = %v, want chat,sender", fields["scope_dimensions"])
}
if fields["scope_chat"] != "channel:c123" {
t.Fatalf("scope_chat = %v, want channel:c123", fields["scope_chat"])
}
if fields["scope_sender"] != "u123" {
t.Fatalf("scope_sender = %v, want u123", fields["scope_sender"])
}
}
func TestResolveMessageRoute_UsesInboundContextAccount(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
},
List: []config.AgentConfig{
{ID: "main", Default: true},
{ID: "work"},
},
},
Session: config.SessionConfig{
Dimensions: []string{"sender"},
},
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "ok"})
route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{
Context: bus.InboundContext{
Channel: "slack",
Account: "workspace-a",
ChatID: "C123",
ChatType: "channel",
SenderID: "U123",
SpaceID: "T001",
SpaceType: "workspace",
},
Content: "hello",
}))
if err != nil {
t.Fatalf("resolveMessageRoute() error = %v", err)
}
if route.AgentID != "main" {
t.Fatalf("AgentID = %q, want main", route.AgentID)
}
if route.MatchedBy != "default" {
t.Fatalf("MatchedBy = %q, want default", route.MatchedBy)
}
if route.AccountID != "workspace-a" {
t.Fatalf("AccountID = %q, want workspace-a", route.AccountID)
}
}
func TestResolveMessageRoute_UsesDispatchRulesInOrder(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
},
List: []config.AgentConfig{
{ID: "main", Default: true},
{ID: "support"},
{ID: "sales"},
},
Dispatch: &config.DispatchConfig{
Rules: []config.DispatchRule{
{
Name: "support-group",
Agent: "support",
When: config.DispatchSelector{
Channel: "telegram",
Chat: "group:-100123",
},
SessionDimensions: []string{"chat"},
},
{
Name: "vip-in-group",
Agent: "sales",
When: config.DispatchSelector{
Channel: "telegram",
Chat: "group:-100123",
Sender: "12345",
},
SessionDimensions: []string{"chat", "sender"},
},
},
},
},
Session: config.SessionConfig{
Dimensions: []string{"sender"},
},
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "ok"})
route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "-100123",
ChatType: "group",
SenderID: "12345",
},
Content: "hello",
}))
if err != nil {
t.Fatalf("resolveMessageRoute() error = %v", err)
}
if route.AgentID != "support" {
t.Fatalf("AgentID = %q, want support", route.AgentID)
}
if route.MatchedBy != "dispatch.rule:support-group" {
t.Fatalf("MatchedBy = %q, want dispatch.rule:support-group", route.MatchedBy)
}
if got := route.SessionPolicy.Dimensions; len(got) != 1 || got[0] != "chat" {
t.Fatalf("SessionPolicy.Dimensions = %v, want [chat]", got)
}
}
func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
tmpDir := t.TempDir()
cfg := config.DefaultConfig()
@@ -765,12 +1034,12 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
path: imagePath,
})
response, err := al.processMessage(context.Background(), bus.InboundMessage{
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "telegram",
ChatID: "chat1",
SenderID: "user1",
Content: "take a screenshot of the screen and send it to me",
})
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
@@ -975,6 +1244,66 @@ func (m *handledMediaProvider) GetDefaultModel() string {
return "handled-media-model"
}
type handledUserProvider struct {
calls int
}
func (m *handledUserProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
m.calls++
if m.calls == 1 {
return &providers.LLMResponse{
Content: "Delivering the result now.",
ToolCalls: []providers.ToolCall{{
ID: "call_handled_user",
Type: "function",
Name: "handled_user_tool",
Arguments: map[string]any{},
}},
}, nil
}
return &providers.LLMResponse{}, nil
}
func (m *handledUserProvider) GetDefaultModel() string {
return "handled-user-model"
}
type messageToolProvider struct {
calls int
}
func (m *messageToolProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
m.calls++
if m.calls == 1 {
return &providers.LLMResponse{
Content: "",
ToolCalls: []providers.ToolCall{{
ID: "call_message",
Type: "function",
Name: "message",
Arguments: map[string]any{"content": "direct tool message"},
}},
}, nil
}
return &providers.LLMResponse{}, nil
}
func (m *messageToolProvider) GetDefaultModel() string {
return "message-tool-model"
}
type artifactThenSendProvider struct {
calls int
}
@@ -1178,6 +1507,24 @@ func (m *handledMediaTool) Execute(ctx context.Context, args map[string]any) *to
return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled()
}
type handledUserTool struct{}
func (m *handledUserTool) Name() string { return "handled_user_tool" }
func (m *handledUserTool) Description() string {
return "Returns a user-visible result and marks delivery as handled"
}
func (m *handledUserTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
func (m *handledUserTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
return tools.UserResult("Handled user output from tool.").WithResponseHandled()
}
type handledMediaWithSteeringProvider struct {
calls int
}
@@ -1391,13 +1738,39 @@ func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, ms
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
defer cancel()
response, err := h.al.processMessage(timeoutCtx, msg)
response, err := h.al.processMessage(timeoutCtx, testInboundMessage(msg))
if err != nil {
tb.Fatalf("processMessage failed: %v", err)
}
return response
}
func testInboundMessage(msg bus.InboundMessage) bus.InboundMessage {
if msg.Context.Channel == "" &&
msg.Context.Account == "" &&
msg.Context.ChatID == "" &&
msg.Context.ChatType == "" &&
msg.Context.TopicID == "" &&
msg.Context.SpaceID == "" &&
msg.Context.SpaceType == "" &&
msg.Context.SenderID == "" &&
msg.Context.MessageID == "" &&
!msg.Context.Mentioned &&
msg.Context.ReplyToMessageID == "" &&
msg.Context.ReplyToSenderID == "" &&
len(msg.Context.ReplyHandles) == 0 &&
len(msg.Context.Raw) == 0 {
msg.Context = bus.InboundContext{
Channel: msg.Channel,
ChatID: msg.ChatID,
ChatType: "direct",
SenderID: msg.SenderID,
MessageID: msg.MessageID,
}
}
return bus.NormalizeInboundMessage(msg)
}
const responseTimeout = 3 * time.Second
func TestProcessMessage_UsesRouteSessionKey(t *testing.T) {
@@ -1423,21 +1796,17 @@ func TestProcessMessage_UsesRouteSessionKey(t *testing.T) {
al := NewAgentLoop(cfg, msgBus, provider)
msg := bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "hello",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "hello",
}
route := al.registry.ResolveRoute(routing.RouteInput{
Channel: msg.Channel,
Peer: extractPeer(msg),
})
sessionKey := route.SessionKey
route := al.registry.ResolveRoute(bus.NormalizeInboundMessage(msg).Context)
sessionKey := al.allocateRouteSession(route, msg).SessionKey
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
@@ -1473,7 +1842,7 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) {
},
},
Session: config.SessionConfig{
DMScope: "per-channel-peer",
Dimensions: []string{"chat"},
},
}
@@ -1483,21 +1852,22 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) {
helper := testHelper{al: al}
baseMsg := bus.InboundMessage{
Channel: "whatsapp",
SenderID: "user1",
ChatID: "chat1",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
Context: bus.InboundContext{
Channel: "whatsapp",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
}
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: baseMsg.Channel,
SenderID: baseMsg.SenderID,
ChatID: baseMsg.ChatID,
Content: "/show channel",
Peer: baseMsg.Peer,
Context: bus.InboundContext{
Channel: baseMsg.Context.Channel,
ChatID: baseMsg.Context.ChatID,
ChatType: baseMsg.Context.ChatType,
SenderID: baseMsg.Context.SenderID,
},
Content: "/show channel",
})
if showResp != "Current Channel: whatsapp" {
t.Fatalf("unexpected /show reply: %q", showResp)
@@ -1507,11 +1877,13 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) {
}
fooResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: baseMsg.Channel,
SenderID: baseMsg.SenderID,
ChatID: baseMsg.ChatID,
Content: "/foo",
Peer: baseMsg.Peer,
Context: bus.InboundContext{
Channel: baseMsg.Context.Channel,
ChatID: baseMsg.Context.ChatID,
ChatType: baseMsg.Context.ChatType,
SenderID: baseMsg.Context.SenderID,
},
Content: "/foo",
})
if fooResp != "LLM reply" {
t.Fatalf("unexpected /foo reply: %q", fooResp)
@@ -1521,11 +1893,13 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) {
}
newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: baseMsg.Channel,
SenderID: baseMsg.SenderID,
ChatID: baseMsg.ChatID,
Content: "/new",
Peer: baseMsg.Peer,
Context: bus.InboundContext{
Channel: baseMsg.Context.Channel,
ChatID: baseMsg.Context.ChatID,
ChatType: baseMsg.Context.ChatType,
SenderID: baseMsg.Context.SenderID,
},
Content: "/new",
})
if newResp != "LLM reply" {
t.Fatalf("unexpected /new reply: %q", newResp)
@@ -1578,10 +1952,6 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
SenderID: "user1",
ChatID: "chat1",
Content: "/switch model to deepseek",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
t.Fatalf("unexpected /switch reply: %q", switchResp)
@@ -1592,10 +1962,6 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
SenderID: "user1",
ChatID: "chat1",
Content: "/show model",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(showResp, "Current Model: deepseek (Provider: openrouter)") {
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
@@ -1643,10 +2009,6 @@ func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
SenderID: "user1",
ChatID: "chat1",
Content: "/switch model to missing",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if switchResp != `model "missing" not found in model_list or providers` {
t.Fatalf("unexpected /switch error reply: %q", switchResp)
@@ -1657,10 +2019,6 @@ func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
SenderID: "user1",
ChatID: "chat1",
Content: "/show model",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(showResp, "Current Model: local (Provider: openai)") {
t.Fatalf("unexpected /show model reply after rejected switch: %q", showResp)
@@ -1727,10 +2085,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
SenderID: "user1",
ChatID: "chat1",
Content: "hello before switch",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if firstResp != "local reply" {
t.Fatalf("unexpected response before switch: %q", firstResp)
@@ -1750,10 +2104,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
SenderID: "user1",
ChatID: "chat1",
Content: "/switch model to deepseek",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
t.Fatalf("unexpected /switch reply: %q", switchResp)
@@ -1764,10 +2114,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
SenderID: "user1",
ChatID: "chat1",
Content: "hello after switch",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if secondResp != "remote reply" {
t.Fatalf("unexpected response after switch: %q", secondResp)
@@ -1857,10 +2203,6 @@ func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
SenderID: "user1",
ChatID: "chat1",
Content: "hi",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if resp != "light reply" {
t.Fatalf("response = %q, want %q", resp, "light reply")
@@ -1942,7 +2284,6 @@ func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
SenderID: "user1",
ChatID: "chat1",
Content: "hi",
Peer: bus.Peer{Kind: "direct", ID: "user1"},
})
if resp != "fallback reply" {
@@ -2020,7 +2361,6 @@ func TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered(t *
SenderID: "user1",
ChatID: "chat1",
Content: "hi",
Peer: bus.Peer{Kind: "direct", ID: "user1"},
})
if resp != "active provider reply" {
@@ -2291,14 +2631,16 @@ func TestAgentLoop_ToolLimitUsesDedicatedFallback(t *testing.T) {
if defaultAgent == nil {
t.Fatal("No default agent found")
}
route := al.registry.ResolveRoute(routing.RouteInput{
Channel: "test",
Peer: &routing.RoutePeer{
Kind: "direct",
ID: "cron",
},
route := al.registry.ResolveRoute(bus.InboundContext{
Channel: "test",
ChatType: "direct",
SenderID: "cron",
})
history := defaultAgent.Sessions.GetHistory(route.SessionKey)
history := defaultAgent.Sessions.GetHistory(al.allocateRouteSession(route, testInboundMessage(bus.InboundMessage{
Channel: "test",
SenderID: "cron",
ChatID: "chat1",
})).SessionKey)
if len(history) != 4 {
t.Fatalf("history len = %d, want 4", len(history))
}
@@ -2556,8 +2898,7 @@ func TestHandleReasoning(t *testing.T) {
for i := 0; ; i++ {
fillCtx, fillCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
err := msgBus.PublishOutbound(fillCtx, bus.OutboundMessage{
Channel: "filler",
ChatID: "filler",
Context: bus.NewOutboundContext("filler", "filler", ""),
Content: fmt.Sprintf("filler-%d", i),
})
fillCancel()
@@ -2631,12 +2972,12 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T
chManager.RegisterChannel("telegram", &fakeChannel{id: "reason-chat"})
al.SetChannelManager(chManager)
response, err := al.processMessage(context.Background(), bus.InboundMessage{
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "hello",
})
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
@@ -2652,6 +2993,9 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T
if outbound.ChatID != "reason-chat" {
t.Fatalf("reasoning chatID = %q, want %q", outbound.ChatID, "reason-chat")
}
if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "reason-chat" {
t.Fatalf("unexpected reasoning context: %+v", outbound.Context)
}
if outbound.Content != "thinking trace" {
t.Fatalf("reasoning content = %q, want %q", outbound.Content, "thinking trace")
}
@@ -2711,8 +3055,12 @@ func TestProcessMessage_PicoPublishesReasoningAsThoughtMessage(t *testing.T) {
if thoughtMsg.Channel != "pico" || thoughtMsg.ChatID != "pico:test-session" {
t.Fatalf("thought message route = %s/%s, want pico/pico:test-session", thoughtMsg.Channel, thoughtMsg.ChatID)
}
if thoughtMsg.Metadata[metadataKeyMessageKind] != messageKindThought {
t.Fatalf("thought metadata kind = %q, want %q", thoughtMsg.Metadata[metadataKeyMessageKind], messageKindThought)
if thoughtMsg.Context.Raw[metadataKeyMessageKind] != messageKindThought {
t.Fatalf(
"thought metadata kind = %q, want %q",
thoughtMsg.Context.Raw[metadataKeyMessageKind],
messageKindThought,
)
}
}
@@ -2793,12 +3141,12 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
provider := &toolFeedbackProvider{filePath: heartbeatFile}
al := NewAgentLoop(cfg, msgBus, provider)
response, err := al.processMessage(context.Background(), bus.InboundMessage{
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "telegram",
SenderID: "user-1",
ChatID: "chat-1",
Content: "check tool feedback",
})
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
@@ -2814,14 +3162,73 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
if outbound.ChatID != "chat-1" {
t.Fatalf("tool feedback chatID = %q, want %q", outbound.ChatID, "chat-1")
}
if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "chat-1" {
t.Fatalf("unexpected tool feedback context: %+v", outbound.Context)
}
if !strings.Contains(outbound.Content, "`read_file`") {
t.Fatalf("tool feedback content = %q, want read_file preview", outbound.Content)
}
if outbound.AgentID != "main" {
t.Fatalf("tool feedback agent_id = %q, want main", outbound.AgentID)
}
if outbound.SessionKey == "" {
t.Fatal("expected tool feedback to carry session_key")
}
if outbound.Scope == nil || outbound.Scope.AgentID != "main" || outbound.Scope.Channel != "telegram" {
t.Fatalf("expected tool feedback scope, got %+v", outbound.Scope)
}
case <-time.After(2 * time.Second):
t.Fatal("expected outbound tool feedback for regular messages")
}
}
func TestProcessMessage_MessageToolPublishesOutboundWithTurnMetadata(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Workspace = t.TempDir()
cfg.Agents.Defaults.ModelName = "test-model"
cfg.Agents.Defaults.MaxTokens = 4096
cfg.Agents.Defaults.MaxToolIterations = 10
cfg.Session.Dimensions = []string{"chat"}
msgBus := bus.NewMessageBus()
provider := &messageToolProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "telegram",
SenderID: "user-1",
ChatID: "chat-1",
Content: "send a direct message",
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response == "" {
t.Fatal("expected processMessage() to return a final loop response")
}
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "direct tool message" {
t.Fatalf("outbound content = %q, want direct tool message", outbound.Content)
}
if outbound.AgentID != "main" {
t.Fatalf("outbound agent_id = %q, want main", outbound.AgentID)
}
if outbound.SessionKey == "" {
t.Fatal("expected message tool outbound to carry session_key")
}
if outbound.Scope == nil || outbound.Scope.Values["chat"] != "direct:chat-1" {
t.Fatalf("unexpected message tool outbound scope: %+v", outbound.Scope)
}
if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "chat-1" {
t.Fatalf("unexpected message tool outbound context: %+v", outbound.Context)
}
case <-time.After(2 * time.Second):
t.Fatal("expected message tool outbound")
}
}
func TestRun_PicoPublishesAssistantContentDuringToolCallsWithoutFinalDuplicate(t *testing.T) {
tmpDir := t.TempDir()
@@ -3363,13 +3770,13 @@ func TestProcessMessage_ContextOverflowRecovery(t *testing.T) {
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "assistant", Content: "response"})
}
response, err := al.processMessage(context.Background(), bus.InboundMessage{
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "test",
ChatID: "chat1",
SenderID: "user1",
SessionKey: "test-session",
Content: "trigger recovery",
})
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
@@ -3405,12 +3812,12 @@ func TestProcessMessage_ContextOverflow_AnthropicStyle(t *testing.T) {
return &providers.LLMResponse{Content: "Anthropic recovery success"}, nil
}
response, err := al.processMessage(context.Background(), bus.InboundMessage{
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "test",
ChatID: "chat1",
SenderID: "user1",
Content: "hello",
})
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
+4 -3
View File
@@ -3,6 +3,7 @@ package agent
import (
"sync"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
@@ -64,9 +65,9 @@ func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) {
return agent, ok
}
// ResolveRoute determines which agent handles the message.
func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute {
return r.resolver.ResolveRoute(input)
// ResolveRoute determines which agent handles the normalized inbound context.
func (r *AgentRegistry) ResolveRoute(inbound bus.InboundContext) routing.ResolvedRoute {
return r.resolver.ResolveRoute(inbound)
}
// ListAgentIDs returns all registered agent IDs.
+35 -8
View File
@@ -3,12 +3,14 @@ package agent
import (
"context"
"fmt"
"sort"
"strings"
"sync"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -290,12 +292,22 @@ func (al *AgentLoop) continueWithSteeringMessages(
ctx context.Context,
agent *AgentInstance,
sessionKey, channel, chatID string,
scope *session.SessionScope,
steeringMsgs []providers.Message,
) (string, error) {
dispatch := DispatchRequest{
SessionKey: sessionKey,
SessionScope: session.CloneScope(scope),
}
if channel != "" || chatID != "" {
dispatch.InboundContext = &bus.InboundContext{
Channel: channel,
ChatID: chatID,
ChatType: inferChatTypeFromSessionScope(scope),
}
}
return al.runAgentLoop(ctx, agent, processOptions{
SessionKey: sessionKey,
Channel: channel,
ChatID: chatID,
Dispatch: dispatch,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
@@ -310,9 +322,19 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
return nil
}
if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil {
if agent, ok := registry.GetAgent(parsed.AgentID); ok {
return agent
agentIDs := registry.ListAgentIDs()
sort.Strings(agentIDs)
for _, agentID := range agentIDs {
agent, ok := registry.GetAgent(agentID)
if !ok || agent == nil {
continue
}
resolvedAgentID := session.ResolveAgentID(agent.Sessions, sessionKey)
if resolvedAgentID == "" {
continue
}
if scopedAgent, ok := registry.GetAgent(resolvedAgentID); ok {
return scopedAgent
}
}
@@ -352,7 +374,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
}
}
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs)
var scope *session.SessionScope
if metaStore, ok := agent.Sessions.(session.MetadataAwareSessionStore); ok {
scope = metaStore.GetSessionScope(sessionKey)
}
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, scope, steeringMsgs)
}
func (al *AgentLoop) InterruptGraceful(hint string) error {
+89 -36
View File
@@ -17,6 +17,7 @@ import (
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -357,7 +358,7 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
},
},
Session: config.SessionConfig{
DMScope: "per-peer",
Dimensions: []string{"sender"},
},
}
@@ -365,14 +366,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
al := NewAgentLoop(cfg, msgBus, &mockProvider{})
activeMsg := bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "active turn",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "active turn",
}
activeScope, activeAgentID, ok := al.resolveSteeringTarget(activeMsg)
if !ok {
@@ -380,14 +380,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
}
otherMsg := bus.InboundMessage{
Channel: "telegram",
SenderID: "user2",
ChatID: "chat2",
Content: "other session",
Peer: bus.Peer{
Kind: "direct",
ID: "user2",
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "chat2",
ChatType: "direct",
SenderID: "user2",
},
Content: "other session",
}
otherScope, _, ok := al.resolveSteeringTarget(otherMsg)
if !ok {
@@ -422,9 +421,9 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
select {
case <-ctx.Done():
t.Fatalf("timeout waiting for requeued message on outbound bus")
case requeued := <-msgBus.OutboundChan():
if requeued.Channel != otherMsg.Channel || requeued.ChatID != otherMsg.ChatID ||
t.Fatalf("timeout waiting for requeued message on inbound bus")
case requeued := <-msgBus.InboundChan():
if requeued.Context.Channel != otherMsg.Context.Channel || requeued.Context.ChatID != otherMsg.Context.ChatID ||
requeued.Content != otherMsg.Content {
t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg)
}
@@ -841,24 +840,22 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
}()
first := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "first message",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "first message",
}
late := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "late append",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "late append",
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
@@ -949,7 +946,7 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
},
}
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
provider := &blockingDirectProvider{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
@@ -1013,6 +1010,62 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
}
}
func TestAgentLoop_AgentForSession_UsesStoredScopeMetadata(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
List: []config.AgentConfig{
{ID: "sales", Default: true},
{ID: "support"},
},
},
}
al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{})
support, ok := al.registry.GetAgent("support")
if !ok || support == nil {
t.Fatal("expected support agent")
}
metaStore, ok := support.Sessions.(session.MetadataAwareSessionStore)
if !ok {
t.Fatal("support session store does not support metadata")
}
alias := "agent:support:slack:channel:c001"
key := session.BuildOpaqueSessionKey(alias)
scope := &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "support",
Channel: "slack",
Account: "default",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "channel:c001",
},
}
metaStore.EnsureSessionMetadata(key, scope, []string{alias})
got := al.agentForSession(key)
if got == nil {
t.Fatal("agentForSession() returned nil")
}
if got.ID != "support" {
t.Fatalf("agentForSession() = %q, want %q", got.ID, "support")
}
}
func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
@@ -1060,7 +1113,7 @@ func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
},
}
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
al.SetMediaStore(store)
@@ -1168,7 +1221,7 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
al := NewAgentLoop(cfg, msgBus, provider)
al.RegisterTool(tool1)
al.RegisterTool(tool2)
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
@@ -1322,7 +1375,7 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
al := NewAgentLoop(cfg, msgBus, provider)
started := make(chan struct{})
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
+13 -7
View File
@@ -351,15 +351,17 @@ func spawnSubTurn(
}
// Create processOptions for the child turn
dispatch := DispatchRequest{
SessionKey: childID,
UserMessage: cfg.SystemPrompt,
Media: nil,
InboundContext: cloneInboundContext(parentTS.opts.Dispatch.InboundContext),
}
opts := processOptions{
SessionKey: childID,
Channel: parentTS.channel,
ChatID: parentTS.chatID,
SenderID: parentTS.opts.SenderID,
Dispatch: dispatch,
SenderID: parentTS.opts.Dispatch.SenderID(),
SenderDisplayName: parentTS.opts.SenderDisplayName,
UserMessage: cfg.SystemPrompt, // Task description becomes the first user message
SystemPromptOverride: cfg.ActualSystemPrompt,
Media: nil,
InitialSteeringMessages: cfg.InitialMessages,
DefaultResponse: "",
EnableSummary: false,
@@ -369,7 +371,11 @@ func spawnSubTurn(
}
// Create event scope for the child turn
scope := al.newTurnEventScope(agent.ID, childID)
scope := al.newTurnEventScope(
agent.ID,
childID,
newTurnContext(opts.Dispatch.InboundContext, opts.Dispatch.RouteResult, opts.Dispatch.SessionScope),
)
// Create child turnState using the new API
childTS := newTurnState(&agent, opts, scope)
+15 -12
View File
@@ -56,6 +56,7 @@ type turnState struct {
turnID string
agentID string
sessionKey string
turnCtx *TurnContext
channel string
chatID string
@@ -115,11 +116,12 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
scope: scope,
turnID: scope.turnID,
agentID: agent.ID,
sessionKey: opts.SessionKey,
channel: opts.Channel,
chatID: opts.ChatID,
userMessage: opts.UserMessage,
media: append([]string(nil), opts.Media...),
sessionKey: opts.Dispatch.SessionKey,
turnCtx: cloneTurnContext(scope.context),
channel: opts.Dispatch.Channel(),
chatID: opts.Dispatch.ChatID(),
userMessage: opts.Dispatch.UserMessage,
media: append([]string(nil), opts.Dispatch.Media...),
phase: TurnPhaseSetup,
startedAt: time.Now(),
}
@@ -127,7 +129,7 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
// Bind session store and capture initial history length for rollback logic
if agent != nil && agent.Sessions != nil {
ts.session = agent.Sessions
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.SessionKey))
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.Dispatch.SessionKey))
}
return ts
@@ -302,12 +304,13 @@ func (ts *turnState) hardAbortRequested() bool {
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
snap := ts.snapshot()
return EventMeta{
AgentID: snap.AgentID,
TurnID: snap.TurnID,
SessionKey: snap.SessionKey,
Iteration: snap.Iteration,
Source: source,
TracePath: tracePath,
AgentID: snap.AgentID,
TurnID: snap.TurnID,
SessionKey: snap.SessionKey,
Iteration: snap.Iteration,
Source: source,
TracePath: tracePath,
turnContext: cloneTurnContext(ts.turnCtx),
}
}
+92
View File
@@ -0,0 +1,92 @@
package agent
import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
)
// TurnContext carries normalized turn-scoped facts that can be shared across
// events, hooks, and other runtime observers without re-parsing legacy fields.
type TurnContext struct {
Inbound *bus.InboundContext `json:"inbound,omitempty"`
Route *routing.ResolvedRoute `json:"route,omitempty"`
Scope *session.SessionScope `json:"scope,omitempty"`
}
func newTurnContext(
inbound *bus.InboundContext,
route *routing.ResolvedRoute,
scope *session.SessionScope,
) *TurnContext {
if inbound == nil && route == nil && scope == nil {
return nil
}
return &TurnContext{
Inbound: cloneInboundContext(inbound),
Route: cloneResolvedRoute(route),
Scope: session.CloneScope(scope),
}
}
func cloneTurnContext(ctx *TurnContext) *TurnContext {
if ctx == nil {
return nil
}
cloned := *ctx
cloned.Inbound = cloneInboundContext(ctx.Inbound)
cloned.Route = cloneResolvedRoute(ctx.Route)
cloned.Scope = session.CloneScope(ctx.Scope)
return &cloned
}
func cloneInboundContext(ctx *bus.InboundContext) *bus.InboundContext {
if ctx == nil {
return nil
}
cloned := *ctx
cloned.ReplyHandles = cloneStringMap(ctx.ReplyHandles)
cloned.Raw = cloneStringMap(ctx.Raw)
return &cloned
}
func cloneStringMap(src map[string]string) map[string]string {
if len(src) == 0 {
return nil
}
cloned := make(map[string]string, len(src))
for k, v := range src {
cloned[k] = v
}
return cloned
}
func cloneEventMeta(meta EventMeta) EventMeta {
meta.turnContext = cloneTurnContext(meta.turnContext)
return meta
}
func cloneResolvedRoute(route *routing.ResolvedRoute) *routing.ResolvedRoute {
if route == nil {
return nil
}
cloned := *route
cloned.SessionPolicy = routing.SessionPolicy{
Dimensions: append([]string(nil), route.SessionPolicy.Dimensions...),
IdentityLinks: cloneIdentityLinks(route.SessionPolicy.IdentityLinks),
}
return &cloned
}
func cloneIdentityLinks(src map[string][]string) map[string][]string {
if len(src) == 0 {
return nil
}
cloned := make(map[string][]string, len(src))
for canonical, ids := range src {
dup := make([]string, len(ids))
copy(dup, ids)
cloned[canonical] = dup
}
return cloned
}
+10 -9
View File
@@ -226,8 +226,7 @@ func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) {
logger.ErrorCF("voice-agent", "Failed to publish leave control", map[string]any{"error": err})
}
if err := a.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: channelType,
ChatID: acc.chatID,
Context: bus.NewOutboundContext(channelType, acc.chatID, ""),
Content: "Goodbye! Leaving the voice channel.",
}); err != nil {
logger.ErrorCF("voice-agent", "Failed to publish goodbye message", map[string]any{"error": err})
@@ -238,14 +237,16 @@ func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) {
oralPrompt := "\n\n[SYSTEM]: The user just spoke this to you over voice chat. Please reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally."
if err := a.bus.PublishInbound(ctx, bus.InboundMessage{
Channel: channelType,
SenderID: acc.speakerID,
ChatID: acc.chatID,
Content: res.Text + oralPrompt,
Peer: bus.Peer{Kind: "channel", ID: acc.chatID},
Metadata: map[string]string{
"is_voice": "true",
Context: bus.InboundContext{
Channel: channelType,
ChatID: acc.chatID,
ChatType: "channel",
SenderID: acc.speakerID,
Raw: map[string]string{
"is_voice": "true",
},
},
Content: res.Text + oralPrompt,
}); err != nil {
logger.ErrorCF("voice-agent", "Failed to publish inbound message", map[string]any{"error": err})
}
+2 -2
View File
@@ -185,8 +185,8 @@ func TestAgentCheckSilencePublishesInboundAndCleansUp(t *testing.T) {
if !strings.Contains(msg.Content, "hello there") {
t.Fatalf("unexpected inbound content: %q", msg.Content)
}
if msg.Metadata["is_voice"] != "true" {
t.Fatalf("expected is_voice metadata, got %#v", msg.Metadata)
if msg.Context.Raw["is_voice"] != "true" {
t.Fatalf("expected is_voice metadata, got %#v", msg.Context.Raw)
}
case <-time.After(500 * time.Millisecond):
t.Fatal("expected inbound publish")
+19 -1
View File
@@ -12,6 +12,12 @@ import (
// ErrBusClosed is returned when publishing to a closed MessageBus.
var ErrBusClosed = errors.New("message bus closed")
var (
ErrMissingInboundContext = errors.New("inbound message context is required")
ErrMissingOutboundContext = errors.New("outbound message context is required")
ErrMissingOutboundMediaContext = errors.New("outbound media context is required")
)
const defaultBusBufferSize = 64
// StreamDelegate is implemented by the channel Manager to provide streaming
@@ -49,7 +55,7 @@ func NewMessageBus() *MessageBus {
inbound: make(chan InboundMessage, defaultBusBufferSize),
outbound: make(chan OutboundMessage, defaultBusBufferSize),
outboundMedia: make(chan OutboundMediaMessage, defaultBusBufferSize),
audioChunks: make(chan AudioChunk, defaultBusBufferSize*4), // Audio chunks need more buffer
audioChunks: make(chan AudioChunk, defaultBusBufferSize*4), // Audio chunks need more buffer.
voiceControls: make(chan VoiceControl, defaultBusBufferSize),
done: make(chan struct{}),
}
@@ -84,6 +90,10 @@ func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error
}
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
msg = NormalizeInboundMessage(msg)
if msg.Context.isZero() {
return ErrMissingInboundContext
}
return publish(ctx, mb, mb.inbound, msg)
}
@@ -92,6 +102,10 @@ func (mb *MessageBus) InboundChan() <-chan InboundMessage {
}
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
msg = NormalizeOutboundMessage(msg)
if msg.Context.isZero() {
return ErrMissingOutboundContext
}
return publish(ctx, mb, mb.outbound, msg)
}
@@ -100,6 +114,10 @@ func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
}
func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
msg = NormalizeOutboundMediaMessage(msg)
if msg.Context.isZero() {
return ErrMissingOutboundMediaContext
}
return publish(ctx, mb, mb.outboundMedia, msg)
}
+438 -15
View File
@@ -14,10 +14,13 @@ func TestPublishConsume(t *testing.T) {
ctx := context.Background()
msg := InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "hello",
Context: InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "hello",
}
if err := mb.PublishInbound(ctx, msg); err != nil {
@@ -34,6 +37,138 @@ func TestPublishConsume(t *testing.T) {
if got.Channel != "test" {
t.Fatalf("expected channel 'test', got %q", got.Channel)
}
if got.Context.Channel != "test" {
t.Fatalf("expected context channel 'test', got %q", got.Context.Channel)
}
if got.Context.ChatID != "chat1" {
t.Fatalf("expected context chat ID 'chat1', got %q", got.Context.ChatID)
}
if got.Context.SenderID != "user1" {
t.Fatalf("expected context sender ID 'user1', got %q", got.Context.SenderID)
}
}
func TestPublishInbound_NormalizesContext(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := InboundMessage{
Context: InboundContext{
Channel: "slack",
Account: "workspace-a",
ChatID: "C456/1712",
ChatType: "group",
TopicID: "1712",
SpaceID: "T001",
SpaceType: "team",
SenderID: "U123",
MessageID: "1712.01",
ReplyToMessageID: "1700.01",
Mentioned: true,
},
Content: "hello",
}
if err := mb.PublishInbound(context.Background(), msg); err != nil {
t.Fatalf("PublishInbound failed: %v", err)
}
got := <-mb.InboundChan()
if got.Context.Channel != "slack" {
t.Fatalf("expected context channel slack, got %q", got.Context.Channel)
}
if got.Context.Account != "workspace-a" {
t.Fatalf("expected context account workspace-a, got %q", got.Context.Account)
}
if got.Context.ChatType != "group" {
t.Fatalf("expected context chat type group, got %q", got.Context.ChatType)
}
if got.Context.TopicID != "1712" {
t.Fatalf("expected topic 1712, got %q", got.Context.TopicID)
}
if got.Context.SpaceType != "team" || got.Context.SpaceID != "T001" {
t.Fatalf("expected team space T001, got %q/%q", got.Context.SpaceType, got.Context.SpaceID)
}
if !got.Context.Mentioned {
t.Fatal("expected mentioned=true in context")
}
if got.Context.ReplyToMessageID != "1700.01" {
t.Fatalf("expected reply_to_message_id 1700.01, got %q", got.Context.ReplyToMessageID)
}
}
func TestPublishInbound_MirrorsContextIntoConvenienceFields(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := InboundMessage{
Context: InboundContext{
Channel: "telegram",
Account: "bot-a",
ChatID: "-1001",
ChatType: "group",
TopicID: "42",
SpaceID: "guild-9",
SpaceType: "guild",
SenderID: "user-1",
MessageID: "777",
Mentioned: true,
ReplyToMessageID: "666",
},
Content: "hi",
}
if err := mb.PublishInbound(context.Background(), msg); err != nil {
t.Fatalf("PublishInbound failed: %v", err)
}
got := <-mb.InboundChan()
if got.Channel != "telegram" {
t.Fatalf("expected legacy channel telegram, got %q", got.Channel)
}
if got.ChatID != "-1001" {
t.Fatalf("expected legacy chat ID -1001, got %q", got.ChatID)
}
if got.SenderID != "user-1" {
t.Fatalf("expected legacy sender ID user-1, got %q", got.SenderID)
}
if got.MessageID != "777" {
t.Fatalf("expected legacy message ID 777, got %q", got.MessageID)
}
if got.Context.Account != "bot-a" || got.Context.SpaceID != "guild-9" || got.Context.TopicID != "42" {
t.Fatalf("unexpected normalized context: %+v", got.Context)
}
}
func TestPublishInbound_BackfillsContextFromLegacyFields(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := InboundMessage{
Channel: "pico",
ChatID: "session-1",
SenderID: "user-1",
MessageID: "msg-1",
Content: "hello",
}
if err := mb.PublishInbound(context.Background(), msg); err != nil {
t.Fatalf("PublishInbound failed: %v", err)
}
got := <-mb.InboundChan()
if got.Context.Channel != "pico" {
t.Fatalf("expected context channel pico, got %q", got.Context.Channel)
}
if got.Context.ChatID != "session-1" {
t.Fatalf("expected context chat ID session-1, got %q", got.Context.ChatID)
}
if got.Context.SenderID != "user-1" {
t.Fatalf("expected context sender ID user-1, got %q", got.Context.SenderID)
}
if got.Context.MessageID != "msg-1" {
t.Fatalf("expected context message ID msg-1, got %q", got.Context.MessageID)
}
}
func TestPublishOutboundSubscribe(t *testing.T) {
@@ -43,8 +178,10 @@ func TestPublishOutboundSubscribe(t *testing.T) {
ctx := context.Background()
msg := OutboundMessage{
Channel: "telegram",
ChatID: "123",
Context: InboundContext{
Channel: "telegram",
ChatID: "123",
},
Content: "world",
}
@@ -59,6 +196,222 @@ func TestPublishOutboundSubscribe(t *testing.T) {
if got.Content != "world" {
t.Fatalf("expected content 'world', got %q", got.Content)
}
if got.Context.Channel != "telegram" || got.Context.ChatID != "123" {
t.Fatalf("expected normalized outbound context, got %+v", got.Context)
}
}
func TestPublishOutbound_MirrorsContextToLegacyFields(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := OutboundMessage{
Context: InboundContext{
Channel: "telegram",
ChatID: "chat-42",
ReplyToMessageID: "msg-9",
},
AgentID: "main",
SessionKey: "sk_v1_123",
Scope: &OutboundScope{
Version: 1,
AgentID: "main",
Channel: "telegram",
Account: "bot-a",
Dimensions: []string{"chat", "sender"},
Values: map[string]string{
"chat": "direct:chat-42",
"sender": "user-1",
},
},
Content: "reply",
}
if err := mb.PublishOutbound(context.Background(), msg); err != nil {
t.Fatalf("PublishOutbound failed: %v", err)
}
got := <-mb.OutboundChan()
if got.Channel != "telegram" {
t.Fatalf("expected legacy channel telegram, got %q", got.Channel)
}
if got.ChatID != "chat-42" {
t.Fatalf("expected legacy chat ID chat-42, got %q", got.ChatID)
}
if got.ReplyToMessageID != "msg-9" {
t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID)
}
if got.AgentID != "main" || got.SessionKey != "sk_v1_123" {
t.Fatalf("unexpected outbound turn metadata: agent=%q session=%q", got.AgentID, got.SessionKey)
}
if got.Scope == nil || got.Scope.AgentID != "main" || got.Scope.Values["chat"] != "direct:chat-42" {
t.Fatalf("unexpected outbound scope: %+v", got.Scope)
}
if got.Context.Channel != "telegram" || got.Context.ChatID != "chat-42" {
t.Fatalf("unexpected outbound context: %+v", got.Context)
}
}
func TestPublishOutbound_PreservesExplicitReplyToMessageID(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := OutboundMessage{
Context: InboundContext{
Channel: "telegram",
ChatID: "chat-42",
},
ReplyToMessageID: "msg-9",
Content: "reply",
}
if err := mb.PublishOutbound(context.Background(), msg); err != nil {
t.Fatalf("PublishOutbound failed: %v", err)
}
got := <-mb.OutboundChan()
if got.ReplyToMessageID != "msg-9" {
t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID)
}
if got.Context.ReplyToMessageID != "msg-9" {
t.Fatalf("expected context reply_to_message_id msg-9, got %q", got.Context.ReplyToMessageID)
}
}
func TestPublishOutbound_PreservesExplicitReplyToMessageIDWhenContextReplyIsBlank(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := OutboundMessage{
Context: InboundContext{
Channel: "telegram",
ChatID: "chat-42",
ReplyToMessageID: " ",
},
ReplyToMessageID: "msg-9",
Content: "reply",
}
if err := mb.PublishOutbound(context.Background(), msg); err != nil {
t.Fatalf("PublishOutbound failed: %v", err)
}
got := <-mb.OutboundChan()
if got.ReplyToMessageID != "msg-9" {
t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID)
}
if got.Context.ReplyToMessageID != "msg-9" {
t.Fatalf("expected context reply_to_message_id msg-9, got %q", got.Context.ReplyToMessageID)
}
}
func TestPublishOutboundMedia_MirrorsContextToLegacyFields(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := OutboundMediaMessage{
Context: InboundContext{
Channel: "slack",
ChatID: "C001",
},
AgentID: "support",
SessionKey: "sk_v1_media",
Scope: &OutboundScope{
Version: 1,
AgentID: "support",
Channel: "slack",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "channel:c001",
},
},
Parts: []MediaPart{{Type: "image", Ref: "media://1"}},
}
if err := mb.PublishOutboundMedia(context.Background(), msg); err != nil {
t.Fatalf("PublishOutboundMedia failed: %v", err)
}
got := <-mb.OutboundMediaChan()
if got.Channel != "slack" {
t.Fatalf("expected legacy channel slack, got %q", got.Channel)
}
if got.ChatID != "C001" {
t.Fatalf("expected legacy chat ID C001, got %q", got.ChatID)
}
if got.AgentID != "support" || got.SessionKey != "sk_v1_media" {
t.Fatalf("unexpected outbound media turn metadata: agent=%q session=%q", got.AgentID, got.SessionKey)
}
if got.Scope == nil || got.Scope.Values["chat"] != "channel:c001" {
t.Fatalf("unexpected outbound media scope: %+v", got.Scope)
}
if got.Context.Channel != "slack" || got.Context.ChatID != "C001" {
t.Fatalf("unexpected outbound media context: %+v", got.Context)
}
}
func TestPublishAudioChunkSubscribe(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
chunk := AudioChunk{
SessionID: "voice-1",
SpeakerID: "speaker-1",
ChatID: "chat-1",
Channel: "discord",
Sequence: 7,
Format: "opus",
Data: []byte{0x01, 0x02},
}
if err := mb.PublishAudioChunk(context.Background(), chunk); err != nil {
t.Fatalf("PublishAudioChunk failed: %v", err)
}
got, ok := <-mb.AudioChunksChan()
if !ok {
t.Fatal("AudioChunksChan returned ok=false")
}
if got.SessionID != "voice-1" || got.Sequence != 7 {
t.Fatalf("unexpected audio chunk: %+v", got)
}
}
func TestPublishVoiceControlSubscribe(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
ctrl := VoiceControl{
SessionID: "voice-1",
ChatID: "chat-1",
Type: "command",
Action: "start",
}
if err := mb.PublishVoiceControl(context.Background(), ctrl); err != nil {
t.Fatalf("PublishVoiceControl failed: %v", err)
}
got, ok := <-mb.VoiceControlsChan()
if !ok {
t.Fatal("VoiceControlsChan returned ok=false")
}
if got.Type != "command" || got.Action != "start" {
t.Fatalf("unexpected voice control: %+v", got)
}
}
func TestNewOutboundContext_NormalizesReplyAddress(t *testing.T) {
ctx := NewOutboundContext(" telegram ", " chat-42 ", " msg-9 ")
if ctx.Channel != "telegram" {
t.Fatalf("expected channel telegram, got %q", ctx.Channel)
}
if ctx.ChatID != "chat-42" {
t.Fatalf("expected chat_id chat-42, got %q", ctx.ChatID)
}
if ctx.ReplyToMessageID != "msg-9" {
t.Fatalf("expected reply_to_message_id msg-9, got %q", ctx.ReplyToMessageID)
}
}
func TestPublishInbound_ContextCancel(t *testing.T) {
@@ -68,7 +421,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
// Fill the buffer
ctx := context.Background()
for i := range defaultBusBufferSize {
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
if err := mb.PublishInbound(ctx, InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat-fill",
ChatType: "direct",
SenderID: "user-fill",
},
Content: "fill",
}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
}
@@ -77,7 +438,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
cancelCtx, cancel := context.WithCancel(context.Background())
cancel()
err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"})
err := mb.PublishInbound(cancelCtx, InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat-overflow",
ChatType: "direct",
SenderID: "user-overflow",
},
Content: "overflow",
})
if err == nil {
t.Fatal("expected error from canceled context, got nil")
}
@@ -90,7 +459,15 @@ func TestPublishInbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
mb.Close()
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
err := mb.PublishInbound(context.Background(), InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "test",
})
if err != ErrBusClosed {
t.Fatalf("expected ErrBusClosed, got %v", err)
}
@@ -100,7 +477,13 @@ func TestPublishOutbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
mb.Close()
err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"})
err := mb.PublishOutbound(context.Background(), OutboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat1",
},
Content: "test",
})
if err != ErrBusClosed {
t.Fatalf("expected ErrBusClosed, got %v", err)
}
@@ -112,14 +495,30 @@ func TestConsumeInbound_ContextCancel(t *testing.T) {
defer mb.Close()
for i := range defaultBusBufferSize {
if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil {
if err := mb.PublishInbound(context.Background(), InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat-fill",
ChatType: "direct",
SenderID: "user-fill",
},
Content: "fill",
}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"})
mb.PublishInbound(ctx, InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat-cancel",
ChatType: "direct",
SenderID: "user-cancel",
},
Content: "ContextCancel",
})
select {
case <-ctx.Done():
@@ -213,7 +612,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
// Fill the buffer
for i := range defaultBusBufferSize {
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
if err := mb.PublishInbound(ctx, InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat-fill",
ChatType: "direct",
SenderID: "user-fill",
},
Content: "fill",
}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
}
@@ -222,7 +629,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"})
err := mb.PublishInbound(timeoutCtx, InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat-overflow",
ChatType: "direct",
SenderID: "user-overflow",
},
Content: "overflow",
})
if err == nil {
t.Fatal("expected error when buffer is full and context times out")
}
@@ -240,7 +655,15 @@ func TestCloseIdempotent(t *testing.T) {
mb.Close()
// After close, publish should return ErrBusClosed
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
err := mb.PublishInbound(context.Background(), InboundMessage{
Context: InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "test",
})
if err != ErrBusClosed {
t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err)
}
+81
View File
@@ -0,0 +1,81 @@
package bus
import "strings"
// NormalizeInboundMessage ensures the inbound context is normalized and keeps
// convenience mirrors in sync for runtime consumers.
func NormalizeInboundMessage(msg InboundMessage) InboundMessage {
if msg.Context.Channel == "" {
msg.Context.Channel = msg.Channel
}
if msg.Context.ChatID == "" {
msg.Context.ChatID = msg.ChatID
}
if msg.Context.SenderID == "" {
msg.Context.SenderID = msg.SenderID
}
if msg.Context.MessageID == "" {
msg.Context.MessageID = msg.MessageID
}
msg.Context = normalizeInboundContext(msg.Context)
msg.Channel = msg.Context.Channel
msg.SenderID = msg.Context.SenderID
msg.ChatID = msg.Context.ChatID
if msg.MessageID == "" {
msg.MessageID = msg.Context.MessageID
}
if msg.Context.MessageID == "" {
msg.Context.MessageID = msg.MessageID
}
return msg
}
func (ctx InboundContext) isZero() bool {
return ctx.Channel == "" &&
ctx.Account == "" &&
ctx.ChatID == "" &&
ctx.ChatType == "" &&
ctx.TopicID == "" &&
ctx.SpaceID == "" &&
ctx.SpaceType == "" &&
ctx.SenderID == "" &&
ctx.MessageID == "" &&
!ctx.Mentioned &&
ctx.ReplyToMessageID == "" &&
ctx.ReplyToSenderID == "" &&
len(ctx.ReplyHandles) == 0 &&
len(ctx.Raw) == 0
}
func normalizeInboundContext(ctx InboundContext) InboundContext {
ctx.Channel = strings.TrimSpace(ctx.Channel)
ctx.Account = strings.TrimSpace(ctx.Account)
ctx.ChatID = strings.TrimSpace(ctx.ChatID)
ctx.ChatType = normalizeKind(ctx.ChatType)
ctx.TopicID = strings.TrimSpace(ctx.TopicID)
ctx.SpaceID = strings.TrimSpace(ctx.SpaceID)
ctx.SpaceType = normalizeKind(ctx.SpaceType)
ctx.SenderID = strings.TrimSpace(ctx.SenderID)
ctx.MessageID = strings.TrimSpace(ctx.MessageID)
ctx.ReplyToMessageID = strings.TrimSpace(ctx.ReplyToMessageID)
ctx.ReplyToSenderID = strings.TrimSpace(ctx.ReplyToSenderID)
ctx.ReplyHandles = cloneStringMap(ctx.ReplyHandles)
ctx.Raw = cloneStringMap(ctx.Raw)
return ctx
}
func cloneStringMap(src map[string]string) map[string]string {
if len(src) == 0 {
return nil
}
dst := make(map[string]string, len(src))
for k, v := range src {
dst[k] = v
}
return dst
}
func normalizeKind(kind string) string {
return strings.ToLower(strings.TrimSpace(kind))
}
+84
View File
@@ -0,0 +1,84 @@
package bus
import "strings"
// NewOutboundContext builds the minimal normalized addressing context required
// to deliver an outbound text message or reply.
func NewOutboundContext(channel, chatID, replyToMessageID string) InboundContext {
return normalizeInboundContext(InboundContext{
Channel: strings.TrimSpace(channel),
ChatID: strings.TrimSpace(chatID),
ReplyToMessageID: strings.TrimSpace(replyToMessageID),
})
}
// NormalizeOutboundMessage ensures Context is normalized and keeps convenience
// mirrors in sync for runtime consumers.
func NormalizeOutboundMessage(msg OutboundMessage) OutboundMessage {
msg.Channel = strings.TrimSpace(msg.Channel)
msg.ChatID = strings.TrimSpace(msg.ChatID)
msg.ReplyToMessageID = strings.TrimSpace(msg.ReplyToMessageID)
if msg.Context.Channel == "" {
msg.Context.Channel = msg.Channel
}
if msg.Context.ChatID == "" {
msg.Context.ChatID = msg.ChatID
}
if msg.Context.ReplyToMessageID == "" {
msg.Context.ReplyToMessageID = msg.ReplyToMessageID
}
msg.Context = normalizeInboundContext(msg.Context)
if msg.Channel == "" {
msg.Channel = msg.Context.Channel
}
if msg.ChatID == "" {
msg.ChatID = msg.Context.ChatID
}
if msg.ReplyToMessageID == "" {
msg.ReplyToMessageID = msg.Context.ReplyToMessageID
}
if msg.Context.ReplyToMessageID == "" {
msg.Context.ReplyToMessageID = msg.ReplyToMessageID
}
msg.Scope = cloneOutboundScope(msg.Scope)
return msg
}
// NormalizeOutboundMediaMessage ensures media outbound messages also carry a
// normalized context while keeping convenience mirrors in sync.
func NormalizeOutboundMediaMessage(msg OutboundMediaMessage) OutboundMediaMessage {
msg.Channel = strings.TrimSpace(msg.Channel)
msg.ChatID = strings.TrimSpace(msg.ChatID)
if msg.Context.Channel == "" {
msg.Context.Channel = msg.Channel
}
if msg.Context.ChatID == "" {
msg.Context.ChatID = msg.ChatID
}
msg.Context = normalizeInboundContext(msg.Context)
if msg.Channel == "" {
msg.Channel = msg.Context.Channel
}
if msg.ChatID == "" {
msg.ChatID = msg.Context.ChatID
}
msg.Scope = cloneOutboundScope(msg.Scope)
return msg
}
func cloneOutboundScope(scope *OutboundScope) *OutboundScope {
if scope == nil {
return nil
}
cloned := *scope
if len(scope.Dimensions) > 0 {
cloned.Dimensions = append([]string(nil), scope.Dimensions...)
}
if len(scope.Values) > 0 {
cloned.Values = make(map[string]string, len(scope.Values))
for key, value := range scope.Values {
cloned.Values[key] = value
}
}
return &cloned
}
+64 -25
View File
@@ -1,11 +1,5 @@
package bus
// Peer identifies the routing peer for a message (direct, group, channel, etc.)
type Peer struct {
Kind string `json:"kind"` // "direct" | "group" | "channel" | ""
ID string `json:"id"`
}
// SenderInfo provides structured sender identity information.
type SenderInfo struct {
Platform string `json:"platform,omitempty"` // "telegram", "discord", "slack", ...
@@ -15,26 +9,67 @@ type SenderInfo struct {
DisplayName string `json:"display_name,omitempty"` // display name
}
// InboundContext captures the normalized, platform-agnostic facts about an
// inbound message. This is the source of truth for routing and session
// allocation.
type InboundContext struct {
Channel string `json:"channel"`
Account string `json:"account,omitempty"`
ChatID string `json:"chat_id"`
ChatType string `json:"chat_type,omitempty"` // direct / group / channel
TopicID string `json:"topic_id,omitempty"`
SpaceID string `json:"space_id,omitempty"`
SpaceType string `json:"space_type,omitempty"` // guild / team / workspace / tenant
SenderID string `json:"sender_id"`
MessageID string `json:"message_id,omitempty"`
Mentioned bool `json:"mentioned,omitempty"`
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
ReplyToSenderID string `json:"reply_to_sender_id,omitempty"`
ReplyHandles map[string]string `json:"reply_handles,omitempty"`
Raw map[string]string `json:"raw,omitempty"`
}
type InboundMessage struct {
Channel string `json:"channel"`
SenderID string `json:"sender_id"`
Sender SenderInfo `json:"sender"`
ChatID string `json:"chat_id"`
Content string `json:"content"`
Media []string `json:"media,omitempty"`
Peer Peer `json:"peer"` // routing peer
MessageID string `json:"message_id,omitempty"` // platform message ID
MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope
SessionKey string `json:"session_key"`
Metadata map[string]string `json:"metadata,omitempty"`
Context InboundContext `json:"context"`
Sender SenderInfo `json:"sender"`
Content string `json:"content"`
Media []string `json:"media,omitempty"`
MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope
SessionKey string `json:"session_key"`
// Convenience mirrors derived from Context for runtime consumers.
Channel string `json:"channel"`
SenderID string `json:"sender_id"`
ChatID string `json:"chat_id"`
MessageID string `json:"message_id,omitempty"` // platform message ID
}
// OutboundScope captures the structured session scope associated with an
// outbound turn result without depending on the session package.
type OutboundScope struct {
Version int `json:"version,omitempty"`
AgentID string `json:"agent_id,omitempty"`
Channel string `json:"channel,omitempty"`
Account string `json:"account,omitempty"`
Dimensions []string `json:"dimensions,omitempty"`
Values map[string]string `json:"values,omitempty"`
}
type OutboundMessage struct {
Channel string `json:"channel"`
ChatID string `json:"chat_id"`
Content string `json:"content"`
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
Channel string `json:"channel"`
ChatID string `json:"chat_id"`
Context InboundContext `json:"context"`
AgentID string `json:"agent_id,omitempty"`
SessionKey string `json:"session_key,omitempty"`
Scope *OutboundScope `json:"scope,omitempty"`
Content string `json:"content"`
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
}
// MediaPart describes a single media attachment to send.
@@ -48,9 +83,13 @@ type MediaPart struct {
// OutboundMediaMessage carries media attachments from Agent to channels via the bus.
type OutboundMediaMessage struct {
Channel string `json:"channel"`
ChatID string `json:"chat_id"`
Parts []MediaPart `json:"parts"`
Channel string `json:"channel"`
ChatID string `json:"chat_id"`
Context InboundContext `json:"context"`
AgentID string `json:"agent_id,omitempty"`
SessionKey string `json:"session_key,omitempty"`
Scope *OutboundScope `json:"scope,omitempty"`
Parts []MediaPart `json:"parts"`
}
// AudioChunk represents a chunk of streaming voice data.
+39 -19
View File
@@ -260,12 +260,11 @@ func (c *BaseChannel) IsAllowedSender(sender bus.SenderInfo) bool {
return false
}
func (c *BaseChannel) HandleMessage(
func (c *BaseChannel) HandleMessageWithContext(
ctx context.Context,
peer bus.Peer,
messageID, senderID, chatID, content string,
deliveryChatID, content string,
media []string,
metadata map[string]string,
inboundCtx bus.InboundContext,
senderOpts ...bus.SenderInfo,
) {
// Use SenderInfo-based allow check when available, else fall back to string
@@ -273,6 +272,7 @@ func (c *BaseChannel) HandleMessage(
if len(senderOpts) > 0 {
sender = senderOpts[0]
}
senderID := strings.TrimSpace(inboundCtx.SenderID)
if sender.CanonicalID != "" || sender.PlatformID != "" {
if !c.IsAllowedSender(sender) {
return
@@ -289,20 +289,28 @@ func (c *BaseChannel) HandleMessage(
resolvedSenderID = sender.CanonicalID
}
scope := BuildMediaScope(c.name, chatID, messageID)
if resolvedSenderID == "" {
resolvedSenderID = senderID
}
inboundCtx.Channel = c.name
if inboundCtx.ChatID == "" {
inboundCtx.ChatID = deliveryChatID
}
if inboundCtx.SenderID == "" {
inboundCtx.SenderID = resolvedSenderID
}
scope := BuildMediaScope(c.name, deliveryChatID, inboundCtx.MessageID)
msg := bus.InboundMessage{
Channel: c.name,
SenderID: resolvedSenderID,
Context: inboundCtx,
Sender: sender,
ChatID: chatID,
Content: content,
Media: media,
Peer: peer,
MessageID: messageID,
MediaScope: scope,
Metadata: metadata,
}
msg = bus.NormalizeInboundMessage(msg)
// Auto-trigger typing indicator, message reaction, and placeholder before publishing.
// Each capability is independent — all three may fire for the same message.
@@ -313,14 +321,14 @@ func (c *BaseChannel) HandleMessage(
if c.owner != nil && c.placeholderRecorder != nil {
// Typing
if tc, ok := c.owner.(TypingCapable); ok {
if stop, err := tc.StartTyping(ctx, chatID); err == nil {
c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop)
if stop, err := tc.StartTyping(ctx, deliveryChatID); err == nil {
c.placeholderRecorder.RecordTypingStop(c.name, deliveryChatID, stop)
}
}
// Reaction
if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" {
if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil {
c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo)
if rc, ok := c.owner.(ReactionCapable); ok && msg.MessageID != "" {
if undo, err := rc.ReactToMessage(ctx, deliveryChatID, msg.MessageID); err == nil {
c.placeholderRecorder.RecordReactionUndo(c.name, deliveryChatID, undo)
}
}
// Placeholder — independent pipeline.
@@ -329,8 +337,8 @@ func (c *BaseChannel) HandleMessage(
// "Thinking…" only once the voice has been processed.
if !audioAnnotationRe.MatchString(content) {
if pc, ok := c.owner.(PlaceholderCapable); ok {
if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" {
c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID)
if phID, err := pc.SendPlaceholder(ctx, deliveryChatID); err == nil && phID != "" {
c.placeholderRecorder.RecordPlaceholder(c.name, deliveryChatID, phID)
}
}
}
@@ -339,12 +347,24 @@ func (c *BaseChannel) HandleMessage(
if err := c.bus.PublishInbound(ctx, msg); err != nil {
logger.ErrorCF("channels", "Failed to publish inbound message", map[string]any{
"channel": c.name,
"chat_id": chatID,
"chat_id": deliveryChatID,
"error": err.Error(),
})
}
}
// HandleInboundContext publishes a normalized inbound message using only the
// structured context.
func (c *BaseChannel) HandleInboundContext(
ctx context.Context,
deliveryChatID, content string,
media []string,
inboundCtx bus.InboundContext,
senderOpts ...bus.SenderInfo,
) {
c.HandleMessageWithContext(ctx, deliveryChatID, content, media, inboundCtx, senderOpts...)
}
func (c *BaseChannel) SetRunning(running bool) {
c.running.Store(running)
}
+56
View File
@@ -1,6 +1,7 @@
package channels
import (
"context"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -263,3 +264,58 @@ func TestIsAllowedSender(t *testing.T) {
})
}
}
func TestHandleInboundContext_PublishesNormalizedContext(t *testing.T) {
tests := []struct {
name string
inbound bus.InboundContext
wantChat string
wantSender string
}{
{
name: "direct uses sender as peer",
inbound: bus.InboundContext{
Channel: "test",
ChatID: "chat-1",
ChatType: "direct",
SenderID: "user-1",
MessageID: "msg-1",
},
wantChat: "chat-1",
wantSender: "user-1",
},
{
name: "group uses chat as peer",
inbound: bus.InboundContext{
Channel: "test",
ChatID: "group-1",
ChatType: "group",
SenderID: "user-2",
MessageID: "msg-2",
},
wantChat: "group-1",
wantSender: "user-2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgBus := bus.NewMessageBus()
defer msgBus.Close()
ch := NewBaseChannel("test", nil, msgBus, nil)
ch.HandleInboundContext(context.Background(), tt.inbound.ChatID, "hello", nil, tt.inbound)
msg := <-msgBus.InboundChan()
if msg.ChatID != tt.wantChat {
t.Fatalf("ChatID = %q, want %q", msg.ChatID, tt.wantChat)
}
if msg.SenderID != tt.wantSender {
t.Fatalf("SenderID = %q, want %q", msg.SenderID, tt.wantSender)
}
if msg.Context.ChatType != tt.inbound.ChatType {
t.Fatalf("ChatType = %q, want %q", msg.Context.ChatType, tt.inbound.ChatType)
}
})
}
}
+22 -10
View File
@@ -185,16 +185,15 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
"session_webhook": data.SessionWebhook,
}
var peer bus.Peer
var (
chatType string
isMentioned bool
)
if data.ConversationType == "1" {
peerID := senderID
if peerID == "" {
peerID = chatID
}
peer = bus.Peer{Kind: "direct", ID: peerID}
chatType = "direct"
} else {
peer = bus.Peer{Kind: "group", ID: data.ConversationId}
isMentioned := data.IsInAtList
chatType = "group"
isMentioned = data.IsInAtList
if isMentioned {
content = stripLeadingAtMentions(content)
}
@@ -232,8 +231,21 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
return nil, nil
}
// Handle the message through the base channel
c.HandleMessage(ctx, peer, "", resolvedSenderID, chatID, content, nil, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: "dingtalk",
ChatID: chatID,
ChatType: chatType,
SenderID: resolvedSenderID,
Mentioned: isMentioned,
Raw: metadata,
}
if data.SessionWebhook != "" {
inboundCtx.ReplyHandles = map[string]string{
"session_webhook": data.SessionWebhook,
}
}
c.HandleInboundContext(ctx, chatID, content, nil, inboundCtx, sender)
// Return nil to indicate we've handled the message asynchronously
// The response will be sent through the message bus
+8 -5
View File
@@ -75,8 +75,8 @@ func TestOnChatBotMessageReceived_GroupMentionOnlyUsesIsInAtListAndStripsMention
if inbound.ChatID != "group-abc" {
t.Fatalf("chat_id=%q", inbound.ChatID)
}
if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-abc" {
t.Fatalf("peer=%+v", inbound.Peer)
if inbound.Context.ChatType != "group" {
t.Fatalf("chat_type=%q", inbound.Context.ChatType)
}
if inbound.Content != "/help" {
t.Fatalf("content=%q", inbound.Content)
@@ -103,12 +103,15 @@ func TestOnChatBotMessageReceived_DirectFallbackSenderIDUsesConversationID(t *te
if inbound.ChatID != "conv-direct-42" {
t.Fatalf("chat_id=%q", inbound.ChatID)
}
if inbound.Peer.Kind != "direct" || inbound.Peer.ID != "openid-user-42" {
t.Fatalf("peer=%+v", inbound.Peer)
if inbound.Context.ChatType != "direct" {
t.Fatalf("chat_type=%q", inbound.Context.ChatType)
}
if inbound.SenderID != "dingtalk:openid-user-42" {
if inbound.SenderID != "openid-user-42" {
t.Fatalf("sender_id=%q", inbound.SenderID)
}
if inbound.Sender.CanonicalID != "dingtalk:openid-user-42" {
t.Fatalf("sender canonical_id=%q", inbound.Sender.CanonicalID)
}
if _, ok := ch.sessionWebhooks.Load("conv-direct-42"); !ok {
t.Fatal("expected session webhook keyed by conversation_id")
+18 -6
View File
@@ -408,8 +408,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
// In guild (group) channels, apply unified group trigger filtering
// DMs (GuildID is empty) always get a response
isMentioned := false
if m.GuildID != "" {
isMentioned := false
for _, mention := range m.Mentions {
if mention.ID == c.botUserID {
isMentioned = true
@@ -506,14 +506,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
})
peerKind := "channel"
peerID := m.ChannelID
if m.GuildID == "" {
peerKind = "direct"
peerID = senderID
}
peer := bus.Peer{Kind: peerKind, ID: peerID}
metadata := map[string]string{
"user_id": senderID,
"username": m.Author.Username,
@@ -522,8 +518,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
"channel_id": m.ChannelID,
"is_dm": fmt.Sprintf("%t", m.GuildID == ""),
}
inboundCtx := bus.InboundContext{
Channel: c.Name(),
ChatID: m.ChannelID,
ChatType: peerKind,
SenderID: senderID,
MessageID: m.ID,
Mentioned: isMentioned,
Raw: metadata,
}
if m.GuildID != "" {
inboundCtx.SpaceID = m.GuildID
inboundCtx.SpaceType = "guild"
}
if m.MessageReference != nil {
inboundCtx.ReplyToMessageID = m.MessageReference.MessageID
}
c.HandleMessage(c.ctx, peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata, sender)
c.HandleInboundContext(c.ctx, m.ChannelID, content, mediaPaths, inboundCtx, sender)
}
// startTyping starts a continuous typing indicator loop for the given chatID.
+25 -5
View File
@@ -445,17 +445,23 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
// Append media tags to content (like Telegram does)
content = appendMediaTags(content, messageType, mediaRefs)
if content == "" {
content = "[empty message]"
}
chatType := stringValue(message.ChatType)
metadata := buildInboundMetadata(message, sender)
var peer bus.Peer
var (
inboundChatType string
isMentioned bool
)
if chatType == "p2p" {
peer = bus.Peer{Kind: "direct", ID: senderID}
inboundChatType = "direct"
} else {
peer = bus.Peer{Kind: "group", ID: chatID}
inboundChatType = "group"
// Check if bot was mentioned
isMentioned := c.isBotMentioned(message)
isMentioned = c.isBotMentioned(message)
// Strip mention placeholders from content before group trigger check
if len(message.Mentions) > 0 {
@@ -490,7 +496,21 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
"thread_id": stringValue(message.ThreadId),
})
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo)
inboundCtx := bus.InboundContext{
Channel: "feishu",
ChatID: chatID,
ChatType: inboundChatType,
SenderID: senderID,
MessageID: messageID,
Mentioned: isMentioned,
Raw: metadata,
}
if sender != nil && sender.TenantKey != nil && *sender.TenantKey != "" {
inboundCtx.SpaceType = "tenant"
inboundCtx.SpaceID = *sender.TenantKey
}
c.HandleInboundContext(ctx, chatID, content, mediaRefs, inboundCtx, senderInfo)
return nil
}
+18 -5
View File
@@ -51,14 +51,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
isDM := !strings.HasPrefix(target, "#") && !strings.HasPrefix(target, "&")
var chatID string
var peer bus.Peer
if isDM {
chatID = nick
peer = bus.Peer{Kind: "direct", ID: nick}
} else {
chatID = target
peer = bus.Peer{Kind: "group", ID: target}
}
sender := bus.SenderInfo{
@@ -73,9 +70,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
return
}
isMentioned := false
// For channel messages, check group trigger (mention detection)
if !isDM {
isMentioned := isBotMentioned(content, currentNick)
isMentioned = isBotMentioned(content, currentNick)
if isMentioned {
content = stripBotMention(content, currentNick)
}
@@ -100,7 +99,21 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
metadata["channel"] = target
}
c.HandleMessage(c.ctx, peer, messageID, nick, chatID, content, nil, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: "irc",
ChatID: chatID,
SenderID: nick,
MessageID: messageID,
Mentioned: isMentioned,
Raw: metadata,
}
if isDM {
inboundCtx.ChatType = "direct"
} else {
inboundCtx.ChatType = "group"
}
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
}
// nickMentionedAt returns the byte index where botNick is mentioned in content
+21 -9
View File
@@ -354,8 +354,9 @@ func (c *LINEChannel) processEvent(event lineEvent) {
}
// In group chats, apply unified group trigger filtering
isMentioned := false
if isGroup {
isMentioned := c.isBotMentioned(msg)
isMentioned = c.isBotMentioned(msg)
respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
if !respond {
logger.DebugCF("line", "Ignoring group message by group trigger", map[string]any{
@@ -371,13 +372,6 @@ func (c *LINEChannel) processEvent(event lineEvent) {
"source_type": event.Source.Type,
}
var peer bus.Peer
if isGroup {
peer = bus.Peer{Kind: "group", ID: chatID}
} else {
peer = bus.Peer{Kind: "direct", ID: senderID}
}
logger.DebugCF("line", "Received message", map[string]any{
"sender_id": senderID,
"chat_id": chatID,
@@ -396,7 +390,25 @@ func (c *LINEChannel) processEvent(event lineEvent) {
return
}
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, mediaPaths, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: c.Name(),
ChatID: chatID,
ChatType: map[bool]string{true: "group", false: "direct"}[isGroup],
SenderID: senderID,
MessageID: msg.ID,
Mentioned: isMentioned,
Raw: metadata,
}
if event.ReplyToken != "" {
inboundCtx.ReplyHandles = map[string]string{
"reply_token": event.ReplyToken,
}
if msg.QuoteToken != "" {
inboundCtx.ReplyHandles["quote_token"] = msg.QuoteToken
}
}
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
}
// isBotMentioned checks if the bot is mentioned in the message.
+9 -11
View File
@@ -200,17 +200,15 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) {
return
}
c.HandleMessage(
c.ctx,
bus.Peer{Kind: "channel", ID: "default"},
"",
senderID,
chatID,
content,
[]string{},
metadata,
sender,
)
inboundCtx := bus.InboundContext{
Channel: "maixcam",
ChatID: chatID,
ChatType: "channel",
SenderID: senderID,
Raw: metadata,
}
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
}
func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) {
+47 -23
View File
@@ -98,6 +98,22 @@ type asyncTask struct {
cancel context.CancelFunc
}
func outboundMessageChannel(msg bus.OutboundMessage) string {
return msg.Context.Channel
}
func outboundMessageChatID(msg bus.OutboundMessage) string {
return msg.ChatID
}
func outboundMediaChannel(msg bus.OutboundMediaMessage) string {
return msg.Context.Channel
}
func outboundMediaChatID(msg bus.OutboundMediaMessage) string {
return msg.ChatID
}
// RecordPlaceholder registers a placeholder message for later editing.
// Implements PlaceholderRecorder.
func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) {
@@ -161,7 +177,8 @@ func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) {
// preSend handles typing stop, reaction undo, and placeholder editing before sending a message.
// Returns the delivered message IDs and true when delivery completed before a normal Send.
func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) ([]string, bool) {
key := name + ":" + msg.ChatID
chatID := outboundMessageChatID(msg)
key := name + ":" + chatID
// 1. Stop typing
if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
@@ -183,9 +200,9 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
// Prefer deleting the placeholder (cleaner UX than editing to same content)
if deleter, ok := ch.(MessageDeleter); ok {
deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort
deleter.DeleteMessage(ctx, chatID, entry.id) // best effort
} else if editor, ok := ch.(MessageEditor); ok {
editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content) // fallback
editor.EditMessage(ctx, chatID, entry.id, msg.Content) // fallback
}
}
}
@@ -196,7 +213,7 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
if editor, ok := ch.(MessageEditor); ok {
if err := editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content); err == nil {
if err := editor.EditMessage(ctx, chatID, entry.id, msg.Content); err == nil {
return []string{entry.id}, true
}
// edit failed → fall through to normal Send
@@ -212,7 +229,8 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
// delivery never edits the placeholder because there is no text payload to
// replace it with; it only attempts to delete the placeholder when possible.
func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.OutboundMediaMessage, ch Channel) {
key := name + ":" + msg.ChatID
chatID := outboundMediaChatID(msg)
key := name + ":" + chatID
// 1. Stop typing
if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
@@ -235,7 +253,7 @@ func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.Outboun
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
if deleter, ok := ch.(MessageDeleter); ok {
deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort
deleter.DeleteMessage(ctx, chatID, entry.id) // best effort
}
}
}
@@ -820,7 +838,7 @@ func (m *Manager) sendWithRetry(
// All retries exhausted or permanent failure
logger.ErrorCF("channels", "Send failed", map[string]any{
"channel": name,
"chat_id": msg.ChatID,
"chat_id": outboundMessageChatID(msg),
"error": lastErr.Error(),
"retries": maxRetries,
})
@@ -882,7 +900,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
dispatchLoop(
ctx, m,
m.bus.OutboundChan(),
func(msg bus.OutboundMessage) string { return msg.Channel },
func(msg bus.OutboundMessage) string { return outboundMessageChannel(msg) },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
select {
case w.queue <- msg:
@@ -902,7 +920,7 @@ func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
dispatchLoop(
ctx, m,
m.bus.OutboundMediaChan(),
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
func(msg bus.OutboundMediaMessage) string { return outboundMediaChannel(msg) },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
select {
case w.mediaQueue <- msg:
@@ -1001,7 +1019,7 @@ func (m *Manager) sendMediaWithRetry(
// All retries exhausted or permanent failure
logger.ErrorCF("channels", "SendMedia failed", map[string]any{
"channel": name,
"chat_id": msg.ChatID,
"chat_id": outboundMediaChatID(msg),
"error": lastErr.Error(),
"retries": maxRetries,
})
@@ -1200,16 +1218,19 @@ func (m *Manager) UnregisterChannel(name string) {
// delivered (or all retries are exhausted), which preserves ordering when
// a subsequent operation depends on the message having been sent.
func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) error {
msg = bus.NormalizeOutboundMessage(msg)
channelName := outboundMessageChannel(msg)
m.mu.RLock()
_, exists := m.channels[msg.Channel]
w, wExists := m.workers[msg.Channel]
_, exists := m.channels[channelName]
w, wExists := m.workers[channelName]
m.mu.RUnlock()
if !exists {
return fmt.Errorf("channel %s not found", msg.Channel)
return fmt.Errorf("channel %s not found", channelName)
}
if !wExists || w == nil {
return fmt.Errorf("channel %s has no active worker", msg.Channel)
return fmt.Errorf("channel %s has no active worker", channelName)
}
maxLen := 0
@@ -1220,10 +1241,10 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro
for _, chunk := range SplitMessage(msg.Content, maxLen) {
chunkMsg := msg
chunkMsg.Content = chunk
m.sendWithRetry(ctx, msg.Channel, w, chunkMsg)
m.sendWithRetry(ctx, channelName, w, chunkMsg)
}
} else {
m.sendWithRetry(ctx, msg.Channel, w, msg)
m.sendWithRetry(ctx, channelName, w, msg)
}
return nil
}
@@ -1233,19 +1254,22 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro
// retries are exhausted), which preserves ordering when later agent behavior
// depends on actual media delivery.
func (m *Manager) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
msg = bus.NormalizeOutboundMediaMessage(msg)
channelName := outboundMediaChannel(msg)
m.mu.RLock()
_, exists := m.channels[msg.Channel]
w, wExists := m.workers[msg.Channel]
_, exists := m.channels[channelName]
w, wExists := m.workers[channelName]
m.mu.RUnlock()
if !exists {
return fmt.Errorf("channel %s not found", msg.Channel)
return fmt.Errorf("channel %s not found", channelName)
}
if !wExists || w == nil {
return fmt.Errorf("channel %s has no active worker", msg.Channel)
return fmt.Errorf("channel %s has no active worker", channelName)
}
_, err := m.sendMediaWithRetry(ctx, msg.Channel, w, msg)
_, err := m.sendMediaWithRetry(ctx, channelName, w, msg)
return err
}
@@ -1260,10 +1284,10 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten
}
msg := bus.OutboundMessage{
Channel: channelName,
ChatID: chatID,
Context: bus.NewOutboundContext(channelName, chatID, ""),
Content: content,
}
msg = bus.NormalizeOutboundMessage(msg)
if wExists && w != nil {
select {
+138 -47
View File
@@ -175,11 +175,11 @@ func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := m.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
if err := m.bus.PublishOutbound(pubCtx, testOutboundMessage(bus.OutboundMessage{
Channel: "good",
ChatID: "chat-1",
Content: "hello",
}); err != nil {
})); err != nil {
t.Fatalf("PublishOutbound() error = %v", err)
}
@@ -197,6 +197,20 @@ func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
}
}
func testOutboundMessage(msg bus.OutboundMessage) bus.OutboundMessage {
if msg.Context.Channel == "" && msg.Context.ChatID == "" {
msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, msg.ReplyToMessageID)
}
return bus.NormalizeOutboundMessage(msg)
}
func testOutboundMediaMessage(msg bus.OutboundMediaMessage) bus.OutboundMediaMessage {
if msg.Context.Channel == "" && msg.Context.ChatID == "" {
msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, "")
}
return bus.NormalizeOutboundMediaMessage(msg)
}
func TestSendWithRetry_Success(t *testing.T) {
m := newTestManager()
var callCount int
@@ -212,7 +226,7 @@ func TestSendWithRetry_Success(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
m.sendWithRetry(ctx, "test", w, msg)
@@ -239,7 +253,7 @@ func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
m.sendWithRetry(ctx, "test", w, msg)
@@ -263,7 +277,7 @@ func TestSendWithRetry_PermanentFailure(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
m.sendWithRetry(ctx, "test", w, msg)
@@ -287,7 +301,7 @@ func TestSendWithRetry_NotRunning(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
m.sendWithRetry(ctx, "test", w, msg)
@@ -314,7 +328,7 @@ func TestSendWithRetry_RateLimitRetry(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
start := time.Now()
m.sendWithRetry(ctx, "test", w, msg)
@@ -344,7 +358,7 @@ func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
m.sendWithRetry(ctx, "test", w, msg)
@@ -370,11 +384,11 @@ func TestSendMedia_Success(t *testing.T) {
m.channels["test"] = ch
m.workers["test"] = w
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
Channel: "test",
ChatID: "chat1",
Parts: []bus.MediaPart{{Ref: "media://abc"}},
})
}))
if err != nil {
t.Fatalf("SendMedia() error = %v", err)
}
@@ -397,11 +411,11 @@ func TestSendMedia_PropagatesFailure(t *testing.T) {
m.channels["test"] = ch
m.workers["test"] = w
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
Channel: "test",
ChatID: "chat1",
Parts: []bus.MediaPart{{Ref: "media://abc"}},
})
}))
if err == nil {
t.Fatal("expected SendMedia to return error")
}
@@ -424,11 +438,11 @@ func TestSendMedia_UnsupportedChannelReturnsError(t *testing.T) {
m.channels["test"] = ch
m.workers["test"] = w
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
Channel: "test",
ChatID: "chat1",
Parts: []bus.MediaPart{{Ref: "media://abc"}},
})
}))
if err == nil {
t.Fatal("expected SendMedia to return error for unsupported channel")
}
@@ -454,11 +468,11 @@ func TestSendMedia_DeletesPlaceholderBeforeSending(t *testing.T) {
m.workers["test"] = w
m.RecordPlaceholder("test", "chat1", "placeholder-1")
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
Channel: "test",
ChatID: "chat1",
Parts: []bus.MediaPart{{Ref: "media://abc"}},
})
}))
if err != nil {
t.Fatalf("SendMedia() error = %v", err)
}
@@ -491,7 +505,7 @@ func TestSendWithRetry_UnknownError(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
m.sendWithRetry(ctx, "test", w, msg)
@@ -515,7 +529,7 @@ func TestSendWithRetry_ContextCancelled(t *testing.T) {
}
ctx, cancel := context.WithCancel(context.Background())
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
// Cancel context after first Send attempt returns
ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error {
@@ -561,7 +575,7 @@ func TestWorkerRateLimiter(t *testing.T) {
// Enqueue 4 messages
for i := range 4 {
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}
w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)})
}
// Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin)
@@ -637,7 +651,7 @@ func TestRunWorker_MessageSplitting(t *testing.T) {
go m.runWorker(ctx, "test", w)
// Send a message that should be split
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"}
w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"})
time.Sleep(100 * time.Millisecond)
@@ -678,7 +692,7 @@ func TestSendWithRetry_ExponentialBackoff(t *testing.T) {
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
start := time.Now()
m.sendWithRetry(ctx, "test", w, msg)
@@ -738,7 +752,7 @@ func TestPreSend_PlaceholderEditSuccess(t *testing.T) {
// Register placeholder
m.RecordPlaceholder("test", "123", "456")
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
_, edited := m.preSend(context.Background(), "test", msg, ch)
if !edited {
@@ -768,7 +782,7 @@ func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) {
m.RecordPlaceholder("test", "123", "456")
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
_, edited := m.preSend(context.Background(), "test", msg, ch)
if edited {
@@ -827,7 +841,7 @@ func TestPreSend_TypingStopCalled(t *testing.T) {
stopCalled = true
})
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
m.preSend(context.Background(), "test", msg, ch)
if !stopCalled {
@@ -844,7 +858,7 @@ func TestPreSend_NoRegisteredState(t *testing.T) {
},
}
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
_, edited := m.preSend(context.Background(), "test", msg, ch)
if edited {
@@ -874,7 +888,7 @@ func TestPreSend_TypingAndPlaceholder(t *testing.T) {
})
m.RecordPlaceholder("test", "123", "456")
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
_, edited := m.preSend(context.Background(), "test", msg, ch)
if !stopCalled {
@@ -938,7 +952,7 @@ func TestRecordTypingStop_ReplacesExistingStop(t *testing.T) {
t.Fatalf("expected replacement typing stop to stay active until preSend, got %d calls", newStopCalls)
}
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
m.preSend(context.Background(), "test", msg, &mockChannel{})
if newStopCalls != 1 {
@@ -972,7 +986,7 @@ func TestSendWithRetry_PreSendEditsPlaceholder(t *testing.T) {
limiter: rate.NewLimiter(rate.Inf, 1),
}
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
m.sendWithRetry(context.Background(), "test", w, msg)
if sendCalled {
@@ -1135,7 +1149,7 @@ func TestPreSendStillWorksWithWrappedTypes(t *testing.T) {
})
m.RecordPlaceholder("test", "chat1", "ph_id")
msg := bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"}
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"})
_, edited := m.preSend(context.Background(), "test", msg, ch)
if !stopCalled {
@@ -1238,11 +1252,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) {
// Transcription feedback arrives first — it should consume the placeholder
// and be delivered via EditMessage, not Send.
msgTranscript := bus.OutboundMessage{
msgTranscript := testOutboundMessage(bus.OutboundMessage{
Channel: "mock",
ChatID: "chat-1",
Content: "Transcript: hello",
}
})
mgr.sendWithRetry(ctx, "mock", worker, msgTranscript)
if mockCh.editedMessages != 1 {
@@ -1258,11 +1272,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) {
}
// Final LLM response arrives — no placeholder left, so it goes through Send
msgFinal := bus.OutboundMessage{
msgFinal := testOutboundMessage(bus.OutboundMessage{
Channel: "mock",
ChatID: "chat-1",
Content: "Final Answer",
}
})
mgr.sendWithRetry(ctx, "mock", worker, msgFinal)
if len(mockCh.sentMessages) != 1 {
@@ -1288,12 +1302,12 @@ func TestSendMessage_Synchronous(t *testing.T) {
m.channels["test"] = ch
m.workers["test"] = w
msg := bus.OutboundMessage{
msg := testOutboundMessage(bus.OutboundMessage{
Channel: "test",
ChatID: "123",
Content: "hello world",
ReplyToMessageID: "msg-456",
}
})
err := m.SendMessage(context.Background(), msg)
if err != nil {
@@ -1315,11 +1329,11 @@ func TestSendMessage_Synchronous(t *testing.T) {
func TestSendMessage_UnknownChannel(t *testing.T) {
m := newTestManager()
msg := bus.OutboundMessage{
msg := testOutboundMessage(bus.OutboundMessage{
Channel: "nonexistent",
ChatID: "123",
Content: "hello",
}
})
err := m.SendMessage(context.Background(), msg)
if err == nil {
@@ -1336,11 +1350,11 @@ func TestSendMessage_NoWorker(t *testing.T) {
m.channels["test"] = ch
// No worker registered
msg := bus.OutboundMessage{
msg := testOutboundMessage(bus.OutboundMessage{
Channel: "test",
ChatID: "123",
Content: "hello",
}
})
err := m.SendMessage(context.Background(), msg)
if err == nil {
@@ -1369,11 +1383,11 @@ func TestSendMessage_WithRetry(t *testing.T) {
m.channels["test"] = ch
m.workers["test"] = w
msg := bus.OutboundMessage{
msg := testOutboundMessage(bus.OutboundMessage{
Channel: "test",
ChatID: "123",
Content: "retry me",
}
})
err := m.SendMessage(context.Background(), msg)
if err != nil {
@@ -1385,6 +1399,46 @@ func TestSendMessage_WithRetry(t *testing.T) {
}
}
func TestSendMessage_ContextOnlyUsesContextAddressing(t *testing.T) {
m := newTestManager()
var received []bus.OutboundMessage
ch := &mockChannel{
sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
received = append(received, msg)
return nil
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
m.channels["test"] = ch
m.workers["test"] = w
msg := testOutboundMessage(bus.OutboundMessage{
Context: bus.NewOutboundContext("test", "123", "msg-9"),
Content: "hello",
})
if err := m.SendMessage(context.Background(), msg); err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(received) != 1 {
t.Fatalf("expected 1 message sent, got %d", len(received))
}
if received[0].Channel != "test" || received[0].ChatID != "123" {
t.Fatalf("expected mirrored legacy address, got %+v", received[0])
}
if received[0].Context.Channel != "test" || received[0].Context.ChatID != "123" {
t.Fatalf("expected context address to be preserved, got %+v", received[0].Context)
}
if received[0].ReplyToMessageID != "msg-9" {
t.Fatalf("expected reply_to_message_id msg-9, got %q", received[0].ReplyToMessageID)
}
}
func TestSendMessage_WithSplitting(t *testing.T) {
m := newTestManager()
@@ -1406,11 +1460,11 @@ func TestSendMessage_WithSplitting(t *testing.T) {
m.channels["test"] = ch
m.workers["test"] = w
msg := bus.OutboundMessage{
msg := testOutboundMessage(bus.OutboundMessage{
Channel: "test",
ChatID: "123",
Content: "hello world",
}
})
err := m.SendMessage(context.Background(), msg)
if err != nil {
@@ -1422,6 +1476,43 @@ func TestSendMessage_WithSplitting(t *testing.T) {
}
}
func TestSendMedia_ContextOnlyUsesContextAddressing(t *testing.T) {
m := newTestManager()
var received []bus.OutboundMediaMessage
ch := &mockMediaChannel{
sendMediaFn: func(_ context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
received = append(received, msg)
return nil, nil
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
m.channels["test"] = ch
m.workers["test"] = w
msg := testOutboundMediaMessage(bus.OutboundMediaMessage{
Context: bus.NewOutboundContext("test", "media-chat", ""),
Parts: []bus.MediaPart{{Type: "image", Ref: "media://1"}},
})
if err := m.SendMedia(context.Background(), msg); err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(received) != 1 {
t.Fatalf("expected 1 media message sent, got %d", len(received))
}
if received[0].Channel != "test" || received[0].ChatID != "media-chat" {
t.Fatalf("expected mirrored legacy media address, got %+v", received[0])
}
if received[0].Context.Channel != "test" || received[0].Context.ChatID != "media-chat" {
t.Fatalf("expected media context address to be preserved, got %+v", received[0].Context)
}
}
func TestSendMessage_PreservesOrdering(t *testing.T) {
m := newTestManager()
@@ -1441,12 +1532,12 @@ func TestSendMessage_PreservesOrdering(t *testing.T) {
m.workers["test"] = w
// Send two messages sequentially — they must arrive in order
_ = m.SendMessage(context.Background(), bus.OutboundMessage{
_ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{
Channel: "test", ChatID: "1", Content: "first",
})
_ = m.SendMessage(context.Background(), bus.OutboundMessage{
}))
_ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{
Channel: "test", ChatID: "1", Content: "second",
})
}))
if len(order) != 2 {
t.Fatalf("expected 2 messages, got %d", len(order))
+13 -13
View File
@@ -739,10 +739,8 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
}
peerKind := "direct"
peerID := senderID
if isGroup {
peerKind = "group"
peerID = roomID
}
metadata := map[string]string{
@@ -755,17 +753,19 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
metadata["reply_to_msg_id"] = replyTo.String()
}
c.HandleMessage(
c.baseContext(),
bus.Peer{Kind: peerKind, ID: peerID},
evt.ID.String(),
senderID,
roomID,
content,
mediaPaths,
metadata,
sender,
)
inboundCtx := bus.InboundContext{
Channel: "matrix",
ChatID: roomID,
ChatType: peerKind,
SenderID: senderID,
MessageID: evt.ID.String(),
Raw: metadata,
}
if replyTo := msgEvt.GetRelatesTo().GetReplyTo(); replyTo != "" {
inboundCtx.ReplyToMessageID = replyTo.String()
}
c.HandleInboundContext(c.baseContext(), roomID, content, mediaPaths, inboundCtx, sender)
}
// decryptEvent decrypts an encrypted event and returns the decrypted message event content.
+18 -5
View File
@@ -995,8 +995,8 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
senderID := strconv.FormatInt(userID, 10)
var chatID string
var peer bus.Peer
var contextChatID string
var contextChatType string
metadata := map[string]string{}
@@ -1007,12 +1007,14 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
switch raw.MessageType {
case "private":
chatID = "private:" + senderID
peer = bus.Peer{Kind: "direct", ID: senderID}
contextChatID = senderID
contextChatType = "direct"
case "group":
groupIDStr := strconv.FormatInt(groupID, 10)
chatID = "group:" + groupIDStr
peer = bus.Peer{Kind: "group", ID: groupIDStr}
contextChatID = groupIDStr
contextChatType = "group"
metadata["group_id"] = groupIDStr
senderUserID, _ := parseJSONInt64(sender.UserID)
@@ -1076,7 +1078,18 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
return
}
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, parsed.Media, metadata, senderInfo)
inboundCtx := bus.InboundContext{
Channel: c.Name(),
ChatID: contextChatID,
ChatType: contextChatType,
SenderID: senderID,
MessageID: messageID,
Mentioned: isBotMentioned,
ReplyToMessageID: parsed.ReplyTo,
Raw: metadata,
}
c.HandleInboundContext(c.ctx, chatID, content, parsed.Media, inboundCtx, senderInfo)
}
func (c *OneBotChannel) isDuplicate(messageID string) bool {
+13 -6
View File
@@ -259,8 +259,6 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
chatID := "pico_client:" + sessionID
senderID := "pico-remote"
peer := bus.Peer{Kind: "direct", ID: chatID}
sender := bus.SenderInfo{
Platform: "pico_client",
PlatformID: senderID,
@@ -271,10 +269,19 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
return
}
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, map[string]string{
"platform": "pico_client",
"session_id": sessionID,
}, sender)
inboundCtx := bus.InboundContext{
Channel: "pico_client",
ChatID: chatID,
ChatType: "direct",
SenderID: senderID,
MessageID: msg.ID,
Raw: map[string]string{
"platform": "pico_client",
"session_id": sessionID,
},
}
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
}
// Send sends a message to the remote server.
+14 -7
View File
@@ -39,11 +39,11 @@ var allowedInlineImageMIMETypes = map[string]struct{}{
"image/bmp": {},
}
func outboundMessageIsThought(metadata map[string]string) bool {
if len(metadata) == 0 {
func outboundMessageIsThought(msg bus.OutboundMessage) bool {
if len(msg.Context.Raw) == 0 {
return false
}
return strings.EqualFold(strings.TrimSpace(metadata["message_kind"]), MessageKindThought)
return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), MessageKindThought)
}
// writeJSON sends a JSON message to the connection with write locking.
@@ -260,7 +260,7 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
if !c.IsRunning() {
return nil, channels.ErrNotRunning
}
isThought := outboundMessageIsThought(msg.Metadata)
isThought := outboundMessageIsThought(msg)
outMsg := newMessage(TypeMessageCreate, map[string]any{
PayloadKeyContent: msg.Content,
@@ -578,8 +578,6 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
chatID := "pico:" + sessionID
senderID := "pico-user"
peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID}
metadata := map[string]string{
"platform": "pico",
"session_id": sessionID,
@@ -602,7 +600,16 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
return
}
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, media, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: "pico",
ChatID: chatID,
ChatType: "direct",
SenderID: senderID,
MessageID: msg.ID,
Raw: metadata,
}
c.HandleInboundContext(c.ctx, chatID, content, media, inboundCtx, sender)
}
// truncate truncates a string to maxLen runes.
+31 -20
View File
@@ -590,12 +590,22 @@ func qqFileType(partType string) uint64 {
}
func (c *QQChannel) maxBase64FileSizeBytes() int64 {
if c.config == nil {
return 0
}
if c.config.MaxBase64FileSizeMiB <= 0 {
return 0
}
return c.config.MaxBase64FileSizeMiB * bytesPerMiB
}
func (c *QQChannel) accountID() string {
if c.config == nil {
return ""
}
return c.config.AppID
}
// handleC2CMessage handles QQ private messages.
func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
return func(event *dto.WSPayload, data *dto.WSC2CMessageData) error {
@@ -649,17 +659,17 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
metadata := map[string]string{
"account_id": senderID,
}
inboundCtx := bus.InboundContext{
Channel: c.Name(),
Account: c.accountID(),
ChatID: senderID,
ChatType: "direct",
SenderID: senderID,
MessageID: data.ID,
Raw: metadata,
}
c.HandleMessage(c.ctx,
bus.Peer{Kind: "direct", ID: senderID},
data.ID,
senderID,
senderID,
content,
mediaPaths,
metadata,
sender,
)
c.HandleInboundContext(c.ctx, senderID, content, mediaPaths, inboundCtx, sender)
return nil
}
@@ -727,17 +737,18 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
"account_id": senderID,
"group_id": data.GroupID,
}
inboundCtx := bus.InboundContext{
Channel: c.Name(),
Account: c.accountID(),
ChatID: data.GroupID,
ChatType: "group",
SenderID: senderID,
MessageID: data.ID,
Mentioned: true,
Raw: metadata,
}
c.HandleMessage(c.ctx,
bus.Peer{Kind: "group", ID: data.GroupID},
data.ID,
senderID,
data.GroupID,
content,
mediaPaths,
metadata,
sender,
)
c.HandleInboundContext(c.ctx, data.GroupID, content, mediaPaths, inboundCtx, sender)
return nil
}
+4 -4
View File
@@ -54,8 +54,8 @@ func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) {
if !ok {
t.Fatal("expected inbound message")
}
if inbound.Metadata["account_id"] != "7750283E123456" {
t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456")
if inbound.Context.Raw["account_id"] != "7750283E123456" {
t.Fatalf("account_id raw = %q, want %q", inbound.Context.Raw["account_id"], "7750283E123456")
}
return
}
@@ -165,8 +165,8 @@ func TestHandleGroupATMessage_AttachmentOnlyPublishesMedia(t *testing.T) {
if !strings.HasPrefix(inbound.Media[0], "media://") {
t.Fatalf("inbound.Media[0] = %q, want media:// ref", inbound.Media[0])
}
if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-1" {
t.Fatalf("inbound.Peer = %+v, want group/group-1", inbound.Peer)
if inbound.Context.ChatType != "group" {
t.Fatalf("inbound.Context.ChatType = %q, want group", inbound.Context.ChatType)
}
}
+83 -28
View File
@@ -117,7 +117,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]str
return nil, channels.ErrNotRunning
}
channelID, threadTS := parseSlackChatID(msg.ChatID)
deliveryChatID, channelID, threadTS := resolveSlackOutboundTarget(msg.ChatID, &msg.Context)
if channelID == "" {
return nil, fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
}
@@ -139,7 +139,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]str
return nil, fmt.Errorf("slack send: %w", channels.ErrTemporary)
}
if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok {
if ref, ok := c.pendingAcks.LoadAndDelete(deliveryChatID); ok {
msgRef := ref.(slackMessageRef)
c.api.AddReaction("white_check_mark", slack.ItemRef{
Channel: msgRef.ChannelID,
@@ -161,7 +161,7 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
return nil, channels.ErrNotRunning
}
channelID, _ := parseSlackChatID(msg.ChatID)
_, channelID, threadTS := resolveSlackMediaOutboundTarget(msg.ChatID, &msg.Context)
if channelID == "" {
return nil, fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
}
@@ -192,10 +192,11 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
}
_, err = c.api.UploadFileV2Context(ctx, slack.UploadFileV2Parameters{
Channel: channelID,
File: localPath,
Filename: filename,
Title: title,
Channel: channelID,
ThreadTimestamp: threadTS,
File: localPath,
Filename: filename,
Title: title,
})
if err != nil {
logger.ErrorCF("slack", "Failed to upload media", map[string]any{
@@ -360,14 +361,10 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
}
peerKind := "channel"
peerID := channelID
if strings.HasPrefix(channelID, "D") {
peerKind = "direct"
peerID = senderID
}
peer := bus.Peer{Kind: peerKind, ID: peerID}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
@@ -383,7 +380,22 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
"has_thread": threadTS != "",
})
c.HandleMessage(c.ctx, peer, messageTS, senderID, chatID, content, mediaPaths, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: c.Name(),
Account: c.teamID,
ChatID: channelID,
ChatType: peerKind,
SenderID: senderID,
MessageID: messageTS,
SpaceID: c.teamID,
SpaceType: "workspace",
Raw: metadata,
}
if threadTS != "" {
inboundCtx.TopicID = threadTS
}
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
}
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
@@ -431,14 +443,10 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
}
mentionPeerKind := "channel"
mentionPeerID := channelID
if strings.HasPrefix(channelID, "D") {
mentionPeerKind = "direct"
mentionPeerID = senderID
}
mentionPeer := bus.Peer{Kind: mentionPeerKind, ID: mentionPeerID}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
@@ -447,8 +455,21 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
"is_mention": "true",
"team_id": c.teamID,
}
inboundCtx := bus.InboundContext{
Channel: c.Name(),
Account: c.teamID,
ChatID: channelID,
ChatType: mentionPeerKind,
TopicID: threadTS,
SenderID: senderID,
MessageID: messageTS,
SpaceID: c.teamID,
SpaceType: "workspace",
Mentioned: true,
Raw: metadata,
}
c.HandleMessage(c.ctx, mentionPeer, messageTS, senderID, chatID, content, nil, metadata, mentionSender)
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, mentionSender)
}
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
@@ -495,18 +516,22 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
"command": cmd.Command,
"text": utils.Truncate(content, 50),
})
peerKind := "channel"
if strings.HasPrefix(channelID, "D") {
peerKind = "direct"
}
inboundCtx := bus.InboundContext{
Channel: c.Name(),
Account: c.teamID,
ChatID: channelID,
ChatType: peerKind,
SenderID: senderID,
SpaceID: c.teamID,
SpaceType: "workspace",
Raw: metadata,
}
c.HandleMessage(
c.ctx,
bus.Peer{Kind: "channel", ID: channelID},
"",
senderID,
chatID,
content,
nil,
metadata,
cmdSender,
)
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, cmdSender)
}
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
@@ -541,3 +566,33 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) {
}
return channelID, threadTS
}
func resolveSlackOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (string, string, string) {
deliveryChatID := strings.TrimSpace(chatID)
if deliveryChatID == "" && outboundCtx != nil {
deliveryChatID = strings.TrimSpace(outboundCtx.ChatID)
}
channelID, threadTS := parseSlackChatID(deliveryChatID)
if threadTS == "" && outboundCtx != nil {
threadTS = strings.TrimSpace(outboundCtx.TopicID)
if threadTS != "" && channelID != "" {
deliveryChatID = channelID + "/" + threadTS
}
}
return deliveryChatID, channelID, threadTS
}
func resolveSlackMediaOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (string, string, string) {
deliveryChatID := strings.TrimSpace(chatID)
if deliveryChatID == "" && outboundCtx != nil {
deliveryChatID = strings.TrimSpace(outboundCtx.ChatID)
}
channelID, threadTS := parseSlackChatID(deliveryChatID)
if threadTS == "" && outboundCtx != nil {
threadTS = strings.TrimSpace(outboundCtx.TopicID)
if threadTS != "" && channelID != "" {
deliveryChatID = channelID + "/" + threadTS
}
}
return deliveryChatID, channelID, threadTS
}
+18
View File
@@ -53,6 +53,24 @@ func TestParseSlackChatID(t *testing.T) {
}
}
func TestResolveSlackOutboundTarget_PrefersContextTopicID(t *testing.T) {
deliveryChatID, channelID, threadTS := resolveSlackOutboundTarget("C123456", &bus.InboundContext{
Channel: "slack",
ChatID: "C123456",
TopicID: "1234567890.123456",
})
if deliveryChatID != "C123456/1234567890.123456" {
t.Fatalf("deliveryChatID = %q, want %q", deliveryChatID, "C123456/1234567890.123456")
}
if channelID != "C123456" {
t.Fatalf("channelID = %q, want %q", channelID, "C123456")
}
if threadTS != "1234567890.123456" {
t.Fatalf("threadTS = %q, want %q", threadTS, "1234567890.123456")
}
}
func TestStripBotMention(t *testing.T) {
ch := &SlackChannel{botUserID: "U12345BOT"}
+42 -18
View File
@@ -182,7 +182,7 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]
useMarkdownV2 := c.tgCfg.UseMarkdownV2
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context)
if err != nil {
return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
}
@@ -469,7 +469,7 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
return nil, channels.ErrNotRunning
}
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context)
if err != nil {
return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
}
@@ -697,8 +697,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
}
// In group chats, apply unified group trigger filtering
isMentioned := false
if message.Chat.Type != "private" {
isMentioned := c.isBotMentioned(message)
isMentioned = c.isBotMentioned(message)
if isMentioned {
content = c.stripBotMention(content)
}
@@ -744,13 +745,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
})
peerKind := "direct"
peerID := fmt.Sprintf("%d", user.ID)
if message.Chat.Type != "private" {
peerKind = "group"
peerID = compositeChatID
}
peer := bus.Peer{Kind: peerKind, ID: peerID}
messageID := fmt.Sprintf("%d", message.MessageID)
metadata := map[string]string{
@@ -759,24 +756,29 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
"first_name": user.FirstName,
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
}
if message.ReplyToMessage != nil {
metadata["reply_to_message_id"] = fmt.Sprintf("%d", message.ReplyToMessage.MessageID)
}
// Set parent_peer metadata for per-topic agent binding.
inboundCtx := bus.InboundContext{
Channel: c.Name(),
ChatID: fmt.Sprintf("%d", chatID),
ChatType: peerKind,
SenderID: platformID,
MessageID: messageID,
Mentioned: isMentioned,
Raw: metadata,
}
if message.Chat.IsForum && threadID != 0 {
metadata["parent_peer_kind"] = "topic"
metadata["parent_peer_id"] = fmt.Sprintf("%d", threadID)
inboundCtx.TopicID = fmt.Sprintf("%d", threadID)
}
if message.ReplyToMessage != nil {
inboundCtx.ReplyToMessageID = fmt.Sprintf("%d", message.ReplyToMessage.MessageID)
}
c.HandleMessage(c.ctx,
peer,
messageID,
platformID,
c.HandleMessageWithContext(
c.ctx,
compositeChatID,
content,
mediaPaths,
metadata,
inboundCtx,
sender,
)
return nil
@@ -964,6 +966,28 @@ func parseTelegramChatID(chatID string) (int64, int, error) {
return cid, tid, nil
}
func resolveTelegramOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (int64, int, error) {
targetChatID := strings.TrimSpace(chatID)
if targetChatID == "" && outboundCtx != nil {
targetChatID = strings.TrimSpace(outboundCtx.ChatID)
}
resolvedChatID, resolvedThreadID, err := parseTelegramChatID(targetChatID)
if err != nil {
return 0, 0, err
}
if resolvedThreadID != 0 || outboundCtx == nil {
return resolvedChatID, resolvedThreadID, nil
}
topicID := strings.TrimSpace(outboundCtx.TopicID)
if topicID == "" {
return resolvedChatID, resolvedThreadID, nil
}
if threadID, convErr := strconv.Atoi(topicID); convErr == nil {
return resolvedChatID, threadID, nil
}
return resolvedChatID, resolvedThreadID, nil
}
func logParseFailed(err error, useMarkdownV2 bool) {
parsingName := "HTML"
if useMarkdownV2 {
+42 -26
View File
@@ -528,6 +528,38 @@ func TestSend_WithForumThreadID(t *testing.T) {
assert.Len(t, caller.calls, 1)
}
func TestSend_UsesContextTopicIDWhenChatIDDoesNotIncludeThread(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
_, err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "-1001234567890",
Content: "Hello from topic context",
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "-1001234567890",
TopicID: "42",
},
})
require.NoError(t, err)
require.Len(t, caller.calls, 1)
var params struct {
ChatID int64 `json:"chat_id"`
MessageThreadID int `json:"message_thread_id"`
Text string `json:"text"`
}
require.NoError(t, json.Unmarshal(caller.calls[0].Data.BodyRaw, &params))
assert.Equal(t, int64(-1001234567890), params.ChatID)
assert.Equal(t, 42, params.MessageThreadID)
assert.Equal(t, "Hello from topic context", params.Text)
}
func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
@@ -557,16 +589,10 @@ func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok, "expected inbound message")
// Composite chatID should include thread ID
assert.Equal(t, "-1001234567890/42", inbound.ChatID)
// Peer ID should include thread ID for session key isolation
assert.Equal(t, "group", inbound.Peer.Kind)
assert.Equal(t, "-1001234567890/42", inbound.Peer.ID)
// Parent peer metadata should be set for agent binding
assert.Equal(t, "topic", inbound.Metadata["parent_peer_kind"])
assert.Equal(t, "42", inbound.Metadata["parent_peer_id"])
// ChatID remains the parent chat; TopicID isolates the sub-conversation.
assert.Equal(t, "-1001234567890", inbound.ChatID)
assert.Equal(t, "group", inbound.Context.ChatType)
assert.Equal(t, "42", inbound.Context.TopicID)
}
func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
@@ -599,13 +625,8 @@ func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
// Plain chatID without thread suffix
assert.Equal(t, "-100999", inbound.ChatID)
// Peer ID should be raw chat ID (no thread suffix)
assert.Equal(t, "group", inbound.Peer.Kind)
assert.Equal(t, "-100999", inbound.Peer.ID)
// No parent peer metadata
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
assert.Empty(t, inbound.Metadata["parent_peer_id"])
assert.Equal(t, "group", inbound.Context.ChatType)
assert.Empty(t, inbound.Context.TopicID)
}
func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
@@ -642,13 +663,8 @@ func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
// chatID should NOT include thread suffix for non-forum groups
assert.Equal(t, "-100999", inbound.ChatID)
// Peer ID should be raw chat ID (shared session for whole group)
assert.Equal(t, "group", inbound.Peer.Kind)
assert.Equal(t, "-100999", inbound.Peer.ID)
// No parent peer metadata
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
assert.Empty(t, inbound.Metadata["parent_peer_id"])
assert.Equal(t, "group", inbound.Context.ChatType)
assert.Empty(t, inbound.Context.TopicID)
}
func assertHandleMessageQuotedUserReply(
@@ -701,7 +717,7 @@ func assertHandleMessageQuotedUserReply(
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok)
assert.Equal(t, strconv.Itoa(replyMessageID), inbound.Metadata["reply_to_message_id"])
assert.Equal(t, strconv.Itoa(replyMessageID), inbound.Context.ReplyToMessageID)
assert.Equal(t, expectedContent, inbound.Content)
}
@@ -787,7 +803,7 @@ func TestHandleMessage_ReplyToOwnBotMessage_UsesAssistantRole(t *testing.T) {
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok)
assert.Equal(t, "101", inbound.Metadata["reply_to_message_id"])
assert.Equal(t, "101", inbound.Context.ReplyToMessageID)
assert.Equal(
t,
"[quoted assistant message from afjcjsbx_picoclaw_bot]: Fatto! Ho creato il file notizie_2026_03_28.md\n\nti ricordi questo file?",
+11 -15
View File
@@ -172,14 +172,11 @@ func (c *VKChannel) handleMessage(msg object.MessagesMessage) {
_ = groupTrigger
}
peerKind := "direct"
peerIDStr := userID
chatType := "direct"
if isGroupChat {
peerKind = "group"
peerIDStr = chatID
chatType = "group"
}
peer := bus.Peer{Kind: peerKind, ID: peerIDStr}
messageID := strconv.Itoa(msg.ConversationMessageID)
metadata := map[string]string{
@@ -187,16 +184,15 @@ func (c *VKChannel) handleMessage(msg object.MessagesMessage) {
"is_group": fmt.Sprintf("%t", isGroupChat),
}
c.HandleMessage(c.ctx,
peer,
messageID,
userID,
chatID,
text,
nil,
metadata,
sender,
)
c.HandleInboundContext(c.ctx, chatID, text, nil, bus.InboundContext{
Channel: "vk",
ChatID: chatID,
ChatType: chatType,
SenderID: userID,
MessageID: messageID,
Mentioned: isGroupChat && c.isMentioned(msg),
Raw: metadata,
}, sender)
}
func (c *VKChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
+14 -2
View File
@@ -570,7 +570,6 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage)
return err
}
peer := bus.Peer{Kind: peerKind, ID: actualChatID}
metadata := map[string]string{
"channel": "wecom",
"req_id": reqID,
@@ -583,7 +582,20 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage)
metadata["quote_text"] = quoteText
}
c.HandleMessage(c.ctx, peer, msg.MsgID, senderID, actualChatID, content, mediaRefs, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: c.Name(),
Account: strings.TrimSpace(msg.AIBotID),
ChatID: actualChatID,
ChatType: peerKind,
SenderID: senderID,
MessageID: msg.MsgID,
ReplyHandles: map[string]string{
"req_id": reqID,
},
Raw: metadata,
}
c.HandleInboundContext(c.ctx, actualChatID, content, mediaRefs, inboundCtx, sender)
return nil
}
+4 -4
View File
@@ -50,11 +50,11 @@ func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) {
if inbound.MessageID != "msg-1" {
t.Fatalf("inbound MessageID = %q, want msg-1", inbound.MessageID)
}
if inbound.Peer.ID != "chat-1" {
t.Fatalf("inbound Peer.ID = %q, want chat-1", inbound.Peer.ID)
if inbound.Context.ChatType != "direct" {
t.Fatalf("inbound Context.ChatType = %q, want direct", inbound.Context.ChatType)
}
if inbound.Metadata["req_id"] != "req-1" {
t.Fatalf("inbound req_id = %q, want req-1", inbound.Metadata["req_id"])
if inbound.Context.ReplyHandles["req_id"] != "req-1" {
t.Fatalf("inbound req_id = %q, want req-1", inbound.Context.ReplyHandles["req_id"])
}
default:
t.Fatal("expected inbound message to be published")
+15 -3
View File
@@ -357,8 +357,6 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess
return
}
peer := bus.Peer{Kind: "direct", ID: fromUserID}
metadata := map[string]string{
"from_user_id": fromUserID,
"context_token": msg.ContextToken,
@@ -377,7 +375,21 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess
c.persistContextTokens()
}
c.HandleMessage(ctx, peer, messageID, fromUserID, fromUserID, content, mediaRefs, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: "weixin",
ChatID: fromUserID,
ChatType: "direct",
SenderID: fromUserID,
MessageID: messageID,
Raw: metadata,
}
if msg.ContextToken != "" {
inboundCtx.ReplyHandles = map[string]string{
"context_token": msg.ContextToken,
}
}
c.HandleInboundContext(ctx, fromUserID, content, mediaRefs, inboundCtx, sender)
}
// Send implements channels.Channel by sending a text message to the WeChat user.
+14 -8
View File
@@ -227,13 +227,6 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
metadata["user_name"] = userName
}
var peer bus.Peer
if chatID == senderID {
peer = bus.Peer{Kind: "direct", ID: senderID}
} else {
peer = bus.Peer{Kind: "group", ID: chatID}
}
logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{
"sender": senderID,
"preview": utils.Truncate(content, 50),
@@ -252,5 +245,18 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
return
}
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: "whatsapp",
ChatID: chatID,
SenderID: senderID,
MessageID: messageID,
Raw: metadata,
}
if chatID == senderID {
inboundCtx.ChatType = "direct"
} else {
inboundCtx.ChatType = "group"
}
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
}
@@ -377,7 +377,6 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) {
if evt.Info.Chat.Server == types.GroupServer {
peerKind = "group"
}
peer := bus.Peer{Kind: peerKind, ID: chatID}
messageID := evt.Info.ID
sender := bus.SenderInfo{
Platform: "whatsapp",
@@ -395,7 +394,17 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) {
"WhatsApp message received",
map[string]any{"sender_id": senderID, "content_preview": utils.Truncate(content, 50)},
)
c.HandleMessage(c.runCtx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
inboundCtx := bus.InboundContext{
Channel: "whatsapp",
ChatID: chatID,
SenderID: senderID,
MessageID: messageID,
ChatType: peerKind,
Raw: metadata,
}
c.HandleInboundContext(c.runCtx, chatID, content, mediaPaths, inboundCtx, sender)
}
func (c *WhatsAppNativeChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
+33 -24
View File
@@ -30,10 +30,10 @@ func init() {
// Config is the current config structure with version support.
type Config struct {
Version int `json:"version" yaml:"-"` // Config schema version for migration
// Config schema version for migration.
Version int `json:"version" yaml:"-"`
Isolation IsolationConfig `json:"isolation,omitempty" yaml:"-"`
Agents AgentsConfig `json:"agents" yaml:"-"`
Bindings []AgentBinding `json:"bindings,omitempty" yaml:"-"`
Session SessionConfig `json:"session,omitempty" yaml:"-"`
Channels ChannelsConfig `json:"channel_list" yaml:"channel_list"`
ModelList SecureModelList `json:"model_list" yaml:"model_list"` // New model-centric provider configuration
@@ -120,7 +120,7 @@ type BuildInfo struct {
}
// MarshalJSON implements custom JSON marshaling for Config
// to omit providers section when empty and session when empty
// to omit providers section when empty and session when empty.
func (c *Config) MarshalJSON() ([]byte, error) {
type Alias Config
aux := &struct {
@@ -130,17 +130,18 @@ func (c *Config) MarshalJSON() ([]byte, error) {
Alias: (*Alias)(c),
}
// Only include session if not empty
if c.Session.DMScope != "" || len(c.Session.IdentityLinks) > 0 {
aux.Session = &c.Session
if len(c.Session.Dimensions) > 0 || len(c.Session.IdentityLinks) > 0 {
sessionCfg := c.Session
aux.Session = &sessionCfg
}
return json.Marshal(aux)
}
type AgentsConfig struct {
Defaults AgentDefaults `json:"defaults"`
List []AgentConfig `json:"list,omitempty"`
Defaults AgentDefaults `json:"defaults"`
List []AgentConfig `json:"list,omitempty"`
Dispatch *DispatchConfig `json:"dispatch,omitempty"`
}
// AgentModelConfig supports both string and structured model config.
@@ -197,26 +198,29 @@ type SubagentsConfig struct {
Model *AgentModelConfig `json:"model,omitempty"`
}
type PeerMatch struct {
Kind string `json:"kind"`
ID string `json:"id"`
type DispatchConfig struct {
Rules []DispatchRule `json:"rules,omitempty"`
}
type BindingMatch struct {
Channel string `json:"channel"`
AccountID string `json:"account_id,omitempty"`
Peer *PeerMatch `json:"peer,omitempty"`
GuildID string `json:"guild_id,omitempty"`
TeamID string `json:"team_id,omitempty"`
type DispatchRule struct {
Name string `json:"name,omitempty"`
Agent string `json:"agent"`
When DispatchSelector `json:"when"`
SessionDimensions []string `json:"session_dimensions,omitempty"`
}
type AgentBinding struct {
AgentID string `json:"agent_id"`
Match BindingMatch `json:"match"`
type DispatchSelector struct {
Channel string `json:"channel,omitempty"`
Account string `json:"account,omitempty"`
Space string `json:"space,omitempty"`
Chat string `json:"chat,omitempty"`
Topic string `json:"topic,omitempty"`
Sender string `json:"sender,omitempty"`
Mentioned *bool `json:"mentioned,omitempty"`
}
type SessionConfig struct {
DMScope string `json:"dm_scope,omitempty"`
Dimensions []string `json:"dimensions,omitempty"`
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
}
@@ -509,9 +513,10 @@ type DevicesConfig struct {
}
type VoiceConfig struct {
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"`
TTSModelName string `json:"tts_model_name,omitempty" env:"PICOCLAW_VOICE_TTS_MODEL_NAME"`
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"`
TTSModelName string `json:"tts_model_name,omitempty" env:"PICOCLAW_VOICE_TTS_MODEL_NAME"`
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
ElevenLabsAPIKey string `json:"elevenlabs_api_key,omitempty" env:"PICOCLAW_VOICE_ELEVENLABS_API_KEY"`
}
// ModelConfig represents a model-centric provider configuration.
@@ -1066,6 +1071,8 @@ func LoadConfig(path string) (*Config, error) {
return nil, fmt.Errorf("unsupported config version: %d", versionInfo.Version)
}
applyLegacyBindingsMigration(data, cfg)
if err = env.Parse(cfg); err != nil {
return nil, err
}
@@ -1256,6 +1263,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
ThinkingLevel: m.ThinkingLevel,
ExtraBody: m.ExtraBody,
CustomHeaders: m.CustomHeaders,
UserAgent: m.UserAgent,
isVirtual: true,
}
expanded = append(expanded, additionalEntry)
@@ -1277,6 +1285,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
ThinkingLevel: m.ThinkingLevel,
ExtraBody: m.ExtraBody,
CustomHeaders: m.CustomHeaders,
UserAgent: m.UserAgent,
APIKeys: SimpleSecureStrings(keys[0]),
}
+243 -38
View File
@@ -109,18 +109,8 @@ func TestAgentConfig_FullParse(t *testing.T) {
}
]
},
"bindings": [
{
"agent_id": "support",
"match": {
"channel": "telegram",
"account_id": "*",
"peer": {"kind": "direct", "id": "user123"}
}
}
],
"session": {
"dm_scope": "per-peer",
"dimensions": ["sender"],
"identity_links": {
"john": ["telegram:123", "discord:john#1234"]
}
@@ -158,19 +148,8 @@ func TestAgentConfig_FullParse(t *testing.T) {
t.Errorf("support.Subagents = %+v", support.Subagents)
}
if len(cfg.Bindings) != 1 {
t.Fatalf("bindings len = %d, want 1", len(cfg.Bindings))
}
binding := cfg.Bindings[0]
if binding.AgentID != "support" || binding.Match.Channel != "telegram" {
t.Errorf("binding = %+v", binding)
}
if binding.Match.Peer == nil || binding.Match.Peer.Kind != "direct" || binding.Match.Peer.ID != "user123" {
t.Errorf("binding.Match.Peer = %+v", binding.Match.Peer)
}
if cfg.Session.DMScope != "per-peer" {
t.Errorf("Session.DMScope = %q", cfg.Session.DMScope)
if len(cfg.Session.Dimensions) != 1 || cfg.Session.Dimensions[0] != "sender" {
t.Errorf("Session.Dimensions = %v", cfg.Session.Dimensions)
}
if len(cfg.Session.IdentityLinks) != 1 {
t.Errorf("Session.IdentityLinks = %v", cfg.Session.IdentityLinks)
@@ -236,8 +215,242 @@ func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) {
if len(cfg.Agents.List) != 0 {
t.Errorf("agents.list should be empty for backward compat, got %d", len(cfg.Agents.List))
}
if len(cfg.Bindings) != 0 {
t.Errorf("bindings should be empty, got %d", len(cfg.Bindings))
}
func TestAgentConfig_ParsesDispatchRules(t *testing.T) {
jsonData := `{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7"
},
"list": [
{ "id": "main", "default": true },
{ "id": "support" }
],
"dispatch": {
"rules": [
{
"name": "support-vip",
"agent": "support",
"when": {
"channel": "telegram",
"chat": "group:-100123",
"sender": "12345",
"mentioned": true
},
"session_dimensions": ["chat", "sender"]
}
]
}
}
}`
cfg := DefaultConfig()
if err := json.Unmarshal([]byte(jsonData), cfg); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if cfg.Agents.Dispatch == nil {
t.Fatal("Agents.Dispatch should not be nil")
}
if len(cfg.Agents.Dispatch.Rules) != 1 {
t.Fatalf("Dispatch.Rules len = %d, want 1", len(cfg.Agents.Dispatch.Rules))
}
rule := cfg.Agents.Dispatch.Rules[0]
if rule.Name != "support-vip" || rule.Agent != "support" {
t.Fatalf("rule = %+v", rule)
}
if rule.When.Channel != "telegram" || rule.When.Chat != "group:-100123" || rule.When.Sender != "12345" {
t.Fatalf("rule.When = %+v", rule.When)
}
if rule.When.Mentioned == nil || !*rule.When.Mentioned {
t.Fatalf("rule.When.Mentioned = %+v, want true", rule.When.Mentioned)
}
if got := rule.SessionDimensions; len(got) != 2 || got[0] != "chat" || got[1] != "sender" {
t.Fatalf("rule.SessionDimensions = %v, want [chat sender]", got)
}
}
func TestLoadConfig_MigratesLegacyBindingsToDispatchRules(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
raw := `{
"version": 2,
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7"
},
"list": [
{ "id": "main", "default": true },
{ "id": "support" },
{ "id": "ops" },
{ "id": "slack" }
]
},
"bindings": [
{
"agent_id": "support",
"match": {
"channel": "telegram",
"peer": { "kind": "group", "id": "-100123" }
}
},
{
"agent_id": "ops",
"match": {
"channel": "discord",
"guild_id": "guild-1"
}
},
{
"agent_id": "slack",
"match": {
"channel": "slack",
"account_id": "*"
}
}
]
}`
if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil {
t.Fatalf("WriteFile(configPath): %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if cfg.Agents.Dispatch == nil {
t.Fatal("Agents.Dispatch should not be nil")
}
if len(cfg.Agents.Dispatch.Rules) != 3 {
t.Fatalf("Dispatch.Rules len = %d, want 3", len(cfg.Agents.Dispatch.Rules))
}
first := cfg.Agents.Dispatch.Rules[0]
if first.Agent != "support" {
t.Fatalf("first.Agent = %q, want %q", first.Agent, "support")
}
if first.When.Channel != "telegram" || first.When.Chat != "group:-100123" {
t.Fatalf("first.When = %+v", first.When)
}
if first.When.Account != legacyDefaultAccountID {
t.Fatalf("first.When.Account = %q, want %q", first.When.Account, legacyDefaultAccountID)
}
second := cfg.Agents.Dispatch.Rules[1]
if second.Agent != "ops" || second.When.Space != "guild:guild-1" {
t.Fatalf("second = %+v", second)
}
third := cfg.Agents.Dispatch.Rules[2]
if third.Agent != "slack" {
t.Fatalf("third.Agent = %q, want %q", third.Agent, "slack")
}
if third.When.Channel != "slack" || third.When.Account != "" {
t.Fatalf("third.When = %+v", third.When)
}
}
func TestLoadConfig_PrefersDispatchRulesOverLegacyBindings(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
raw := `{
"version": 2,
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7"
},
"list": [
{ "id": "main", "default": true },
{ "id": "support" }
],
"dispatch": {
"rules": [
{
"name": "explicit",
"agent": "support",
"when": {
"channel": "telegram",
"chat": "group:-100123"
}
}
]
}
},
"bindings": [
{
"agent_id": "main",
"match": {
"channel": "telegram",
"account_id": "*"
}
}
]
}`
if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil {
t.Fatalf("WriteFile(configPath): %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if cfg.Agents.Dispatch == nil {
t.Fatal("Agents.Dispatch should not be nil")
}
if len(cfg.Agents.Dispatch.Rules) != 1 {
t.Fatalf("Dispatch.Rules len = %d, want 1", len(cfg.Agents.Dispatch.Rules))
}
if cfg.Agents.Dispatch.Rules[0].Name != "explicit" {
t.Fatalf("Dispatch.Rules[0].Name = %q, want %q", cfg.Agents.Dispatch.Rules[0].Name, "explicit")
}
}
func TestLoadConfig_MigratesLegacyDirectBindingsWithIdentityLinks(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
raw := `{
"version": 2,
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7"
},
"list": [
{ "id": "main", "default": true },
{ "id": "support" }
]
},
"session": {
"identity_links": {
"john": ["telegram:123", "123"]
}
},
"bindings": [
{
"agent_id": "support",
"match": {
"channel": "telegram",
"peer": { "kind": "direct", "id": "123" }
}
}
]
}`
if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil {
t.Fatalf("WriteFile(configPath): %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if cfg.Agents.Dispatch == nil || len(cfg.Agents.Dispatch.Rules) != 1 {
t.Fatalf("Dispatch.Rules = %+v, want 1 migrated rule", cfg.Agents.Dispatch)
}
if got := cfg.Agents.Dispatch.Rules[0].When.Sender; got != "john" {
t.Fatalf("migrated sender selector = %q, want %q", got, "john")
}
}
@@ -374,13 +587,6 @@ func TestDefaultConfig_WebTools(t *testing.T) {
}
}
func TestDefaultConfig_ReadFileMode(t *testing.T) {
cfg := DefaultConfig()
if cfg.Tools.ReadFile.EffectiveMode() != ReadFileModeBytes {
t.Fatalf("expected default read_file mode %q, got %q", ReadFileModeBytes, cfg.Tools.ReadFile.EffectiveMode())
}
}
func TestSaveConfig_FilePermissions(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("file permission bits are not enforced on Windows")
@@ -825,7 +1031,7 @@ func TestLoadConfig_HooksProcessConfig(t *testing.T) {
}
}
// TestDefaultConfig_DMScope verifies the default dm_scope value
// TestDefaultConfig_SessionDimensions verifies the default session dimensions
// TestDefaultConfig_SummarizationThresholds verifies summarization defaults
func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
cfg := DefaultConfig()
@@ -838,11 +1044,11 @@ func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
}
}
func TestDefaultConfig_DMScope(t *testing.T) {
func TestDefaultConfig_SessionDimensions(t *testing.T) {
cfg := DefaultConfig()
if cfg.Session.DMScope != "per-channel-peer" {
t.Errorf("Session.DMScope = %q, want 'per-channel-peer'", cfg.Session.DMScope)
if len(cfg.Session.Dimensions) != 1 || cfg.Session.Dimensions[0] != "chat" {
t.Errorf("Session.Dimensions = %v, want [chat]", cfg.Session.Dimensions)
}
}
@@ -1076,7 +1282,6 @@ func TestLoadConfig_TelegramPlaceholderTextAcceptsSingleString(t *testing.T) {
data := `{
"version": 1,
"agents": { "defaults": { "workspace": "", "model": "", "max_tokens": 0, "max_tool_iterations": 0 } },
"bindings": [],
"session": {},
"channels": {
"telegram": {
+3 -2
View File
@@ -41,9 +41,8 @@ func DefaultConfig() *Config {
SplitOnMarker: false,
},
},
Bindings: []AgentBinding{},
Session: SessionConfig{
DMScope: "per-channel-peer",
Dimensions: []string{"chat"},
},
Channels: defaultChannels(),
Hooks: HooksConfig{
@@ -422,7 +421,9 @@ func DefaultConfig() *Config {
},
Voice: VoiceConfig{
ModelName: "",
TTSModelName: "",
EchoTranscription: false,
ElevenLabsAPIKey: "",
},
BuildInfo: BuildInfo{
Version: Version,
+267
View File
@@ -0,0 +1,267 @@
package config
import (
"encoding/json"
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/logger"
)
const legacyDefaultAccountID = "default"
type legacyBindingsEnvelope struct {
Bindings json.RawMessage `json:"bindings"`
}
type legacyAgentBinding struct {
AgentID string `json:"agent_id"`
Match legacyBindingMatch `json:"match"`
}
type legacyBindingMatch struct {
Channel string `json:"channel"`
AccountID string `json:"account_id,omitempty"`
Peer *legacyPeerMatch `json:"peer,omitempty"`
GuildID string `json:"guild_id,omitempty"`
TeamID string `json:"team_id,omitempty"`
}
type legacyPeerMatch struct {
Kind string `json:"kind"`
ID string `json:"id"`
}
func applyLegacyBindingsMigration(data []byte, cfg *Config) {
if cfg == nil {
return
}
bindings, found, err := decodeLegacyBindings(data)
if err != nil {
logger.WarnF(
"legacy bindings config detected but could not be decoded",
map[string]any{"error": err},
)
return
}
if !found {
return
}
if cfg.Agents.Dispatch != nil && len(cfg.Agents.Dispatch.Rules) > 0 {
logger.WarnF(
"legacy bindings config is deprecated and ignored because agents.dispatch.rules is configured",
map[string]any{"bindings": len(bindings), "dispatch_rules": len(cfg.Agents.Dispatch.Rules)},
)
return
}
rules, dropped := migrateLegacyBindings(bindings, cfg.Session.IdentityLinks)
if len(rules) == 0 {
logger.WarnF(
"legacy bindings config is deprecated and could not be migrated",
map[string]any{"bindings": len(bindings), "dropped_bindings": dropped},
)
return
}
if cfg.Agents.Dispatch == nil {
cfg.Agents.Dispatch = &DispatchConfig{}
}
cfg.Agents.Dispatch.Rules = rules
fields := map[string]any{
"bindings": len(bindings),
"dispatch_rules": len(rules),
}
if dropped > 0 {
fields["dropped_bindings"] = dropped
}
logger.WarnF("legacy bindings config is deprecated; migrated to agents.dispatch.rules in memory", fields)
}
func decodeLegacyBindings(data []byte) ([]legacyAgentBinding, bool, error) {
var envelope legacyBindingsEnvelope
if err := json.Unmarshal(data, &envelope); err != nil {
return nil, false, err
}
if len(envelope.Bindings) == 0 {
return nil, false, nil
}
var bindings []legacyAgentBinding
if err := json.Unmarshal(envelope.Bindings, &bindings); err != nil {
return nil, true, err
}
return bindings, true, nil
}
func migrateLegacyBindings(bindings []legacyAgentBinding, identityLinks map[string][]string) ([]DispatchRule, int) {
if len(bindings) == 0 {
return nil, 0
}
type prioritizedRule struct {
rule DispatchRule
index int
kind int
}
prioritized := make([]prioritizedRule, 0, len(bindings))
dropped := 0
for i, binding := range bindings {
rule, kind, ok := migrateLegacyBinding(binding, i, identityLinks)
if !ok {
dropped++
continue
}
prioritized = append(prioritized, prioritizedRule{rule: rule, index: i, kind: kind})
}
if len(prioritized) == 0 {
return nil, dropped
}
rules := make([]DispatchRule, 0, len(prioritized))
for kind := 0; kind <= 4; kind++ {
for _, item := range prioritized {
if item.kind == kind {
rules = append(rules, item.rule)
}
}
}
return rules, dropped
}
func migrateLegacyBinding(
binding legacyAgentBinding,
index int,
identityLinks map[string][]string,
) (DispatchRule, int, bool) {
channel := strings.ToLower(strings.TrimSpace(binding.Match.Channel))
agentID := strings.TrimSpace(binding.AgentID)
if channel == "" || agentID == "" {
return DispatchRule{}, 0, false
}
rule := DispatchRule{
Name: fmt.Sprintf("legacy-binding-%d", index+1),
Agent: agentID,
When: DispatchSelector{
Channel: channel,
},
}
switch normalizeLegacyAccountSelector(binding.Match.AccountID) {
case "":
case "*":
default:
rule.When.Account = normalizeLegacyAccountSelector(binding.Match.AccountID)
}
if peer := binding.Match.Peer; peer != nil {
peerKind := strings.ToLower(strings.TrimSpace(peer.Kind))
peerID := strings.TrimSpace(peer.ID)
if peerID == "" {
return DispatchRule{}, 0, false
}
switch peerKind {
case "direct":
rule.When.Sender = canonicalLegacyBindingSenderID(channel, peerID, identityLinks)
return rule, 0, true
case "group", "channel":
rule.When.Chat = peerKind + ":" + peerID
return rule, 0, true
case "topic":
rule.When.Topic = "topic:" + peerID
return rule, 0, true
default:
return DispatchRule{}, 0, false
}
}
if guildID := strings.TrimSpace(binding.Match.GuildID); guildID != "" {
rule.When.Space = "guild:" + guildID
return rule, 1, true
}
if teamID := strings.TrimSpace(binding.Match.TeamID); teamID != "" {
rule.When.Space = "team:" + teamID
return rule, 2, true
}
accountSelector := normalizeLegacyAccountSelector(binding.Match.AccountID)
if accountSelector == "*" {
rule.When.Account = ""
return rule, 4, true
}
rule.When.Account = accountSelector
return rule, 3, true
}
func normalizeLegacyAccountSelector(accountID string) string {
accountID = strings.TrimSpace(accountID)
switch accountID {
case "":
return legacyDefaultAccountID
case "*":
return "*"
default:
return strings.ToLower(accountID)
}
}
func canonicalLegacyBindingSenderID(channel, peerID string, identityLinks map[string][]string) string {
peerID = strings.TrimSpace(peerID)
if peerID == "" {
return ""
}
if linked := resolveLegacyBindingLinkedID(identityLinks, channel, peerID); linked != "" {
return strings.ToLower(linked)
}
return strings.ToLower(peerID)
}
func resolveLegacyBindingLinkedID(identityLinks map[string][]string, channel, peerID string) string {
if len(identityLinks) == 0 {
return ""
}
peerID = strings.TrimSpace(peerID)
if peerID == "" {
return ""
}
candidates := make(map[string]struct{})
rawCandidate := strings.ToLower(peerID)
if rawCandidate != "" {
candidates[rawCandidate] = struct{}{}
}
channel = strings.ToLower(strings.TrimSpace(channel))
if channel != "" {
candidates[channel+":"+rawCandidate] = struct{}{}
}
if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 {
candidates[rawCandidate[idx+1:]] = struct{}{}
}
for canonical, ids := range identityLinks {
canonical = strings.TrimSpace(canonical)
if canonical == "" {
continue
}
for _, id := range ids {
normalized := strings.ToLower(strings.TrimSpace(id))
if normalized == "" {
continue
}
if _, ok := candidates[normalized]; ok {
return canonical
}
}
}
return ""
}
+1 -2
View File
@@ -131,8 +131,7 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: platform,
ChatID: userID,
Context: bus.NewOutboundContext(platform, userID, ""),
Content: msg,
})
+1 -2
View File
@@ -339,8 +339,7 @@ func (hs *HeartbeatService) sendResponse(response string) {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: platform,
ChatID: userID,
Context: bus.NewOutboundContext(platform, userID, ""),
Content: response,
})
+334 -15
View File
@@ -32,14 +32,19 @@ const (
maxLineSize = 10 * 1024 * 1024 // 10 MB
)
// sessionMeta holds per-session metadata stored in a .meta.json file.
type sessionMeta struct {
Key string `json:"key"`
Summary string `json:"summary"`
Skip int `json:"skip"`
Count int `json:"count"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
// SessionMeta holds per-session metadata stored in a .meta.json file.
//
// Scope is stored as raw JSON so pkg/memory can stay decoupled from the
// higher-level session package while still preserving structured scope data.
type SessionMeta struct {
Key string `json:"key"`
Summary string `json:"summary"`
Skip int `json:"skip"`
Count int `json:"count"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Scope json.RawMessage `json:"scope,omitempty"`
Aliases []string `json:"aliases,omitempty"`
}
// JSONLStore implements Store using append-only JSONL files.
@@ -98,25 +103,31 @@ func sanitizeKey(key string) string {
// readMeta loads the metadata file for a session.
// Returns a zero-value sessionMeta if the file does not exist.
func (s *JSONLStore) readMeta(key string) (sessionMeta, error) {
func (s *JSONLStore) readMeta(key string) (SessionMeta, error) {
data, err := os.ReadFile(s.metaPath(key))
if os.IsNotExist(err) {
return sessionMeta{Key: key}, nil
return SessionMeta{Key: key}, nil
}
if err != nil {
return sessionMeta{}, fmt.Errorf("memory: read meta: %w", err)
return SessionMeta{}, fmt.Errorf("memory: read meta: %w", err)
}
var meta sessionMeta
var meta SessionMeta
err = json.Unmarshal(data, &meta)
if err != nil {
return sessionMeta{}, fmt.Errorf("memory: decode meta: %w", err)
return SessionMeta{}, fmt.Errorf("memory: decode meta: %w", err)
}
if meta.Key == "" {
meta.Key = key
}
return meta, nil
}
// writeMeta atomically writes the metadata file using the project's
// standard WriteFileAtomic (temp + fsync + rename).
func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error {
func (s *JSONLStore) writeMeta(key string, meta SessionMeta) error {
if strings.TrimSpace(meta.Key) == "" {
meta.Key = key
}
data, err := json.MarshalIndent(meta, "", " ")
if err != nil {
return fmt.Errorf("memory: encode meta: %w", err)
@@ -124,6 +135,314 @@ func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error {
return fileutil.WriteFileAtomic(s.metaPath(key), data, 0o644)
}
func cloneRawJSON(data json.RawMessage) json.RawMessage {
if len(data) == 0 {
return nil
}
return append(json.RawMessage(nil), data...)
}
func normalizeAliases(canonicalKey string, aliases []string) []string {
if len(aliases) == 0 {
return nil
}
normalized := make([]string, 0, len(aliases))
seen := make(map[string]struct{}, len(aliases))
canonicalKey = strings.TrimSpace(canonicalKey)
for _, alias := range aliases {
alias = strings.TrimSpace(alias)
if alias == "" || alias == canonicalKey {
continue
}
if _, ok := seen[alias]; ok {
continue
}
seen[alias] = struct{}{}
normalized = append(normalized, alias)
}
if len(normalized) == 0 {
return nil
}
return normalized
}
func (s *JSONLStore) sessionExists(key string) bool {
if key == "" {
return false
}
if _, err := os.Stat(s.jsonlPath(key)); err == nil {
return true
}
if _, err := os.Stat(s.metaPath(key)); err == nil {
return true
}
return false
}
// GetSessionMeta returns the current metadata snapshot for sessionKey.
func (s *JSONLStore) GetSessionMeta(_ context.Context, sessionKey string) (SessionMeta, error) {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return SessionMeta{}, err
}
meta.Scope = cloneRawJSON(meta.Scope)
if len(meta.Aliases) > 0 {
meta.Aliases = append([]string(nil), meta.Aliases...)
}
return meta, nil
}
// UpsertSessionMeta stores structured session metadata while preserving
// summary/count/skip timestamps maintained by the core JSONL store.
func (s *JSONLStore) UpsertSessionMeta(
_ context.Context,
sessionKey string,
scope json.RawMessage,
aliases []string,
) error {
l := s.sessionLock(sessionKey)
l.Lock()
defer l.Unlock()
meta, err := s.readMeta(sessionKey)
if err != nil {
return err
}
meta.Scope = cloneRawJSON(scope)
meta.Aliases = normalizeAliases(sessionKey, aliases)
now := time.Now()
if meta.CreatedAt.IsZero() {
meta.CreatedAt = now
}
meta.UpdatedAt = now
return s.writeMeta(sessionKey, meta)
}
// PromoteAliasHistory atomically promotes the first non-empty alias session
// into the canonical session when the canonical session is still empty.
func (s *JSONLStore) PromoteAliasHistory(
_ context.Context,
sessionKey string,
scope json.RawMessage,
aliases []string,
) (bool, error) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" {
return false, nil
}
aliases = normalizeAliases(sessionKey, aliases)
for _, alias := range aliases {
unlock := s.lockSessionPair(sessionKey, alias)
promoted, err := s.promoteAliasHistoryLocked(sessionKey, alias, scope, aliases)
unlock()
if err != nil || promoted {
return promoted, err
}
}
return false, nil
}
// ResolveSessionKey returns the canonical session key for a candidate key.
// It short-circuits direct canonical keys when possible, then scans metadata
// once to resolve aliases or canonical metadata keys.
func (s *JSONLStore) ResolveSessionKey(_ context.Context, sessionKey string) (string, bool, error) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" {
return "", false, nil
}
hasDirectSession := s.sessionExists(sessionKey)
if hasDirectSession && shouldShortCircuitSessionResolve(sessionKey) {
return sessionKey, true, nil
}
entries, err := os.ReadDir(s.dir)
if err != nil {
return "", false, fmt.Errorf("memory: read sessions dir: %w", err)
}
var directMetaMatch string
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") {
continue
}
data, readErr := os.ReadFile(filepath.Join(s.dir, entry.Name()))
if readErr != nil {
log.Printf("memory: skipping unreadable meta %s: %v", entry.Name(), readErr)
continue
}
var meta SessionMeta
if err := json.Unmarshal(data, &meta); err != nil {
log.Printf("memory: skipping corrupt meta %s: %v", entry.Name(), err)
continue
}
if meta.Key == "" {
continue
}
if meta.Key == sessionKey {
directMetaMatch = meta.Key
}
for _, alias := range meta.Aliases {
if alias == sessionKey && meta.Key != sessionKey {
return meta.Key, true, nil
}
}
}
if directMetaMatch != "" {
return directMetaMatch, true, nil
}
if hasDirectSession {
return sessionKey, true, nil
}
return "", false, nil
}
func shouldShortCircuitSessionResolve(sessionKey string) bool {
sessionKey = strings.TrimSpace(strings.ToLower(sessionKey))
if sessionKey == "" {
return false
}
return !strings.ContainsAny(sessionKey, ":/\\")
}
func (s *JSONLStore) lockSessionPair(keyA, keyB string) func() {
lockA := s.sessionLock(keyA)
lockB := s.sessionLock(keyB)
if lockA == lockB {
lockA.Lock()
return func() { lockA.Unlock() }
}
if keyA <= keyB {
lockA.Lock()
lockB.Lock()
return func() {
lockB.Unlock()
lockA.Unlock()
}
}
lockB.Lock()
lockA.Lock()
return func() {
lockA.Unlock()
lockB.Unlock()
}
}
func (s *JSONLStore) promoteAliasHistoryLocked(
sessionKey string,
alias string,
scope json.RawMessage,
aliases []string,
) (bool, error) {
canonicalMeta, err := s.readMeta(sessionKey)
if err != nil {
return false, err
}
canonicalHasContent, err := s.sessionHasVisibleContentLocked(sessionKey, canonicalMeta)
if err != nil {
return false, err
}
if canonicalHasContent {
return false, nil
}
aliasMeta, err := s.readMeta(alias)
if err != nil {
return false, err
}
aliasHistory, err := readMessages(s.jsonlPath(alias), aliasMeta.Skip)
if err != nil {
return false, err
}
aliasSummary := strings.TrimSpace(aliasMeta.Summary)
if len(aliasHistory) == 0 && aliasSummary == "" {
return false, nil
}
previousJSONL, hadPreviousJSONL, err := s.readRawJSONL(sessionKey)
if err != nil {
return false, err
}
now := time.Now()
if canonicalMeta.CreatedAt.IsZero() {
canonicalMeta.CreatedAt = now
}
canonicalMeta.Scope = cloneRawJSON(scope)
canonicalMeta.Aliases = normalizeAliases(sessionKey, aliases)
canonicalMeta.Skip = 0
canonicalMeta.Count = len(aliasHistory)
canonicalMeta.UpdatedAt = now
if aliasSummary != "" {
canonicalMeta.Summary = aliasSummary
}
if err := s.rewriteJSONL(sessionKey, aliasHistory); err != nil {
return false, err
}
if err := s.writeMeta(sessionKey, canonicalMeta); err != nil {
if rollbackErr := s.restoreRawJSONL(sessionKey, previousJSONL, hadPreviousJSONL); rollbackErr != nil {
return false, fmt.Errorf("memory: write promoted meta: %w (rollback jsonl: %v)", err, rollbackErr)
}
return false, err
}
return true, nil
}
func (s *JSONLStore) sessionHasVisibleContentLocked(sessionKey string, meta SessionMeta) (bool, error) {
if meta.Count-meta.Skip > 0 || strings.TrimSpace(meta.Summary) != "" {
return true, nil
}
if meta.Count != 0 || meta.Skip != 0 {
return false, nil
}
history, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
if err != nil {
return false, err
}
return len(history) > 0, nil
}
func (s *JSONLStore) readRawJSONL(sessionKey string) ([]byte, bool, error) {
data, err := os.ReadFile(s.jsonlPath(sessionKey))
if os.IsNotExist(err) {
return nil, false, nil
}
if err != nil {
return nil, false, fmt.Errorf("memory: read jsonl: %w", err)
}
return data, true, nil
}
func (s *JSONLStore) restoreRawJSONL(sessionKey string, data []byte, existed bool) error {
path := s.jsonlPath(sessionKey)
if !existed {
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("memory: remove jsonl rollback: %w", err)
}
return nil
}
if err := fileutil.WriteFileAtomic(path, data, 0o644); err != nil {
return fmt.Errorf("memory: restore jsonl rollback: %w", err)
}
return nil
}
// readMessages reads valid JSON lines from a .jsonl file, skipping
// the first `skip` lines without unmarshaling them. This avoids the
// cost of json.Unmarshal on logically truncated messages.
@@ -471,7 +790,7 @@ func (s *JSONLStore) ListSessions() []string {
if err != nil {
continue
}
var meta sessionMeta
var meta SessionMeta
if err := json.Unmarshal(data, &meta); err != nil {
continue
}
+138
View File
@@ -2,8 +2,10 @@ package memory
import (
"context"
"encoding/json"
"os"
"path/filepath"
"reflect"
"sync"
"testing"
@@ -241,6 +243,142 @@ func TestSetSummary_GetSummary(t *testing.T) {
}
}
func TestSessionMetaScopeAndAliasesPersist(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
scope := json.RawMessage(`{"version":1,"channel":"telegram","values":{"chat":"group:c1"}}`)
aliases := []string{"legacy:one", "legacy:one", "canonical"}
if err := store.UpsertSessionMeta(ctx, "canonical", scope, aliases); err != nil {
t.Fatalf("UpsertSessionMeta() error = %v", err)
}
meta, err := store.GetSessionMeta(ctx, "canonical")
if err != nil {
t.Fatalf("GetSessionMeta() error = %v", err)
}
var gotScope map[string]any
if err := json.Unmarshal(meta.Scope, &gotScope); err != nil {
t.Fatalf("Unmarshal(meta.Scope) error = %v", err)
}
var wantScope map[string]any
if err := json.Unmarshal(scope, &wantScope); err != nil {
t.Fatalf("Unmarshal(scope) error = %v", err)
}
if !reflect.DeepEqual(gotScope, wantScope) {
t.Fatalf("meta.Scope = %#v, want %#v", gotScope, wantScope)
}
if len(meta.Aliases) != 1 || meta.Aliases[0] != "legacy:one" {
t.Fatalf("meta.Aliases = %#v, want [legacy:one]", meta.Aliases)
}
}
func TestResolveSessionKeyByAlias(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
if err := store.AddMessage(ctx, "canonical", "user", "hello"); err != nil {
t.Fatalf("AddMessage() error = %v", err)
}
if err := store.UpsertSessionMeta(ctx, "canonical", nil, []string{"legacy:key"}); err != nil {
t.Fatalf("UpsertSessionMeta() error = %v", err)
}
resolved, found, err := store.ResolveSessionKey(ctx, "legacy:key")
if err != nil {
t.Fatalf("ResolveSessionKey() error = %v", err)
}
if !found {
t.Fatal("ResolveSessionKey() did not find alias")
}
if resolved != "canonical" {
t.Fatalf("resolved = %q, want %q", resolved, "canonical")
}
}
func TestResolveSessionKeyByAlias_PrefersMetadataOverLegacyFile(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
if err := store.AddMessage(ctx, "legacy:key", "user", "legacy"); err != nil {
t.Fatalf("AddMessage(legacy) error = %v", err)
}
if err := store.AddMessage(ctx, "canonical", "user", "canonical"); err != nil {
t.Fatalf("AddMessage(canonical) error = %v", err)
}
if err := store.UpsertSessionMeta(ctx, "canonical", nil, []string{"legacy:key"}); err != nil {
t.Fatalf("UpsertSessionMeta() error = %v", err)
}
resolved, found, err := store.ResolveSessionKey(ctx, "legacy:key")
if err != nil {
t.Fatalf("ResolveSessionKey() error = %v", err)
}
if !found {
t.Fatal("ResolveSessionKey() did not find alias")
}
if resolved != "canonical" {
t.Fatalf("resolved = %q, want %q", resolved, "canonical")
}
}
func TestResolveSessionKey_DirectHitSkipsCorruptMetadata(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
if err := store.AddMessage(ctx, "canonical", "user", "hello"); err != nil {
t.Fatalf("AddMessage() error = %v", err)
}
if err := os.WriteFile(
filepath.Join(store.dir, "broken.meta.json"),
[]byte("{not-json"),
0o644,
); err != nil {
t.Fatalf("WriteFile(broken.meta.json) error = %v", err)
}
resolved, found, err := store.ResolveSessionKey(ctx, "canonical")
if err != nil {
t.Fatalf("ResolveSessionKey() error = %v", err)
}
if !found {
t.Fatal("ResolveSessionKey() did not find direct session")
}
if resolved != "canonical" {
t.Fatalf("resolved = %q, want %q", resolved, "canonical")
}
}
func TestResolveSessionKey_SkipsCorruptMetadataDuringAliasScan(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
if err := store.AddMessage(ctx, "canonical", "user", "hello"); err != nil {
t.Fatalf("AddMessage() error = %v", err)
}
if err := store.UpsertSessionMeta(ctx, "canonical", nil, []string{"legacy:key"}); err != nil {
t.Fatalf("UpsertSessionMeta() error = %v", err)
}
if err := os.WriteFile(
filepath.Join(store.dir, "broken.meta.json"),
[]byte("{not-json"),
0o644,
); err != nil {
t.Fatalf("WriteFile(broken.meta.json) error = %v", err)
}
resolved, found, err := store.ResolveSessionKey(ctx, "legacy:key")
if err != nil {
t.Fatalf("ResolveSessionKey() error = %v", err)
}
if !found {
t.Fatal("ResolveSessionKey() did not find alias")
}
if resolved != "canonical" {
t.Fatalf("resolved = %q, want %q", resolved, "canonical")
}
}
func TestTruncateHistory_KeepLast(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
+245 -184
View File
@@ -1,32 +1,29 @@
package routing
import (
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
// RouteInput contains the routing context from an inbound message.
type RouteInput struct {
Channel string
AccountID string
Peer *RoutePeer
ParentPeer *RoutePeer
GuildID string
TeamID string
// SessionPolicy describes how a routed message should be mapped to a session.
type SessionPolicy struct {
Dimensions []string
IdentityLinks map[string][]string
}
// ResolvedRoute is the result of agent routing.
type ResolvedRoute struct {
AgentID string
Channel string
AccountID string
SessionKey string
MainSessionKey string
MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default"
AgentID string
Channel string
AccountID string
SessionPolicy SessionPolicy
MatchedBy string
}
// RouteResolver determines which agent handles a message based on config bindings.
// RouteResolver determines which agent handles a message.
type RouteResolver struct {
cfg *config.Config
}
@@ -36,182 +33,32 @@ func NewRouteResolver(cfg *config.Config) *RouteResolver {
return &RouteResolver{cfg: cfg}
}
// ResolveRoute determines which agent handles the message and constructs session keys.
// Implements the 7-level priority cascade:
// peer > parent_peer > guild > team > account > channel_wildcard > default
func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute {
channel := strings.ToLower(strings.TrimSpace(input.Channel))
accountID := NormalizeAccountID(input.AccountID)
peer := input.Peer
// ResolveRoute determines which agent handles the message from a normalized
// inbound context and returns the session policy that should be used to
// allocate session state.
func (r *RouteResolver) ResolveRoute(inbound bus.InboundContext) ResolvedRoute {
channel := strings.ToLower(strings.TrimSpace(inbound.Channel))
accountID := NormalizeAccountID(inbound.Account)
identityLinks := cloneIdentityLinks(r.cfg.Session.IdentityLinks)
view := buildDispatchView(inbound, identityLinks)
dmScope := DMScope(r.cfg.Session.DMScope)
if dmScope == "" {
dmScope = DMScopeMain
}
identityLinks := r.cfg.Session.IdentityLinks
bindings := r.filterBindings(channel, accountID)
choose := func(agentID string, matchedBy string) ResolvedRoute {
resolvedAgentID := r.pickAgentID(agentID)
sessionKey := strings.ToLower(BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: resolvedAgentID,
if rule := r.matchDispatchRule(view); rule != nil {
return ResolvedRoute{
AgentID: r.pickAgentID(rule.Agent),
Channel: channel,
AccountID: accountID,
Peer: peer,
DMScope: dmScope,
IdentityLinks: identityLinks,
}))
mainSessionKey := strings.ToLower(BuildAgentMainSessionKey(resolvedAgentID))
return ResolvedRoute{
AgentID: resolvedAgentID,
Channel: channel,
AccountID: accountID,
SessionKey: sessionKey,
MainSessionKey: mainSessionKey,
MatchedBy: matchedBy,
SessionPolicy: r.sessionPolicy(rule),
MatchedBy: matchedByForRule(rule),
}
}
// Priority 1: Peer binding
if peer != nil && strings.TrimSpace(peer.ID) != "" {
if match := r.findPeerMatch(bindings, peer); match != nil {
return choose(match.AgentID, "binding.peer")
}
return ResolvedRoute{
AgentID: r.pickAgentID(r.resolveDefaultAgentID()),
Channel: channel,
AccountID: accountID,
SessionPolicy: r.sessionPolicy(nil),
MatchedBy: "default",
}
// Priority 2: Parent peer binding
parentPeer := input.ParentPeer
if parentPeer != nil && strings.TrimSpace(parentPeer.ID) != "" {
if match := r.findPeerMatch(bindings, parentPeer); match != nil {
return choose(match.AgentID, "binding.peer.parent")
}
}
// Priority 3: Guild binding
guildID := strings.TrimSpace(input.GuildID)
if guildID != "" {
if match := r.findGuildMatch(bindings, guildID); match != nil {
return choose(match.AgentID, "binding.guild")
}
}
// Priority 4: Team binding
teamID := strings.TrimSpace(input.TeamID)
if teamID != "" {
if match := r.findTeamMatch(bindings, teamID); match != nil {
return choose(match.AgentID, "binding.team")
}
}
// Priority 5: Account binding
if match := r.findAccountMatch(bindings); match != nil {
return choose(match.AgentID, "binding.account")
}
// Priority 6: Channel wildcard binding
if match := r.findChannelWildcardMatch(bindings); match != nil {
return choose(match.AgentID, "binding.channel")
}
// Priority 7: Default agent
return choose(r.resolveDefaultAgentID(), "default")
}
func (r *RouteResolver) filterBindings(channel, accountID string) []config.AgentBinding {
var filtered []config.AgentBinding
for _, b := range r.cfg.Bindings {
matchChannel := strings.ToLower(strings.TrimSpace(b.Match.Channel))
if matchChannel == "" || matchChannel != channel {
continue
}
if !matchesAccountID(b.Match.AccountID, accountID) {
continue
}
filtered = append(filtered, b)
}
return filtered
}
func matchesAccountID(matchAccountID, actual string) bool {
trimmed := strings.TrimSpace(matchAccountID)
if trimmed == "" {
return actual == DefaultAccountID
}
if trimmed == "*" {
return true
}
return strings.ToLower(trimmed) == strings.ToLower(actual)
}
func (r *RouteResolver) findPeerMatch(bindings []config.AgentBinding, peer *RoutePeer) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
if b.Match.Peer == nil {
continue
}
peerKind := strings.ToLower(strings.TrimSpace(b.Match.Peer.Kind))
peerID := strings.TrimSpace(b.Match.Peer.ID)
if peerKind == "" || peerID == "" {
continue
}
if peerKind == strings.ToLower(peer.Kind) && peerID == peer.ID {
return b
}
}
return nil
}
func (r *RouteResolver) findGuildMatch(bindings []config.AgentBinding, guildID string) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
matchGuild := strings.TrimSpace(b.Match.GuildID)
if matchGuild != "" && matchGuild == guildID {
return &bindings[i]
}
}
return nil
}
func (r *RouteResolver) findTeamMatch(bindings []config.AgentBinding, teamID string) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
matchTeam := strings.TrimSpace(b.Match.TeamID)
if matchTeam != "" && matchTeam == teamID {
return &bindings[i]
}
}
return nil
}
func (r *RouteResolver) findAccountMatch(bindings []config.AgentBinding) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
accountID := strings.TrimSpace(b.Match.AccountID)
if accountID == "*" {
continue
}
if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" {
continue
}
return &bindings[i]
}
return nil
}
func (r *RouteResolver) findChannelWildcardMatch(bindings []config.AgentBinding) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
accountID := strings.TrimSpace(b.Match.AccountID)
if accountID != "*" {
continue
}
if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" {
continue
}
return &bindings[i]
}
return nil
}
func (r *RouteResolver) pickAgentID(agentID string) string {
@@ -250,3 +97,217 @@ func (r *RouteResolver) resolveDefaultAgentID() string {
}
return DefaultAgentID
}
func (r *RouteResolver) sessionPolicy(rule *config.DispatchRule) SessionPolicy {
dimensions := r.cfg.Session.Dimensions
if rule != nil && len(rule.SessionDimensions) > 0 {
dimensions = rule.SessionDimensions
}
return SessionPolicy{
Dimensions: normalizeSessionDimensions(dimensions),
IdentityLinks: cloneIdentityLinks(r.cfg.Session.IdentityLinks),
}
}
func normalizeSessionDimensions(dimensions []string) []string {
if len(dimensions) == 0 {
return nil
}
normalized := make([]string, 0, len(dimensions))
seen := make(map[string]struct{}, len(dimensions))
for _, dimension := range dimensions {
dimension = strings.ToLower(strings.TrimSpace(dimension))
switch dimension {
case "space", "chat", "topic", "sender":
default:
continue
}
if _, ok := seen[dimension]; ok {
continue
}
seen[dimension] = struct{}{}
normalized = append(normalized, dimension)
}
if len(normalized) == 0 {
return nil
}
return normalized
}
func cloneIdentityLinks(src map[string][]string) map[string][]string {
if len(src) == 0 {
return nil
}
cloned := make(map[string][]string, len(src))
for canonical, ids := range src {
dup := make([]string, len(ids))
copy(dup, ids)
cloned[canonical] = dup
}
return cloned
}
type dispatchView struct {
Channel string
Account string
Space string
Chat string
Topic string
Sender string
Mentioned bool
}
func (r *RouteResolver) matchDispatchRule(view dispatchView) *config.DispatchRule {
if r.cfg == nil || r.cfg.Agents.Dispatch == nil || len(r.cfg.Agents.Dispatch.Rules) == 0 {
return nil
}
for i := range r.cfg.Agents.Dispatch.Rules {
rule := &r.cfg.Agents.Dispatch.Rules[i]
if !selectorHasAnyConstraint(rule.When) {
continue
}
if ruleMatchesView(*rule, view) {
return rule
}
}
return nil
}
func ruleMatchesView(rule config.DispatchRule, view dispatchView) bool {
when := normalizeDispatchSelector(rule.When)
if when.Channel != "" && when.Channel != view.Channel {
return false
}
if when.Account != "" && when.Account != view.Account {
return false
}
if when.Space != "" && when.Space != view.Space {
return false
}
if when.Chat != "" && when.Chat != view.Chat {
return false
}
if when.Topic != "" && when.Topic != view.Topic {
return false
}
if when.Sender != "" && when.Sender != view.Sender {
return false
}
if when.Mentioned != nil && *when.Mentioned != view.Mentioned {
return false
}
return true
}
func matchedByForRule(rule *config.DispatchRule) string {
if rule == nil {
return "default"
}
name := strings.TrimSpace(rule.Name)
if name == "" {
return "dispatch.rule"
}
return "dispatch.rule:" + strings.ToLower(name)
}
func buildDispatchView(inbound bus.InboundContext, identityLinks map[string][]string) dispatchView {
view := dispatchView{
Channel: strings.ToLower(strings.TrimSpace(inbound.Channel)),
Account: NormalizeAccountID(inbound.Account),
Mentioned: inbound.Mentioned,
}
if spaceID := strings.TrimSpace(inbound.SpaceID); spaceID != "" {
spaceType := strings.ToLower(strings.TrimSpace(inbound.SpaceType))
if spaceType == "" {
spaceType = "space"
}
view.Space = fmt.Sprintf("%s:%s", spaceType, strings.ToLower(spaceID))
}
if chatID := strings.TrimSpace(inbound.ChatID); chatID != "" {
chatType := strings.ToLower(strings.TrimSpace(inbound.ChatType))
if chatType == "" {
chatType = "direct"
}
view.Chat = fmt.Sprintf("%s:%s", chatType, strings.ToLower(chatID))
}
if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" {
view.Topic = "topic:" + strings.ToLower(topicID)
}
view.Sender = canonicalDispatchSenderID(inbound.Channel, inbound.SenderID, identityLinks)
return view
}
func normalizeDispatchSelector(selector config.DispatchSelector) config.DispatchSelector {
selector.Channel = strings.ToLower(strings.TrimSpace(selector.Channel))
selector.Account = NormalizeAccountID(selector.Account)
selector.Space = strings.ToLower(strings.TrimSpace(selector.Space))
selector.Chat = strings.ToLower(strings.TrimSpace(selector.Chat))
selector.Topic = strings.ToLower(strings.TrimSpace(selector.Topic))
selector.Sender = strings.ToLower(strings.TrimSpace(selector.Sender))
return selector
}
func selectorHasAnyConstraint(selector config.DispatchSelector) bool {
return strings.TrimSpace(selector.Channel) != "" ||
strings.TrimSpace(selector.Account) != "" ||
strings.TrimSpace(selector.Space) != "" ||
strings.TrimSpace(selector.Chat) != "" ||
strings.TrimSpace(selector.Topic) != "" ||
strings.TrimSpace(selector.Sender) != "" ||
selector.Mentioned != nil
}
func canonicalDispatchSenderID(channel, rawID string, identityLinks map[string][]string) string {
normalizedID := strings.TrimSpace(rawID)
if normalizedID == "" {
return ""
}
if linked := resolveLinkedDispatchID(identityLinks, channel, normalizedID); linked != "" {
normalizedID = linked
}
return strings.ToLower(normalizedID)
}
func resolveLinkedDispatchID(identityLinks map[string][]string, channel, peerID string) string {
if len(identityLinks) == 0 {
return ""
}
peerID = strings.TrimSpace(peerID)
if peerID == "" {
return ""
}
candidates := make(map[string]bool)
rawCandidate := strings.ToLower(peerID)
if rawCandidate != "" {
candidates[rawCandidate] = true
}
channel = strings.ToLower(strings.TrimSpace(channel))
if channel != "" {
candidates[fmt.Sprintf("%s:%s", channel, rawCandidate)] = true
}
if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 {
candidates[rawCandidate[idx+1:]] = true
}
for canonical, ids := range identityLinks {
canonicalName := strings.TrimSpace(canonical)
if canonicalName == "" {
continue
}
for _, id := range ids {
normalized := strings.ToLower(strings.TrimSpace(id))
if normalized != "" && candidates[normalized] {
return canonicalName
}
}
}
return ""
}
+126 -190
View File
@@ -3,10 +3,11 @@ package routing
import (
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *config.Config {
func testConfig(agents []config.AgentConfig) *config.Config {
return &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
@@ -15,20 +16,20 @@ func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *co
},
List: agents,
},
Bindings: bindings,
Session: config.SessionConfig{
DMScope: "per-peer",
Dimensions: []string{"sender"},
},
}
}
func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) {
cfg := testConfig(nil, nil)
cfg := testConfig(nil)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
route := r.ResolveRoute(bus.InboundContext{
Channel: "telegram",
ChatType: "direct",
SenderID: "user1",
})
if route.AgentID != DefaultAgentID {
@@ -37,202 +38,152 @@ func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) {
if route.MatchedBy != "default" {
t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy)
}
if len(route.SessionPolicy.Dimensions) != 1 || route.SessionPolicy.Dimensions[0] != "sender" {
t.Errorf("SessionPolicy.Dimensions = %v, want [sender]", route.SessionPolicy.Dimensions)
}
if route.SessionPolicy.IdentityLinks != nil {
t.Errorf("SessionPolicy.IdentityLinks = %v, want nil", route.SessionPolicy.IdentityLinks)
}
}
func TestResolveRoute_PeerBinding(t *testing.T) {
agents := []config.AgentConfig{
{ID: "sales", Default: true},
{ID: "support"},
func TestResolveRoute_UsesNormalizedInboundContextFields(t *testing.T) {
cfg := testConfig([]config.AgentConfig{{ID: "sales", Default: true}})
r := NewRouteResolver(cfg)
route := r.ResolveRoute(bus.InboundContext{
Channel: "Telegram",
Account: "Bot2",
ChatType: "direct",
SenderID: "user123",
})
if route.AgentID != "sales" {
t.Errorf("AgentID = %q, want 'sales'", route.AgentID)
}
bindings := []config.AgentBinding{
{
AgentID: "support",
Match: config.BindingMatch{
Channel: "telegram",
AccountID: "*",
Peer: &config.PeerMatch{Kind: "direct", ID: "user123"},
if route.Channel != "telegram" {
t.Errorf("Channel = %q, want 'telegram'", route.Channel)
}
if route.AccountID != "bot2" {
t.Errorf("AccountID = %q, want 'bot2'", route.AccountID)
}
if route.MatchedBy != "default" {
t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy)
}
}
func TestResolveRoute_DispatchFirstMatchWins(t *testing.T) {
cfg := testConfig([]config.AgentConfig{
{ID: "main", Default: true},
{ID: "support"},
{ID: "sales"},
})
cfg.Agents.Dispatch = &config.DispatchConfig{
Rules: []config.DispatchRule{
{
Name: "support-group",
Agent: "support",
When: config.DispatchSelector{
Channel: "telegram",
Chat: "group:-100123",
},
},
{
Name: "vip-in-group",
Agent: "sales",
When: config.DispatchSelector{
Channel: "telegram",
Chat: "group:-100123",
Sender: "12345",
},
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
route := r.ResolveRoute(bus.InboundContext{
Channel: "telegram",
ChatID: "-100123",
ChatType: "group",
SenderID: "12345",
})
if route.AgentID != "support" {
t.Errorf("AgentID = %q, want 'support'", route.AgentID)
t.Fatalf("AgentID = %q, want support", route.AgentID)
}
if route.MatchedBy != "binding.peer" {
t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy)
if route.MatchedBy != "dispatch.rule:support-group" {
t.Fatalf("MatchedBy = %q, want dispatch.rule:support-group", route.MatchedBy)
}
}
func TestResolveRoute_GuildBinding(t *testing.T) {
agents := []config.AgentConfig{
{ID: "general", Default: true},
{ID: "gaming"},
}
bindings := []config.AgentBinding{
{
AgentID: "gaming",
Match: config.BindingMatch{
Channel: "discord",
AccountID: "*",
GuildID: "guild-abc",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "discord",
GuildID: "guild-abc",
Peer: &RoutePeer{Kind: "channel", ID: "ch1"},
})
if route.AgentID != "gaming" {
t.Errorf("AgentID = %q, want 'gaming'", route.AgentID)
}
if route.MatchedBy != "binding.guild" {
t.Errorf("MatchedBy = %q, want 'binding.guild'", route.MatchedBy)
}
}
func TestResolveRoute_TeamBinding(t *testing.T) {
agents := []config.AgentConfig{
{ID: "general", Default: true},
{ID: "work"},
}
bindings := []config.AgentBinding{
{
AgentID: "work",
Match: config.BindingMatch{
Channel: "slack",
AccountID: "*",
TeamID: "T12345",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "slack",
TeamID: "T12345",
Peer: &RoutePeer{Kind: "channel", ID: "C001"},
})
if route.AgentID != "work" {
t.Errorf("AgentID = %q, want 'work'", route.AgentID)
}
if route.MatchedBy != "binding.team" {
t.Errorf("MatchedBy = %q, want 'binding.team'", route.MatchedBy)
}
}
func TestResolveRoute_AccountBinding(t *testing.T) {
agents := []config.AgentConfig{
{ID: "default-agent", Default: true},
{ID: "premium"},
}
bindings := []config.AgentBinding{
{
AgentID: "premium",
Match: config.BindingMatch{
Channel: "telegram",
AccountID: "bot2",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
AccountID: "bot2",
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
})
if route.AgentID != "premium" {
t.Errorf("AgentID = %q, want 'premium'", route.AgentID)
}
if route.MatchedBy != "binding.account" {
t.Errorf("MatchedBy = %q, want 'binding.account'", route.MatchedBy)
}
}
func TestResolveRoute_ChannelWildcard(t *testing.T) {
agents := []config.AgentConfig{
func TestResolveRoute_DispatchOverridesSessionDimensions(t *testing.T) {
cfg := testConfig([]config.AgentConfig{
{ID: "main", Default: true},
{ID: "telegram-bot"},
}
bindings := []config.AgentBinding{
{
AgentID: "telegram-bot",
Match: config.BindingMatch{
Channel: "telegram",
AccountID: "*",
{ID: "support"},
})
cfg.Session.Dimensions = []string{"chat"}
cfg.Agents.Dispatch = &config.DispatchConfig{
Rules: []config.DispatchRule{
{
Name: "support-dm",
Agent: "support",
When: config.DispatchSelector{
Channel: "telegram",
Chat: "direct:user-1",
},
SessionDimensions: []string{"chat", "sender"},
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
route := r.ResolveRoute(bus.InboundContext{
Channel: "telegram",
ChatID: "user-1",
ChatType: "direct",
SenderID: "user-1",
})
if route.AgentID != "telegram-bot" {
t.Errorf("AgentID = %q, want 'telegram-bot'", route.AgentID)
if route.AgentID != "support" {
t.Fatalf("AgentID = %q, want support", route.AgentID)
}
if route.MatchedBy != "binding.channel" {
t.Errorf("MatchedBy = %q, want 'binding.channel'", route.MatchedBy)
if got := route.SessionPolicy.Dimensions; len(got) != 2 || got[0] != "chat" || got[1] != "sender" {
t.Fatalf("SessionPolicy.Dimensions = %v, want [chat sender]", got)
}
}
func TestResolveRoute_PriorityOrder_PeerBeatsGuild(t *testing.T) {
agents := []config.AgentConfig{
{ID: "general", Default: true},
{ID: "vip"},
{ID: "gaming"},
}
bindings := []config.AgentBinding{
{
AgentID: "vip",
Match: config.BindingMatch{
Channel: "discord",
AccountID: "*",
Peer: &config.PeerMatch{Kind: "direct", ID: "user-vip"},
},
},
{
AgentID: "gaming",
Match: config.BindingMatch{
Channel: "discord",
AccountID: "*",
GuildID: "guild-1",
func TestResolveRoute_DispatchMentionedRule(t *testing.T) {
cfg := testConfig([]config.AgentConfig{
{ID: "main", Default: true},
{ID: "support"},
})
mentioned := true
cfg.Agents.Dispatch = &config.DispatchConfig{
Rules: []config.DispatchRule{
{
Name: "slack-mentions",
Agent: "support",
When: config.DispatchSelector{
Channel: "slack",
Space: "workspace:t001",
Mentioned: &mentioned,
},
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "discord",
GuildID: "guild-1",
Peer: &RoutePeer{Kind: "direct", ID: "user-vip"},
route := r.ResolveRoute(bus.InboundContext{
Channel: "slack",
ChatID: "C123",
ChatType: "channel",
SpaceID: "T001",
SpaceType: "workspace",
SenderID: "U123",
Mentioned: true,
})
if route.AgentID != "vip" {
t.Errorf("AgentID = %q, want 'vip' (peer should beat guild)", route.AgentID)
}
if route.MatchedBy != "binding.peer" {
t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy)
if route.AgentID != "support" {
t.Fatalf("AgentID = %q, want support", route.AgentID)
}
}
@@ -240,21 +191,10 @@ func TestResolveRoute_InvalidAgentFallsToDefault(t *testing.T) {
agents := []config.AgentConfig{
{ID: "main", Default: true},
}
bindings := []config.AgentBinding{
{
AgentID: "nonexistent",
Match: config.BindingMatch{
Channel: "telegram",
AccountID: "*",
},
},
}
cfg := testConfig(agents, bindings)
cfg := testConfig(agents)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
})
route := r.ResolveRoute(bus.InboundContext{Channel: "telegram"})
if route.AgentID != "main" {
t.Errorf("AgentID = %q, want 'main' (invalid agent should fall to default)", route.AgentID)
@@ -267,12 +207,10 @@ func TestResolveRoute_DefaultAgentSelection(t *testing.T) {
{ID: "beta", Default: true},
{ID: "gamma"},
}
cfg := testConfig(agents, nil)
cfg := testConfig(agents)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "cli",
})
route := r.ResolveRoute(bus.InboundContext{Channel: "cli"})
if route.AgentID != "beta" {
t.Errorf("AgentID = %q, want 'beta' (marked as default)", route.AgentID)
@@ -284,12 +222,10 @@ func TestResolveRoute_NoDefaultUsesFirst(t *testing.T) {
{ID: "alpha"},
{ID: "beta"},
}
cfg := testConfig(agents, nil)
cfg := testConfig(agents)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "cli",
})
route := r.ResolveRoute(bus.InboundContext{Channel: "cli"})
if route.AgentID != "alpha" {
t.Errorf("AgentID = %q, want 'alpha' (first in list)", route.AgentID)
-192
View File
@@ -1,192 +0,0 @@
package routing
import (
"fmt"
"strings"
)
// DMScope controls DM session isolation granularity.
type DMScope string
const (
DMScopeMain DMScope = "main"
DMScopePerPeer DMScope = "per-peer"
DMScopePerChannelPeer DMScope = "per-channel-peer"
DMScopePerAccountChannelPeer DMScope = "per-account-channel-peer"
)
// RoutePeer represents a chat peer with kind and ID.
type RoutePeer struct {
Kind string // "direct", "group", "channel"
ID string
}
// SessionKeyParams holds all inputs for session key construction.
type SessionKeyParams struct {
AgentID string
Channel string
AccountID string
Peer *RoutePeer
DMScope DMScope
IdentityLinks map[string][]string
}
// ParsedSessionKey is the result of parsing an agent-scoped session key.
type ParsedSessionKey struct {
AgentID string
Rest string
}
// BuildAgentMainSessionKey returns "agent:<agentId>:main".
func BuildAgentMainSessionKey(agentID string) string {
return fmt.Sprintf("agent:%s:%s", NormalizeAgentID(agentID), DefaultMainKey)
}
// BuildAgentPeerSessionKey constructs a session key based on agent, channel, peer, and DM scope.
func BuildAgentPeerSessionKey(params SessionKeyParams) string {
agentID := NormalizeAgentID(params.AgentID)
peer := params.Peer
if peer == nil {
peer = &RoutePeer{Kind: "direct"}
}
peerKind := strings.TrimSpace(peer.Kind)
if peerKind == "" {
peerKind = "direct"
}
if peerKind == "direct" {
dmScope := params.DMScope
if dmScope == "" {
dmScope = DMScopeMain
}
peerID := strings.TrimSpace(peer.ID)
// Resolve identity links (cross-platform collapse)
if dmScope != DMScopeMain && peerID != "" {
if linked := resolveLinkedPeerID(params.IdentityLinks, params.Channel, peerID); linked != "" {
peerID = linked
}
}
peerID = strings.ToLower(peerID)
switch dmScope {
case DMScopePerAccountChannelPeer:
if peerID != "" {
channel := normalizeChannel(params.Channel)
accountID := NormalizeAccountID(params.AccountID)
return fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, accountID, peerID)
}
case DMScopePerChannelPeer:
if peerID != "" {
channel := normalizeChannel(params.Channel)
return fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID)
}
case DMScopePerPeer:
if peerID != "" {
return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID)
}
}
return BuildAgentMainSessionKey(agentID)
}
// Group/channel peers always get per-peer sessions
channel := normalizeChannel(params.Channel)
peerID := strings.ToLower(strings.TrimSpace(peer.ID))
if peerID == "" {
peerID = "unknown"
}
return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID)
}
// ParseAgentSessionKey extracts agentId and rest from "agent:<agentId>:<rest>".
func ParseAgentSessionKey(sessionKey string) *ParsedSessionKey {
raw := strings.TrimSpace(sessionKey)
if raw == "" {
return nil
}
parts := strings.SplitN(raw, ":", 3)
if len(parts) < 3 {
return nil
}
if parts[0] != "agent" {
return nil
}
agentID := strings.TrimSpace(parts[1])
rest := parts[2]
if agentID == "" || rest == "" {
return nil
}
return &ParsedSessionKey{AgentID: agentID, Rest: rest}
}
// IsSubagentSessionKey returns true if the session key represents a subagent.
func IsSubagentSessionKey(sessionKey string) bool {
raw := strings.TrimSpace(sessionKey)
if raw == "" {
return false
}
if strings.HasPrefix(strings.ToLower(raw), "subagent:") {
return true
}
parsed := ParseAgentSessionKey(raw)
if parsed == nil {
return false
}
return strings.HasPrefix(strings.ToLower(parsed.Rest), "subagent:")
}
func normalizeChannel(channel string) string {
c := strings.TrimSpace(strings.ToLower(channel))
if c == "" {
return "unknown"
}
return c
}
func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string {
if len(identityLinks) == 0 {
return ""
}
peerID = strings.TrimSpace(peerID)
if peerID == "" {
return ""
}
candidates := make(map[string]bool)
rawCandidate := strings.ToLower(peerID)
if rawCandidate != "" {
candidates[rawCandidate] = true
}
channel = strings.ToLower(strings.TrimSpace(channel))
if channel != "" {
scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID))
candidates[scopedCandidate] = true
}
// If peerID is already in canonical "platform:id" format, also add the
// bare ID part as a candidate for backward compatibility with identity_links
// that use raw IDs (e.g. "123" instead of "telegram:123").
if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 {
bareID := rawCandidate[idx+1:]
candidates[bareID] = true
}
if len(candidates) == 0 {
return ""
}
for canonical, ids := range identityLinks {
canonicalName := strings.TrimSpace(canonical)
if canonicalName == "" {
continue
}
for _, id := range ids {
normalized := strings.ToLower(strings.TrimSpace(id))
if normalized != "" && candidates[normalized] {
return canonicalName
}
}
}
return ""
}
-207
View File
@@ -1,207 +0,0 @@
package routing
import "testing"
func TestBuildAgentMainSessionKey(t *testing.T) {
got := BuildAgentMainSessionKey("sales")
want := "agent:sales:main"
if got != want {
t.Errorf("BuildAgentMainSessionKey('sales') = %q, want %q", got, want)
}
}
func TestBuildAgentMainSessionKey_Normalizes(t *testing.T) {
got := BuildAgentMainSessionKey("Sales Bot")
want := "agent:sales-bot:main"
if got != want {
t.Errorf("BuildAgentMainSessionKey('Sales Bot') = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_DMScopeMain(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
DMScope: DMScopeMain,
})
want := "agent:main:main"
if got != want {
t.Errorf("DMScopeMain = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_DMScopePerPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
DMScope: DMScopePerPeer,
})
want := "agent:main:direct:user123"
if got != want {
t.Errorf("DMScopePerPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_DMScopePerChannelPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
DMScope: DMScopePerChannelPeer,
})
want := "agent:main:telegram:direct:user123"
if got != want {
t.Errorf("DMScopePerChannelPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_DMScopePerAccountChannelPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
AccountID: "bot1",
Peer: &RoutePeer{Kind: "direct", ID: "User123"},
DMScope: DMScopePerAccountChannelPeer,
})
want := "agent:main:telegram:bot1:direct:user123"
if got != want {
t.Errorf("DMScopePerAccountChannelPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_GroupPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "group", ID: "chat456"},
DMScope: DMScopePerPeer,
})
want := "agent:main:telegram:group:chat456"
if got != want {
t.Errorf("GroupPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_NilPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: nil,
DMScope: DMScopePerPeer,
})
// nil peer defaults to direct with empty ID, falls to main
want := "agent:main:main"
if got != want {
t.Errorf("NilPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) {
links := map[string][]string{
"john": {"telegram:user123", "discord:john#1234"},
}
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
DMScope: DMScopePerPeer,
IdentityLinks: links,
})
want := "agent:main:direct:john"
if got != want {
t.Errorf("IdentityLink = %q, want %q", got, want)
}
}
func TestResolveLinkedPeerID_CanonicalPeerID(t *testing.T) {
// When peerID is already in canonical "platform:id" format,
// it should match identity_links that use the bare ID.
links := map[string][]string{
"john": {"123"},
}
got := resolveLinkedPeerID(links, "telegram", "telegram:123")
if got != "john" {
t.Errorf("resolveLinkedPeerID with canonical peerID = %q, want %q", got, "john")
}
}
func TestResolveLinkedPeerID_CanonicalInLinks(t *testing.T) {
// When identity_links contain canonical IDs and peerID is canonical too
links := map[string][]string{
"john": {"telegram:123", "discord:456"},
}
got := resolveLinkedPeerID(links, "telegram", "telegram:123")
if got != "john" {
t.Errorf("resolveLinkedPeerID canonical in links = %q, want %q", got, "john")
}
}
func TestResolveLinkedPeerID_BarePeerIDMatchesCanonicalLink(t *testing.T) {
// When peerID is bare "123" and links have "telegram:123",
// the scoped candidate "telegram:123" should match.
links := map[string][]string{
"john": {"telegram:123"},
}
got := resolveLinkedPeerID(links, "telegram", "123")
if got != "john" {
t.Errorf("resolveLinkedPeerID bare peer matches canonical link = %q, want %q", got, "john")
}
}
func TestResolveLinkedPeerID_NoMatch(t *testing.T) {
links := map[string][]string{
"john": {"telegram:123"},
}
got := resolveLinkedPeerID(links, "discord", "999")
if got != "" {
t.Errorf("resolveLinkedPeerID no match = %q, want empty", got)
}
}
func TestParseAgentSessionKey_Valid(t *testing.T) {
parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123")
if parsed == nil {
t.Fatal("expected non-nil result")
}
if parsed.AgentID != "sales" {
t.Errorf("AgentID = %q, want 'sales'", parsed.AgentID)
}
if parsed.Rest != "telegram:direct:user123" {
t.Errorf("Rest = %q, want 'telegram:direct:user123'", parsed.Rest)
}
}
func TestParseAgentSessionKey_Invalid(t *testing.T) {
tests := []string{
"",
"foo:bar",
"notprefix:sales:main",
"agent::main",
"agent:sales:",
}
for _, input := range tests {
if got := ParseAgentSessionKey(input); got != nil {
t.Errorf("ParseAgentSessionKey(%q) = %+v, want nil", input, got)
}
}
}
func TestIsSubagentSessionKey(t *testing.T) {
tests := []struct {
input string
want bool
}{
{"subagent:task-1", true},
{"agent:main:subagent:task-1", true},
{"agent:main:main", false},
{"agent:main:telegram:direct:user123", false},
{"", false},
}
for _, tt := range tests {
if got := IsSubagentSessionKey(tt.input); got != tt.want {
t.Errorf("IsSubagentSessionKey(%q) = %v, want %v", tt.input, got, tt.want)
}
}
}
+213
View File
@@ -0,0 +1,213 @@
package session
import (
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/routing"
)
// Allocation contains the concrete session keys selected for a routed turn.
// The current implementation intentionally preserves the legacy session-key
// layout while moving key construction out of the router.
type Allocation struct {
Scope SessionScope
SessionKey string
SessionAliases []string
MainSessionKey string
MainAliases []string
}
// AllocationInput contains the routing result and peer context needed to
// derive the session keys for a turn.
type AllocationInput struct {
AgentID string
Context bus.InboundContext
SessionPolicy routing.SessionPolicy
}
// AllocateRouteSession maps a route decision onto a structured scope and the
// current opaque session-key format.
func AllocateRouteSession(input AllocationInput) Allocation {
scope := buildSessionScope(input)
legacySessionAliases := buildLegacySessionAliases(input)
legacyMainSessionKey := strings.ToLower(BuildLegacyMainAlias(input.AgentID))
return Allocation{
Scope: scope,
SessionKey: BuildSessionKey(scope),
SessionAliases: legacySessionAliases,
MainSessionKey: BuildOpaqueSessionKey(legacyMainSessionKey),
MainAliases: []string{legacyMainSessionKey},
}
}
func buildSessionScope(input AllocationInput) SessionScope {
inbound := input.Context
includeTopicInChatDimension := shouldPreserveTelegramForumIsolation(input)
scope := SessionScope{
Version: ScopeVersionV1,
AgentID: routing.NormalizeAgentID(input.AgentID),
Channel: strings.ToLower(strings.TrimSpace(inbound.Channel)),
Account: routing.NormalizeAccountID(inbound.Account),
}
if scope.Channel == "" {
scope.Channel = "unknown"
}
dimensions := make([]string, 0, len(input.SessionPolicy.Dimensions))
values := make(map[string]string, len(input.SessionPolicy.Dimensions))
for _, dimension := range input.SessionPolicy.Dimensions {
switch dimension {
case "space":
if spaceID := strings.TrimSpace(inbound.SpaceID); spaceID != "" {
spaceType := strings.ToLower(strings.TrimSpace(inbound.SpaceType))
if spaceType == "" {
spaceType = "space"
}
dimensions = append(dimensions, "space")
values["space"] = fmt.Sprintf("%s:%s", spaceType, strings.ToLower(spaceID))
}
case "chat":
chatID := strings.TrimSpace(inbound.ChatID)
if chatID == "" {
continue
}
if includeTopicInChatDimension {
if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" {
chatID = chatID + "/" + topicID
}
}
chatType := strings.ToLower(strings.TrimSpace(inbound.ChatType))
if chatType == "" {
chatType = "direct"
}
dimensions = append(dimensions, "chat")
values["chat"] = fmt.Sprintf("%s:%s", chatType, strings.ToLower(chatID))
case "topic":
if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" {
dimensions = append(dimensions, "topic")
values["topic"] = "topic:" + strings.ToLower(topicID)
}
case "sender":
senderID := CanonicalSessionIdentityID(
inbound.Channel,
inbound.SenderID,
input.SessionPolicy.IdentityLinks,
)
if senderID == "" {
continue
}
dimensions = append(dimensions, "sender")
values["sender"] = senderID
}
}
if len(dimensions) > 0 {
scope.Dimensions = dimensions
scope.Values = values
}
return scope
}
func buildLegacySessionAliases(input AllocationInput) []string {
aliases := []string{strings.ToLower(BuildLegacyMainAlias(input.AgentID))}
inbound := input.Context
if strings.EqualFold(strings.TrimSpace(inbound.ChatType), "direct") {
peerIDs := buildLegacyDirectPeerIDs(input)
if len(peerIDs) == 0 {
return uniqueAliases(aliases)
}
for _, peerID := range peerIDs {
aliases = append(
aliases,
BuildLegacyDirectAliases(input.AgentID, inbound.Channel, inbound.Account, peerID)...,
)
}
return uniqueAliases(aliases)
}
peerID := strings.TrimSpace(inbound.ChatID)
if peerID == "" {
return uniqueAliases(aliases)
}
if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" {
peerID = peerID + "/" + topicID
}
aliases = append(aliases, BuildLegacyPeerAlias(
input.AgentID,
inbound.Channel,
strings.ToLower(strings.TrimSpace(inbound.ChatType)),
peerID,
))
return uniqueAliases(aliases)
}
func shouldPreserveTelegramForumIsolation(input AllocationInput) bool {
inbound := input.Context
if !strings.EqualFold(strings.TrimSpace(inbound.Channel), "telegram") {
return false
}
if strings.TrimSpace(inbound.TopicID) == "" {
return false
}
for _, dimension := range input.SessionPolicy.Dimensions {
if strings.EqualFold(strings.TrimSpace(dimension), "topic") {
return false
}
}
return true
}
func buildLegacyDirectPeerIDs(input AllocationInput) []string {
inbound := input.Context
peerIDs := make([]string, 0, 3)
rawSenderID := strings.TrimSpace(inbound.SenderID)
if rawSenderID != "" {
peerIDs = append(peerIDs, strings.ToLower(rawSenderID))
}
canonicalSenderID := CanonicalSessionIdentityID(
inbound.Channel,
inbound.SenderID,
input.SessionPolicy.IdentityLinks,
)
if canonicalSenderID != "" {
peerIDs = append(peerIDs, canonicalSenderID)
}
chatID := strings.TrimSpace(inbound.ChatID)
if chatID != "" {
peerIDs = append(peerIDs, strings.ToLower(chatID))
}
return uniqueAliases(peerIDs)
}
func uniqueAliases(aliases []string) []string {
if len(aliases) == 0 {
return nil
}
normalized := make([]string, 0, len(aliases))
seen := make(map[string]struct{}, len(aliases))
for _, alias := range aliases {
alias = strings.TrimSpace(strings.ToLower(alias))
if alias == "" {
continue
}
if _, ok := seen[alias]; ok {
continue
}
seen[alias] = struct{}{}
normalized = append(normalized, alias)
}
if len(normalized) == 0 {
return nil
}
return normalized
}
+160
View File
@@ -0,0 +1,160 @@
package session
import (
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/routing"
)
func TestAllocateRouteSession_PerPeerDM(t *testing.T) {
allocation := AllocateRouteSession(AllocationInput{
AgentID: "main",
Context: bus.InboundContext{
Channel: "telegram",
Account: "default",
ChatID: "dm-123",
ChatType: "direct",
SenderID: "User123",
},
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"sender"},
},
})
if allocation.SessionKey == "" || !IsOpaqueSessionKey(allocation.SessionKey) {
t.Fatalf("SessionKey = %q, want opaque session key", allocation.SessionKey)
}
if !containsAlias(allocation.SessionAliases, "agent:main:direct:user123") {
t.Fatalf("SessionAliases = %v, want to contain agent:main:direct:user123", allocation.SessionAliases)
}
if allocation.MainSessionKey == "" || !IsOpaqueSessionKey(allocation.MainSessionKey) {
t.Fatalf("MainSessionKey = %q, want opaque session key", allocation.MainSessionKey)
}
if len(allocation.MainAliases) != 1 || allocation.MainAliases[0] != "agent:main:main" {
t.Fatalf("MainAliases = %v, want [agent:main:main]", allocation.MainAliases)
}
if allocation.Scope.Version != ScopeVersionV1 {
t.Fatalf("Scope.Version = %d, want %d", allocation.Scope.Version, ScopeVersionV1)
}
if len(allocation.Scope.Dimensions) != 1 || allocation.Scope.Dimensions[0] != "sender" {
t.Fatalf("Scope.Dimensions = %v, want [sender]", allocation.Scope.Dimensions)
}
if allocation.Scope.Values["sender"] != "user123" {
t.Fatalf("Scope.Values[sender] = %q, want user123", allocation.Scope.Values["sender"])
}
}
func TestAllocateRouteSession_GroupPeer(t *testing.T) {
allocation := AllocateRouteSession(AllocationInput{
AgentID: "main",
Context: bus.InboundContext{
Channel: "slack",
Account: "workspace-a",
ChatID: "C001",
ChatType: "channel",
SenderID: "U001",
},
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"chat"},
},
})
if allocation.SessionKey == "" || !IsOpaqueSessionKey(allocation.SessionKey) {
t.Fatalf("SessionKey = %q, want opaque session key", allocation.SessionKey)
}
if !containsAlias(allocation.SessionAliases, "agent:main:slack:channel:c001") {
t.Fatalf("SessionAliases = %v, want to contain agent:main:slack:channel:c001", allocation.SessionAliases)
}
if allocation.MainSessionKey == "" || !IsOpaqueSessionKey(allocation.MainSessionKey) {
t.Fatalf("MainSessionKey = %q, want opaque session key", allocation.MainSessionKey)
}
if len(allocation.MainAliases) != 1 || allocation.MainAliases[0] != "agent:main:main" {
t.Fatalf("MainAliases = %v, want [agent:main:main]", allocation.MainAliases)
}
if len(allocation.Scope.Dimensions) != 1 || allocation.Scope.Dimensions[0] != "chat" {
t.Fatalf("Scope.Dimensions = %v, want [chat]", allocation.Scope.Dimensions)
}
if allocation.Scope.Values["chat"] != "channel:c001" {
t.Fatalf("Scope.Values[chat] = %q, want channel:c001", allocation.Scope.Values["chat"])
}
}
func TestAllocateRouteSession_TelegramForumTopicsRemainIsolatedByDefault(t *testing.T) {
first := AllocateRouteSession(AllocationInput{
AgentID: "main",
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "-1001234567890",
ChatType: "group",
TopicID: "42",
SenderID: "7",
},
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"chat"},
},
})
second := AllocateRouteSession(AllocationInput{
AgentID: "main",
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "-1001234567890",
ChatType: "group",
TopicID: "99",
SenderID: "7",
},
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"chat"},
},
})
if first.SessionKey == second.SessionKey {
t.Fatalf("forum topics should not share default session key: %q", first.SessionKey)
}
if got := first.Scope.Values["chat"]; got != "group:-1001234567890/42" {
t.Fatalf("first.Scope.Values[chat] = %q, want %q", got, "group:-1001234567890/42")
}
if got := second.Scope.Values["chat"]; got != "group:-1001234567890/99" {
t.Fatalf("second.Scope.Values[chat] = %q, want %q", got, "group:-1001234567890/99")
}
}
func TestAllocateRouteSession_PicoDirectAliasesIncludeLegacyChatKey(t *testing.T) {
allocation := AllocateRouteSession(AllocationInput{
AgentID: "main",
Context: bus.InboundContext{
Channel: "pico",
Account: "default",
ChatID: "pico:session-123",
ChatType: "direct",
SenderID: "pico-user",
},
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"sender"},
},
})
if !containsAlias(allocation.SessionAliases, "agent:main:pico:direct:pico:session-123") {
t.Fatalf("SessionAliases = %v, want pico legacy alias", allocation.SessionAliases)
}
}
func TestBuildOpaqueSessionKey_IsStable(t *testing.T) {
first := BuildOpaqueSessionKey("agent:main:direct:user123")
second := BuildOpaqueSessionKey("agent:main:direct:user123")
if first != second {
t.Fatalf("BuildOpaqueSessionKey() mismatch: %q != %q", first, second)
}
if !IsOpaqueSessionKey(first) {
t.Fatalf("expected opaque session key, got %q", first)
}
}
func containsAlias(aliases []string, want string) bool {
for _, alias := range aliases {
if alias == want {
return true
}
}
return false
}
+106
View File
@@ -2,7 +2,9 @@ package session
import (
"context"
"encoding/json"
"log"
"strings"
"github.com/sipeed/picoclaw/pkg/memory"
"github.com/sipeed/picoclaw/pkg/providers"
@@ -15,24 +17,123 @@ type JSONLBackend struct {
store memory.Store
}
type metaAwareStore interface {
GetSessionMeta(ctx context.Context, sessionKey string) (memory.SessionMeta, error)
UpsertSessionMeta(ctx context.Context, sessionKey string, scope json.RawMessage, aliases []string) error
ResolveSessionKey(ctx context.Context, sessionKey string) (string, bool, error)
}
type aliasPromotingStore interface {
PromoteAliasHistory(ctx context.Context, sessionKey string, scope json.RawMessage, aliases []string) (bool, error)
}
// MetadataAwareSessionStore exposes structured session metadata operations.
type MetadataAwareSessionStore interface {
EnsureSessionMetadata(sessionKey string, scope *SessionScope, aliases []string)
ResolveSessionKey(sessionKey string) string
GetSessionScope(sessionKey string) *SessionScope
}
// NewJSONLBackend wraps a memory.Store for use as a SessionStore.
func NewJSONLBackend(store memory.Store) *JSONLBackend {
return &JSONLBackend{store: store}
}
func (b *JSONLBackend) resolveSessionKey(sessionKey string) string {
metaStore, ok := b.store.(metaAwareStore)
if !ok {
return sessionKey
}
resolved, found, err := metaStore.ResolveSessionKey(context.Background(), sessionKey)
if err != nil {
log.Printf("session: resolve session key: %v", err)
return sessionKey
}
if found && resolved != "" {
return resolved
}
return sessionKey
}
// ResolveSessionKey maps aliases onto their canonical session key when the
// underlying store supports structured metadata. Unknown aliases fall back to
// the original input so existing callers remain compatible.
func (b *JSONLBackend) ResolveSessionKey(sessionKey string) string {
return b.resolveSessionKey(sessionKey)
}
// EnsureSessionMetadata persists scope and alias metadata for a session.
func (b *JSONLBackend) EnsureSessionMetadata(sessionKey string, scope *SessionScope, aliases []string) {
metaStore, ok := b.store.(metaAwareStore)
if !ok {
return
}
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" {
return
}
var rawScope json.RawMessage
if scope != nil {
data, err := json.Marshal(scope)
if err != nil {
log.Printf("session: encode session scope: %v", err)
return
}
rawScope = data
}
ctx := context.Background()
if err := metaStore.UpsertSessionMeta(ctx, sessionKey, rawScope, aliases); err != nil {
log.Printf("session: upsert session metadata: %v", err)
return
}
if promotingStore, ok := b.store.(aliasPromotingStore); ok {
if _, err := promotingStore.PromoteAliasHistory(ctx, sessionKey, rawScope, aliases); err != nil {
log.Printf("session: promote alias history: %v", err)
}
}
}
// GetSessionScope reads structured scope metadata for a session key or alias.
func (b *JSONLBackend) GetSessionScope(sessionKey string) *SessionScope {
metaStore, ok := b.store.(metaAwareStore)
if !ok {
return nil
}
sessionKey = b.resolveSessionKey(sessionKey)
meta, err := metaStore.GetSessionMeta(context.Background(), sessionKey)
if err != nil {
log.Printf("session: get session metadata: %v", err)
return nil
}
if len(meta.Scope) == 0 {
return nil
}
var scope SessionScope
if err := json.Unmarshal(meta.Scope, &scope); err != nil {
log.Printf("session: decode session scope: %v", err)
return nil
}
return CloneScope(&scope)
}
func (b *JSONLBackend) AddMessage(sessionKey, role, content string) {
sessionKey = b.resolveSessionKey(sessionKey)
if err := b.store.AddMessage(context.Background(), sessionKey, role, content); err != nil {
log.Printf("session: add message: %v", err)
}
}
func (b *JSONLBackend) AddFullMessage(sessionKey string, msg providers.Message) {
sessionKey = b.resolveSessionKey(sessionKey)
if err := b.store.AddFullMessage(context.Background(), sessionKey, msg); err != nil {
log.Printf("session: add full message: %v", err)
}
}
func (b *JSONLBackend) GetHistory(key string) []providers.Message {
key = b.resolveSessionKey(key)
msgs, err := b.store.GetHistory(context.Background(), key)
if err != nil {
log.Printf("session: get history: %v", err)
@@ -42,6 +143,7 @@ func (b *JSONLBackend) GetHistory(key string) []providers.Message {
}
func (b *JSONLBackend) GetSummary(key string) string {
key = b.resolveSessionKey(key)
summary, err := b.store.GetSummary(context.Background(), key)
if err != nil {
log.Printf("session: get summary: %v", err)
@@ -51,18 +153,21 @@ func (b *JSONLBackend) GetSummary(key string) string {
}
func (b *JSONLBackend) SetSummary(key, summary string) {
key = b.resolveSessionKey(key)
if err := b.store.SetSummary(context.Background(), key, summary); err != nil {
log.Printf("session: set summary: %v", err)
}
}
func (b *JSONLBackend) SetHistory(key string, history []providers.Message) {
key = b.resolveSessionKey(key)
if err := b.store.SetHistory(context.Background(), key, history); err != nil {
log.Printf("session: set history: %v", err)
}
}
func (b *JSONLBackend) TruncateHistory(key string, keepLast int) {
key = b.resolveSessionKey(key)
if err := b.store.TruncateHistory(context.Background(), key, keepLast); err != nil {
log.Printf("session: truncate history: %v", err)
}
@@ -72,6 +177,7 @@ func (b *JSONLBackend) TruncateHistory(key string, keepLast int) {
// immediately, the data is already durable. Save runs compaction to reclaim
// space from logically truncated messages (no-op when there are none).
func (b *JSONLBackend) Save(key string) error {
key = b.resolveSessionKey(key)
return b.store.Compact(context.Background(), key)
}
+125
View File
@@ -4,8 +4,10 @@ import (
"fmt"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/memory"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
)
@@ -177,3 +179,126 @@ func TestJSONLBackend_SummarizeFlow(t *testing.T) {
t.Errorf("first message = %q, want %q", history[0].Content, "msg 16")
}
}
func TestJSONLBackend_ResolveAliasAndPersistMetadata(t *testing.T) {
b := newBackend(t)
scope := &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "telegram",
Account: "default",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "group:c1",
},
}
b.EnsureSessionMetadata("canonical", scope, []string{"legacy"})
if got := b.ResolveSessionKey("legacy"); got != "canonical" {
t.Fatalf("ResolveSessionKey() = %q, want %q", got, "canonical")
}
b.AddMessage("legacy", "user", "hello through alias")
history := b.GetHistory("canonical")
if len(history) != 1 {
t.Fatalf("len(history) = %d, want 1", len(history))
}
if history[0].Content != "hello through alias" {
t.Fatalf("history[0].Content = %q, want %q", history[0].Content, "hello through alias")
}
resolvedScope := b.GetSessionScope("legacy")
if resolvedScope == nil {
t.Fatal("GetSessionScope() returned nil")
}
if resolvedScope.AgentID != scope.AgentID || resolvedScope.Values["chat"] != scope.Values["chat"] {
t.Fatalf("GetSessionScope() = %+v, want %+v", resolvedScope, scope)
}
}
func TestJSONLBackend_EnsureSessionMetadata_PromotesLegacyAliasHistory(t *testing.T) {
b := newBackend(t)
legacyKey := "agent:main:direct:legacy-user"
b.AddMessage(legacyKey, "user", "legacy history")
b.SetSummary(legacyKey, "legacy summary")
canonicalKey := session.BuildOpaqueSessionKey(legacyKey)
b.EnsureSessionMetadata(canonicalKey, &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
}, []string{legacyKey})
if got := b.ResolveSessionKey(legacyKey); got != canonicalKey {
t.Fatalf("ResolveSessionKey() = %q, want %q", got, canonicalKey)
}
history := b.GetHistory(canonicalKey)
if len(history) != 1 || history[0].Content != "legacy history" {
t.Fatalf("promoted history = %+v", history)
}
if summary := b.GetSummary(canonicalKey); summary != "legacy summary" {
t.Fatalf("promoted summary = %q, want %q", summary, "legacy summary")
}
}
func TestJSONLBackend_EnsureSessionMetadata_PromotesLegacyPicoDirectAliasHistory(t *testing.T) {
b := newBackend(t)
legacyKey := "agent:main:pico:direct:pico:session-123"
b.AddMessage(legacyKey, "user", "legacy pico history")
scope := &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "pico",
Account: "default",
Dimensions: []string{"sender"},
Values: map[string]string{
"sender": "pico-user",
},
}
allocation := session.AllocateRouteSession(session.AllocationInput{
AgentID: "main",
Context: bus.InboundContext{
Channel: "pico",
Account: "default",
ChatID: "pico:session-123",
ChatType: "direct",
SenderID: "pico-user",
},
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"sender"},
},
})
b.EnsureSessionMetadata(allocation.SessionKey, scope, allocation.SessionAliases)
if got := b.ResolveSessionKey(legacyKey); got != allocation.SessionKey {
t.Fatalf("ResolveSessionKey() = %q, want %q", got, allocation.SessionKey)
}
history := b.GetHistory(allocation.SessionKey)
if len(history) != 1 || history[0].Content != "legacy pico history" {
t.Fatalf("promoted history = %+v", history)
}
}
func TestJSONLBackend_EnsureSessionMetadata_DoesNotOverwriteNonEmptyCanonicalHistory(t *testing.T) {
b := newBackend(t)
canonicalKey := session.BuildOpaqueSessionKey("agent:main:direct:current-user")
legacyKey := "agent:main:direct:legacy-user"
b.AddMessage(canonicalKey, "user", "current canonical history")
b.AddMessage(legacyKey, "user", "legacy history")
b.EnsureSessionMetadata(canonicalKey, &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
}, []string{legacyKey})
history := b.GetHistory(canonicalKey)
if len(history) != 1 || history[0].Content != "current canonical history" {
t.Fatalf("canonical history overwritten: %+v", history)
}
}
+205
View File
@@ -0,0 +1,205 @@
package session
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/routing"
)
const (
sessionKeyV1Prefix = "sk_v1_"
legacyAgentSessionKeyPrefix = "agent:"
)
type ParsedLegacySessionKey struct {
AgentID string
Rest string
}
// BuildOpaqueSessionKey returns a stable opaque session key derived from a
// canonical alias string. The alias remains available through metadata for
// compatibility and migration purposes.
func BuildOpaqueSessionKey(alias string) string {
normalized := strings.TrimSpace(strings.ToLower(alias))
if normalized == "" {
return ""
}
sum := sha256.Sum256([]byte(normalized))
return sessionKeyV1Prefix + hex.EncodeToString(sum[:])
}
// IsOpaqueSessionKey returns true when the key matches the current opaque
// session-key format.
func IsOpaqueSessionKey(key string) bool {
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(key)), sessionKeyV1Prefix)
}
func IsLegacyAgentSessionKey(key string) bool {
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(key)), legacyAgentSessionKeyPrefix)
}
func IsExplicitSessionKey(key string) bool {
return IsOpaqueSessionKey(key) || IsLegacyAgentSessionKey(key)
}
func ParseLegacyAgentSessionKey(sessionKey string) *ParsedLegacySessionKey {
raw := strings.TrimSpace(sessionKey)
if raw == "" {
return nil
}
parts := strings.SplitN(raw, ":", 3)
if len(parts) < 3 || parts[0] != "agent" {
return nil
}
agentID := strings.TrimSpace(parts[1])
rest := parts[2]
if agentID == "" || rest == "" {
return nil
}
return &ParsedLegacySessionKey{AgentID: agentID, Rest: rest}
}
// ResolveAgentID returns the routed agent ID associated with a session. It
// prefers structured session scope metadata when available and falls back to
// legacy agent-scoped session keys for compatibility.
func ResolveAgentID(store any, sessionKey string) string {
if scopeReader, ok := store.(interface {
GetSessionScope(sessionKey string) *SessionScope
}); ok {
scope := scopeReader.GetSessionScope(sessionKey)
if scope != nil && strings.TrimSpace(scope.AgentID) != "" {
return routing.NormalizeAgentID(scope.AgentID)
}
}
if parsed := ParseLegacyAgentSessionKey(sessionKey); parsed != nil {
return routing.NormalizeAgentID(parsed.AgentID)
}
return ""
}
func BuildLegacyMainAlias(agentID string) string {
return fmt.Sprintf("agent:%s:main", routing.NormalizeAgentID(agentID))
}
// BuildMainSessionKey returns the canonical opaque main-session key for an
// agent. The corresponding legacy alias remains available via
// BuildLegacyMainAlias for compatibility and migration logic.
func BuildMainSessionKey(agentID string) string {
return BuildOpaqueSessionKey(BuildLegacyMainAlias(agentID))
}
func BuildLegacyDirectAliases(agentID, channel, account, peerID string) []string {
agentID = routing.NormalizeAgentID(agentID)
channel = normalizeLegacyChannel(channel)
account = routing.NormalizeAccountID(account)
peerID = strings.ToLower(strings.TrimSpace(peerID))
if peerID == "" {
return nil
}
return []string{
fmt.Sprintf("agent:%s:direct:%s", agentID, peerID),
fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID),
fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, account, peerID),
}
}
func BuildLegacyPeerAlias(agentID, channel, peerKind, peerID string) string {
agentID = routing.NormalizeAgentID(agentID)
channel = normalizeLegacyChannel(channel)
peerKind = strings.ToLower(strings.TrimSpace(peerKind))
if peerKind == "" {
peerKind = "unknown"
}
peerID = strings.ToLower(strings.TrimSpace(peerID))
if peerID == "" {
peerID = "unknown"
}
return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID)
}
// CanonicalSessionIdentityID collapses an identity using identity_links when
// possible, then returns a normalized lowercase identifier.
func CanonicalSessionIdentityID(channel, rawID string, identityLinks map[string][]string) string {
normalizedID := strings.TrimSpace(rawID)
if normalizedID == "" {
return ""
}
if linked := resolveLinkedPeerID(identityLinks, channel, normalizedID); linked != "" {
normalizedID = linked
}
return strings.ToLower(normalizedID)
}
func normalizeLegacyChannel(channel string) string {
channel = strings.ToLower(strings.TrimSpace(channel))
if channel == "" {
return "unknown"
}
return channel
}
func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string {
if len(identityLinks) == 0 {
return ""
}
peerID = strings.TrimSpace(peerID)
if peerID == "" {
return ""
}
candidates := make(map[string]bool)
rawCandidate := strings.ToLower(peerID)
if rawCandidate != "" {
candidates[rawCandidate] = true
}
channel = strings.ToLower(strings.TrimSpace(channel))
if channel != "" {
candidates[fmt.Sprintf("%s:%s", channel, rawCandidate)] = true
}
if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 {
candidates[rawCandidate[idx+1:]] = true
}
for canonical, ids := range identityLinks {
canonicalName := strings.TrimSpace(canonical)
if canonicalName == "" {
continue
}
for _, id := range ids {
normalized := strings.ToLower(strings.TrimSpace(id))
if normalized != "" && candidates[normalized] {
return canonicalName
}
}
}
return ""
}
// CanonicalScopeSignature returns a stable serialized representation of scope.
func CanonicalScopeSignature(scope SessionScope) string {
parts := []string{
fmt.Sprintf("v=%d", scope.Version),
fmt.Sprintf("agent=%s", strings.TrimSpace(strings.ToLower(scope.AgentID))),
fmt.Sprintf("channel=%s", strings.TrimSpace(strings.ToLower(scope.Channel))),
fmt.Sprintf("account=%s", strings.TrimSpace(strings.ToLower(scope.Account))),
}
for _, dimension := range scope.Dimensions {
dimension = strings.TrimSpace(strings.ToLower(dimension))
if dimension == "" {
continue
}
value := strings.TrimSpace(strings.ToLower(scope.Values[dimension]))
parts = append(parts, fmt.Sprintf("%s=%s", dimension, value))
}
return strings.Join(parts, "|")
}
// BuildSessionKey returns the current opaque key for a structured session scope.
func BuildSessionKey(scope SessionScope) string {
return BuildOpaqueSessionKey(CanonicalScopeSignature(scope))
}
+100
View File
@@ -0,0 +1,100 @@
package session
import "testing"
type testScopeReader struct {
scope *SessionScope
}
func (r testScopeReader) GetSessionScope(sessionKey string) *SessionScope {
return CloneScope(r.scope)
}
func TestIsExplicitSessionKey(t *testing.T) {
tests := []struct {
key string
want bool
}{
{"sk_v1_abc", true},
{"agent:main:direct:user123", true},
{"custom-key", false},
{"", false},
}
for _, tt := range tests {
if got := IsExplicitSessionKey(tt.key); got != tt.want {
t.Fatalf("IsExplicitSessionKey(%q) = %v, want %v", tt.key, got, tt.want)
}
}
}
func TestParseLegacyAgentSessionKey(t *testing.T) {
parsed := ParseLegacyAgentSessionKey("agent:sales:telegram:direct:user123")
if parsed == nil {
t.Fatal("expected parsed legacy key, got nil")
}
if parsed.AgentID != "sales" {
t.Fatalf("AgentID = %q, want sales", parsed.AgentID)
}
if parsed.Rest != "telegram:direct:user123" {
t.Fatalf("Rest = %q, want telegram:direct:user123", parsed.Rest)
}
if got := ParseLegacyAgentSessionKey("sk_v1_abc"); got != nil {
t.Fatalf("expected nil for opaque key, got %+v", got)
}
}
func TestBuildLegacyDirectAliases(t *testing.T) {
aliases := BuildLegacyDirectAliases("Main", "Telegram", "BotA", "User123")
want := []string{
"agent:main:direct:user123",
"agent:main:telegram:direct:user123",
"agent:main:telegram:bota:direct:user123",
}
if len(aliases) != len(want) {
t.Fatalf("len(aliases) = %d, want %d", len(aliases), len(want))
}
for i := range want {
if aliases[i] != want[i] {
t.Fatalf("aliases[%d] = %q, want %q", i, aliases[i], want[i])
}
}
}
func TestBuildLegacyPeerAlias(t *testing.T) {
got := BuildLegacyPeerAlias("Main", "Slack", "channel", "C001")
if got != "agent:main:slack:channel:c001" {
t.Fatalf("BuildLegacyPeerAlias() = %q", got)
}
}
func TestBuildMainSessionKey(t *testing.T) {
got := BuildMainSessionKey("Main")
if !IsOpaqueSessionKey(got) {
t.Fatalf("BuildMainSessionKey() = %q, want opaque key", got)
}
if got != BuildOpaqueSessionKey("agent:main:main") {
t.Fatalf("BuildMainSessionKey() = %q, want stable main-key hash", got)
}
}
func TestResolveAgentID_PrefersSessionScope(t *testing.T) {
store := testScopeReader{
scope: &SessionScope{
Version: ScopeVersionV1,
AgentID: "Support",
Channel: "slack",
},
}
if got := ResolveAgentID(store, "sk_v1_anything"); got != "support" {
t.Fatalf("ResolveAgentID() = %q, want support", got)
}
}
func TestResolveAgentID_FallsBackToLegacyKey(t *testing.T) {
if got := ResolveAgentID(nil, "agent:Sales:telegram:direct:user123"); got != "sales" {
t.Fatalf("ResolveAgentID() = %q, want sales", got)
}
}
+32
View File
@@ -0,0 +1,32 @@
package session
// ScopeVersionV1 is the first structured session-scope schema version.
const ScopeVersionV1 = 1
// SessionScope describes the semantic session partition selected for a turn.
type SessionScope struct {
Version int `json:"version"`
AgentID string `json:"agent_id"`
Channel string `json:"channel"`
Account string `json:"account"`
Dimensions []string `json:"dimensions"`
Values map[string]string `json:"values"`
}
// CloneScope returns a deep copy of scope.
func CloneScope(scope *SessionScope) *SessionScope {
if scope == nil {
return nil
}
cloned := *scope
if len(scope.Dimensions) > 0 {
cloned.Dimensions = append([]string(nil), scope.Dimensions...)
}
if len(scope.Values) > 0 {
cloned.Values = make(map[string]string, len(scope.Values))
for key, value := range scope.Values {
cloned.Values[key] = value
}
}
return &cloned
}
+38 -1
View File
@@ -1,6 +1,10 @@
package tools
import "context"
import (
"context"
"github.com/sipeed/picoclaw/pkg/session"
)
// Tool is the interface that all tools must implement.
type Tool interface {
@@ -25,6 +29,9 @@ var (
ctxKeyChatID = &toolCtxKey{"chatID"}
ctxKeyMessageID = &toolCtxKey{"messageID"}
ctxKeyReplyToMessageID = &toolCtxKey{"replyToMessageID"}
ctxKeyAgentID = &toolCtxKey{"agentID"}
ctxKeySessionKey = &toolCtxKey{"sessionKey"}
ctxKeySessionScope = &toolCtxKey{"sessionScope"}
)
// WithToolContext returns a child context carrying channel and chatID.
@@ -51,6 +58,18 @@ func WithToolInboundContext(
return ctx
}
// WithToolSessionContext returns a child context carrying turn-scoped session metadata.
func WithToolSessionContext(
ctx context.Context,
agentID, sessionKey string,
scope *session.SessionScope,
) context.Context {
ctx = context.WithValue(ctx, ctxKeyAgentID, agentID)
ctx = context.WithValue(ctx, ctxKeySessionKey, sessionKey)
ctx = context.WithValue(ctx, ctxKeySessionScope, session.CloneScope(scope))
return ctx
}
// ToolChannel extracts the channel from ctx, or "" if unset.
func ToolChannel(ctx context.Context) string {
v, _ := ctx.Value(ctxKeyChannel).(string)
@@ -75,6 +94,24 @@ func ToolReplyToMessageID(ctx context.Context) string {
return v
}
// ToolAgentID extracts the active turn's agent ID from ctx, or "" if unset.
func ToolAgentID(ctx context.Context) string {
v, _ := ctx.Value(ctxKeyAgentID).(string)
return v
}
// ToolSessionKey extracts the active turn's session key from ctx, or "" if unset.
func ToolSessionKey(ctx context.Context) string {
v, _ := ctx.Value(ctxKeySessionKey).(string)
return v
}
// ToolSessionScope extracts the active turn's structured session scope from ctx.
func ToolSessionScope(ctx context.Context) *session.SessionScope {
scope, _ := ctx.Value(ctxKeySessionScope).(*session.SessionScope)
return session.CloneScope(scope)
}
// AsyncCallback is a function type that async tools use to notify completion.
// When an async tool finishes its work, it calls this callback with the result.
//
+2 -4
View File
@@ -311,8 +311,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Context: bus.NewOutboundContext(channel, chatID, ""),
Content: output,
})
return "ok"
@@ -335,8 +334,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Context: bus.NewOutboundContext(channel, chatID, ""),
Content: output,
})
return "ok"
+4 -4
View File
@@ -6,7 +6,7 @@ import (
"sync"
)
type SendCallback func(channel, chatID, content, replyToMessageID string) error
type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error
// sentTarget records the channel+chatID that the message tool sent to.
type sentTarget struct {
@@ -15,7 +15,7 @@ type sentTarget struct {
}
type MessageTool struct {
sendCallback SendCallback
sendCallback SendCallbackWithContext
mu sync.Mutex
sentTargets []sentTarget // Tracks all targets sent to in the current round
}
@@ -86,7 +86,7 @@ func (t *MessageTool) HasSentTo(channel, chatID string) bool {
return false
}
func (t *MessageTool) SetSendCallback(callback SendCallback) {
func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) {
t.sendCallback = callback
}
@@ -115,7 +115,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
}
if err := t.sendCallback(channel, chatID, content, replyToMessageID); err != nil {
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil {
return &ToolResult{
ForLLM: fmt.Sprintf("sending message: %v", err),
IsError: true,
+49 -5
View File
@@ -4,16 +4,22 @@ import (
"context"
"errors"
"testing"
"github.com/sipeed/picoclaw/pkg/session"
)
func TestMessageTool_Execute_Success(t *testing.T) {
tool := NewMessageTool()
var sentChannel, sentChatID, sentContent string
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
sentChannel = channel
sentChatID = chatID
sentContent = content
if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil {
t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v",
ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx))
}
return nil
})
@@ -61,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
tool := NewMessageTool()
var sentChannel, sentChatID string
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
sentChannel = channel
sentChatID = chatID
return nil
@@ -96,7 +102,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
tool := NewMessageTool()
sendErr := errors.New("network error")
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
return sendErr
})
@@ -149,7 +155,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
tool := NewMessageTool()
// No WithToolContext — channel/chatID are empty
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
return nil
})
@@ -266,7 +272,7 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
tool := NewMessageTool()
var sentReplyTo string
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
sentReplyTo = replyToMessageID
return nil
})
@@ -285,3 +291,41 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
t.Fatalf("expected reply_to_message_id msg-123, got %q", sentReplyTo)
}
}
func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
tool := NewMessageTool()
var gotAgentID, gotSessionKey string
var gotScope *session.SessionScope
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
gotAgentID = ToolAgentID(ctx)
gotSessionKey = ToolSessionKey(ctx)
gotScope = ToolSessionScope(ctx)
return nil
})
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
ctx = WithToolSessionContext(ctx, "main", "sk_v1_tool", &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "telegram",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "direct:test-chat-id",
},
})
result := tool.Execute(ctx, map[string]any{"content": "Hello, world!"})
if result.IsError {
t.Fatalf("expected success, got error: %s", result.ForLLM)
}
if gotAgentID != "main" {
t.Fatalf("ToolAgentID() = %q, want main", gotAgentID)
}
if gotSessionKey != "sk_v1_tool" {
t.Fatalf("ToolSessionKey() = %q, want sk_v1_tool", gotSessionKey)
}
if gotScope == nil || gotScope.Values["chat"] != "direct:test-chat-id" {
t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope)
}
}
+276 -117
View File
@@ -13,7 +13,9 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/memory"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -49,26 +51,12 @@ type sessionChatMessage struct {
Media []string `json:"media,omitempty"`
}
type sessionMetaFile struct {
Key string `json:"key"`
Summary string `json:"summary"`
Skip int `json:"skip"`
Count int `json:"count"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// picoSessionPrefix is the key prefix used by the gateway's routing for Pico
// channel sessions. The full key format is:
//
// agent:main:pico:direct:pico:<session-uuid>
//
// The sanitized filename replaces ':' with '_', so on disk it becomes:
//
// agent_main_pico_direct_pico_<session-uuid>.json
// legacyPicoSessionPrefix is the legacy key prefix used by older Pico JSON/JSONL
// sessions before structured scope metadata existed.
const (
picoSessionPrefix = "agent:main:pico:direct:pico:"
sanitizedPicoSessionPrefix = "agent_main_pico_direct_pico_"
legacyPicoSessionPrefix = "agent:main:pico:direct:pico:"
picoSessionPrefix = legacyPicoSessionPrefix
// Keep the session API aligned with the shared JSONL store reader limit in
// pkg/memory/jsonl.go so oversized lines fail consistently everywhere.
maxSessionJSONLLineSize = 10 * 1024 * 1024
@@ -82,28 +70,23 @@ func defaultToolFeedbackMaxArgsLength() int {
return defaults.GetToolFeedbackMaxArgsLength()
}
// extractPicoSessionID extracts the session UUID from a full session key.
// extractLegacyPicoSessionID extracts the session UUID from an old Pico key.
// Returns the UUID and true if the key matches the Pico session pattern.
func extractPicoSessionID(key string) (string, bool) {
if strings.HasPrefix(key, picoSessionPrefix) {
return strings.TrimPrefix(key, picoSessionPrefix), true
}
return "", false
}
func extractPicoSessionIDFromSanitizedKey(key string) (string, bool) {
if strings.HasPrefix(key, sanitizedPicoSessionPrefix) {
return strings.TrimPrefix(key, sanitizedPicoSessionPrefix), true
func extractLegacyPicoSessionID(key string) (string, bool) {
if strings.HasPrefix(key, legacyPicoSessionPrefix) {
return strings.TrimPrefix(key, legacyPicoSessionPrefix), true
}
return "", false
}
func sanitizeSessionKey(key string) string {
return strings.ReplaceAll(key, ":", "_")
key = strings.ReplaceAll(key, ":", "_")
key = strings.ReplaceAll(key, "/", "_")
key = strings.ReplaceAll(key, "\\", "_")
return key
}
func (h *Handler) readLegacySession(dir, sessionID string) (sessionFile, error) {
path := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+sessionID)+".json")
func (h *Handler) readLegacySession(path string) (sessionFile, error) {
data, err := os.ReadFile(path)
if err != nil {
return sessionFile{}, err
@@ -116,18 +99,18 @@ func (h *Handler) readLegacySession(dir, sessionID string) (sessionFile, error)
return sess, nil
}
func (h *Handler) readSessionMeta(path, sessionKey string) (sessionMetaFile, error) {
func (h *Handler) readSessionMeta(path, sessionKey string) (memory.SessionMeta, error) {
data, err := os.ReadFile(path)
if os.IsNotExist(err) {
return sessionMetaFile{Key: sessionKey}, nil
return memory.SessionMeta{Key: sessionKey}, nil
}
if err != nil {
return sessionMetaFile{}, err
return memory.SessionMeta{}, err
}
var meta sessionMetaFile
var meta memory.SessionMeta
if err := json.Unmarshal(data, &meta); err != nil {
return sessionMetaFile{}, err
return memory.SessionMeta{}, err
}
if meta.Key == "" {
meta.Key = sessionKey
@@ -170,8 +153,7 @@ func (h *Handler) readSessionMessages(path string, skip int) ([]providers.Messag
return msgs, nil
}
func (h *Handler) readJSONLSession(dir, sessionID string) (sessionFile, error) {
sessionKey := picoSessionPrefix + sessionID
func (h *Handler) readJSONLSession(dir, sessionKey string) (sessionFile, error) {
base := filepath.Join(dir, sanitizeSessionKey(sessionKey))
jsonlPath := base + ".jsonl"
metaPath := base + ".meta.json"
@@ -208,6 +190,213 @@ func (h *Handler) readJSONLSession(dir, sessionID string) (sessionFile, error) {
}, nil
}
type picoJSONLSessionRef struct {
ID string
Key string
}
type picoLegacySessionRef struct {
ID string
Path string
}
func extractPicoSessionIDFromScope(scope session.SessionScope) (string, bool) {
if !strings.EqualFold(strings.TrimSpace(scope.Channel), "pico") {
return "", false
}
candidates := []string{
strings.TrimSpace(scope.Values["sender"]),
strings.TrimSpace(scope.Values["chat"]),
}
for _, candidate := range candidates {
if candidate == "" {
continue
}
if idx := strings.Index(candidate, "pico:"); idx >= 0 {
sessionID := strings.TrimSpace(candidate[idx+len("pico:"):])
if sessionID != "" {
return sessionID, true
}
}
}
return "", false
}
func sessionRefFromMeta(meta memory.SessionMeta) (picoJSONLSessionRef, bool) {
if len(meta.Scope) == 0 {
if sessionID, ok := extractLegacyPicoSessionID(meta.Key); ok {
return picoJSONLSessionRef{ID: sessionID, Key: meta.Key}, true
}
for _, alias := range meta.Aliases {
if sessionID, ok := extractLegacyPicoSessionID(alias); ok {
return picoJSONLSessionRef{ID: sessionID, Key: meta.Key}, true
}
}
return picoJSONLSessionRef{}, false
}
var scope session.SessionScope
if err := json.Unmarshal(meta.Scope, &scope); err != nil {
return picoJSONLSessionRef{}, false
}
sessionID, ok := extractPicoSessionIDFromScope(scope)
if !ok {
if legacySessionID, ok := extractLegacyPicoSessionID(meta.Key); ok {
return picoJSONLSessionRef{ID: legacySessionID, Key: meta.Key}, true
}
for _, alias := range meta.Aliases {
if legacySessionID, ok := extractLegacyPicoSessionID(alias); ok {
return picoJSONLSessionRef{ID: legacySessionID, Key: meta.Key}, true
}
}
return picoJSONLSessionRef{}, false
}
return picoJSONLSessionRef{ID: sessionID, Key: meta.Key}, true
}
func (h *Handler) findPicoJSONLSessions(dir string) ([]picoJSONLSessionRef, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
refs := make([]picoJSONLSessionRef, 0)
seen := make(map[string]struct{})
metaBackedBases := make(map[string]struct{})
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") {
continue
}
name := entry.Name()
metaPath := filepath.Join(dir, name)
meta, err := h.readSessionMeta(metaPath, "")
if err != nil {
continue
}
ref, ok := sessionRefFromMeta(meta)
if !ok || ref.Key == "" || ref.ID == "" {
continue
}
metaBackedBases[strings.TrimSuffix(name, ".meta.json")] = struct{}{}
if _, exists := seen[ref.ID]; exists {
continue
}
seen[ref.ID] = struct{}{}
refs = append(refs, ref)
}
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".jsonl") {
continue
}
name := entry.Name()
base := strings.TrimSuffix(name, ".jsonl")
if _, ok := metaBackedBases[base]; ok {
continue
}
ref, ok := jsonlSessionRefFromFilename(name)
if !ok || ref.Key == "" || ref.ID == "" {
continue
}
if _, exists := seen[ref.ID]; exists {
continue
}
seen[ref.ID] = struct{}{}
refs = append(refs, ref)
}
return refs, nil
}
func (h *Handler) findPicoJSONLSession(dir, sessionID string) (picoJSONLSessionRef, error) {
refs, err := h.findPicoJSONLSessions(dir)
if err != nil {
return picoJSONLSessionRef{}, err
}
for _, ref := range refs {
if ref.ID == sessionID {
return ref, nil
}
}
return picoJSONLSessionRef{}, os.ErrNotExist
}
func (h *Handler) findLegacyPicoSessions(dir string) ([]picoLegacySessionRef, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
refs := make([]picoLegacySessionRef, 0)
seen := make(map[string]struct{})
for _, entry := range entries {
name := entry.Name()
if entry.IsDir() || filepath.Ext(name) != ".json" || strings.HasSuffix(name, ".meta.json") {
continue
}
path := filepath.Join(dir, entry.Name())
sess, err := h.readLegacySession(path)
if err != nil || isEmptySession(sess) {
continue
}
sessionID, ok := extractLegacyPicoSessionID(sess.Key)
if !ok || sessionID == "" {
continue
}
if _, exists := seen[sessionID]; exists {
continue
}
seen[sessionID] = struct{}{}
refs = append(refs, picoLegacySessionRef{ID: sessionID, Path: path})
}
return refs, nil
}
func jsonlSessionRefFromFilename(name string) (picoJSONLSessionRef, bool) {
if !strings.HasSuffix(name, ".jsonl") {
return picoJSONLSessionRef{}, false
}
base := strings.TrimSuffix(name, ".jsonl")
if base == "" {
return picoJSONLSessionRef{}, false
}
legacyPrefix := sanitizeSessionKey(legacyPicoSessionPrefix)
if strings.HasPrefix(base, legacyPrefix) {
sessionID := strings.TrimPrefix(base, legacyPrefix)
if sessionID == "" {
return picoJSONLSessionRef{}, false
}
return picoJSONLSessionRef{
ID: sessionID,
Key: legacyPicoSessionPrefix + sessionID,
}, true
}
if session.IsOpaqueSessionKey(base) {
return picoJSONLSessionRef{
ID: base,
Key: base,
}, true
}
return picoJSONLSessionRef{}, false
}
func (h *Handler) findLegacyPicoSession(dir, sessionID string) (picoLegacySessionRef, error) {
refs, err := h.findLegacyPicoSessions(dir)
if err != nil {
return picoLegacySessionRef{}, err
}
for _, ref := range refs {
if ref.ID == sessionID {
return ref, nil
}
}
return picoLegacySessionRef{}, os.ErrNotExist
}
func buildSessionListItem(sessionID string, sess sessionFile, toolFeedbackMaxArgsLength int) sessionListItem {
preview := ""
for _, msg := range sess.Messages {
@@ -458,8 +647,7 @@ func (h *Handler) handleListSessions(w http.ResponseWriter, r *http.Request) {
return
}
entries, err := os.ReadDir(dir)
if err != nil {
if _, err := os.ReadDir(dir); err != nil {
// Directory doesn't exist yet = no sessions
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode([]sessionListItem{})
@@ -469,74 +657,29 @@ func (h *Handler) handleListSessions(w http.ResponseWriter, r *http.Request) {
items := []sessionListItem{}
seen := make(map[string]struct{})
for _, entry := range entries {
if entry.IsDir() {
continue
if refs, findErr := h.findPicoJSONLSessions(dir); findErr == nil {
for _, ref := range refs {
sess, loadErr := h.readJSONLSession(dir, ref.Key)
if loadErr != nil || isEmptySession(sess) {
continue
}
seen[ref.ID] = struct{}{}
items = append(items, buildSessionListItem(ref.ID, sess, toolFeedbackMaxArgsLength))
}
}
name := entry.Name()
var (
sessionID string
sess sessionFile
loadErr error
ok bool
)
switch {
case strings.HasSuffix(name, ".jsonl"):
sessionID, ok = extractPicoSessionIDFromSanitizedKey(strings.TrimSuffix(name, ".jsonl"))
if !ok {
if legacyRefs, findErr := h.findLegacyPicoSessions(dir); findErr == nil {
for _, ref := range legacyRefs {
if _, exists := seen[ref.ID]; exists {
continue
}
sess, loadErr = h.readJSONLSession(dir, sessionID)
if loadErr == nil && isEmptySession(sess) {
sess, loadErr := h.readLegacySession(ref.Path)
if loadErr != nil || isEmptySession(sess) {
continue
}
case strings.HasSuffix(name, ".meta.json"):
continue
case filepath.Ext(name) == ".json":
base := strings.TrimSuffix(name, ".json")
if _, statErr := os.Stat(filepath.Join(dir, base+".jsonl")); statErr == nil {
if jsonlSessionID, found := extractPicoSessionIDFromSanitizedKey(base); found {
if jsonlSess, jsonlErr := h.readJSONLSession(
dir,
jsonlSessionID,
); jsonlErr == nil &&
!isEmptySession(jsonlSess) {
continue
}
}
}
data, err := os.ReadFile(filepath.Join(dir, name))
if err != nil {
continue
}
if err := json.Unmarshal(data, &sess); err != nil {
continue
}
if isEmptySession(sess) {
continue
}
sessionID, ok = extractPicoSessionID(sess.Key)
if !ok {
continue
}
if _, exists := seen[sessionID]; exists {
continue
}
default:
continue
seen[ref.ID] = struct{}{}
items = append(items, buildSessionListItem(ref.ID, sess, toolFeedbackMaxArgsLength))
}
if loadErr != nil {
continue
}
if _, exists := seen[sessionID]; exists {
continue
}
seen[sessionID] = struct{}{}
items = append(items, buildSessionListItem(sessionID, sess, toolFeedbackMaxArgsLength))
}
// Sort by updated descending (most recent first)
@@ -590,13 +733,20 @@ func (h *Handler) handleGetSession(w http.ResponseWriter, r *http.Request) {
return
}
sess, err := h.readJSONLSession(dir, sessionID)
ref, refErr := h.findPicoJSONLSession(dir, sessionID)
var sess sessionFile
err = refErr
if refErr == nil {
sess, err = h.readJSONLSession(dir, ref.Key)
}
if err == nil && isEmptySession(sess) {
err = os.ErrNotExist
}
if err != nil {
if errors.Is(err, os.ErrNotExist) {
sess, err = h.readLegacySession(dir, sessionID)
if legacyRef, legacyErr := h.findLegacyPicoSession(dir, sessionID); legacyErr == nil {
sess, err = h.readLegacySession(legacyRef.Path)
}
if err == nil && isEmptySession(sess) {
err = os.ErrNotExist
}
@@ -639,21 +789,30 @@ func (h *Handler) handleDeleteSession(w http.ResponseWriter, r *http.Request) {
return
}
base := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+sessionID))
jsonlPath := base + ".jsonl"
metaPath := base + ".meta.json"
legacyPath := base + ".json"
removed := false
for _, path := range []string{jsonlPath, metaPath, legacyPath} {
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
continue
if ref, err := h.findPicoJSONLSession(dir, sessionID); err == nil {
base := filepath.Join(dir, sanitizeSessionKey(ref.Key))
for _, path := range []string{base + ".jsonl", base + ".meta.json"} {
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
continue
}
http.Error(w, "failed to delete session", http.StatusInternalServerError)
return
}
http.Error(w, "failed to delete session", http.StatusInternalServerError)
return
removed = true
}
}
if legacyRef, err := h.findLegacyPicoSession(dir, sessionID); err == nil {
if err := os.Remove(legacyRef.Path); err != nil {
if !os.IsNotExist(err) {
http.Error(w, "failed to delete session", http.StatusInternalServerError)
return
}
} else {
removed = true
}
removed = true
}
if !removed {
+166 -12
View File
@@ -36,12 +36,12 @@ func TestHandleListSessions_JSONLStorage(t *testing.T) {
defer cleanup()
dir := sessionsTestDir(t, configPath)
store, err := memory.NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore() error = %v", err)
store, storeErr := memory.NewJSONLStore(dir)
if storeErr != nil {
t.Fatalf("NewJSONLStore() error = %v", storeErr)
}
sessionKey := picoSessionPrefix + "history-jsonl"
sessionKey := legacyPicoSessionPrefix + "history-jsonl"
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
Role: "user",
Content: "Explain why the history API is empty after migration.",
@@ -106,12 +106,12 @@ func TestHandleListSessions_TitleUsesFirstUserMessage(t *testing.T) {
defer cleanup()
dir := sessionsTestDir(t, configPath)
store, err := memory.NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore() error = %v", err)
store, storeErr := memory.NewJSONLStore(dir)
if storeErr != nil {
t.Fatalf("NewJSONLStore() error = %v", storeErr)
}
sessionKey := picoSessionPrefix + "summary-title"
sessionKey := legacyPicoSessionPrefix + "summary-title"
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
Role: "user",
Content: "fallback preview",
@@ -164,7 +164,7 @@ func TestHandleGetSession_JSONLStorage(t *testing.T) {
t.Fatalf("NewJSONLStore() error = %v", err)
}
sessionKey := picoSessionPrefix + "detail-jsonl"
sessionKey := legacyPicoSessionPrefix + "detail-jsonl"
for _, msg := range []providers.Message{
{Role: "user", Content: "first"},
{Role: "assistant", Content: "second"},
@@ -218,6 +218,81 @@ func TestHandleGetSession_JSONLStorage(t *testing.T) {
}
}
func TestHandleSessions_JSONLScopeDiscovery(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
dir := sessionsTestDir(t, configPath)
store, storeErr := memory.NewJSONLStore(dir)
if storeErr != nil {
t.Fatalf("NewJSONLStore() error = %v", storeErr)
}
sessionKey := "sk_v1_scope_discovery"
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
Role: "user",
Content: "scope discovered session",
}); err != nil {
t.Fatalf("AddFullMessage() error = %v", err)
}
if err := store.SetSummary(nil, sessionKey, "scope summary"); err != nil {
t.Fatalf("SetSummary() error = %v", err)
}
scopeData, err := json.Marshal(session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "pico",
Account: "default",
Dimensions: []string{"sender"},
Values: map[string]string{
"sender": "pico:scope-jsonl",
},
})
if err != nil {
t.Fatalf("Marshal(scope) error = %v", err)
}
if err := store.UpsertSessionMeta(nil, sessionKey, scopeData, nil); err != nil {
t.Fatalf("UpsertSessionMeta() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
listRec := httptest.NewRecorder()
listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
mux.ServeHTTP(listRec, listReq)
if listRec.Code != http.StatusOK {
t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String())
}
var items []sessionListItem
if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil {
t.Fatalf("Unmarshal(list) error = %v", err)
}
if len(items) != 1 {
t.Fatalf("len(items) = %d, want 1", len(items))
}
if items[0].ID != "scope-jsonl" {
t.Fatalf("items[0].ID = %q, want %q", items[0].ID, "scope-jsonl")
}
detailRec := httptest.NewRecorder()
detailReq := httptest.NewRequest(http.MethodGet, "/api/sessions/scope-jsonl", nil)
mux.ServeHTTP(detailRec, detailReq)
if detailRec.Code != http.StatusOK {
t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusOK, detailRec.Body.String())
}
deleteRec := httptest.NewRecorder()
deleteReq := httptest.NewRequest(http.MethodDelete, "/api/sessions/scope-jsonl", nil)
mux.ServeHTTP(deleteRec, deleteReq)
if deleteRec.Code != http.StatusNoContent {
t.Fatalf("delete status = %d, want %d, body=%s", deleteRec.Code, http.StatusNoContent, deleteRec.Body.String())
}
}
func TestHandleGetSession_OmitsTransientThoughtMessages(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
@@ -784,7 +859,7 @@ func TestHandleDeleteSession_JSONLStorage(t *testing.T) {
t.Fatalf("NewJSONLStore() error = %v", err)
}
sessionKey := picoSessionPrefix + "delete-jsonl"
sessionKey := legacyPicoSessionPrefix + "delete-jsonl"
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
Role: "user",
Content: "delete me",
@@ -821,7 +896,7 @@ func TestHandleGetSession_LegacyJSONFallback(t *testing.T) {
dir := sessionsTestDir(t, configPath)
manager := session.NewSessionManager(dir)
sessionKey := picoSessionPrefix + "legacy-json"
sessionKey := legacyPicoSessionPrefix + "legacy-json"
manager.AddMessage(sessionKey, "user", "legacy user")
manager.AddMessage(sessionKey, "assistant", "legacy assistant")
if err := manager.Save(sessionKey); err != nil {
@@ -846,7 +921,7 @@ func TestHandleSessions_FiltersEmptyJSONLFiles(t *testing.T) {
defer cleanup()
dir := sessionsTestDir(t, configPath)
base := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+"empty-jsonl"))
base := filepath.Join(dir, sanitizeSessionKey(legacyPicoSessionPrefix+"empty-jsonl"))
if err := os.WriteFile(base+".jsonl", []byte{}, 0o644); err != nil {
t.Fatalf("WriteFile(jsonl) error = %v", err)
}
@@ -879,3 +954,82 @@ func TestHandleSessions_FiltersEmptyJSONLFiles(t *testing.T) {
t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusNotFound, detailRec.Body.String())
}
}
func TestHandleSessions_ListsLegacyJSONLWithoutMeta(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
dir := sessionsTestDir(t, configPath)
sessionKey := legacyPicoSessionPrefix + "missing-meta"
base := filepath.Join(dir, sanitizeSessionKey(sessionKey))
line, err := json.Marshal(providers.Message{Role: "user", Content: "recover me"})
if err != nil {
t.Fatalf("Marshal(message) error = %v", err)
}
if err := os.WriteFile(base+".jsonl", append(line, '\n'), 0o644); err != nil {
t.Fatalf("WriteFile(jsonl) error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
listRec := httptest.NewRecorder()
listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
mux.ServeHTTP(listRec, listReq)
if listRec.Code != http.StatusOK {
t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String())
}
var items []sessionListItem
if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil {
t.Fatalf("Unmarshal(list) error = %v", err)
}
if len(items) != 1 {
t.Fatalf("len(items) = %d, want 1", len(items))
}
if items[0].ID != "missing-meta" {
t.Fatalf("items[0].ID = %q, want %q", items[0].ID, "missing-meta")
}
detailRec := httptest.NewRecorder()
detailReq := httptest.NewRequest(http.MethodGet, "/api/sessions/missing-meta", nil)
mux.ServeHTTP(detailRec, detailReq)
if detailRec.Code != http.StatusOK {
t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusOK, detailRec.Body.String())
}
}
func TestHandleSessions_IgnoresMetaJSONInLegacyFallback(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
dir := sessionsTestDir(t, configPath)
metaOnly := filepath.Join(dir, "agent_main_pico_direct_pico_meta-only.meta.json")
metaOnlyContent := []byte(`{"key":"agent:main:pico:direct:pico:meta-only","summary":"meta only"}`)
if err := os.WriteFile(metaOnly, metaOnlyContent, 0o644); err != nil {
t.Fatalf("WriteFile(meta) error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
listRec := httptest.NewRecorder()
listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
mux.ServeHTTP(listRec, listReq)
if listRec.Code != http.StatusOK {
t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String())
}
var items []sessionListItem
if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil {
t.Fatalf("Unmarshal(list) error = %v", err)
}
if len(items) != 0 {
t.Fatalf("len(items) = %d, want 0", len(items))
}
}