mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #2249 from alexhoshina/refactor-inbound-context-routing-session
Refactor inbound context routing session
This commit is contained in:
+61
-116
@@ -120,133 +120,78 @@ dammi le ultime news
|
||||
- Unknown slash command (for example `/foo`) passes through to normal LLM processing.
|
||||
- Registered but unsupported command on the current channel (for example `/show` on WhatsApp) returns an explicit user-facing error and stops further processing.
|
||||
|
||||
### Agent Bindings (Route messages to specific agents)
|
||||
### Routing
|
||||
|
||||
Use `bindings` in `config.json` to route incoming messages to different agents by channel/account/context.
|
||||
Routing is configured through `agents.dispatch.rules`.
|
||||
|
||||
Each rule matches against the normalized inbound context produced by channels.
|
||||
Rules are evaluated from top to bottom. The first matching rule wins. If no
|
||||
rule matches, PicoClaw falls back to the configured default agent.
|
||||
|
||||
Supported match fields:
|
||||
|
||||
* `channel`
|
||||
* `account`
|
||||
* `space`
|
||||
* `chat`
|
||||
* `topic`
|
||||
* `sender`
|
||||
* `mentioned`
|
||||
|
||||
Match values use the same scope vocabulary as the session system:
|
||||
|
||||
* `space`: `workspace:t001`, `guild:123456`
|
||||
* `chat`: `direct:user123`, `group:-100123`, `channel:c123`
|
||||
* `topic`: `topic:42`
|
||||
* `sender`: a normalized sender identifier for the platform
|
||||
|
||||
Rules may optionally override the global `session.dimensions` value through
|
||||
`session_dimensions`. This allows routing and session allocation to stay aligned
|
||||
without reintroducing the old `bindings` or `dm_scope` formats.
|
||||
|
||||
Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model_name": "gpt-4o-mini"
|
||||
},
|
||||
"list": [
|
||||
{ "id": "main", "default": true, "name": "Main Assistant" },
|
||||
{ "id": "support", "name": "Support Assistant" },
|
||||
{ "id": "sales", "name": "Sales Assistant" }
|
||||
]
|
||||
},
|
||||
"bindings": [
|
||||
{
|
||||
"agent_id": "support",
|
||||
"match": {
|
||||
"channel": "telegram",
|
||||
"account_id": "*",
|
||||
"peer": { "kind": "direct", "id": "user123" }
|
||||
}
|
||||
},
|
||||
{
|
||||
"agent_id": "sales",
|
||||
"match": {
|
||||
"channel": "discord",
|
||||
"account_id": "my-discord-bot",
|
||||
"guild_id": "987654321"
|
||||
}
|
||||
{ "id": "main", "default": true },
|
||||
{ "id": "support" },
|
||||
{ "id": "sales" }
|
||||
],
|
||||
"dispatch": {
|
||||
"rules": [
|
||||
{
|
||||
"name": "vip in support group",
|
||||
"agent": "sales",
|
||||
"when": {
|
||||
"channel": "telegram",
|
||||
"chat": "group:-1001234567890",
|
||||
"sender": "12345"
|
||||
},
|
||||
"session_dimensions": ["chat", "sender"]
|
||||
},
|
||||
{
|
||||
"name": "telegram support group",
|
||||
"agent": "support",
|
||||
"when": {
|
||||
"channel": "telegram",
|
||||
"chat": "group:-1001234567890"
|
||||
},
|
||||
"session_dimensions": ["chat"]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### `bindings` fields
|
||||
|
||||
| Field | Required | Description |
|
||||
|-------|----------|-------------|
|
||||
| `agent_id` | Yes | Target agent id in `agents.list` |
|
||||
| `match.channel` | Yes | Channel name (e.g. `telegram`, `discord`) |
|
||||
| `match.account_id` | No | Channel account filter. Use `"*"` for all accounts of that channel. If omitted, only default account is matched |
|
||||
| `match.peer.kind` + `match.peer.id` | No | Exact peer match (e.g. direct chat / topic / group id) |
|
||||
| `match.guild_id` | No | Guild/server-level match |
|
||||
| `match.team_id` | No | Team/workspace-level match |
|
||||
|
||||
#### Matching priority
|
||||
|
||||
When multiple bindings exist, PicoClaw resolves in this order:
|
||||
|
||||
1. `peer`
|
||||
2. `parent_peer` (for thread/topic parent contexts)
|
||||
3. `guild_id`
|
||||
4. `team_id`
|
||||
5. `account_id` (non-wildcard)
|
||||
6. channel wildcard (`account_id: "*"`)
|
||||
7. default agent
|
||||
|
||||
If a binding points to a missing `agent_id`, PicoClaw falls back to the default agent.
|
||||
|
||||
#### How matching works (step-by-step)
|
||||
|
||||
1. PicoClaw first filters bindings by `match.channel` (must equal current channel).
|
||||
2. It then filters by `match.account_id`:
|
||||
- omitted: match only the channel's default account
|
||||
- `"*"`: match all accounts on this channel
|
||||
- explicit value: exact account id match (case-insensitive)
|
||||
3. From the remaining candidates, it applies the priority chain above and stops at the first hit.
|
||||
|
||||
In other words: **channel + account form the candidate set; peer/guild/team then decide final winner**.
|
||||
|
||||
#### Common recipes
|
||||
|
||||
**1) Route one specific DM user to a specialist agent**
|
||||
|
||||
```json
|
||||
{
|
||||
"agent_id": "support",
|
||||
"match": {
|
||||
"channel": "telegram",
|
||||
"account_id": "*",
|
||||
"peer": { "kind": "direct", "id": "user123" }
|
||||
},
|
||||
"session": {
|
||||
"dimensions": ["chat"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**2) Route one Discord server (guild) to a dedicated agent**
|
||||
|
||||
```json
|
||||
{
|
||||
"agent_id": "sales",
|
||||
"match": {
|
||||
"channel": "discord",
|
||||
"account_id": "my-discord-bot",
|
||||
"guild_id": "987654321"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**3) Route all remaining traffic of a channel to a fallback agent**
|
||||
|
||||
```json
|
||||
{
|
||||
"agent_id": "main",
|
||||
"match": {
|
||||
"channel": "discord",
|
||||
"account_id": "*"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Authoring guidelines (important)
|
||||
|
||||
- Keep exactly one clear default agent in `agents.list` (`"default": true`).
|
||||
- Put specific rules (`peer`, `guild_id`, `team_id`) and broad rules (`account_id: "*"` only) together safely; priority already guarantees specific rules win.
|
||||
- Avoid duplicate rules with the same specificity and match values. If duplicates exist, the first matching entry in the config array wins.
|
||||
- Ensure every `agent_id` exists in `agents.list`; unknown IDs silently fall back to default.
|
||||
|
||||
#### Troubleshooting checklist
|
||||
|
||||
- **Rule not taking effect?** Check `match.channel` spelling first (must be exact).
|
||||
- **Expected account-specific routing but still using default?** Verify `match.account_id` equals actual runtime account id.
|
||||
- **Wildcard catches too much traffic?** Add more specific `peer/guild/team` rules for critical paths.
|
||||
- **Unexpected default fallback?** Confirm `agent_id` exists and is not misspelled.
|
||||
In the example above, the VIP rule must appear before the broader group rule.
|
||||
Because routing is strictly ordered, more specific rules should be placed
|
||||
earlier and broader fallback rules later.
|
||||
|
||||
### 🔒 Security Sandbox
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ func (m *legacyContextManager) Compact(_ context.Context, req *CompactRequest) e
|
||||
if result, ok := m.forceCompression(req.SessionKey); ok {
|
||||
m.al.emitEvent(
|
||||
EventKindContextCompress,
|
||||
m.al.newTurnEventScope("", req.SessionKey).meta(0, "forceCompression", "turn.context.compress"),
|
||||
m.al.newTurnEventScope("", req.SessionKey, nil).meta(0, "forceCompression", "turn.context.compress"),
|
||||
ContextCompressPayload{
|
||||
Reason: req.Reason,
|
||||
DroppedMessages: result.DroppedMessages,
|
||||
@@ -247,7 +247,7 @@ func (m *legacyContextManager) summarizeSession(agent *AgentInstance, sessionKey
|
||||
agent.Sessions.Save(sessionKey)
|
||||
m.al.emitEvent(
|
||||
EventKindSessionSummarize,
|
||||
m.al.newTurnEventScope(agent.ID, sessionKey).meta(0, "summarizeSession", "turn.session.summarize"),
|
||||
m.al.newTurnEventScope(agent.ID, sessionKey, nil).meta(0, "summarizeSession", "turn.session.summarize"),
|
||||
SessionSummarizePayload{
|
||||
SummarizedMessages: len(validMessages),
|
||||
KeptMessages: keepCount,
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
// DispatchRequest is the normalized runtime input passed into the agent loop
|
||||
// after routing and session allocation have completed.
|
||||
type DispatchRequest struct {
|
||||
SessionKey string
|
||||
SessionAliases []string
|
||||
InboundContext *bus.InboundContext
|
||||
RouteResult *routing.ResolvedRoute
|
||||
SessionScope *session.SessionScope
|
||||
UserMessage string
|
||||
Media []string
|
||||
}
|
||||
|
||||
func (r DispatchRequest) Channel() string {
|
||||
if r.InboundContext == nil {
|
||||
return ""
|
||||
}
|
||||
return r.InboundContext.Channel
|
||||
}
|
||||
|
||||
func (r DispatchRequest) ChatID() string {
|
||||
if r.InboundContext == nil {
|
||||
return ""
|
||||
}
|
||||
return r.InboundContext.ChatID
|
||||
}
|
||||
|
||||
func (r DispatchRequest) MessageID() string {
|
||||
if r.InboundContext == nil {
|
||||
return ""
|
||||
}
|
||||
return r.InboundContext.MessageID
|
||||
}
|
||||
|
||||
func (r DispatchRequest) ReplyToMessageID() string {
|
||||
if r.InboundContext == nil {
|
||||
return ""
|
||||
}
|
||||
return r.InboundContext.ReplyToMessageID
|
||||
}
|
||||
|
||||
func (r DispatchRequest) SenderID() string {
|
||||
if r.InboundContext == nil {
|
||||
return ""
|
||||
}
|
||||
return r.InboundContext.SenderID
|
||||
}
|
||||
|
||||
func normalizeProcessOptionsInPlace(opts *processOptions) {
|
||||
if opts == nil {
|
||||
return
|
||||
}
|
||||
*opts = normalizeProcessOptions(*opts)
|
||||
}
|
||||
|
||||
func normalizeProcessOptions(opts processOptions) processOptions {
|
||||
if opts.Dispatch.SessionKey == "" {
|
||||
opts.Dispatch.SessionKey = strings.TrimSpace(opts.SessionKey)
|
||||
}
|
||||
if len(opts.Dispatch.SessionAliases) == 0 && len(opts.SessionAliases) > 0 {
|
||||
opts.Dispatch.SessionAliases = append([]string(nil), opts.SessionAliases...)
|
||||
}
|
||||
if opts.Dispatch.UserMessage == "" {
|
||||
opts.Dispatch.UserMessage = opts.UserMessage
|
||||
}
|
||||
if len(opts.Dispatch.Media) == 0 && len(opts.Media) > 0 {
|
||||
opts.Dispatch.Media = append([]string(nil), opts.Media...)
|
||||
}
|
||||
if opts.Dispatch.RouteResult == nil {
|
||||
opts.Dispatch.RouteResult = cloneResolvedRoute(opts.RouteResult)
|
||||
}
|
||||
if opts.Dispatch.SessionScope == nil {
|
||||
opts.Dispatch.SessionScope = session.CloneScope(opts.SessionScope)
|
||||
}
|
||||
if opts.Dispatch.InboundContext == nil {
|
||||
if opts.InboundContext != nil {
|
||||
opts.Dispatch.InboundContext = cloneInboundContext(opts.InboundContext)
|
||||
} else if opts.Channel != "" || opts.ChatID != "" || opts.SenderID != "" ||
|
||||
opts.MessageID != "" || opts.ReplyToMessageID != "" {
|
||||
inbound := bus.InboundContext{
|
||||
Channel: strings.TrimSpace(opts.Channel),
|
||||
ChatID: strings.TrimSpace(opts.ChatID),
|
||||
SenderID: strings.TrimSpace(opts.SenderID),
|
||||
MessageID: strings.TrimSpace(opts.MessageID),
|
||||
ReplyToMessageID: strings.TrimSpace(opts.ReplyToMessageID),
|
||||
}
|
||||
inbound.ChatType = inferChatTypeFromSessionScope(opts.Dispatch.SessionScope)
|
||||
if inbound.Channel != "" || inbound.ChatID != "" || inbound.SenderID != "" ||
|
||||
inbound.MessageID != "" || inbound.ReplyToMessageID != "" {
|
||||
inbound = bus.NormalizeInboundMessage(bus.InboundMessage{Context: inbound}).Context
|
||||
opts.Dispatch.InboundContext = &inbound
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Keep legacy mirrors populated while the rest of the runtime migrates.
|
||||
opts.SessionKey = opts.Dispatch.SessionKey
|
||||
opts.SessionAliases = append([]string(nil), opts.Dispatch.SessionAliases...)
|
||||
opts.UserMessage = opts.Dispatch.UserMessage
|
||||
opts.Media = append([]string(nil), opts.Dispatch.Media...)
|
||||
opts.InboundContext = cloneInboundContext(opts.Dispatch.InboundContext)
|
||||
opts.RouteResult = cloneResolvedRoute(opts.Dispatch.RouteResult)
|
||||
opts.SessionScope = session.CloneScope(opts.Dispatch.SessionScope)
|
||||
if opts.InboundContext != nil {
|
||||
if opts.Channel == "" {
|
||||
opts.Channel = opts.InboundContext.Channel
|
||||
}
|
||||
if opts.ChatID == "" {
|
||||
opts.ChatID = opts.InboundContext.ChatID
|
||||
}
|
||||
if opts.MessageID == "" {
|
||||
opts.MessageID = opts.InboundContext.MessageID
|
||||
}
|
||||
if opts.ReplyToMessageID == "" {
|
||||
opts.ReplyToMessageID = opts.InboundContext.ReplyToMessageID
|
||||
}
|
||||
if opts.SenderID == "" {
|
||||
opts.SenderID = opts.InboundContext.SenderID
|
||||
}
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
func inferChatTypeFromSessionScope(scope *session.SessionScope) string {
|
||||
if scope == nil || len(scope.Values) == 0 {
|
||||
return ""
|
||||
}
|
||||
chatValue := strings.TrimSpace(scope.Values["chat"])
|
||||
if chatValue == "" {
|
||||
return ""
|
||||
}
|
||||
chatType, _, ok := strings.Cut(chatValue, ":")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(chatType))
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
func TestNormalizeProcessOptions_PopulatesDispatchFromLegacyFields(t *testing.T) {
|
||||
opts := normalizeProcessOptions(processOptions{
|
||||
SessionKey: "session-1",
|
||||
SessionAliases: []string{"legacy:one"},
|
||||
Channel: "telegram",
|
||||
ChatID: "chat-1",
|
||||
MessageID: "msg-1",
|
||||
ReplyToMessageID: "reply-1",
|
||||
SenderID: "user-1",
|
||||
UserMessage: "hello",
|
||||
Media: []string{"media://one"},
|
||||
})
|
||||
|
||||
if opts.Dispatch.SessionKey != "session-1" {
|
||||
t.Fatalf("Dispatch.SessionKey = %q, want session-1", opts.Dispatch.SessionKey)
|
||||
}
|
||||
if len(opts.Dispatch.SessionAliases) != 1 || opts.Dispatch.SessionAliases[0] != "legacy:one" {
|
||||
t.Fatalf("Dispatch.SessionAliases = %v, want [legacy:one]", opts.Dispatch.SessionAliases)
|
||||
}
|
||||
if opts.Dispatch.Channel() != "telegram" || opts.Dispatch.ChatID() != "chat-1" {
|
||||
t.Fatalf(
|
||||
"dispatch addressing = (%q,%q), want (telegram,chat-1)",
|
||||
opts.Dispatch.Channel(),
|
||||
opts.Dispatch.ChatID(),
|
||||
)
|
||||
}
|
||||
if opts.Dispatch.SenderID() != "user-1" || opts.Dispatch.MessageID() != "msg-1" {
|
||||
t.Fatalf("dispatch sender/message = (%q,%q)", opts.Dispatch.SenderID(), opts.Dispatch.MessageID())
|
||||
}
|
||||
if opts.Dispatch.ReplyToMessageID() != "reply-1" {
|
||||
t.Fatalf("Dispatch.ReplyToMessageID() = %q, want reply-1", opts.Dispatch.ReplyToMessageID())
|
||||
}
|
||||
if opts.Dispatch.UserMessage != "hello" {
|
||||
t.Fatalf("Dispatch.UserMessage = %q, want hello", opts.Dispatch.UserMessage)
|
||||
}
|
||||
if len(opts.Dispatch.Media) != 1 || opts.Dispatch.Media[0] != "media://one" {
|
||||
t.Fatalf("Dispatch.Media = %v, want [media://one]", opts.Dispatch.Media)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProcessOptions_UsesDispatchAsSourceOfTruth(t *testing.T) {
|
||||
inbound := &bus.InboundContext{
|
||||
Channel: "slack",
|
||||
ChatID: "C123",
|
||||
ChatType: "channel",
|
||||
SenderID: "U123",
|
||||
MessageID: "m-1",
|
||||
ReplyToMessageID: "parent-1",
|
||||
}
|
||||
route := &routing.ResolvedRoute{
|
||||
AgentID: "support",
|
||||
Channel: "slack",
|
||||
AccountID: "workspace-a",
|
||||
MatchedBy: "dispatch.rule:test",
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"chat", "sender"},
|
||||
},
|
||||
}
|
||||
scope := &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "support",
|
||||
Channel: "slack",
|
||||
Account: "workspace-a",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "channel:c123",
|
||||
},
|
||||
}
|
||||
|
||||
opts := normalizeProcessOptions(processOptions{
|
||||
Dispatch: DispatchRequest{
|
||||
SessionKey: "sk_v1_example",
|
||||
SessionAliases: []string{"agent:support:slack:channel:c123"},
|
||||
InboundContext: inbound,
|
||||
RouteResult: route,
|
||||
SessionScope: scope,
|
||||
UserMessage: "hello",
|
||||
Media: []string{"media://one"},
|
||||
},
|
||||
})
|
||||
|
||||
if opts.SessionKey != "sk_v1_example" {
|
||||
t.Fatalf("SessionKey = %q, want sk_v1_example", opts.SessionKey)
|
||||
}
|
||||
if opts.Channel != "slack" || opts.ChatID != "C123" {
|
||||
t.Fatalf("legacy mirrors = (%q,%q), want (slack,C123)", opts.Channel, opts.ChatID)
|
||||
}
|
||||
if opts.SenderID != "U123" || opts.MessageID != "m-1" {
|
||||
t.Fatalf("legacy sender/message = (%q,%q)", opts.SenderID, opts.MessageID)
|
||||
}
|
||||
if opts.ReplyToMessageID != "parent-1" {
|
||||
t.Fatalf("ReplyToMessageID = %q, want parent-1", opts.ReplyToMessageID)
|
||||
}
|
||||
if opts.RouteResult == nil || opts.RouteResult.AgentID != "support" {
|
||||
t.Fatalf("RouteResult = %#v, want support route", opts.RouteResult)
|
||||
}
|
||||
if opts.SessionScope == nil || opts.SessionScope.AgentID != "support" {
|
||||
t.Fatalf("SessionScope = %#v, want support scope", opts.SessionScope)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProcessOptions_InfersLegacyChatTypeFromSessionScope(t *testing.T) {
|
||||
opts := normalizeProcessOptions(processOptions{
|
||||
Channel: "telegram",
|
||||
ChatID: "-100123",
|
||||
SenderID: "user-1",
|
||||
UserMessage: "hello",
|
||||
SessionScope: &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "group:-100123",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if opts.Dispatch.InboundContext == nil {
|
||||
t.Fatal("Dispatch.InboundContext is nil")
|
||||
}
|
||||
if opts.Dispatch.InboundContext.ChatType != "group" {
|
||||
t.Fatalf("Dispatch.InboundContext.ChatType = %q, want group", opts.Dispatch.InboundContext.ChatType)
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -136,6 +138,31 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
InboundContext: &bus.InboundContext{
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
ChatType: "direct",
|
||||
SenderID: "tester",
|
||||
},
|
||||
RouteResult: &routing.ResolvedRoute{
|
||||
AgentID: "main",
|
||||
Channel: "cli",
|
||||
AccountID: routing.DefaultAccountID,
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"sender"},
|
||||
},
|
||||
MatchedBy: "default",
|
||||
},
|
||||
SessionScope: &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "cli",
|
||||
Account: routing.DefaultAccountID,
|
||||
Dimensions: []string{"sender"},
|
||||
Values: map[string]string{
|
||||
"sender": "tester",
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
@@ -176,6 +203,18 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
|
||||
if evt.Meta.SessionKey != "session-1" {
|
||||
t.Fatalf("event %d has session key %q, want session-1", i, evt.Meta.SessionKey)
|
||||
}
|
||||
if evt.Context == nil || evt.Context.Inbound == nil {
|
||||
t.Fatalf("event %d missing inbound turn context", i)
|
||||
}
|
||||
if evt.Context.Inbound.Channel != "cli" || evt.Context.Inbound.SenderID != "tester" {
|
||||
t.Fatalf("event %d inbound context = %+v", i, evt.Context.Inbound)
|
||||
}
|
||||
if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" {
|
||||
t.Fatalf("event %d missing route context: %+v", i, evt.Context.Route)
|
||||
}
|
||||
if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "tester" {
|
||||
t.Fatalf("event %d missing session scope: %+v", i, evt.Context.Scope)
|
||||
}
|
||||
}
|
||||
|
||||
startPayload, ok := events[0].Payload.(TurnStartPayload)
|
||||
@@ -472,7 +511,6 @@ func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
// Use legacyContextManager's summarizeSession via contextManager interface
|
||||
lcm := &legacyContextManager{al: al}
|
||||
lcm.summarizeSession(defaultAgent, "session-1")
|
||||
|
||||
@@ -572,12 +610,6 @@ func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
|
||||
if payload.SourceTool != "async_followup" {
|
||||
t.Fatalf("expected source tool async_followup, got %q", payload.SourceTool)
|
||||
}
|
||||
if payload.Channel != "cli" {
|
||||
t.Fatalf("expected channel cli, got %q", payload.Channel)
|
||||
}
|
||||
if payload.ChatID != "direct" {
|
||||
t.Fatalf("expected chat id direct, got %q", payload.ChatID)
|
||||
}
|
||||
if payload.ContentLen != len("background result") {
|
||||
t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen)
|
||||
}
|
||||
|
||||
+2
-4
@@ -86,6 +86,7 @@ type Event struct {
|
||||
Kind EventKind
|
||||
Time time.Time
|
||||
Meta EventMeta
|
||||
Context *TurnContext
|
||||
Payload any
|
||||
}
|
||||
|
||||
@@ -98,6 +99,7 @@ type EventMeta struct {
|
||||
Iteration int
|
||||
TracePath string
|
||||
Source string
|
||||
turnContext *TurnContext
|
||||
}
|
||||
|
||||
// TurnEndStatus describes the terminal state of a turn.
|
||||
@@ -114,8 +116,6 @@ const (
|
||||
|
||||
// TurnStartPayload describes the start of a turn.
|
||||
type TurnStartPayload struct {
|
||||
Channel string
|
||||
ChatID string
|
||||
UserMessage string
|
||||
MediaCount int
|
||||
}
|
||||
@@ -217,8 +217,6 @@ type SteeringInjectedPayload struct {
|
||||
// FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus.
|
||||
type FollowUpQueuedPayload struct {
|
||||
SourceTool string
|
||||
Channel string
|
||||
ChatID string
|
||||
ContentLen int
|
||||
}
|
||||
|
||||
|
||||
+15
-8
@@ -90,12 +90,11 @@ type ToolApprover interface {
|
||||
|
||||
type LLMHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []providers.Message `json:"messages,omitempty"`
|
||||
Tools []providers.ToolDefinition `json:"tools,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
GracefulTerminal bool `json:"graceful_terminal,omitempty"`
|
||||
}
|
||||
|
||||
@@ -104,6 +103,8 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Messages = cloneProviderMessages(r.Messages)
|
||||
cloned.Tools = cloneToolDefinitions(r.Tools)
|
||||
cloned.Options = cloneStringAnyMap(r.Options)
|
||||
@@ -112,10 +113,9 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
|
||||
|
||||
type LLMHookResponse struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Response *providers.LLMResponse `json:"response,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *LLMHookResponse) Clone() *LLMHookResponse {
|
||||
@@ -123,12 +123,15 @@ func (r *LLMHookResponse) Clone() *LLMHookResponse {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Response = cloneLLMResponse(r.Response)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolCallHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
@@ -141,6 +144,8 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
cloned.HookResult = cloneToolResult(r.HookResult)
|
||||
return &cloned
|
||||
@@ -148,10 +153,9 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
|
||||
type ToolApprovalRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
|
||||
@@ -159,18 +163,19 @@ func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolResultHookResponse struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Result *tools.ToolResult `json:"result,omitempty"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
|
||||
@@ -178,6 +183,8 @@ func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
cloned.Result = cloneToolResult(r.Result)
|
||||
return &cloned
|
||||
|
||||
+51
-3
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -108,7 +109,8 @@ func (p *llmHookTestProvider) GetDefaultModel() string {
|
||||
}
|
||||
|
||||
type llmObserverHook struct {
|
||||
eventCh chan Event
|
||||
eventCh chan Event
|
||||
lastInbound *bus.InboundContext
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
@@ -125,6 +127,9 @@ func (h *llmObserverHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
if req.Context != nil {
|
||||
h.lastInbound = cloneInboundContext(req.Context.Inbound)
|
||||
}
|
||||
next := req.Clone()
|
||||
next.Model = "hook-model"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
@@ -157,6 +162,31 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
InboundContext: &bus.InboundContext{
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
ChatType: "direct",
|
||||
SenderID: "hook-user",
|
||||
},
|
||||
RouteResult: &routing.ResolvedRoute{
|
||||
AgentID: "main",
|
||||
Channel: "cli",
|
||||
AccountID: routing.DefaultAccountID,
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"sender"},
|
||||
},
|
||||
MatchedBy: "default",
|
||||
},
|
||||
SessionScope: &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "cli",
|
||||
Account: routing.DefaultAccountID,
|
||||
Dimensions: []string{"sender"},
|
||||
Values: map[string]string{
|
||||
"sender": "hook-user",
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
@@ -171,12 +201,30 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
||||
if lastModel != "hook-model" {
|
||||
t.Fatalf("expected model hook-model, got %q", lastModel)
|
||||
}
|
||||
if hook.lastInbound == nil {
|
||||
t.Fatal("expected hook to receive inbound context")
|
||||
}
|
||||
if hook.lastInbound.Channel != "cli" || hook.lastInbound.SenderID != "hook-user" {
|
||||
t.Fatalf("hook inbound context = %+v", hook.lastInbound)
|
||||
}
|
||||
if hook.lastInbound != nil && hook.lastInbound.ChatID != "direct" {
|
||||
t.Fatalf("hook inbound chat ID = %q, want direct", hook.lastInbound.ChatID)
|
||||
}
|
||||
|
||||
select {
|
||||
case evt := <-hook.eventCh:
|
||||
if evt.Kind != EventKindTurnEnd {
|
||||
t.Fatalf("expected turn end event, got %v", evt.Kind)
|
||||
}
|
||||
if evt.Context == nil || evt.Context.Inbound == nil {
|
||||
t.Fatal("expected observer event to carry inbound context")
|
||||
}
|
||||
if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" {
|
||||
t.Fatalf("expected observer event to carry route context, got %+v", evt.Context.Route)
|
||||
}
|
||||
if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "hook-user" {
|
||||
t.Fatalf("expected observer event to carry session scope, got %+v", evt.Context.Scope)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for hook observer event")
|
||||
}
|
||||
@@ -725,7 +773,7 @@ func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) {
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
@@ -801,7 +849,7 @@ func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
|
||||
+450
-207
File diff suppressed because it is too large
Load Diff
+513
-106
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -38,7 +39,13 @@ func (f *fakeChannel) ReasoningChannelID() string { return f.id
|
||||
|
||||
type fakeMediaChannel struct {
|
||||
fakeChannel
|
||||
sentMedia []bus.OutboundMediaMessage
|
||||
sentMessages []bus.OutboundMessage
|
||||
sentMedia []bus.OutboundMediaMessage
|
||||
}
|
||||
|
||||
func (f *fakeMediaChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
|
||||
f.sentMessages = append(f.sentMessages, msg)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
|
||||
@@ -139,7 +146,7 @@ func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "discord",
|
||||
SenderID: "discord:123",
|
||||
Sender: bus.SenderInfo{
|
||||
@@ -147,7 +154,7 @@ func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
|
||||
},
|
||||
ChatID: "group-1",
|
||||
Content: "hello",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -198,12 +205,12 @@ func TestProcessMessage_UseCommandLoadsRequestedSkill(t *testing.T) {
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/use shell explain how to list files",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -288,12 +295,12 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) {
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/use shell",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() arm error = %v", err)
|
||||
}
|
||||
@@ -301,12 +308,12 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) {
|
||||
t.Fatalf("arm response = %q, want armed confirmation", response)
|
||||
}
|
||||
|
||||
response, err = al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err = al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "explain how to list files",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() follow-up error = %v", err)
|
||||
}
|
||||
@@ -619,12 +626,12 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
|
||||
path: imagePath,
|
||||
})
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -661,16 +668,21 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
route, _, err := al.resolveMessageRoute(bus.InboundMessage{
|
||||
route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("resolveMessageRoute() error = %v", err)
|
||||
}
|
||||
sessionKey := resolveScopeKey(route, "")
|
||||
sessionKey := resolveScopeKey(al.allocateRouteSession(route, testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})).SessionKey, "")
|
||||
history := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) == 0 {
|
||||
t.Fatal("expected session history to be saved")
|
||||
@@ -714,12 +726,12 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes
|
||||
loop: al,
|
||||
})
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -734,6 +746,263 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunAgentLoop_ResponseHandledToolPublishesForUserWhenSendResponseDisabled(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Workspace = tmpDir
|
||||
cfg.Agents.Defaults.ModelName = "test-model"
|
||||
cfg.Agents.Defaults.MaxTokens = 4096
|
||||
cfg.Agents.Defaults.MaxToolIterations = 10
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &handledUserProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
al.SetMediaStore(store)
|
||||
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
||||
al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel))
|
||||
al.RegisterTool(&handledUserTool{})
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
response, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
||||
Dispatch: DispatchRequest{
|
||||
SessionKey: "session-1",
|
||||
UserMessage: "take a screenshot of the screen and send it to me",
|
||||
SessionScope: &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: defaultAgent.ID,
|
||||
Channel: "telegram",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "direct:chat1",
|
||||
},
|
||||
},
|
||||
InboundContext: &bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
},
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop() error = %v", err)
|
||||
}
|
||||
if response != "" {
|
||||
t.Fatalf("expected no final response when tool already handled delivery, got %q", response)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for len(telegramChannel.sentMessages) == 0 && time.Now().Before(deadline) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if len(telegramChannel.sentMessages) != 1 {
|
||||
t.Fatalf("expected exactly 1 sent text message, got %d", len(telegramChannel.sentMessages))
|
||||
}
|
||||
if telegramChannel.sentMessages[0].Content != "Handled user output from tool." {
|
||||
t.Fatalf("unexpected sent text message: %+v", telegramChannel.sentMessages[0])
|
||||
}
|
||||
if telegramChannel.sentMessages[0].AgentID != defaultAgent.ID {
|
||||
t.Fatalf("sent text agent_id = %q, want %q", telegramChannel.sentMessages[0].AgentID, defaultAgent.ID)
|
||||
}
|
||||
if telegramChannel.sentMessages[0].SessionKey != "session-1" {
|
||||
t.Fatalf("sent text session_key = %q, want session-1", telegramChannel.sentMessages[0].SessionKey)
|
||||
}
|
||||
if telegramChannel.sentMessages[0].Scope == nil ||
|
||||
telegramChannel.sentMessages[0].Scope.Values["chat"] != "direct:chat1" {
|
||||
t.Fatalf("unexpected sent text scope: %+v", telegramChannel.sentMessages[0].Scope)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendEventContextFields_IncludesInboundRouteAndScope(t *testing.T) {
|
||||
fields := map[string]any{}
|
||||
|
||||
appendEventContextFields(fields, &TurnContext{
|
||||
Inbound: &bus.InboundContext{
|
||||
Channel: "slack",
|
||||
Account: "workspace-a",
|
||||
ChatID: "C123",
|
||||
ChatType: "channel",
|
||||
TopicID: "thread-42",
|
||||
SpaceType: "workspace",
|
||||
SpaceID: "T001",
|
||||
SenderID: "U123",
|
||||
Mentioned: true,
|
||||
},
|
||||
Route: &routing.ResolvedRoute{
|
||||
AgentID: "support",
|
||||
Channel: "slack",
|
||||
AccountID: "workspace-a",
|
||||
MatchedBy: "default",
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"chat", "sender"},
|
||||
IdentityLinks: map[string][]string{
|
||||
"canonical-user": {"slack:U123"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scope: &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "support",
|
||||
Channel: "slack",
|
||||
Account: "workspace-a",
|
||||
Dimensions: []string{"chat", "sender"},
|
||||
Values: map[string]string{
|
||||
"chat": "channel:c123",
|
||||
"sender": "u123",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if fields["inbound_channel"] != "slack" {
|
||||
t.Fatalf("inbound_channel = %v, want slack", fields["inbound_channel"])
|
||||
}
|
||||
if fields["inbound_topic_id"] != "thread-42" {
|
||||
t.Fatalf("inbound_topic_id = %v, want thread-42", fields["inbound_topic_id"])
|
||||
}
|
||||
if fields["route_matched_by"] != "default" {
|
||||
t.Fatalf("route_matched_by = %v, want default", fields["route_matched_by"])
|
||||
}
|
||||
if fields["route_dimensions"] != "chat,sender" {
|
||||
t.Fatalf("route_dimensions = %v, want chat,sender", fields["route_dimensions"])
|
||||
}
|
||||
if fields["route_identity_link_count"] != 1 {
|
||||
t.Fatalf("route_identity_link_count = %v, want 1", fields["route_identity_link_count"])
|
||||
}
|
||||
if fields["scope_dimensions"] != "chat,sender" {
|
||||
t.Fatalf("scope_dimensions = %v, want chat,sender", fields["scope_dimensions"])
|
||||
}
|
||||
if fields["scope_chat"] != "channel:c123" {
|
||||
t.Fatalf("scope_chat = %v, want channel:c123", fields["scope_chat"])
|
||||
}
|
||||
if fields["scope_sender"] != "u123" {
|
||||
t.Fatalf("scope_sender = %v, want u123", fields["scope_sender"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMessageRoute_UsesInboundContextAccount(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
},
|
||||
List: []config.AgentConfig{
|
||||
{ID: "main", Default: true},
|
||||
{ID: "work"},
|
||||
},
|
||||
},
|
||||
Session: config.SessionConfig{
|
||||
Dimensions: []string{"sender"},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "ok"})
|
||||
|
||||
route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "slack",
|
||||
Account: "workspace-a",
|
||||
ChatID: "C123",
|
||||
ChatType: "channel",
|
||||
SenderID: "U123",
|
||||
SpaceID: "T001",
|
||||
SpaceType: "workspace",
|
||||
},
|
||||
Content: "hello",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("resolveMessageRoute() error = %v", err)
|
||||
}
|
||||
if route.AgentID != "main" {
|
||||
t.Fatalf("AgentID = %q, want main", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "default" {
|
||||
t.Fatalf("MatchedBy = %q, want default", route.MatchedBy)
|
||||
}
|
||||
if route.AccountID != "workspace-a" {
|
||||
t.Fatalf("AccountID = %q, want workspace-a", route.AccountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMessageRoute_UsesDispatchRulesInOrder(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
},
|
||||
List: []config.AgentConfig{
|
||||
{ID: "main", Default: true},
|
||||
{ID: "support"},
|
||||
{ID: "sales"},
|
||||
},
|
||||
Dispatch: &config.DispatchConfig{
|
||||
Rules: []config.DispatchRule{
|
||||
{
|
||||
Name: "support-group",
|
||||
Agent: "support",
|
||||
When: config.DispatchSelector{
|
||||
Channel: "telegram",
|
||||
Chat: "group:-100123",
|
||||
},
|
||||
SessionDimensions: []string{"chat"},
|
||||
},
|
||||
{
|
||||
Name: "vip-in-group",
|
||||
Agent: "sales",
|
||||
When: config.DispatchSelector{
|
||||
Channel: "telegram",
|
||||
Chat: "group:-100123",
|
||||
Sender: "12345",
|
||||
},
|
||||
SessionDimensions: []string{"chat", "sender"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Session: config.SessionConfig{
|
||||
Dimensions: []string{"sender"},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "ok"})
|
||||
|
||||
route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "-100123",
|
||||
ChatType: "group",
|
||||
SenderID: "12345",
|
||||
},
|
||||
Content: "hello",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("resolveMessageRoute() error = %v", err)
|
||||
}
|
||||
if route.AgentID != "support" {
|
||||
t.Fatalf("AgentID = %q, want support", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "dispatch.rule:support-group" {
|
||||
t.Fatalf("MatchedBy = %q, want dispatch.rule:support-group", route.MatchedBy)
|
||||
}
|
||||
if got := route.SessionPolicy.Dimensions; len(got) != 1 || got[0] != "chat" {
|
||||
t.Fatalf("SessionPolicy.Dimensions = %v, want [chat]", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := config.DefaultConfig()
|
||||
@@ -765,12 +1034,12 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
|
||||
path: imagePath,
|
||||
})
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -975,6 +1244,66 @@ func (m *handledMediaProvider) GetDefaultModel() string {
|
||||
return "handled-media-model"
|
||||
}
|
||||
|
||||
type handledUserProvider struct {
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *handledUserProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.calls++
|
||||
if m.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Delivering the result now.",
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_handled_user",
|
||||
Type: "function",
|
||||
Name: "handled_user_tool",
|
||||
Arguments: map[string]any{},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
return &providers.LLMResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *handledUserProvider) GetDefaultModel() string {
|
||||
return "handled-user-model"
|
||||
}
|
||||
|
||||
type messageToolProvider struct {
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *messageToolProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.calls++
|
||||
if m.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
Content: "",
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_message",
|
||||
Type: "function",
|
||||
Name: "message",
|
||||
Arguments: map[string]any{"content": "direct tool message"},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
return &providers.LLMResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *messageToolProvider) GetDefaultModel() string {
|
||||
return "message-tool-model"
|
||||
}
|
||||
|
||||
type artifactThenSendProvider struct {
|
||||
calls int
|
||||
}
|
||||
@@ -1178,6 +1507,24 @@ func (m *handledMediaTool) Execute(ctx context.Context, args map[string]any) *to
|
||||
return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled()
|
||||
}
|
||||
|
||||
type handledUserTool struct{}
|
||||
|
||||
func (m *handledUserTool) Name() string { return "handled_user_tool" }
|
||||
func (m *handledUserTool) Description() string {
|
||||
return "Returns a user-visible result and marks delivery as handled"
|
||||
}
|
||||
|
||||
func (m *handledUserTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *handledUserTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
return tools.UserResult("Handled user output from tool.").WithResponseHandled()
|
||||
}
|
||||
|
||||
type handledMediaWithSteeringProvider struct {
|
||||
calls int
|
||||
}
|
||||
@@ -1391,13 +1738,39 @@ func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, ms
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
|
||||
defer cancel()
|
||||
|
||||
response, err := h.al.processMessage(timeoutCtx, msg)
|
||||
response, err := h.al.processMessage(timeoutCtx, testInboundMessage(msg))
|
||||
if err != nil {
|
||||
tb.Fatalf("processMessage failed: %v", err)
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
func testInboundMessage(msg bus.InboundMessage) bus.InboundMessage {
|
||||
if msg.Context.Channel == "" &&
|
||||
msg.Context.Account == "" &&
|
||||
msg.Context.ChatID == "" &&
|
||||
msg.Context.ChatType == "" &&
|
||||
msg.Context.TopicID == "" &&
|
||||
msg.Context.SpaceID == "" &&
|
||||
msg.Context.SpaceType == "" &&
|
||||
msg.Context.SenderID == "" &&
|
||||
msg.Context.MessageID == "" &&
|
||||
!msg.Context.Mentioned &&
|
||||
msg.Context.ReplyToMessageID == "" &&
|
||||
msg.Context.ReplyToSenderID == "" &&
|
||||
len(msg.Context.ReplyHandles) == 0 &&
|
||||
len(msg.Context.Raw) == 0 {
|
||||
msg.Context = bus.InboundContext{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
ChatType: "direct",
|
||||
SenderID: msg.SenderID,
|
||||
MessageID: msg.MessageID,
|
||||
}
|
||||
}
|
||||
return bus.NormalizeInboundMessage(msg)
|
||||
}
|
||||
|
||||
const responseTimeout = 3 * time.Second
|
||||
|
||||
func TestProcessMessage_UsesRouteSessionKey(t *testing.T) {
|
||||
@@ -1423,21 +1796,17 @@ func TestProcessMessage_UsesRouteSessionKey(t *testing.T) {
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
msg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "hello",
|
||||
}
|
||||
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: msg.Channel,
|
||||
Peer: extractPeer(msg),
|
||||
})
|
||||
sessionKey := route.SessionKey
|
||||
route := al.registry.ResolveRoute(bus.NormalizeInboundMessage(msg).Context)
|
||||
sessionKey := al.allocateRouteSession(route, msg).SessionKey
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
@@ -1473,7 +1842,7 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) {
|
||||
},
|
||||
},
|
||||
Session: config.SessionConfig{
|
||||
DMScope: "per-channel-peer",
|
||||
Dimensions: []string{"chat"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1483,21 +1852,22 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) {
|
||||
helper := testHelper{al: al}
|
||||
|
||||
baseMsg := bus.InboundMessage{
|
||||
Channel: "whatsapp",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "whatsapp",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
}
|
||||
|
||||
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: baseMsg.Channel,
|
||||
SenderID: baseMsg.SenderID,
|
||||
ChatID: baseMsg.ChatID,
|
||||
Content: "/show channel",
|
||||
Peer: baseMsg.Peer,
|
||||
Context: bus.InboundContext{
|
||||
Channel: baseMsg.Context.Channel,
|
||||
ChatID: baseMsg.Context.ChatID,
|
||||
ChatType: baseMsg.Context.ChatType,
|
||||
SenderID: baseMsg.Context.SenderID,
|
||||
},
|
||||
Content: "/show channel",
|
||||
})
|
||||
if showResp != "Current Channel: whatsapp" {
|
||||
t.Fatalf("unexpected /show reply: %q", showResp)
|
||||
@@ -1507,11 +1877,13 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) {
|
||||
}
|
||||
|
||||
fooResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: baseMsg.Channel,
|
||||
SenderID: baseMsg.SenderID,
|
||||
ChatID: baseMsg.ChatID,
|
||||
Content: "/foo",
|
||||
Peer: baseMsg.Peer,
|
||||
Context: bus.InboundContext{
|
||||
Channel: baseMsg.Context.Channel,
|
||||
ChatID: baseMsg.Context.ChatID,
|
||||
ChatType: baseMsg.Context.ChatType,
|
||||
SenderID: baseMsg.Context.SenderID,
|
||||
},
|
||||
Content: "/foo",
|
||||
})
|
||||
if fooResp != "LLM reply" {
|
||||
t.Fatalf("unexpected /foo reply: %q", fooResp)
|
||||
@@ -1521,11 +1893,13 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) {
|
||||
}
|
||||
|
||||
newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: baseMsg.Channel,
|
||||
SenderID: baseMsg.SenderID,
|
||||
ChatID: baseMsg.ChatID,
|
||||
Content: "/new",
|
||||
Peer: baseMsg.Peer,
|
||||
Context: bus.InboundContext{
|
||||
Channel: baseMsg.Context.Channel,
|
||||
ChatID: baseMsg.Context.ChatID,
|
||||
ChatType: baseMsg.Context.ChatType,
|
||||
SenderID: baseMsg.Context.SenderID,
|
||||
},
|
||||
Content: "/new",
|
||||
})
|
||||
if newResp != "LLM reply" {
|
||||
t.Fatalf("unexpected /new reply: %q", newResp)
|
||||
@@ -1578,10 +1952,6 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to deepseek",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
@@ -1592,10 +1962,6 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/show model",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: deepseek (Provider: openrouter)") {
|
||||
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
|
||||
@@ -1643,10 +2009,6 @@ func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to missing",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if switchResp != `model "missing" not found in model_list or providers` {
|
||||
t.Fatalf("unexpected /switch error reply: %q", switchResp)
|
||||
@@ -1657,10 +2019,6 @@ func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/show model",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: local (Provider: openai)") {
|
||||
t.Fatalf("unexpected /show model reply after rejected switch: %q", showResp)
|
||||
@@ -1727,10 +2085,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello before switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if firstResp != "local reply" {
|
||||
t.Fatalf("unexpected response before switch: %q", firstResp)
|
||||
@@ -1750,10 +2104,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to deepseek",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
@@ -1764,10 +2114,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello after switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if secondResp != "remote reply" {
|
||||
t.Fatalf("unexpected response after switch: %q", secondResp)
|
||||
@@ -1857,10 +2203,6 @@ func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if resp != "light reply" {
|
||||
t.Fatalf("response = %q, want %q", resp, "light reply")
|
||||
@@ -1942,7 +2284,6 @@ func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
Peer: bus.Peer{Kind: "direct", ID: "user1"},
|
||||
})
|
||||
|
||||
if resp != "fallback reply" {
|
||||
@@ -2020,7 +2361,6 @@ func TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered(t *
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
Peer: bus.Peer{Kind: "direct", ID: "user1"},
|
||||
})
|
||||
|
||||
if resp != "active provider reply" {
|
||||
@@ -2291,14 +2631,16 @@ func TestAgentLoop_ToolLimitUsesDedicatedFallback(t *testing.T) {
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("No default agent found")
|
||||
}
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: "test",
|
||||
Peer: &routing.RoutePeer{
|
||||
Kind: "direct",
|
||||
ID: "cron",
|
||||
},
|
||||
route := al.registry.ResolveRoute(bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatType: "direct",
|
||||
SenderID: "cron",
|
||||
})
|
||||
history := defaultAgent.Sessions.GetHistory(route.SessionKey)
|
||||
history := defaultAgent.Sessions.GetHistory(al.allocateRouteSession(route, testInboundMessage(bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "cron",
|
||||
ChatID: "chat1",
|
||||
})).SessionKey)
|
||||
if len(history) != 4 {
|
||||
t.Fatalf("history len = %d, want 4", len(history))
|
||||
}
|
||||
@@ -2556,8 +2898,7 @@ func TestHandleReasoning(t *testing.T) {
|
||||
for i := 0; ; i++ {
|
||||
fillCtx, fillCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
err := msgBus.PublishOutbound(fillCtx, bus.OutboundMessage{
|
||||
Channel: "filler",
|
||||
ChatID: "filler",
|
||||
Context: bus.NewOutboundContext("filler", "filler", ""),
|
||||
Content: fmt.Sprintf("filler-%d", i),
|
||||
})
|
||||
fillCancel()
|
||||
@@ -2631,12 +2972,12 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T
|
||||
chManager.RegisterChannel("telegram", &fakeChannel{id: "reason-chat"})
|
||||
al.SetChannelManager(chManager)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -2652,6 +2993,9 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T
|
||||
if outbound.ChatID != "reason-chat" {
|
||||
t.Fatalf("reasoning chatID = %q, want %q", outbound.ChatID, "reason-chat")
|
||||
}
|
||||
if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "reason-chat" {
|
||||
t.Fatalf("unexpected reasoning context: %+v", outbound.Context)
|
||||
}
|
||||
if outbound.Content != "thinking trace" {
|
||||
t.Fatalf("reasoning content = %q, want %q", outbound.Content, "thinking trace")
|
||||
}
|
||||
@@ -2711,8 +3055,12 @@ func TestProcessMessage_PicoPublishesReasoningAsThoughtMessage(t *testing.T) {
|
||||
if thoughtMsg.Channel != "pico" || thoughtMsg.ChatID != "pico:test-session" {
|
||||
t.Fatalf("thought message route = %s/%s, want pico/pico:test-session", thoughtMsg.Channel, thoughtMsg.ChatID)
|
||||
}
|
||||
if thoughtMsg.Metadata[metadataKeyMessageKind] != messageKindThought {
|
||||
t.Fatalf("thought metadata kind = %q, want %q", thoughtMsg.Metadata[metadataKeyMessageKind], messageKindThought)
|
||||
if thoughtMsg.Context.Raw[metadataKeyMessageKind] != messageKindThought {
|
||||
t.Fatalf(
|
||||
"thought metadata kind = %q, want %q",
|
||||
thoughtMsg.Context.Raw[metadataKeyMessageKind],
|
||||
messageKindThought,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2793,12 +3141,12 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
|
||||
provider := &toolFeedbackProvider{filePath: heartbeatFile}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user-1",
|
||||
ChatID: "chat-1",
|
||||
Content: "check tool feedback",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -2814,14 +3162,73 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
|
||||
if outbound.ChatID != "chat-1" {
|
||||
t.Fatalf("tool feedback chatID = %q, want %q", outbound.ChatID, "chat-1")
|
||||
}
|
||||
if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "chat-1" {
|
||||
t.Fatalf("unexpected tool feedback context: %+v", outbound.Context)
|
||||
}
|
||||
if !strings.Contains(outbound.Content, "`read_file`") {
|
||||
t.Fatalf("tool feedback content = %q, want read_file preview", outbound.Content)
|
||||
}
|
||||
if outbound.AgentID != "main" {
|
||||
t.Fatalf("tool feedback agent_id = %q, want main", outbound.AgentID)
|
||||
}
|
||||
if outbound.SessionKey == "" {
|
||||
t.Fatal("expected tool feedback to carry session_key")
|
||||
}
|
||||
if outbound.Scope == nil || outbound.Scope.AgentID != "main" || outbound.Scope.Channel != "telegram" {
|
||||
t.Fatalf("expected tool feedback scope, got %+v", outbound.Scope)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected outbound tool feedback for regular messages")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_MessageToolPublishesOutboundWithTurnMetadata(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Workspace = t.TempDir()
|
||||
cfg.Agents.Defaults.ModelName = "test-model"
|
||||
cfg.Agents.Defaults.MaxTokens = 4096
|
||||
cfg.Agents.Defaults.MaxToolIterations = 10
|
||||
cfg.Session.Dimensions = []string{"chat"}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &messageToolProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user-1",
|
||||
ChatID: "chat-1",
|
||||
Content: "send a direct message",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response == "" {
|
||||
t.Fatal("expected processMessage() to return a final loop response")
|
||||
}
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "direct tool message" {
|
||||
t.Fatalf("outbound content = %q, want direct tool message", outbound.Content)
|
||||
}
|
||||
if outbound.AgentID != "main" {
|
||||
t.Fatalf("outbound agent_id = %q, want main", outbound.AgentID)
|
||||
}
|
||||
if outbound.SessionKey == "" {
|
||||
t.Fatal("expected message tool outbound to carry session_key")
|
||||
}
|
||||
if outbound.Scope == nil || outbound.Scope.Values["chat"] != "direct:chat-1" {
|
||||
t.Fatalf("unexpected message tool outbound scope: %+v", outbound.Scope)
|
||||
}
|
||||
if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "chat-1" {
|
||||
t.Fatalf("unexpected message tool outbound context: %+v", outbound.Context)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected message tool outbound")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_PicoPublishesAssistantContentDuringToolCallsWithoutFinalDuplicate(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
@@ -3363,13 +3770,13 @@ func TestProcessMessage_ContextOverflowRecovery(t *testing.T) {
|
||||
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "assistant", Content: "response"})
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
SessionKey: "test-session",
|
||||
Content: "trigger recovery",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -3405,12 +3812,12 @@ func TestProcessMessage_ContextOverflow_AnthropicStyle(t *testing.T) {
|
||||
return &providers.LLMResponse{Content: "Anthropic recovery success"}, nil
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "hello",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package agent
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
@@ -64,9 +65,9 @@ func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) {
|
||||
return agent, ok
|
||||
}
|
||||
|
||||
// ResolveRoute determines which agent handles the message.
|
||||
func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute {
|
||||
return r.resolver.ResolveRoute(input)
|
||||
// ResolveRoute determines which agent handles the normalized inbound context.
|
||||
func (r *AgentRegistry) ResolveRoute(inbound bus.InboundContext) routing.ResolvedRoute {
|
||||
return r.resolver.ResolveRoute(inbound)
|
||||
}
|
||||
|
||||
// ListAgentIDs returns all registered agent IDs.
|
||||
|
||||
+35
-8
@@ -3,12 +3,14 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -290,12 +292,22 @@ func (al *AgentLoop) continueWithSteeringMessages(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
sessionKey, channel, chatID string,
|
||||
scope *session.SessionScope,
|
||||
steeringMsgs []providers.Message,
|
||||
) (string, error) {
|
||||
dispatch := DispatchRequest{
|
||||
SessionKey: sessionKey,
|
||||
SessionScope: session.CloneScope(scope),
|
||||
}
|
||||
if channel != "" || chatID != "" {
|
||||
dispatch.InboundContext = &bus.InboundContext{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
ChatType: inferChatTypeFromSessionScope(scope),
|
||||
}
|
||||
}
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Dispatch: dispatch,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
@@ -310,9 +322,19 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
|
||||
return nil
|
||||
}
|
||||
|
||||
if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil {
|
||||
if agent, ok := registry.GetAgent(parsed.AgentID); ok {
|
||||
return agent
|
||||
agentIDs := registry.ListAgentIDs()
|
||||
sort.Strings(agentIDs)
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := registry.GetAgent(agentID)
|
||||
if !ok || agent == nil {
|
||||
continue
|
||||
}
|
||||
resolvedAgentID := session.ResolveAgentID(agent.Sessions, sessionKey)
|
||||
if resolvedAgentID == "" {
|
||||
continue
|
||||
}
|
||||
if scopedAgent, ok := registry.GetAgent(resolvedAgentID); ok {
|
||||
return scopedAgent
|
||||
}
|
||||
}
|
||||
|
||||
@@ -352,7 +374,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
|
||||
}
|
||||
}
|
||||
|
||||
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs)
|
||||
var scope *session.SessionScope
|
||||
if metaStore, ok := agent.Sessions.(session.MetadataAwareSessionStore); ok {
|
||||
scope = metaStore.GetSessionScope(sessionKey)
|
||||
}
|
||||
|
||||
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, scope, steeringMsgs)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) InterruptGraceful(hint string) error {
|
||||
|
||||
+89
-36
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -357,7 +358,7 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
},
|
||||
},
|
||||
Session: config.SessionConfig{
|
||||
DMScope: "per-peer",
|
||||
Dimensions: []string{"sender"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -365,14 +366,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
al := NewAgentLoop(cfg, msgBus, &mockProvider{})
|
||||
|
||||
activeMsg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "active turn",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "active turn",
|
||||
}
|
||||
activeScope, activeAgentID, ok := al.resolveSteeringTarget(activeMsg)
|
||||
if !ok {
|
||||
@@ -380,14 +380,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
}
|
||||
|
||||
otherMsg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user2",
|
||||
ChatID: "chat2",
|
||||
Content: "other session",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user2",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat2",
|
||||
ChatType: "direct",
|
||||
SenderID: "user2",
|
||||
},
|
||||
Content: "other session",
|
||||
}
|
||||
otherScope, _, ok := al.resolveSteeringTarget(otherMsg)
|
||||
if !ok {
|
||||
@@ -422,9 +421,9 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("timeout waiting for requeued message on outbound bus")
|
||||
case requeued := <-msgBus.OutboundChan():
|
||||
if requeued.Channel != otherMsg.Channel || requeued.ChatID != otherMsg.ChatID ||
|
||||
t.Fatalf("timeout waiting for requeued message on inbound bus")
|
||||
case requeued := <-msgBus.InboundChan():
|
||||
if requeued.Context.Channel != otherMsg.Context.Channel || requeued.Context.ChatID != otherMsg.Context.ChatID ||
|
||||
requeued.Content != otherMsg.Content {
|
||||
t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg)
|
||||
}
|
||||
@@ -841,24 +840,22 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
|
||||
}()
|
||||
|
||||
first := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "first message",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "first message",
|
||||
}
|
||||
late := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "late append",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "late append",
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
@@ -949,7 +946,7 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
|
||||
},
|
||||
}
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
provider := &blockingDirectProvider{
|
||||
firstStarted: make(chan struct{}),
|
||||
releaseFirst: make(chan struct{}),
|
||||
@@ -1013,6 +1010,62 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_AgentForSession_UsesStoredScopeMetadata(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
List: []config.AgentConfig{
|
||||
{ID: "sales", Default: true},
|
||||
{ID: "support"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{})
|
||||
support, ok := al.registry.GetAgent("support")
|
||||
if !ok || support == nil {
|
||||
t.Fatal("expected support agent")
|
||||
}
|
||||
|
||||
metaStore, ok := support.Sessions.(session.MetadataAwareSessionStore)
|
||||
if !ok {
|
||||
t.Fatal("support session store does not support metadata")
|
||||
}
|
||||
|
||||
alias := "agent:support:slack:channel:c001"
|
||||
key := session.BuildOpaqueSessionKey(alias)
|
||||
scope := &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "support",
|
||||
Channel: "slack",
|
||||
Account: "default",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "channel:c001",
|
||||
},
|
||||
}
|
||||
metaStore.EnsureSessionMetadata(key, scope, []string{alias})
|
||||
|
||||
got := al.agentForSession(key)
|
||||
if got == nil {
|
||||
t.Fatal("agentForSession() returned nil")
|
||||
}
|
||||
if got.ID != "support" {
|
||||
t.Fatalf("agentForSession() = %q, want %q", got.ID, "support")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
@@ -1060,7 +1113,7 @@ func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.SetMediaStore(store)
|
||||
@@ -1168,7 +1221,7 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(tool1)
|
||||
al.RegisterTool(tool2)
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
@@ -1322,7 +1375,7 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
started := make(chan struct{})
|
||||
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
|
||||
+13
-7
@@ -351,15 +351,17 @@ func spawnSubTurn(
|
||||
}
|
||||
|
||||
// Create processOptions for the child turn
|
||||
dispatch := DispatchRequest{
|
||||
SessionKey: childID,
|
||||
UserMessage: cfg.SystemPrompt,
|
||||
Media: nil,
|
||||
InboundContext: cloneInboundContext(parentTS.opts.Dispatch.InboundContext),
|
||||
}
|
||||
opts := processOptions{
|
||||
SessionKey: childID,
|
||||
Channel: parentTS.channel,
|
||||
ChatID: parentTS.chatID,
|
||||
SenderID: parentTS.opts.SenderID,
|
||||
Dispatch: dispatch,
|
||||
SenderID: parentTS.opts.Dispatch.SenderID(),
|
||||
SenderDisplayName: parentTS.opts.SenderDisplayName,
|
||||
UserMessage: cfg.SystemPrompt, // Task description becomes the first user message
|
||||
SystemPromptOverride: cfg.ActualSystemPrompt,
|
||||
Media: nil,
|
||||
InitialSteeringMessages: cfg.InitialMessages,
|
||||
DefaultResponse: "",
|
||||
EnableSummary: false,
|
||||
@@ -369,7 +371,11 @@ func spawnSubTurn(
|
||||
}
|
||||
|
||||
// Create event scope for the child turn
|
||||
scope := al.newTurnEventScope(agent.ID, childID)
|
||||
scope := al.newTurnEventScope(
|
||||
agent.ID,
|
||||
childID,
|
||||
newTurnContext(opts.Dispatch.InboundContext, opts.Dispatch.RouteResult, opts.Dispatch.SessionScope),
|
||||
)
|
||||
|
||||
// Create child turnState using the new API
|
||||
childTS := newTurnState(&agent, opts, scope)
|
||||
|
||||
+15
-12
@@ -56,6 +56,7 @@ type turnState struct {
|
||||
turnID string
|
||||
agentID string
|
||||
sessionKey string
|
||||
turnCtx *TurnContext
|
||||
|
||||
channel string
|
||||
chatID string
|
||||
@@ -115,11 +116,12 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
|
||||
scope: scope,
|
||||
turnID: scope.turnID,
|
||||
agentID: agent.ID,
|
||||
sessionKey: opts.SessionKey,
|
||||
channel: opts.Channel,
|
||||
chatID: opts.ChatID,
|
||||
userMessage: opts.UserMessage,
|
||||
media: append([]string(nil), opts.Media...),
|
||||
sessionKey: opts.Dispatch.SessionKey,
|
||||
turnCtx: cloneTurnContext(scope.context),
|
||||
channel: opts.Dispatch.Channel(),
|
||||
chatID: opts.Dispatch.ChatID(),
|
||||
userMessage: opts.Dispatch.UserMessage,
|
||||
media: append([]string(nil), opts.Dispatch.Media...),
|
||||
phase: TurnPhaseSetup,
|
||||
startedAt: time.Now(),
|
||||
}
|
||||
@@ -127,7 +129,7 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
|
||||
// Bind session store and capture initial history length for rollback logic
|
||||
if agent != nil && agent.Sessions != nil {
|
||||
ts.session = agent.Sessions
|
||||
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.SessionKey))
|
||||
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.Dispatch.SessionKey))
|
||||
}
|
||||
|
||||
return ts
|
||||
@@ -302,12 +304,13 @@ func (ts *turnState) hardAbortRequested() bool {
|
||||
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
|
||||
snap := ts.snapshot()
|
||||
return EventMeta{
|
||||
AgentID: snap.AgentID,
|
||||
TurnID: snap.TurnID,
|
||||
SessionKey: snap.SessionKey,
|
||||
Iteration: snap.Iteration,
|
||||
Source: source,
|
||||
TracePath: tracePath,
|
||||
AgentID: snap.AgentID,
|
||||
TurnID: snap.TurnID,
|
||||
SessionKey: snap.SessionKey,
|
||||
Iteration: snap.Iteration,
|
||||
Source: source,
|
||||
TracePath: tracePath,
|
||||
turnContext: cloneTurnContext(ts.turnCtx),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
// TurnContext carries normalized turn-scoped facts that can be shared across
|
||||
// events, hooks, and other runtime observers without re-parsing legacy fields.
|
||||
type TurnContext struct {
|
||||
Inbound *bus.InboundContext `json:"inbound,omitempty"`
|
||||
Route *routing.ResolvedRoute `json:"route,omitempty"`
|
||||
Scope *session.SessionScope `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
func newTurnContext(
|
||||
inbound *bus.InboundContext,
|
||||
route *routing.ResolvedRoute,
|
||||
scope *session.SessionScope,
|
||||
) *TurnContext {
|
||||
if inbound == nil && route == nil && scope == nil {
|
||||
return nil
|
||||
}
|
||||
return &TurnContext{
|
||||
Inbound: cloneInboundContext(inbound),
|
||||
Route: cloneResolvedRoute(route),
|
||||
Scope: session.CloneScope(scope),
|
||||
}
|
||||
}
|
||||
|
||||
func cloneTurnContext(ctx *TurnContext) *TurnContext {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *ctx
|
||||
cloned.Inbound = cloneInboundContext(ctx.Inbound)
|
||||
cloned.Route = cloneResolvedRoute(ctx.Route)
|
||||
cloned.Scope = session.CloneScope(ctx.Scope)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneInboundContext(ctx *bus.InboundContext) *bus.InboundContext {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *ctx
|
||||
cloned.ReplyHandles = cloneStringMap(ctx.ReplyHandles)
|
||||
cloned.Raw = cloneStringMap(ctx.Raw)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneStringMap(src map[string]string) map[string]string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make(map[string]string, len(src))
|
||||
for k, v := range src {
|
||||
cloned[k] = v
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneEventMeta(meta EventMeta) EventMeta {
|
||||
meta.turnContext = cloneTurnContext(meta.turnContext)
|
||||
return meta
|
||||
}
|
||||
|
||||
func cloneResolvedRoute(route *routing.ResolvedRoute) *routing.ResolvedRoute {
|
||||
if route == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *route
|
||||
cloned.SessionPolicy = routing.SessionPolicy{
|
||||
Dimensions: append([]string(nil), route.SessionPolicy.Dimensions...),
|
||||
IdentityLinks: cloneIdentityLinks(route.SessionPolicy.IdentityLinks),
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneIdentityLinks(src map[string][]string) map[string][]string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make(map[string][]string, len(src))
|
||||
for canonical, ids := range src {
|
||||
dup := make([]string, len(ids))
|
||||
copy(dup, ids)
|
||||
cloned[canonical] = dup
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
+10
-9
@@ -226,8 +226,7 @@ func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) {
|
||||
logger.ErrorCF("voice-agent", "Failed to publish leave control", map[string]any{"error": err})
|
||||
}
|
||||
if err := a.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: channelType,
|
||||
ChatID: acc.chatID,
|
||||
Context: bus.NewOutboundContext(channelType, acc.chatID, ""),
|
||||
Content: "Goodbye! Leaving the voice channel.",
|
||||
}); err != nil {
|
||||
logger.ErrorCF("voice-agent", "Failed to publish goodbye message", map[string]any{"error": err})
|
||||
@@ -238,14 +237,16 @@ func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) {
|
||||
oralPrompt := "\n\n[SYSTEM]: The user just spoke this to you over voice chat. Please reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally."
|
||||
|
||||
if err := a.bus.PublishInbound(ctx, bus.InboundMessage{
|
||||
Channel: channelType,
|
||||
SenderID: acc.speakerID,
|
||||
ChatID: acc.chatID,
|
||||
Content: res.Text + oralPrompt,
|
||||
Peer: bus.Peer{Kind: "channel", ID: acc.chatID},
|
||||
Metadata: map[string]string{
|
||||
"is_voice": "true",
|
||||
Context: bus.InboundContext{
|
||||
Channel: channelType,
|
||||
ChatID: acc.chatID,
|
||||
ChatType: "channel",
|
||||
SenderID: acc.speakerID,
|
||||
Raw: map[string]string{
|
||||
"is_voice": "true",
|
||||
},
|
||||
},
|
||||
Content: res.Text + oralPrompt,
|
||||
}); err != nil {
|
||||
logger.ErrorCF("voice-agent", "Failed to publish inbound message", map[string]any{"error": err})
|
||||
}
|
||||
|
||||
@@ -185,8 +185,8 @@ func TestAgentCheckSilencePublishesInboundAndCleansUp(t *testing.T) {
|
||||
if !strings.Contains(msg.Content, "hello there") {
|
||||
t.Fatalf("unexpected inbound content: %q", msg.Content)
|
||||
}
|
||||
if msg.Metadata["is_voice"] != "true" {
|
||||
t.Fatalf("expected is_voice metadata, got %#v", msg.Metadata)
|
||||
if msg.Context.Raw["is_voice"] != "true" {
|
||||
t.Fatalf("expected is_voice metadata, got %#v", msg.Context.Raw)
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("expected inbound publish")
|
||||
|
||||
+19
-1
@@ -12,6 +12,12 @@ import (
|
||||
// ErrBusClosed is returned when publishing to a closed MessageBus.
|
||||
var ErrBusClosed = errors.New("message bus closed")
|
||||
|
||||
var (
|
||||
ErrMissingInboundContext = errors.New("inbound message context is required")
|
||||
ErrMissingOutboundContext = errors.New("outbound message context is required")
|
||||
ErrMissingOutboundMediaContext = errors.New("outbound media context is required")
|
||||
)
|
||||
|
||||
const defaultBusBufferSize = 64
|
||||
|
||||
// StreamDelegate is implemented by the channel Manager to provide streaming
|
||||
@@ -49,7 +55,7 @@ func NewMessageBus() *MessageBus {
|
||||
inbound: make(chan InboundMessage, defaultBusBufferSize),
|
||||
outbound: make(chan OutboundMessage, defaultBusBufferSize),
|
||||
outboundMedia: make(chan OutboundMediaMessage, defaultBusBufferSize),
|
||||
audioChunks: make(chan AudioChunk, defaultBusBufferSize*4), // Audio chunks need more buffer
|
||||
audioChunks: make(chan AudioChunk, defaultBusBufferSize*4), // Audio chunks need more buffer.
|
||||
voiceControls: make(chan VoiceControl, defaultBusBufferSize),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
@@ -84,6 +90,10 @@ func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
|
||||
msg = NormalizeInboundMessage(msg)
|
||||
if msg.Context.isZero() {
|
||||
return ErrMissingInboundContext
|
||||
}
|
||||
return publish(ctx, mb, mb.inbound, msg)
|
||||
}
|
||||
|
||||
@@ -92,6 +102,10 @@ func (mb *MessageBus) InboundChan() <-chan InboundMessage {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
|
||||
msg = NormalizeOutboundMessage(msg)
|
||||
if msg.Context.isZero() {
|
||||
return ErrMissingOutboundContext
|
||||
}
|
||||
return publish(ctx, mb, mb.outbound, msg)
|
||||
}
|
||||
|
||||
@@ -100,6 +114,10 @@ func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
|
||||
msg = NormalizeOutboundMediaMessage(msg)
|
||||
if msg.Context.isZero() {
|
||||
return ErrMissingOutboundMediaContext
|
||||
}
|
||||
return publish(ctx, mb, mb.outboundMedia, msg)
|
||||
}
|
||||
|
||||
|
||||
+438
-15
@@ -14,10 +14,13 @@ func TestPublishConsume(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
msg := InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello",
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "hello",
|
||||
}
|
||||
|
||||
if err := mb.PublishInbound(ctx, msg); err != nil {
|
||||
@@ -34,6 +37,138 @@ func TestPublishConsume(t *testing.T) {
|
||||
if got.Channel != "test" {
|
||||
t.Fatalf("expected channel 'test', got %q", got.Channel)
|
||||
}
|
||||
if got.Context.Channel != "test" {
|
||||
t.Fatalf("expected context channel 'test', got %q", got.Context.Channel)
|
||||
}
|
||||
if got.Context.ChatID != "chat1" {
|
||||
t.Fatalf("expected context chat ID 'chat1', got %q", got.Context.ChatID)
|
||||
}
|
||||
if got.Context.SenderID != "user1" {
|
||||
t.Fatalf("expected context sender ID 'user1', got %q", got.Context.SenderID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_NormalizesContext(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "slack",
|
||||
Account: "workspace-a",
|
||||
ChatID: "C456/1712",
|
||||
ChatType: "group",
|
||||
TopicID: "1712",
|
||||
SpaceID: "T001",
|
||||
SpaceType: "team",
|
||||
SenderID: "U123",
|
||||
MessageID: "1712.01",
|
||||
ReplyToMessageID: "1700.01",
|
||||
Mentioned: true,
|
||||
},
|
||||
Content: "hello",
|
||||
}
|
||||
|
||||
if err := mb.PublishInbound(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishInbound failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.InboundChan()
|
||||
if got.Context.Channel != "slack" {
|
||||
t.Fatalf("expected context channel slack, got %q", got.Context.Channel)
|
||||
}
|
||||
if got.Context.Account != "workspace-a" {
|
||||
t.Fatalf("expected context account workspace-a, got %q", got.Context.Account)
|
||||
}
|
||||
if got.Context.ChatType != "group" {
|
||||
t.Fatalf("expected context chat type group, got %q", got.Context.ChatType)
|
||||
}
|
||||
if got.Context.TopicID != "1712" {
|
||||
t.Fatalf("expected topic 1712, got %q", got.Context.TopicID)
|
||||
}
|
||||
if got.Context.SpaceType != "team" || got.Context.SpaceID != "T001" {
|
||||
t.Fatalf("expected team space T001, got %q/%q", got.Context.SpaceType, got.Context.SpaceID)
|
||||
}
|
||||
if !got.Context.Mentioned {
|
||||
t.Fatal("expected mentioned=true in context")
|
||||
}
|
||||
if got.Context.ReplyToMessageID != "1700.01" {
|
||||
t.Fatalf("expected reply_to_message_id 1700.01, got %q", got.Context.ReplyToMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_MirrorsContextIntoConvenienceFields(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "telegram",
|
||||
Account: "bot-a",
|
||||
ChatID: "-1001",
|
||||
ChatType: "group",
|
||||
TopicID: "42",
|
||||
SpaceID: "guild-9",
|
||||
SpaceType: "guild",
|
||||
SenderID: "user-1",
|
||||
MessageID: "777",
|
||||
Mentioned: true,
|
||||
ReplyToMessageID: "666",
|
||||
},
|
||||
Content: "hi",
|
||||
}
|
||||
|
||||
if err := mb.PublishInbound(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishInbound failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.InboundChan()
|
||||
if got.Channel != "telegram" {
|
||||
t.Fatalf("expected legacy channel telegram, got %q", got.Channel)
|
||||
}
|
||||
if got.ChatID != "-1001" {
|
||||
t.Fatalf("expected legacy chat ID -1001, got %q", got.ChatID)
|
||||
}
|
||||
if got.SenderID != "user-1" {
|
||||
t.Fatalf("expected legacy sender ID user-1, got %q", got.SenderID)
|
||||
}
|
||||
if got.MessageID != "777" {
|
||||
t.Fatalf("expected legacy message ID 777, got %q", got.MessageID)
|
||||
}
|
||||
if got.Context.Account != "bot-a" || got.Context.SpaceID != "guild-9" || got.Context.TopicID != "42" {
|
||||
t.Fatalf("unexpected normalized context: %+v", got.Context)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_BackfillsContextFromLegacyFields(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := InboundMessage{
|
||||
Channel: "pico",
|
||||
ChatID: "session-1",
|
||||
SenderID: "user-1",
|
||||
MessageID: "msg-1",
|
||||
Content: "hello",
|
||||
}
|
||||
|
||||
if err := mb.PublishInbound(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishInbound failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.InboundChan()
|
||||
if got.Context.Channel != "pico" {
|
||||
t.Fatalf("expected context channel pico, got %q", got.Context.Channel)
|
||||
}
|
||||
if got.Context.ChatID != "session-1" {
|
||||
t.Fatalf("expected context chat ID session-1, got %q", got.Context.ChatID)
|
||||
}
|
||||
if got.Context.SenderID != "user-1" {
|
||||
t.Fatalf("expected context sender ID user-1, got %q", got.Context.SenderID)
|
||||
}
|
||||
if got.Context.MessageID != "msg-1" {
|
||||
t.Fatalf("expected context message ID msg-1, got %q", got.Context.MessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutboundSubscribe(t *testing.T) {
|
||||
@@ -43,8 +178,10 @@ func TestPublishOutboundSubscribe(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
msg := OutboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "123",
|
||||
Context: InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "123",
|
||||
},
|
||||
Content: "world",
|
||||
}
|
||||
|
||||
@@ -59,6 +196,222 @@ func TestPublishOutboundSubscribe(t *testing.T) {
|
||||
if got.Content != "world" {
|
||||
t.Fatalf("expected content 'world', got %q", got.Content)
|
||||
}
|
||||
if got.Context.Channel != "telegram" || got.Context.ChatID != "123" {
|
||||
t.Fatalf("expected normalized outbound context, got %+v", got.Context)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutbound_MirrorsContextToLegacyFields(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := OutboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat-42",
|
||||
ReplyToMessageID: "msg-9",
|
||||
},
|
||||
AgentID: "main",
|
||||
SessionKey: "sk_v1_123",
|
||||
Scope: &OutboundScope{
|
||||
Version: 1,
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Account: "bot-a",
|
||||
Dimensions: []string{"chat", "sender"},
|
||||
Values: map[string]string{
|
||||
"chat": "direct:chat-42",
|
||||
"sender": "user-1",
|
||||
},
|
||||
},
|
||||
Content: "reply",
|
||||
}
|
||||
|
||||
if err := mb.PublishOutbound(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishOutbound failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.OutboundChan()
|
||||
if got.Channel != "telegram" {
|
||||
t.Fatalf("expected legacy channel telegram, got %q", got.Channel)
|
||||
}
|
||||
if got.ChatID != "chat-42" {
|
||||
t.Fatalf("expected legacy chat ID chat-42, got %q", got.ChatID)
|
||||
}
|
||||
if got.ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID)
|
||||
}
|
||||
if got.AgentID != "main" || got.SessionKey != "sk_v1_123" {
|
||||
t.Fatalf("unexpected outbound turn metadata: agent=%q session=%q", got.AgentID, got.SessionKey)
|
||||
}
|
||||
if got.Scope == nil || got.Scope.AgentID != "main" || got.Scope.Values["chat"] != "direct:chat-42" {
|
||||
t.Fatalf("unexpected outbound scope: %+v", got.Scope)
|
||||
}
|
||||
if got.Context.Channel != "telegram" || got.Context.ChatID != "chat-42" {
|
||||
t.Fatalf("unexpected outbound context: %+v", got.Context)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutbound_PreservesExplicitReplyToMessageID(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := OutboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat-42",
|
||||
},
|
||||
ReplyToMessageID: "msg-9",
|
||||
Content: "reply",
|
||||
}
|
||||
|
||||
if err := mb.PublishOutbound(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishOutbound failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.OutboundChan()
|
||||
if got.ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID)
|
||||
}
|
||||
if got.Context.ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected context reply_to_message_id msg-9, got %q", got.Context.ReplyToMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutbound_PreservesExplicitReplyToMessageIDWhenContextReplyIsBlank(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := OutboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat-42",
|
||||
ReplyToMessageID: " ",
|
||||
},
|
||||
ReplyToMessageID: "msg-9",
|
||||
Content: "reply",
|
||||
}
|
||||
|
||||
if err := mb.PublishOutbound(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishOutbound failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.OutboundChan()
|
||||
if got.ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID)
|
||||
}
|
||||
if got.Context.ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected context reply_to_message_id msg-9, got %q", got.Context.ReplyToMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutboundMedia_MirrorsContextToLegacyFields(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := OutboundMediaMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "slack",
|
||||
ChatID: "C001",
|
||||
},
|
||||
AgentID: "support",
|
||||
SessionKey: "sk_v1_media",
|
||||
Scope: &OutboundScope{
|
||||
Version: 1,
|
||||
AgentID: "support",
|
||||
Channel: "slack",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "channel:c001",
|
||||
},
|
||||
},
|
||||
Parts: []MediaPart{{Type: "image", Ref: "media://1"}},
|
||||
}
|
||||
|
||||
if err := mb.PublishOutboundMedia(context.Background(), msg); err != nil {
|
||||
t.Fatalf("PublishOutboundMedia failed: %v", err)
|
||||
}
|
||||
|
||||
got := <-mb.OutboundMediaChan()
|
||||
if got.Channel != "slack" {
|
||||
t.Fatalf("expected legacy channel slack, got %q", got.Channel)
|
||||
}
|
||||
if got.ChatID != "C001" {
|
||||
t.Fatalf("expected legacy chat ID C001, got %q", got.ChatID)
|
||||
}
|
||||
if got.AgentID != "support" || got.SessionKey != "sk_v1_media" {
|
||||
t.Fatalf("unexpected outbound media turn metadata: agent=%q session=%q", got.AgentID, got.SessionKey)
|
||||
}
|
||||
if got.Scope == nil || got.Scope.Values["chat"] != "channel:c001" {
|
||||
t.Fatalf("unexpected outbound media scope: %+v", got.Scope)
|
||||
}
|
||||
if got.Context.Channel != "slack" || got.Context.ChatID != "C001" {
|
||||
t.Fatalf("unexpected outbound media context: %+v", got.Context)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishAudioChunkSubscribe(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
chunk := AudioChunk{
|
||||
SessionID: "voice-1",
|
||||
SpeakerID: "speaker-1",
|
||||
ChatID: "chat-1",
|
||||
Channel: "discord",
|
||||
Sequence: 7,
|
||||
Format: "opus",
|
||||
Data: []byte{0x01, 0x02},
|
||||
}
|
||||
|
||||
if err := mb.PublishAudioChunk(context.Background(), chunk); err != nil {
|
||||
t.Fatalf("PublishAudioChunk failed: %v", err)
|
||||
}
|
||||
|
||||
got, ok := <-mb.AudioChunksChan()
|
||||
if !ok {
|
||||
t.Fatal("AudioChunksChan returned ok=false")
|
||||
}
|
||||
if got.SessionID != "voice-1" || got.Sequence != 7 {
|
||||
t.Fatalf("unexpected audio chunk: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishVoiceControlSubscribe(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
ctrl := VoiceControl{
|
||||
SessionID: "voice-1",
|
||||
ChatID: "chat-1",
|
||||
Type: "command",
|
||||
Action: "start",
|
||||
}
|
||||
|
||||
if err := mb.PublishVoiceControl(context.Background(), ctrl); err != nil {
|
||||
t.Fatalf("PublishVoiceControl failed: %v", err)
|
||||
}
|
||||
|
||||
got, ok := <-mb.VoiceControlsChan()
|
||||
if !ok {
|
||||
t.Fatal("VoiceControlsChan returned ok=false")
|
||||
}
|
||||
if got.Type != "command" || got.Action != "start" {
|
||||
t.Fatalf("unexpected voice control: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOutboundContext_NormalizesReplyAddress(t *testing.T) {
|
||||
ctx := NewOutboundContext(" telegram ", " chat-42 ", " msg-9 ")
|
||||
if ctx.Channel != "telegram" {
|
||||
t.Fatalf("expected channel telegram, got %q", ctx.Channel)
|
||||
}
|
||||
if ctx.ChatID != "chat-42" {
|
||||
t.Fatalf("expected chat_id chat-42, got %q", ctx.ChatID)
|
||||
}
|
||||
if ctx.ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected reply_to_message_id msg-9, got %q", ctx.ReplyToMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_ContextCancel(t *testing.T) {
|
||||
@@ -68,7 +421,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
|
||||
// Fill the buffer
|
||||
ctx := context.Background()
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-fill",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-fill",
|
||||
},
|
||||
Content: "fill",
|
||||
}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
@@ -77,7 +438,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
|
||||
cancelCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"})
|
||||
err := mb.PublishInbound(cancelCtx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-overflow",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-overflow",
|
||||
},
|
||||
Content: "overflow",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from canceled context, got nil")
|
||||
}
|
||||
@@ -90,7 +459,15 @@ func TestPublishInbound_BusClosed(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
mb.Close()
|
||||
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "test",
|
||||
})
|
||||
if err != ErrBusClosed {
|
||||
t.Fatalf("expected ErrBusClosed, got %v", err)
|
||||
}
|
||||
@@ -100,7 +477,13 @@ func TestPublishOutbound_BusClosed(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
mb.Close()
|
||||
|
||||
err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"})
|
||||
err := mb.PublishOutbound(context.Background(), OutboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
},
|
||||
Content: "test",
|
||||
})
|
||||
if err != ErrBusClosed {
|
||||
t.Fatalf("expected ErrBusClosed, got %v", err)
|
||||
}
|
||||
@@ -112,14 +495,30 @@ func TestConsumeInbound_ContextCancel(t *testing.T) {
|
||||
defer mb.Close()
|
||||
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil {
|
||||
if err := mb.PublishInbound(context.Background(), InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-fill",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-fill",
|
||||
},
|
||||
Content: "fill",
|
||||
}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"})
|
||||
mb.PublishInbound(ctx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-cancel",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-cancel",
|
||||
},
|
||||
Content: "ContextCancel",
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -213,7 +612,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
|
||||
|
||||
// Fill the buffer
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-fill",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-fill",
|
||||
},
|
||||
Content: "fill",
|
||||
}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
@@ -222,7 +629,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"})
|
||||
err := mb.PublishInbound(timeoutCtx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-overflow",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-overflow",
|
||||
},
|
||||
Content: "overflow",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when buffer is full and context times out")
|
||||
}
|
||||
@@ -240,7 +655,15 @@ func TestCloseIdempotent(t *testing.T) {
|
||||
mb.Close()
|
||||
|
||||
// After close, publish should return ErrBusClosed
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "test",
|
||||
})
|
||||
if err != ErrBusClosed {
|
||||
t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package bus
|
||||
|
||||
import "strings"
|
||||
|
||||
// NormalizeInboundMessage ensures the inbound context is normalized and keeps
|
||||
// convenience mirrors in sync for runtime consumers.
|
||||
func NormalizeInboundMessage(msg InboundMessage) InboundMessage {
|
||||
if msg.Context.Channel == "" {
|
||||
msg.Context.Channel = msg.Channel
|
||||
}
|
||||
if msg.Context.ChatID == "" {
|
||||
msg.Context.ChatID = msg.ChatID
|
||||
}
|
||||
if msg.Context.SenderID == "" {
|
||||
msg.Context.SenderID = msg.SenderID
|
||||
}
|
||||
if msg.Context.MessageID == "" {
|
||||
msg.Context.MessageID = msg.MessageID
|
||||
}
|
||||
msg.Context = normalizeInboundContext(msg.Context)
|
||||
msg.Channel = msg.Context.Channel
|
||||
msg.SenderID = msg.Context.SenderID
|
||||
msg.ChatID = msg.Context.ChatID
|
||||
if msg.MessageID == "" {
|
||||
msg.MessageID = msg.Context.MessageID
|
||||
}
|
||||
if msg.Context.MessageID == "" {
|
||||
msg.Context.MessageID = msg.MessageID
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func (ctx InboundContext) isZero() bool {
|
||||
return ctx.Channel == "" &&
|
||||
ctx.Account == "" &&
|
||||
ctx.ChatID == "" &&
|
||||
ctx.ChatType == "" &&
|
||||
ctx.TopicID == "" &&
|
||||
ctx.SpaceID == "" &&
|
||||
ctx.SpaceType == "" &&
|
||||
ctx.SenderID == "" &&
|
||||
ctx.MessageID == "" &&
|
||||
!ctx.Mentioned &&
|
||||
ctx.ReplyToMessageID == "" &&
|
||||
ctx.ReplyToSenderID == "" &&
|
||||
len(ctx.ReplyHandles) == 0 &&
|
||||
len(ctx.Raw) == 0
|
||||
}
|
||||
|
||||
func normalizeInboundContext(ctx InboundContext) InboundContext {
|
||||
ctx.Channel = strings.TrimSpace(ctx.Channel)
|
||||
ctx.Account = strings.TrimSpace(ctx.Account)
|
||||
ctx.ChatID = strings.TrimSpace(ctx.ChatID)
|
||||
ctx.ChatType = normalizeKind(ctx.ChatType)
|
||||
ctx.TopicID = strings.TrimSpace(ctx.TopicID)
|
||||
ctx.SpaceID = strings.TrimSpace(ctx.SpaceID)
|
||||
ctx.SpaceType = normalizeKind(ctx.SpaceType)
|
||||
ctx.SenderID = strings.TrimSpace(ctx.SenderID)
|
||||
ctx.MessageID = strings.TrimSpace(ctx.MessageID)
|
||||
ctx.ReplyToMessageID = strings.TrimSpace(ctx.ReplyToMessageID)
|
||||
ctx.ReplyToSenderID = strings.TrimSpace(ctx.ReplyToSenderID)
|
||||
ctx.ReplyHandles = cloneStringMap(ctx.ReplyHandles)
|
||||
ctx.Raw = cloneStringMap(ctx.Raw)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func cloneStringMap(src map[string]string) map[string]string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
dst := make(map[string]string, len(src))
|
||||
for k, v := range src {
|
||||
dst[k] = v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func normalizeKind(kind string) string {
|
||||
return strings.ToLower(strings.TrimSpace(kind))
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package bus
|
||||
|
||||
import "strings"
|
||||
|
||||
// NewOutboundContext builds the minimal normalized addressing context required
|
||||
// to deliver an outbound text message or reply.
|
||||
func NewOutboundContext(channel, chatID, replyToMessageID string) InboundContext {
|
||||
return normalizeInboundContext(InboundContext{
|
||||
Channel: strings.TrimSpace(channel),
|
||||
ChatID: strings.TrimSpace(chatID),
|
||||
ReplyToMessageID: strings.TrimSpace(replyToMessageID),
|
||||
})
|
||||
}
|
||||
|
||||
// NormalizeOutboundMessage ensures Context is normalized and keeps convenience
|
||||
// mirrors in sync for runtime consumers.
|
||||
func NormalizeOutboundMessage(msg OutboundMessage) OutboundMessage {
|
||||
msg.Channel = strings.TrimSpace(msg.Channel)
|
||||
msg.ChatID = strings.TrimSpace(msg.ChatID)
|
||||
msg.ReplyToMessageID = strings.TrimSpace(msg.ReplyToMessageID)
|
||||
if msg.Context.Channel == "" {
|
||||
msg.Context.Channel = msg.Channel
|
||||
}
|
||||
if msg.Context.ChatID == "" {
|
||||
msg.Context.ChatID = msg.ChatID
|
||||
}
|
||||
if msg.Context.ReplyToMessageID == "" {
|
||||
msg.Context.ReplyToMessageID = msg.ReplyToMessageID
|
||||
}
|
||||
msg.Context = normalizeInboundContext(msg.Context)
|
||||
if msg.Channel == "" {
|
||||
msg.Channel = msg.Context.Channel
|
||||
}
|
||||
if msg.ChatID == "" {
|
||||
msg.ChatID = msg.Context.ChatID
|
||||
}
|
||||
if msg.ReplyToMessageID == "" {
|
||||
msg.ReplyToMessageID = msg.Context.ReplyToMessageID
|
||||
}
|
||||
if msg.Context.ReplyToMessageID == "" {
|
||||
msg.Context.ReplyToMessageID = msg.ReplyToMessageID
|
||||
}
|
||||
msg.Scope = cloneOutboundScope(msg.Scope)
|
||||
return msg
|
||||
}
|
||||
|
||||
// NormalizeOutboundMediaMessage ensures media outbound messages also carry a
|
||||
// normalized context while keeping convenience mirrors in sync.
|
||||
func NormalizeOutboundMediaMessage(msg OutboundMediaMessage) OutboundMediaMessage {
|
||||
msg.Channel = strings.TrimSpace(msg.Channel)
|
||||
msg.ChatID = strings.TrimSpace(msg.ChatID)
|
||||
if msg.Context.Channel == "" {
|
||||
msg.Context.Channel = msg.Channel
|
||||
}
|
||||
if msg.Context.ChatID == "" {
|
||||
msg.Context.ChatID = msg.ChatID
|
||||
}
|
||||
msg.Context = normalizeInboundContext(msg.Context)
|
||||
if msg.Channel == "" {
|
||||
msg.Channel = msg.Context.Channel
|
||||
}
|
||||
if msg.ChatID == "" {
|
||||
msg.ChatID = msg.Context.ChatID
|
||||
}
|
||||
msg.Scope = cloneOutboundScope(msg.Scope)
|
||||
return msg
|
||||
}
|
||||
|
||||
func cloneOutboundScope(scope *OutboundScope) *OutboundScope {
|
||||
if scope == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *scope
|
||||
if len(scope.Dimensions) > 0 {
|
||||
cloned.Dimensions = append([]string(nil), scope.Dimensions...)
|
||||
}
|
||||
if len(scope.Values) > 0 {
|
||||
cloned.Values = make(map[string]string, len(scope.Values))
|
||||
for key, value := range scope.Values {
|
||||
cloned.Values[key] = value
|
||||
}
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
+64
-25
@@ -1,11 +1,5 @@
|
||||
package bus
|
||||
|
||||
// Peer identifies the routing peer for a message (direct, group, channel, etc.)
|
||||
type Peer struct {
|
||||
Kind string `json:"kind"` // "direct" | "group" | "channel" | ""
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// SenderInfo provides structured sender identity information.
|
||||
type SenderInfo struct {
|
||||
Platform string `json:"platform,omitempty"` // "telegram", "discord", "slack", ...
|
||||
@@ -15,26 +9,67 @@ type SenderInfo struct {
|
||||
DisplayName string `json:"display_name,omitempty"` // display name
|
||||
}
|
||||
|
||||
// InboundContext captures the normalized, platform-agnostic facts about an
|
||||
// inbound message. This is the source of truth for routing and session
|
||||
// allocation.
|
||||
type InboundContext struct {
|
||||
Channel string `json:"channel"`
|
||||
Account string `json:"account,omitempty"`
|
||||
|
||||
ChatID string `json:"chat_id"`
|
||||
ChatType string `json:"chat_type,omitempty"` // direct / group / channel
|
||||
TopicID string `json:"topic_id,omitempty"`
|
||||
|
||||
SpaceID string `json:"space_id,omitempty"`
|
||||
SpaceType string `json:"space_type,omitempty"` // guild / team / workspace / tenant
|
||||
|
||||
SenderID string `json:"sender_id"`
|
||||
MessageID string `json:"message_id,omitempty"`
|
||||
|
||||
Mentioned bool `json:"mentioned,omitempty"`
|
||||
|
||||
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
|
||||
ReplyToSenderID string `json:"reply_to_sender_id,omitempty"`
|
||||
|
||||
ReplyHandles map[string]string `json:"reply_handles,omitempty"`
|
||||
Raw map[string]string `json:"raw,omitempty"`
|
||||
}
|
||||
|
||||
type InboundMessage struct {
|
||||
Channel string `json:"channel"`
|
||||
SenderID string `json:"sender_id"`
|
||||
Sender SenderInfo `json:"sender"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Content string `json:"content"`
|
||||
Media []string `json:"media,omitempty"`
|
||||
Peer Peer `json:"peer"` // routing peer
|
||||
MessageID string `json:"message_id,omitempty"` // platform message ID
|
||||
MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope
|
||||
SessionKey string `json:"session_key"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Context InboundContext `json:"context"`
|
||||
Sender SenderInfo `json:"sender"`
|
||||
Content string `json:"content"`
|
||||
Media []string `json:"media,omitempty"`
|
||||
MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope
|
||||
SessionKey string `json:"session_key"`
|
||||
|
||||
// Convenience mirrors derived from Context for runtime consumers.
|
||||
Channel string `json:"channel"`
|
||||
SenderID string `json:"sender_id"`
|
||||
ChatID string `json:"chat_id"`
|
||||
MessageID string `json:"message_id,omitempty"` // platform message ID
|
||||
}
|
||||
|
||||
// OutboundScope captures the structured session scope associated with an
|
||||
// outbound turn result without depending on the session package.
|
||||
type OutboundScope struct {
|
||||
Version int `json:"version,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
Account string `json:"account,omitempty"`
|
||||
Dimensions []string `json:"dimensions,omitempty"`
|
||||
Values map[string]string `json:"values,omitempty"`
|
||||
}
|
||||
|
||||
type OutboundMessage struct {
|
||||
Channel string `json:"channel"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Content string `json:"content"`
|
||||
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Channel string `json:"channel"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Context InboundContext `json:"context"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
SessionKey string `json:"session_key,omitempty"`
|
||||
Scope *OutboundScope `json:"scope,omitempty"`
|
||||
Content string `json:"content"`
|
||||
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
|
||||
}
|
||||
|
||||
// MediaPart describes a single media attachment to send.
|
||||
@@ -48,9 +83,13 @@ type MediaPart struct {
|
||||
|
||||
// OutboundMediaMessage carries media attachments from Agent to channels via the bus.
|
||||
type OutboundMediaMessage struct {
|
||||
Channel string `json:"channel"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Parts []MediaPart `json:"parts"`
|
||||
Channel string `json:"channel"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Context InboundContext `json:"context"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
SessionKey string `json:"session_key,omitempty"`
|
||||
Scope *OutboundScope `json:"scope,omitempty"`
|
||||
Parts []MediaPart `json:"parts"`
|
||||
}
|
||||
|
||||
// AudioChunk represents a chunk of streaming voice data.
|
||||
|
||||
+39
-19
@@ -260,12 +260,11 @@ func (c *BaseChannel) IsAllowedSender(sender bus.SenderInfo) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *BaseChannel) HandleMessage(
|
||||
func (c *BaseChannel) HandleMessageWithContext(
|
||||
ctx context.Context,
|
||||
peer bus.Peer,
|
||||
messageID, senderID, chatID, content string,
|
||||
deliveryChatID, content string,
|
||||
media []string,
|
||||
metadata map[string]string,
|
||||
inboundCtx bus.InboundContext,
|
||||
senderOpts ...bus.SenderInfo,
|
||||
) {
|
||||
// Use SenderInfo-based allow check when available, else fall back to string
|
||||
@@ -273,6 +272,7 @@ func (c *BaseChannel) HandleMessage(
|
||||
if len(senderOpts) > 0 {
|
||||
sender = senderOpts[0]
|
||||
}
|
||||
senderID := strings.TrimSpace(inboundCtx.SenderID)
|
||||
if sender.CanonicalID != "" || sender.PlatformID != "" {
|
||||
if !c.IsAllowedSender(sender) {
|
||||
return
|
||||
@@ -289,20 +289,28 @@ func (c *BaseChannel) HandleMessage(
|
||||
resolvedSenderID = sender.CanonicalID
|
||||
}
|
||||
|
||||
scope := BuildMediaScope(c.name, chatID, messageID)
|
||||
if resolvedSenderID == "" {
|
||||
resolvedSenderID = senderID
|
||||
}
|
||||
|
||||
inboundCtx.Channel = c.name
|
||||
if inboundCtx.ChatID == "" {
|
||||
inboundCtx.ChatID = deliveryChatID
|
||||
}
|
||||
if inboundCtx.SenderID == "" {
|
||||
inboundCtx.SenderID = resolvedSenderID
|
||||
}
|
||||
|
||||
scope := BuildMediaScope(c.name, deliveryChatID, inboundCtx.MessageID)
|
||||
|
||||
msg := bus.InboundMessage{
|
||||
Channel: c.name,
|
||||
SenderID: resolvedSenderID,
|
||||
Context: inboundCtx,
|
||||
Sender: sender,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
Media: media,
|
||||
Peer: peer,
|
||||
MessageID: messageID,
|
||||
MediaScope: scope,
|
||||
Metadata: metadata,
|
||||
}
|
||||
msg = bus.NormalizeInboundMessage(msg)
|
||||
|
||||
// Auto-trigger typing indicator, message reaction, and placeholder before publishing.
|
||||
// Each capability is independent — all three may fire for the same message.
|
||||
@@ -313,14 +321,14 @@ func (c *BaseChannel) HandleMessage(
|
||||
if c.owner != nil && c.placeholderRecorder != nil {
|
||||
// Typing
|
||||
if tc, ok := c.owner.(TypingCapable); ok {
|
||||
if stop, err := tc.StartTyping(ctx, chatID); err == nil {
|
||||
c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop)
|
||||
if stop, err := tc.StartTyping(ctx, deliveryChatID); err == nil {
|
||||
c.placeholderRecorder.RecordTypingStop(c.name, deliveryChatID, stop)
|
||||
}
|
||||
}
|
||||
// Reaction
|
||||
if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" {
|
||||
if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil {
|
||||
c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo)
|
||||
if rc, ok := c.owner.(ReactionCapable); ok && msg.MessageID != "" {
|
||||
if undo, err := rc.ReactToMessage(ctx, deliveryChatID, msg.MessageID); err == nil {
|
||||
c.placeholderRecorder.RecordReactionUndo(c.name, deliveryChatID, undo)
|
||||
}
|
||||
}
|
||||
// Placeholder — independent pipeline.
|
||||
@@ -329,8 +337,8 @@ func (c *BaseChannel) HandleMessage(
|
||||
// "Thinking…" only once the voice has been processed.
|
||||
if !audioAnnotationRe.MatchString(content) {
|
||||
if pc, ok := c.owner.(PlaceholderCapable); ok {
|
||||
if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" {
|
||||
c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID)
|
||||
if phID, err := pc.SendPlaceholder(ctx, deliveryChatID); err == nil && phID != "" {
|
||||
c.placeholderRecorder.RecordPlaceholder(c.name, deliveryChatID, phID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -339,12 +347,24 @@ func (c *BaseChannel) HandleMessage(
|
||||
if err := c.bus.PublishInbound(ctx, msg); err != nil {
|
||||
logger.ErrorCF("channels", "Failed to publish inbound message", map[string]any{
|
||||
"channel": c.name,
|
||||
"chat_id": chatID,
|
||||
"chat_id": deliveryChatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HandleInboundContext publishes a normalized inbound message using only the
|
||||
// structured context.
|
||||
func (c *BaseChannel) HandleInboundContext(
|
||||
ctx context.Context,
|
||||
deliveryChatID, content string,
|
||||
media []string,
|
||||
inboundCtx bus.InboundContext,
|
||||
senderOpts ...bus.SenderInfo,
|
||||
) {
|
||||
c.HandleMessageWithContext(ctx, deliveryChatID, content, media, inboundCtx, senderOpts...)
|
||||
}
|
||||
|
||||
func (c *BaseChannel) SetRunning(running bool) {
|
||||
c.running.Store(running)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -263,3 +264,58 @@ func TestIsAllowedSender(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleInboundContext_PublishesNormalizedContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inbound bus.InboundContext
|
||||
wantChat string
|
||||
wantSender string
|
||||
}{
|
||||
{
|
||||
name: "direct uses sender as peer",
|
||||
inbound: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-1",
|
||||
MessageID: "msg-1",
|
||||
},
|
||||
wantChat: "chat-1",
|
||||
wantSender: "user-1",
|
||||
},
|
||||
{
|
||||
name: "group uses chat as peer",
|
||||
inbound: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "group-1",
|
||||
ChatType: "group",
|
||||
SenderID: "user-2",
|
||||
MessageID: "msg-2",
|
||||
},
|
||||
wantChat: "group-1",
|
||||
wantSender: "user-2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
defer msgBus.Close()
|
||||
|
||||
ch := NewBaseChannel("test", nil, msgBus, nil)
|
||||
ch.HandleInboundContext(context.Background(), tt.inbound.ChatID, "hello", nil, tt.inbound)
|
||||
|
||||
msg := <-msgBus.InboundChan()
|
||||
if msg.ChatID != tt.wantChat {
|
||||
t.Fatalf("ChatID = %q, want %q", msg.ChatID, tt.wantChat)
|
||||
}
|
||||
if msg.SenderID != tt.wantSender {
|
||||
t.Fatalf("SenderID = %q, want %q", msg.SenderID, tt.wantSender)
|
||||
}
|
||||
if msg.Context.ChatType != tt.inbound.ChatType {
|
||||
t.Fatalf("ChatType = %q, want %q", msg.Context.ChatType, tt.inbound.ChatType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,16 +185,15 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
"session_webhook": data.SessionWebhook,
|
||||
}
|
||||
|
||||
var peer bus.Peer
|
||||
var (
|
||||
chatType string
|
||||
isMentioned bool
|
||||
)
|
||||
if data.ConversationType == "1" {
|
||||
peerID := senderID
|
||||
if peerID == "" {
|
||||
peerID = chatID
|
||||
}
|
||||
peer = bus.Peer{Kind: "direct", ID: peerID}
|
||||
chatType = "direct"
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "group", ID: data.ConversationId}
|
||||
isMentioned := data.IsInAtList
|
||||
chatType = "group"
|
||||
isMentioned = data.IsInAtList
|
||||
if isMentioned {
|
||||
content = stripLeadingAtMentions(content)
|
||||
}
|
||||
@@ -232,8 +231,21 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(ctx, peer, "", resolvedSenderID, chatID, content, nil, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "dingtalk",
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
SenderID: resolvedSenderID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if data.SessionWebhook != "" {
|
||||
inboundCtx.ReplyHandles = map[string]string{
|
||||
"session_webhook": data.SessionWebhook,
|
||||
}
|
||||
}
|
||||
|
||||
c.HandleInboundContext(ctx, chatID, content, nil, inboundCtx, sender)
|
||||
|
||||
// Return nil to indicate we've handled the message asynchronously
|
||||
// The response will be sent through the message bus
|
||||
|
||||
@@ -75,8 +75,8 @@ func TestOnChatBotMessageReceived_GroupMentionOnlyUsesIsInAtListAndStripsMention
|
||||
if inbound.ChatID != "group-abc" {
|
||||
t.Fatalf("chat_id=%q", inbound.ChatID)
|
||||
}
|
||||
if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-abc" {
|
||||
t.Fatalf("peer=%+v", inbound.Peer)
|
||||
if inbound.Context.ChatType != "group" {
|
||||
t.Fatalf("chat_type=%q", inbound.Context.ChatType)
|
||||
}
|
||||
if inbound.Content != "/help" {
|
||||
t.Fatalf("content=%q", inbound.Content)
|
||||
@@ -103,12 +103,15 @@ func TestOnChatBotMessageReceived_DirectFallbackSenderIDUsesConversationID(t *te
|
||||
if inbound.ChatID != "conv-direct-42" {
|
||||
t.Fatalf("chat_id=%q", inbound.ChatID)
|
||||
}
|
||||
if inbound.Peer.Kind != "direct" || inbound.Peer.ID != "openid-user-42" {
|
||||
t.Fatalf("peer=%+v", inbound.Peer)
|
||||
if inbound.Context.ChatType != "direct" {
|
||||
t.Fatalf("chat_type=%q", inbound.Context.ChatType)
|
||||
}
|
||||
if inbound.SenderID != "dingtalk:openid-user-42" {
|
||||
if inbound.SenderID != "openid-user-42" {
|
||||
t.Fatalf("sender_id=%q", inbound.SenderID)
|
||||
}
|
||||
if inbound.Sender.CanonicalID != "dingtalk:openid-user-42" {
|
||||
t.Fatalf("sender canonical_id=%q", inbound.Sender.CanonicalID)
|
||||
}
|
||||
|
||||
if _, ok := ch.sessionWebhooks.Load("conv-direct-42"); !ok {
|
||||
t.Fatal("expected session webhook keyed by conversation_id")
|
||||
|
||||
@@ -408,8 +408,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
|
||||
// In guild (group) channels, apply unified group trigger filtering
|
||||
// DMs (GuildID is empty) always get a response
|
||||
isMentioned := false
|
||||
if m.GuildID != "" {
|
||||
isMentioned := false
|
||||
for _, mention := range m.Mentions {
|
||||
if mention.ID == c.botUserID {
|
||||
isMentioned = true
|
||||
@@ -506,14 +506,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
})
|
||||
|
||||
peerKind := "channel"
|
||||
peerID := m.ChannelID
|
||||
if m.GuildID == "" {
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"user_id": senderID,
|
||||
"username": m.Author.Username,
|
||||
@@ -522,8 +518,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
"channel_id": m.ChannelID,
|
||||
"is_dm": fmt.Sprintf("%t", m.GuildID == ""),
|
||||
}
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
ChatID: m.ChannelID,
|
||||
ChatType: peerKind,
|
||||
SenderID: senderID,
|
||||
MessageID: m.ID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if m.GuildID != "" {
|
||||
inboundCtx.SpaceID = m.GuildID
|
||||
inboundCtx.SpaceType = "guild"
|
||||
}
|
||||
if m.MessageReference != nil {
|
||||
inboundCtx.ReplyToMessageID = m.MessageReference.MessageID
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata, sender)
|
||||
c.HandleInboundContext(c.ctx, m.ChannelID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// startTyping starts a continuous typing indicator loop for the given chatID.
|
||||
|
||||
@@ -445,17 +445,23 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
|
||||
// Append media tags to content (like Telegram does)
|
||||
content = appendMediaTags(content, messageType, mediaRefs)
|
||||
|
||||
if content == "" {
|
||||
content = "[empty message]"
|
||||
}
|
||||
chatType := stringValue(message.ChatType)
|
||||
metadata := buildInboundMetadata(message, sender)
|
||||
|
||||
var peer bus.Peer
|
||||
var (
|
||||
inboundChatType string
|
||||
isMentioned bool
|
||||
)
|
||||
if chatType == "p2p" {
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
inboundChatType = "direct"
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "group", ID: chatID}
|
||||
inboundChatType = "group"
|
||||
|
||||
// Check if bot was mentioned
|
||||
isMentioned := c.isBotMentioned(message)
|
||||
isMentioned = c.isBotMentioned(message)
|
||||
|
||||
// Strip mention placeholders from content before group trigger check
|
||||
if len(message.Mentions) > 0 {
|
||||
@@ -490,7 +496,21 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
|
||||
"thread_id": stringValue(message.ThreadId),
|
||||
})
|
||||
|
||||
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "feishu",
|
||||
ChatID: chatID,
|
||||
ChatType: inboundChatType,
|
||||
SenderID: senderID,
|
||||
MessageID: messageID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if sender != nil && sender.TenantKey != nil && *sender.TenantKey != "" {
|
||||
inboundCtx.SpaceType = "tenant"
|
||||
inboundCtx.SpaceID = *sender.TenantKey
|
||||
}
|
||||
|
||||
c.HandleInboundContext(ctx, chatID, content, mediaRefs, inboundCtx, senderInfo)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -51,14 +51,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
|
||||
isDM := !strings.HasPrefix(target, "#") && !strings.HasPrefix(target, "&")
|
||||
|
||||
var chatID string
|
||||
var peer bus.Peer
|
||||
|
||||
if isDM {
|
||||
chatID = nick
|
||||
peer = bus.Peer{Kind: "direct", ID: nick}
|
||||
} else {
|
||||
chatID = target
|
||||
peer = bus.Peer{Kind: "group", ID: target}
|
||||
}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
@@ -73,9 +70,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
|
||||
return
|
||||
}
|
||||
|
||||
isMentioned := false
|
||||
|
||||
// For channel messages, check group trigger (mention detection)
|
||||
if !isDM {
|
||||
isMentioned := isBotMentioned(content, currentNick)
|
||||
isMentioned = isBotMentioned(content, currentNick)
|
||||
if isMentioned {
|
||||
content = stripBotMention(content, currentNick)
|
||||
}
|
||||
@@ -100,7 +99,21 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
|
||||
metadata["channel"] = target
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, messageID, nick, chatID, content, nil, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "irc",
|
||||
ChatID: chatID,
|
||||
SenderID: nick,
|
||||
MessageID: messageID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if isDM {
|
||||
inboundCtx.ChatType = "direct"
|
||||
} else {
|
||||
inboundCtx.ChatType = "group"
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// nickMentionedAt returns the byte index where botNick is mentioned in content
|
||||
|
||||
@@ -354,8 +354,9 @@ func (c *LINEChannel) processEvent(event lineEvent) {
|
||||
}
|
||||
|
||||
// In group chats, apply unified group trigger filtering
|
||||
isMentioned := false
|
||||
if isGroup {
|
||||
isMentioned := c.isBotMentioned(msg)
|
||||
isMentioned = c.isBotMentioned(msg)
|
||||
respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
|
||||
if !respond {
|
||||
logger.DebugCF("line", "Ignoring group message by group trigger", map[string]any{
|
||||
@@ -371,13 +372,6 @@ func (c *LINEChannel) processEvent(event lineEvent) {
|
||||
"source_type": event.Source.Type,
|
||||
}
|
||||
|
||||
var peer bus.Peer
|
||||
if isGroup {
|
||||
peer = bus.Peer{Kind: "group", ID: chatID}
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
}
|
||||
|
||||
logger.DebugCF("line", "Received message", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
@@ -396,7 +390,25 @@ func (c *LINEChannel) processEvent(event lineEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, mediaPaths, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
ChatID: chatID,
|
||||
ChatType: map[bool]string{true: "group", false: "direct"}[isGroup],
|
||||
SenderID: senderID,
|
||||
MessageID: msg.ID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if event.ReplyToken != "" {
|
||||
inboundCtx.ReplyHandles = map[string]string{
|
||||
"reply_token": event.ReplyToken,
|
||||
}
|
||||
if msg.QuoteToken != "" {
|
||||
inboundCtx.ReplyHandles["quote_token"] = msg.QuoteToken
|
||||
}
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// isBotMentioned checks if the bot is mentioned in the message.
|
||||
|
||||
@@ -200,17 +200,15 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(
|
||||
c.ctx,
|
||||
bus.Peer{Kind: "channel", ID: "default"},
|
||||
"",
|
||||
senderID,
|
||||
chatID,
|
||||
content,
|
||||
[]string{},
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "maixcam",
|
||||
ChatID: chatID,
|
||||
ChatType: "channel",
|
||||
SenderID: senderID,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
|
||||
}
|
||||
|
||||
func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) {
|
||||
|
||||
+47
-23
@@ -98,6 +98,22 @@ type asyncTask struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func outboundMessageChannel(msg bus.OutboundMessage) string {
|
||||
return msg.Context.Channel
|
||||
}
|
||||
|
||||
func outboundMessageChatID(msg bus.OutboundMessage) string {
|
||||
return msg.ChatID
|
||||
}
|
||||
|
||||
func outboundMediaChannel(msg bus.OutboundMediaMessage) string {
|
||||
return msg.Context.Channel
|
||||
}
|
||||
|
||||
func outboundMediaChatID(msg bus.OutboundMediaMessage) string {
|
||||
return msg.ChatID
|
||||
}
|
||||
|
||||
// RecordPlaceholder registers a placeholder message for later editing.
|
||||
// Implements PlaceholderRecorder.
|
||||
func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) {
|
||||
@@ -161,7 +177,8 @@ func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) {
|
||||
// preSend handles typing stop, reaction undo, and placeholder editing before sending a message.
|
||||
// Returns the delivered message IDs and true when delivery completed before a normal Send.
|
||||
func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) ([]string, bool) {
|
||||
key := name + ":" + msg.ChatID
|
||||
chatID := outboundMessageChatID(msg)
|
||||
key := name + ":" + chatID
|
||||
|
||||
// 1. Stop typing
|
||||
if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
|
||||
@@ -183,9 +200,9 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
// Prefer deleting the placeholder (cleaner UX than editing to same content)
|
||||
if deleter, ok := ch.(MessageDeleter); ok {
|
||||
deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort
|
||||
deleter.DeleteMessage(ctx, chatID, entry.id) // best effort
|
||||
} else if editor, ok := ch.(MessageEditor); ok {
|
||||
editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content) // fallback
|
||||
editor.EditMessage(ctx, chatID, entry.id, msg.Content) // fallback
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -196,7 +213,7 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
if editor, ok := ch.(MessageEditor); ok {
|
||||
if err := editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content); err == nil {
|
||||
if err := editor.EditMessage(ctx, chatID, entry.id, msg.Content); err == nil {
|
||||
return []string{entry.id}, true
|
||||
}
|
||||
// edit failed → fall through to normal Send
|
||||
@@ -212,7 +229,8 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
// delivery never edits the placeholder because there is no text payload to
|
||||
// replace it with; it only attempts to delete the placeholder when possible.
|
||||
func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.OutboundMediaMessage, ch Channel) {
|
||||
key := name + ":" + msg.ChatID
|
||||
chatID := outboundMediaChatID(msg)
|
||||
key := name + ":" + chatID
|
||||
|
||||
// 1. Stop typing
|
||||
if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
|
||||
@@ -235,7 +253,7 @@ func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.Outboun
|
||||
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
if deleter, ok := ch.(MessageDeleter); ok {
|
||||
deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort
|
||||
deleter.DeleteMessage(ctx, chatID, entry.id) // best effort
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -820,7 +838,7 @@ func (m *Manager) sendWithRetry(
|
||||
// All retries exhausted or permanent failure
|
||||
logger.ErrorCF("channels", "Send failed", map[string]any{
|
||||
"channel": name,
|
||||
"chat_id": msg.ChatID,
|
||||
"chat_id": outboundMessageChatID(msg),
|
||||
"error": lastErr.Error(),
|
||||
"retries": maxRetries,
|
||||
})
|
||||
@@ -882,7 +900,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.OutboundChan(),
|
||||
func(msg bus.OutboundMessage) string { return msg.Channel },
|
||||
func(msg bus.OutboundMessage) string { return outboundMessageChannel(msg) },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
|
||||
select {
|
||||
case w.queue <- msg:
|
||||
@@ -902,7 +920,7 @@ func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.OutboundMediaChan(),
|
||||
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
|
||||
func(msg bus.OutboundMediaMessage) string { return outboundMediaChannel(msg) },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
|
||||
select {
|
||||
case w.mediaQueue <- msg:
|
||||
@@ -1001,7 +1019,7 @@ func (m *Manager) sendMediaWithRetry(
|
||||
// All retries exhausted or permanent failure
|
||||
logger.ErrorCF("channels", "SendMedia failed", map[string]any{
|
||||
"channel": name,
|
||||
"chat_id": msg.ChatID,
|
||||
"chat_id": outboundMediaChatID(msg),
|
||||
"error": lastErr.Error(),
|
||||
"retries": maxRetries,
|
||||
})
|
||||
@@ -1200,16 +1218,19 @@ func (m *Manager) UnregisterChannel(name string) {
|
||||
// delivered (or all retries are exhausted), which preserves ordering when
|
||||
// a subsequent operation depends on the message having been sent.
|
||||
func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
msg = bus.NormalizeOutboundMessage(msg)
|
||||
channelName := outboundMessageChannel(msg)
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[msg.Channel]
|
||||
w, wExists := m.workers[msg.Channel]
|
||||
_, exists := m.channels[channelName]
|
||||
w, wExists := m.workers[channelName]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("channel %s not found", msg.Channel)
|
||||
return fmt.Errorf("channel %s not found", channelName)
|
||||
}
|
||||
if !wExists || w == nil {
|
||||
return fmt.Errorf("channel %s has no active worker", msg.Channel)
|
||||
return fmt.Errorf("channel %s has no active worker", channelName)
|
||||
}
|
||||
|
||||
maxLen := 0
|
||||
@@ -1220,10 +1241,10 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
for _, chunk := range SplitMessage(msg.Content, maxLen) {
|
||||
chunkMsg := msg
|
||||
chunkMsg.Content = chunk
|
||||
m.sendWithRetry(ctx, msg.Channel, w, chunkMsg)
|
||||
m.sendWithRetry(ctx, channelName, w, chunkMsg)
|
||||
}
|
||||
} else {
|
||||
m.sendWithRetry(ctx, msg.Channel, w, msg)
|
||||
m.sendWithRetry(ctx, channelName, w, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1233,19 +1254,22 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
// retries are exhausted), which preserves ordering when later agent behavior
|
||||
// depends on actual media delivery.
|
||||
func (m *Manager) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
msg = bus.NormalizeOutboundMediaMessage(msg)
|
||||
channelName := outboundMediaChannel(msg)
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[msg.Channel]
|
||||
w, wExists := m.workers[msg.Channel]
|
||||
_, exists := m.channels[channelName]
|
||||
w, wExists := m.workers[channelName]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("channel %s not found", msg.Channel)
|
||||
return fmt.Errorf("channel %s not found", channelName)
|
||||
}
|
||||
if !wExists || w == nil {
|
||||
return fmt.Errorf("channel %s has no active worker", msg.Channel)
|
||||
return fmt.Errorf("channel %s has no active worker", channelName)
|
||||
}
|
||||
|
||||
_, err := m.sendMediaWithRetry(ctx, msg.Channel, w, msg)
|
||||
_, err := m.sendMediaWithRetry(ctx, channelName, w, msg)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1260,10 +1284,10 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten
|
||||
}
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
Channel: channelName,
|
||||
ChatID: chatID,
|
||||
Context: bus.NewOutboundContext(channelName, chatID, ""),
|
||||
Content: content,
|
||||
}
|
||||
msg = bus.NormalizeOutboundMessage(msg)
|
||||
|
||||
if wExists && w != nil {
|
||||
select {
|
||||
|
||||
+138
-47
@@ -175,11 +175,11 @@ func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer pubCancel()
|
||||
if err := m.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
if err := m.bus.PublishOutbound(pubCtx, testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "good",
|
||||
ChatID: "chat-1",
|
||||
Content: "hello",
|
||||
}); err != nil {
|
||||
})); err != nil {
|
||||
t.Fatalf("PublishOutbound() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -197,6 +197,20 @@ func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func testOutboundMessage(msg bus.OutboundMessage) bus.OutboundMessage {
|
||||
if msg.Context.Channel == "" && msg.Context.ChatID == "" {
|
||||
msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, msg.ReplyToMessageID)
|
||||
}
|
||||
return bus.NormalizeOutboundMessage(msg)
|
||||
}
|
||||
|
||||
func testOutboundMediaMessage(msg bus.OutboundMediaMessage) bus.OutboundMediaMessage {
|
||||
if msg.Context.Channel == "" && msg.Context.ChatID == "" {
|
||||
msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, "")
|
||||
}
|
||||
return bus.NormalizeOutboundMediaMessage(msg)
|
||||
}
|
||||
|
||||
func TestSendWithRetry_Success(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
@@ -212,7 +226,7 @@ func TestSendWithRetry_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -239,7 +253,7 @@ func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -263,7 +277,7 @@ func TestSendWithRetry_PermanentFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -287,7 +301,7 @@ func TestSendWithRetry_NotRunning(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -314,7 +328,7 @@ func TestSendWithRetry_RateLimitRetry(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
start := time.Now()
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
@@ -344,7 +358,7 @@ func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -370,11 +384,11 @@ func TestSendMedia_Success(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
@@ -397,11 +411,11 @@ func TestSendMedia_PropagatesFailure(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
}))
|
||||
if err == nil {
|
||||
t.Fatal("expected SendMedia to return error")
|
||||
}
|
||||
@@ -424,11 +438,11 @@ func TestSendMedia_UnsupportedChannelReturnsError(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
}))
|
||||
if err == nil {
|
||||
t.Fatal("expected SendMedia to return error for unsupported channel")
|
||||
}
|
||||
@@ -454,11 +468,11 @@ func TestSendMedia_DeletesPlaceholderBeforeSending(t *testing.T) {
|
||||
m.workers["test"] = w
|
||||
m.RecordPlaceholder("test", "chat1", "placeholder-1")
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
@@ -491,7 +505,7 @@ func TestSendWithRetry_UnknownError(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -515,7 +529,7 @@ func TestSendWithRetry_ContextCancelled(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
// Cancel context after first Send attempt returns
|
||||
ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
@@ -561,7 +575,7 @@ func TestWorkerRateLimiter(t *testing.T) {
|
||||
|
||||
// Enqueue 4 messages
|
||||
for i := range 4 {
|
||||
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}
|
||||
w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)})
|
||||
}
|
||||
|
||||
// Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin)
|
||||
@@ -637,7 +651,7 @@ func TestRunWorker_MessageSplitting(t *testing.T) {
|
||||
go m.runWorker(ctx, "test", w)
|
||||
|
||||
// Send a message that should be split
|
||||
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"}
|
||||
w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"})
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
@@ -678,7 +692,7 @@ func TestSendWithRetry_ExponentialBackoff(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
start := time.Now()
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
@@ -738,7 +752,7 @@ func TestPreSend_PlaceholderEditSuccess(t *testing.T) {
|
||||
// Register placeholder
|
||||
m.RecordPlaceholder("test", "123", "456")
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if !edited {
|
||||
@@ -768,7 +782,7 @@ func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) {
|
||||
|
||||
m.RecordPlaceholder("test", "123", "456")
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if edited {
|
||||
@@ -827,7 +841,7 @@ func TestPreSend_TypingStopCalled(t *testing.T) {
|
||||
stopCalled = true
|
||||
})
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if !stopCalled {
|
||||
@@ -844,7 +858,7 @@ func TestPreSend_NoRegisteredState(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if edited {
|
||||
@@ -874,7 +888,7 @@ func TestPreSend_TypingAndPlaceholder(t *testing.T) {
|
||||
})
|
||||
m.RecordPlaceholder("test", "123", "456")
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if !stopCalled {
|
||||
@@ -938,7 +952,7 @@ func TestRecordTypingStop_ReplacesExistingStop(t *testing.T) {
|
||||
t.Fatalf("expected replacement typing stop to stay active until preSend, got %d calls", newStopCalls)
|
||||
}
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
m.preSend(context.Background(), "test", msg, &mockChannel{})
|
||||
|
||||
if newStopCalls != 1 {
|
||||
@@ -972,7 +986,7 @@ func TestSendWithRetry_PreSendEditsPlaceholder(t *testing.T) {
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
m.sendWithRetry(context.Background(), "test", w, msg)
|
||||
|
||||
if sendCalled {
|
||||
@@ -1135,7 +1149,7 @@ func TestPreSendStillWorksWithWrappedTypes(t *testing.T) {
|
||||
})
|
||||
m.RecordPlaceholder("test", "chat1", "ph_id")
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if !stopCalled {
|
||||
@@ -1238,11 +1252,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) {
|
||||
|
||||
// Transcription feedback arrives first — it should consume the placeholder
|
||||
// and be delivered via EditMessage, not Send.
|
||||
msgTranscript := bus.OutboundMessage{
|
||||
msgTranscript := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "mock",
|
||||
ChatID: "chat-1",
|
||||
Content: "Transcript: hello",
|
||||
}
|
||||
})
|
||||
mgr.sendWithRetry(ctx, "mock", worker, msgTranscript)
|
||||
|
||||
if mockCh.editedMessages != 1 {
|
||||
@@ -1258,11 +1272,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) {
|
||||
}
|
||||
|
||||
// Final LLM response arrives — no placeholder left, so it goes through Send
|
||||
msgFinal := bus.OutboundMessage{
|
||||
msgFinal := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "mock",
|
||||
ChatID: "chat-1",
|
||||
Content: "Final Answer",
|
||||
}
|
||||
})
|
||||
mgr.sendWithRetry(ctx, "mock", worker, msgFinal)
|
||||
|
||||
if len(mockCh.sentMessages) != 1 {
|
||||
@@ -1288,12 +1302,12 @@ func TestSendMessage_Synchronous(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "hello world",
|
||||
ReplyToMessageID: "msg-456",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err != nil {
|
||||
@@ -1315,11 +1329,11 @@ func TestSendMessage_Synchronous(t *testing.T) {
|
||||
func TestSendMessage_UnknownChannel(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "nonexistent",
|
||||
ChatID: "123",
|
||||
Content: "hello",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err == nil {
|
||||
@@ -1336,11 +1350,11 @@ func TestSendMessage_NoWorker(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
// No worker registered
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "hello",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err == nil {
|
||||
@@ -1369,11 +1383,11 @@ func TestSendMessage_WithRetry(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "retry me",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err != nil {
|
||||
@@ -1385,6 +1399,46 @@ func TestSendMessage_WithRetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMessage_ContextOnlyUsesContextAddressing(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
var received []bus.OutboundMessage
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
|
||||
received = append(received, msg)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Context: bus.NewOutboundContext("test", "123", "msg-9"),
|
||||
Content: "hello",
|
||||
})
|
||||
|
||||
if err := m.SendMessage(context.Background(), msg); err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if len(received) != 1 {
|
||||
t.Fatalf("expected 1 message sent, got %d", len(received))
|
||||
}
|
||||
if received[0].Channel != "test" || received[0].ChatID != "123" {
|
||||
t.Fatalf("expected mirrored legacy address, got %+v", received[0])
|
||||
}
|
||||
if received[0].Context.Channel != "test" || received[0].Context.ChatID != "123" {
|
||||
t.Fatalf("expected context address to be preserved, got %+v", received[0].Context)
|
||||
}
|
||||
if received[0].ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected reply_to_message_id msg-9, got %q", received[0].ReplyToMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMessage_WithSplitting(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
@@ -1406,11 +1460,11 @@ func TestSendMessage_WithSplitting(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "hello world",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err != nil {
|
||||
@@ -1422,6 +1476,43 @@ func TestSendMessage_WithSplitting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_ContextOnlyUsesContextAddressing(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
var received []bus.OutboundMediaMessage
|
||||
ch := &mockMediaChannel{
|
||||
sendMediaFn: func(_ context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
|
||||
received = append(received, msg)
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Context: bus.NewOutboundContext("test", "media-chat", ""),
|
||||
Parts: []bus.MediaPart{{Type: "image", Ref: "media://1"}},
|
||||
})
|
||||
|
||||
if err := m.SendMedia(context.Background(), msg); err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if len(received) != 1 {
|
||||
t.Fatalf("expected 1 media message sent, got %d", len(received))
|
||||
}
|
||||
if received[0].Channel != "test" || received[0].ChatID != "media-chat" {
|
||||
t.Fatalf("expected mirrored legacy media address, got %+v", received[0])
|
||||
}
|
||||
if received[0].Context.Channel != "test" || received[0].Context.ChatID != "media-chat" {
|
||||
t.Fatalf("expected media context address to be preserved, got %+v", received[0].Context)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMessage_PreservesOrdering(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
@@ -1441,12 +1532,12 @@ func TestSendMessage_PreservesOrdering(t *testing.T) {
|
||||
m.workers["test"] = w
|
||||
|
||||
// Send two messages sequentially — they must arrive in order
|
||||
_ = m.SendMessage(context.Background(), bus.OutboundMessage{
|
||||
_ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test", ChatID: "1", Content: "first",
|
||||
})
|
||||
_ = m.SendMessage(context.Background(), bus.OutboundMessage{
|
||||
}))
|
||||
_ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test", ChatID: "1", Content: "second",
|
||||
})
|
||||
}))
|
||||
|
||||
if len(order) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(order))
|
||||
|
||||
@@ -739,10 +739,8 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
|
||||
}
|
||||
|
||||
peerKind := "direct"
|
||||
peerID := senderID
|
||||
if isGroup {
|
||||
peerKind = "group"
|
||||
peerID = roomID
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
@@ -755,17 +753,19 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
|
||||
metadata["reply_to_msg_id"] = replyTo.String()
|
||||
}
|
||||
|
||||
c.HandleMessage(
|
||||
c.baseContext(),
|
||||
bus.Peer{Kind: peerKind, ID: peerID},
|
||||
evt.ID.String(),
|
||||
senderID,
|
||||
roomID,
|
||||
content,
|
||||
mediaPaths,
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "matrix",
|
||||
ChatID: roomID,
|
||||
ChatType: peerKind,
|
||||
SenderID: senderID,
|
||||
MessageID: evt.ID.String(),
|
||||
Raw: metadata,
|
||||
}
|
||||
if replyTo := msgEvt.GetRelatesTo().GetReplyTo(); replyTo != "" {
|
||||
inboundCtx.ReplyToMessageID = replyTo.String()
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.baseContext(), roomID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// decryptEvent decrypts an encrypted event and returns the decrypted message event content.
|
||||
|
||||
@@ -995,8 +995,8 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
|
||||
|
||||
senderID := strconv.FormatInt(userID, 10)
|
||||
var chatID string
|
||||
|
||||
var peer bus.Peer
|
||||
var contextChatID string
|
||||
var contextChatType string
|
||||
|
||||
metadata := map[string]string{}
|
||||
|
||||
@@ -1007,12 +1007,14 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
|
||||
switch raw.MessageType {
|
||||
case "private":
|
||||
chatID = "private:" + senderID
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
contextChatID = senderID
|
||||
contextChatType = "direct"
|
||||
|
||||
case "group":
|
||||
groupIDStr := strconv.FormatInt(groupID, 10)
|
||||
chatID = "group:" + groupIDStr
|
||||
peer = bus.Peer{Kind: "group", ID: groupIDStr}
|
||||
contextChatID = groupIDStr
|
||||
contextChatType = "group"
|
||||
metadata["group_id"] = groupIDStr
|
||||
|
||||
senderUserID, _ := parseJSONInt64(sender.UserID)
|
||||
@@ -1076,7 +1078,18 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, parsed.Media, metadata, senderInfo)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
ChatID: contextChatID,
|
||||
ChatType: contextChatType,
|
||||
SenderID: senderID,
|
||||
MessageID: messageID,
|
||||
Mentioned: isBotMentioned,
|
||||
ReplyToMessageID: parsed.ReplyTo,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, parsed.Media, inboundCtx, senderInfo)
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) isDuplicate(messageID string) bool {
|
||||
|
||||
@@ -259,8 +259,6 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
|
||||
|
||||
chatID := "pico_client:" + sessionID
|
||||
senderID := "pico-remote"
|
||||
peer := bus.Peer{Kind: "direct", ID: chatID}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "pico_client",
|
||||
PlatformID: senderID,
|
||||
@@ -271,10 +269,19 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, map[string]string{
|
||||
"platform": "pico_client",
|
||||
"session_id": sessionID,
|
||||
}, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "pico_client",
|
||||
ChatID: chatID,
|
||||
ChatType: "direct",
|
||||
SenderID: senderID,
|
||||
MessageID: msg.ID,
|
||||
Raw: map[string]string{
|
||||
"platform": "pico_client",
|
||||
"session_id": sessionID,
|
||||
},
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// Send sends a message to the remote server.
|
||||
|
||||
@@ -39,11 +39,11 @@ var allowedInlineImageMIMETypes = map[string]struct{}{
|
||||
"image/bmp": {},
|
||||
}
|
||||
|
||||
func outboundMessageIsThought(metadata map[string]string) bool {
|
||||
if len(metadata) == 0 {
|
||||
func outboundMessageIsThought(msg bus.OutboundMessage) bool {
|
||||
if len(msg.Context.Raw) == 0 {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(metadata["message_kind"]), MessageKindThought)
|
||||
return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), MessageKindThought)
|
||||
}
|
||||
|
||||
// writeJSON sends a JSON message to the connection with write locking.
|
||||
@@ -260,7 +260,7 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
|
||||
if !c.IsRunning() {
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
isThought := outboundMessageIsThought(msg.Metadata)
|
||||
isThought := outboundMessageIsThought(msg)
|
||||
|
||||
outMsg := newMessage(TypeMessageCreate, map[string]any{
|
||||
PayloadKeyContent: msg.Content,
|
||||
@@ -578,8 +578,6 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
|
||||
chatID := "pico:" + sessionID
|
||||
senderID := "pico-user"
|
||||
|
||||
peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"platform": "pico",
|
||||
"session_id": sessionID,
|
||||
@@ -602,7 +600,16 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, media, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "pico",
|
||||
ChatID: chatID,
|
||||
ChatType: "direct",
|
||||
SenderID: senderID,
|
||||
MessageID: msg.ID,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, media, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// truncate truncates a string to maxLen runes.
|
||||
|
||||
+31
-20
@@ -590,12 +590,22 @@ func qqFileType(partType string) uint64 {
|
||||
}
|
||||
|
||||
func (c *QQChannel) maxBase64FileSizeBytes() int64 {
|
||||
if c.config == nil {
|
||||
return 0
|
||||
}
|
||||
if c.config.MaxBase64FileSizeMiB <= 0 {
|
||||
return 0
|
||||
}
|
||||
return c.config.MaxBase64FileSizeMiB * bytesPerMiB
|
||||
}
|
||||
|
||||
func (c *QQChannel) accountID() string {
|
||||
if c.config == nil {
|
||||
return ""
|
||||
}
|
||||
return c.config.AppID
|
||||
}
|
||||
|
||||
// handleC2CMessage handles QQ private messages.
|
||||
func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
|
||||
return func(event *dto.WSPayload, data *dto.WSC2CMessageData) error {
|
||||
@@ -649,17 +659,17 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
|
||||
metadata := map[string]string{
|
||||
"account_id": senderID,
|
||||
}
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
Account: c.accountID(),
|
||||
ChatID: senderID,
|
||||
ChatType: "direct",
|
||||
SenderID: senderID,
|
||||
MessageID: data.ID,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx,
|
||||
bus.Peer{Kind: "direct", ID: senderID},
|
||||
data.ID,
|
||||
senderID,
|
||||
senderID,
|
||||
content,
|
||||
mediaPaths,
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
c.HandleInboundContext(c.ctx, senderID, content, mediaPaths, inboundCtx, sender)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -727,17 +737,18 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
|
||||
"account_id": senderID,
|
||||
"group_id": data.GroupID,
|
||||
}
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
Account: c.accountID(),
|
||||
ChatID: data.GroupID,
|
||||
ChatType: "group",
|
||||
SenderID: senderID,
|
||||
MessageID: data.ID,
|
||||
Mentioned: true,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx,
|
||||
bus.Peer{Kind: "group", ID: data.GroupID},
|
||||
data.ID,
|
||||
senderID,
|
||||
data.GroupID,
|
||||
content,
|
||||
mediaPaths,
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
c.HandleInboundContext(c.ctx, data.GroupID, content, mediaPaths, inboundCtx, sender)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -54,8 +54,8 @@ func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message")
|
||||
}
|
||||
if inbound.Metadata["account_id"] != "7750283E123456" {
|
||||
t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456")
|
||||
if inbound.Context.Raw["account_id"] != "7750283E123456" {
|
||||
t.Fatalf("account_id raw = %q, want %q", inbound.Context.Raw["account_id"], "7750283E123456")
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -165,8 +165,8 @@ func TestHandleGroupATMessage_AttachmentOnlyPublishesMedia(t *testing.T) {
|
||||
if !strings.HasPrefix(inbound.Media[0], "media://") {
|
||||
t.Fatalf("inbound.Media[0] = %q, want media:// ref", inbound.Media[0])
|
||||
}
|
||||
if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-1" {
|
||||
t.Fatalf("inbound.Peer = %+v, want group/group-1", inbound.Peer)
|
||||
if inbound.Context.ChatType != "group" {
|
||||
t.Fatalf("inbound.Context.ChatType = %q, want group", inbound.Context.ChatType)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+83
-28
@@ -117,7 +117,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]str
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
|
||||
channelID, threadTS := parseSlackChatID(msg.ChatID)
|
||||
deliveryChatID, channelID, threadTS := resolveSlackOutboundTarget(msg.ChatID, &msg.Context)
|
||||
if channelID == "" {
|
||||
return nil, fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
|
||||
}
|
||||
@@ -139,7 +139,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]str
|
||||
return nil, fmt.Errorf("slack send: %w", channels.ErrTemporary)
|
||||
}
|
||||
|
||||
if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok {
|
||||
if ref, ok := c.pendingAcks.LoadAndDelete(deliveryChatID); ok {
|
||||
msgRef := ref.(slackMessageRef)
|
||||
c.api.AddReaction("white_check_mark", slack.ItemRef{
|
||||
Channel: msgRef.ChannelID,
|
||||
@@ -161,7 +161,7 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
|
||||
channelID, _ := parseSlackChatID(msg.ChatID)
|
||||
_, channelID, threadTS := resolveSlackMediaOutboundTarget(msg.ChatID, &msg.Context)
|
||||
if channelID == "" {
|
||||
return nil, fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
|
||||
}
|
||||
@@ -192,10 +192,11 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa
|
||||
}
|
||||
|
||||
_, err = c.api.UploadFileV2Context(ctx, slack.UploadFileV2Parameters{
|
||||
Channel: channelID,
|
||||
File: localPath,
|
||||
Filename: filename,
|
||||
Title: title,
|
||||
Channel: channelID,
|
||||
ThreadTimestamp: threadTS,
|
||||
File: localPath,
|
||||
Filename: filename,
|
||||
Title: title,
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorCF("slack", "Failed to upload media", map[string]any{
|
||||
@@ -360,14 +361,10 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
}
|
||||
|
||||
peerKind := "channel"
|
||||
peerID := channelID
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
@@ -383,7 +380,22 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
"has_thread": threadTS != "",
|
||||
})
|
||||
|
||||
c.HandleMessage(c.ctx, peer, messageTS, senderID, chatID, content, mediaPaths, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
Account: c.teamID,
|
||||
ChatID: channelID,
|
||||
ChatType: peerKind,
|
||||
SenderID: senderID,
|
||||
MessageID: messageTS,
|
||||
SpaceID: c.teamID,
|
||||
SpaceType: "workspace",
|
||||
Raw: metadata,
|
||||
}
|
||||
if threadTS != "" {
|
||||
inboundCtx.TopicID = threadTS
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
@@ -431,14 +443,10 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
}
|
||||
|
||||
mentionPeerKind := "channel"
|
||||
mentionPeerID := channelID
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
mentionPeerKind = "direct"
|
||||
mentionPeerID = senderID
|
||||
}
|
||||
|
||||
mentionPeer := bus.Peer{Kind: mentionPeerKind, ID: mentionPeerID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
@@ -447,8 +455,21 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
"is_mention": "true",
|
||||
"team_id": c.teamID,
|
||||
}
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
Account: c.teamID,
|
||||
ChatID: channelID,
|
||||
ChatType: mentionPeerKind,
|
||||
TopicID: threadTS,
|
||||
SenderID: senderID,
|
||||
MessageID: messageTS,
|
||||
SpaceID: c.teamID,
|
||||
SpaceType: "workspace",
|
||||
Mentioned: true,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, mentionPeer, messageTS, senderID, chatID, content, nil, metadata, mentionSender)
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, mentionSender)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
@@ -495,18 +516,22 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
"command": cmd.Command,
|
||||
"text": utils.Truncate(content, 50),
|
||||
})
|
||||
peerKind := "channel"
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
peerKind = "direct"
|
||||
}
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
Account: c.teamID,
|
||||
ChatID: channelID,
|
||||
ChatType: peerKind,
|
||||
SenderID: senderID,
|
||||
SpaceID: c.teamID,
|
||||
SpaceType: "workspace",
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessage(
|
||||
c.ctx,
|
||||
bus.Peer{Kind: "channel", ID: channelID},
|
||||
"",
|
||||
senderID,
|
||||
chatID,
|
||||
content,
|
||||
nil,
|
||||
metadata,
|
||||
cmdSender,
|
||||
)
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, cmdSender)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
|
||||
@@ -541,3 +566,33 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) {
|
||||
}
|
||||
return channelID, threadTS
|
||||
}
|
||||
|
||||
func resolveSlackOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (string, string, string) {
|
||||
deliveryChatID := strings.TrimSpace(chatID)
|
||||
if deliveryChatID == "" && outboundCtx != nil {
|
||||
deliveryChatID = strings.TrimSpace(outboundCtx.ChatID)
|
||||
}
|
||||
channelID, threadTS := parseSlackChatID(deliveryChatID)
|
||||
if threadTS == "" && outboundCtx != nil {
|
||||
threadTS = strings.TrimSpace(outboundCtx.TopicID)
|
||||
if threadTS != "" && channelID != "" {
|
||||
deliveryChatID = channelID + "/" + threadTS
|
||||
}
|
||||
}
|
||||
return deliveryChatID, channelID, threadTS
|
||||
}
|
||||
|
||||
func resolveSlackMediaOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (string, string, string) {
|
||||
deliveryChatID := strings.TrimSpace(chatID)
|
||||
if deliveryChatID == "" && outboundCtx != nil {
|
||||
deliveryChatID = strings.TrimSpace(outboundCtx.ChatID)
|
||||
}
|
||||
channelID, threadTS := parseSlackChatID(deliveryChatID)
|
||||
if threadTS == "" && outboundCtx != nil {
|
||||
threadTS = strings.TrimSpace(outboundCtx.TopicID)
|
||||
if threadTS != "" && channelID != "" {
|
||||
deliveryChatID = channelID + "/" + threadTS
|
||||
}
|
||||
}
|
||||
return deliveryChatID, channelID, threadTS
|
||||
}
|
||||
|
||||
@@ -53,6 +53,24 @@ func TestParseSlackChatID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSlackOutboundTarget_PrefersContextTopicID(t *testing.T) {
|
||||
deliveryChatID, channelID, threadTS := resolveSlackOutboundTarget("C123456", &bus.InboundContext{
|
||||
Channel: "slack",
|
||||
ChatID: "C123456",
|
||||
TopicID: "1234567890.123456",
|
||||
})
|
||||
|
||||
if deliveryChatID != "C123456/1234567890.123456" {
|
||||
t.Fatalf("deliveryChatID = %q, want %q", deliveryChatID, "C123456/1234567890.123456")
|
||||
}
|
||||
if channelID != "C123456" {
|
||||
t.Fatalf("channelID = %q, want %q", channelID, "C123456")
|
||||
}
|
||||
if threadTS != "1234567890.123456" {
|
||||
t.Fatalf("threadTS = %q, want %q", threadTS, "1234567890.123456")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripBotMention(t *testing.T) {
|
||||
ch := &SlackChannel{botUserID: "U12345BOT"}
|
||||
|
||||
|
||||
@@ -182,7 +182,7 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]
|
||||
|
||||
useMarkdownV2 := c.tgCfg.UseMarkdownV2
|
||||
|
||||
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
|
||||
chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
|
||||
}
|
||||
@@ -469,7 +469,7 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
|
||||
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
|
||||
chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
|
||||
}
|
||||
@@ -697,8 +697,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
|
||||
}
|
||||
|
||||
// In group chats, apply unified group trigger filtering
|
||||
isMentioned := false
|
||||
if message.Chat.Type != "private" {
|
||||
isMentioned := c.isBotMentioned(message)
|
||||
isMentioned = c.isBotMentioned(message)
|
||||
if isMentioned {
|
||||
content = c.stripBotMention(content)
|
||||
}
|
||||
@@ -744,13 +745,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
|
||||
})
|
||||
|
||||
peerKind := "direct"
|
||||
peerID := fmt.Sprintf("%d", user.ID)
|
||||
if message.Chat.Type != "private" {
|
||||
peerKind = "group"
|
||||
peerID = compositeChatID
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerID}
|
||||
messageID := fmt.Sprintf("%d", message.MessageID)
|
||||
|
||||
metadata := map[string]string{
|
||||
@@ -759,24 +756,29 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
|
||||
"first_name": user.FirstName,
|
||||
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
|
||||
}
|
||||
if message.ReplyToMessage != nil {
|
||||
metadata["reply_to_message_id"] = fmt.Sprintf("%d", message.ReplyToMessage.MessageID)
|
||||
}
|
||||
|
||||
// Set parent_peer metadata for per-topic agent binding.
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
ChatID: fmt.Sprintf("%d", chatID),
|
||||
ChatType: peerKind,
|
||||
SenderID: platformID,
|
||||
MessageID: messageID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if message.Chat.IsForum && threadID != 0 {
|
||||
metadata["parent_peer_kind"] = "topic"
|
||||
metadata["parent_peer_id"] = fmt.Sprintf("%d", threadID)
|
||||
inboundCtx.TopicID = fmt.Sprintf("%d", threadID)
|
||||
}
|
||||
if message.ReplyToMessage != nil {
|
||||
inboundCtx.ReplyToMessageID = fmt.Sprintf("%d", message.ReplyToMessage.MessageID)
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx,
|
||||
peer,
|
||||
messageID,
|
||||
platformID,
|
||||
c.HandleMessageWithContext(
|
||||
c.ctx,
|
||||
compositeChatID,
|
||||
content,
|
||||
mediaPaths,
|
||||
metadata,
|
||||
inboundCtx,
|
||||
sender,
|
||||
)
|
||||
return nil
|
||||
@@ -964,6 +966,28 @@ func parseTelegramChatID(chatID string) (int64, int, error) {
|
||||
return cid, tid, nil
|
||||
}
|
||||
|
||||
func resolveTelegramOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (int64, int, error) {
|
||||
targetChatID := strings.TrimSpace(chatID)
|
||||
if targetChatID == "" && outboundCtx != nil {
|
||||
targetChatID = strings.TrimSpace(outboundCtx.ChatID)
|
||||
}
|
||||
resolvedChatID, resolvedThreadID, err := parseTelegramChatID(targetChatID)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if resolvedThreadID != 0 || outboundCtx == nil {
|
||||
return resolvedChatID, resolvedThreadID, nil
|
||||
}
|
||||
topicID := strings.TrimSpace(outboundCtx.TopicID)
|
||||
if topicID == "" {
|
||||
return resolvedChatID, resolvedThreadID, nil
|
||||
}
|
||||
if threadID, convErr := strconv.Atoi(topicID); convErr == nil {
|
||||
return resolvedChatID, threadID, nil
|
||||
}
|
||||
return resolvedChatID, resolvedThreadID, nil
|
||||
}
|
||||
|
||||
func logParseFailed(err error, useMarkdownV2 bool) {
|
||||
parsingName := "HTML"
|
||||
if useMarkdownV2 {
|
||||
|
||||
@@ -528,6 +528,38 @@ func TestSend_WithForumThreadID(t *testing.T) {
|
||||
assert.Len(t, caller.calls, 1)
|
||||
}
|
||||
|
||||
func TestSend_UsesContextTopicIDWhenChatIDDoesNotIncludeThread(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
return successResponse(t), nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
|
||||
_, err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "-1001234567890",
|
||||
Content: "Hello from topic context",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "-1001234567890",
|
||||
TopicID: "42",
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, caller.calls, 1)
|
||||
|
||||
var params struct {
|
||||
ChatID int64 `json:"chat_id"`
|
||||
MessageThreadID int `json:"message_thread_id"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(caller.calls[0].Data.BodyRaw, ¶ms))
|
||||
assert.Equal(t, int64(-1001234567890), params.ChatID)
|
||||
assert.Equal(t, 42, params.MessageThreadID)
|
||||
assert.Equal(t, "Hello from topic context", params.Text)
|
||||
}
|
||||
|
||||
func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &TelegramChannel{
|
||||
@@ -557,16 +589,10 @@ func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
|
||||
inbound, ok := <-messageBus.InboundChan()
|
||||
require.True(t, ok, "expected inbound message")
|
||||
|
||||
// Composite chatID should include thread ID
|
||||
assert.Equal(t, "-1001234567890/42", inbound.ChatID)
|
||||
|
||||
// Peer ID should include thread ID for session key isolation
|
||||
assert.Equal(t, "group", inbound.Peer.Kind)
|
||||
assert.Equal(t, "-1001234567890/42", inbound.Peer.ID)
|
||||
|
||||
// Parent peer metadata should be set for agent binding
|
||||
assert.Equal(t, "topic", inbound.Metadata["parent_peer_kind"])
|
||||
assert.Equal(t, "42", inbound.Metadata["parent_peer_id"])
|
||||
// ChatID remains the parent chat; TopicID isolates the sub-conversation.
|
||||
assert.Equal(t, "-1001234567890", inbound.ChatID)
|
||||
assert.Equal(t, "group", inbound.Context.ChatType)
|
||||
assert.Equal(t, "42", inbound.Context.TopicID)
|
||||
}
|
||||
|
||||
func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
|
||||
@@ -599,13 +625,8 @@ func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
|
||||
// Plain chatID without thread suffix
|
||||
assert.Equal(t, "-100999", inbound.ChatID)
|
||||
|
||||
// Peer ID should be raw chat ID (no thread suffix)
|
||||
assert.Equal(t, "group", inbound.Peer.Kind)
|
||||
assert.Equal(t, "-100999", inbound.Peer.ID)
|
||||
|
||||
// No parent peer metadata
|
||||
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
|
||||
assert.Empty(t, inbound.Metadata["parent_peer_id"])
|
||||
assert.Equal(t, "group", inbound.Context.ChatType)
|
||||
assert.Empty(t, inbound.Context.TopicID)
|
||||
}
|
||||
|
||||
func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
|
||||
@@ -642,13 +663,8 @@ func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
|
||||
// chatID should NOT include thread suffix for non-forum groups
|
||||
assert.Equal(t, "-100999", inbound.ChatID)
|
||||
|
||||
// Peer ID should be raw chat ID (shared session for whole group)
|
||||
assert.Equal(t, "group", inbound.Peer.Kind)
|
||||
assert.Equal(t, "-100999", inbound.Peer.ID)
|
||||
|
||||
// No parent peer metadata
|
||||
assert.Empty(t, inbound.Metadata["parent_peer_kind"])
|
||||
assert.Empty(t, inbound.Metadata["parent_peer_id"])
|
||||
assert.Equal(t, "group", inbound.Context.ChatType)
|
||||
assert.Empty(t, inbound.Context.TopicID)
|
||||
}
|
||||
|
||||
func assertHandleMessageQuotedUserReply(
|
||||
@@ -701,7 +717,7 @@ func assertHandleMessageQuotedUserReply(
|
||||
|
||||
inbound, ok := <-messageBus.InboundChan()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, strconv.Itoa(replyMessageID), inbound.Metadata["reply_to_message_id"])
|
||||
assert.Equal(t, strconv.Itoa(replyMessageID), inbound.Context.ReplyToMessageID)
|
||||
assert.Equal(t, expectedContent, inbound.Content)
|
||||
}
|
||||
|
||||
@@ -787,7 +803,7 @@ func TestHandleMessage_ReplyToOwnBotMessage_UsesAssistantRole(t *testing.T) {
|
||||
|
||||
inbound, ok := <-messageBus.InboundChan()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "101", inbound.Metadata["reply_to_message_id"])
|
||||
assert.Equal(t, "101", inbound.Context.ReplyToMessageID)
|
||||
assert.Equal(
|
||||
t,
|
||||
"[quoted assistant message from afjcjsbx_picoclaw_bot]: Fatto! Ho creato il file notizie_2026_03_28.md\n\nti ricordi questo file?",
|
||||
|
||||
+11
-15
@@ -172,14 +172,11 @@ func (c *VKChannel) handleMessage(msg object.MessagesMessage) {
|
||||
_ = groupTrigger
|
||||
}
|
||||
|
||||
peerKind := "direct"
|
||||
peerIDStr := userID
|
||||
chatType := "direct"
|
||||
if isGroupChat {
|
||||
peerKind = "group"
|
||||
peerIDStr = chatID
|
||||
chatType = "group"
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerIDStr}
|
||||
messageID := strconv.Itoa(msg.ConversationMessageID)
|
||||
|
||||
metadata := map[string]string{
|
||||
@@ -187,16 +184,15 @@ func (c *VKChannel) handleMessage(msg object.MessagesMessage) {
|
||||
"is_group": fmt.Sprintf("%t", isGroupChat),
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx,
|
||||
peer,
|
||||
messageID,
|
||||
userID,
|
||||
chatID,
|
||||
text,
|
||||
nil,
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
c.HandleInboundContext(c.ctx, chatID, text, nil, bus.InboundContext{
|
||||
Channel: "vk",
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
SenderID: userID,
|
||||
MessageID: messageID,
|
||||
Mentioned: isGroupChat && c.isMentioned(msg),
|
||||
Raw: metadata,
|
||||
}, sender)
|
||||
}
|
||||
|
||||
func (c *VKChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
|
||||
|
||||
@@ -570,7 +570,6 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage)
|
||||
return err
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: actualChatID}
|
||||
metadata := map[string]string{
|
||||
"channel": "wecom",
|
||||
"req_id": reqID,
|
||||
@@ -583,7 +582,20 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage)
|
||||
metadata["quote_text"] = quoteText
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, msg.MsgID, senderID, actualChatID, content, mediaRefs, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
Account: strings.TrimSpace(msg.AIBotID),
|
||||
ChatID: actualChatID,
|
||||
ChatType: peerKind,
|
||||
SenderID: senderID,
|
||||
MessageID: msg.MsgID,
|
||||
ReplyHandles: map[string]string{
|
||||
"req_id": reqID,
|
||||
},
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, actualChatID, content, mediaRefs, inboundCtx, sender)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -50,11 +50,11 @@ func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) {
|
||||
if inbound.MessageID != "msg-1" {
|
||||
t.Fatalf("inbound MessageID = %q, want msg-1", inbound.MessageID)
|
||||
}
|
||||
if inbound.Peer.ID != "chat-1" {
|
||||
t.Fatalf("inbound Peer.ID = %q, want chat-1", inbound.Peer.ID)
|
||||
if inbound.Context.ChatType != "direct" {
|
||||
t.Fatalf("inbound Context.ChatType = %q, want direct", inbound.Context.ChatType)
|
||||
}
|
||||
if inbound.Metadata["req_id"] != "req-1" {
|
||||
t.Fatalf("inbound req_id = %q, want req-1", inbound.Metadata["req_id"])
|
||||
if inbound.Context.ReplyHandles["req_id"] != "req-1" {
|
||||
t.Fatalf("inbound req_id = %q, want req-1", inbound.Context.ReplyHandles["req_id"])
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected inbound message to be published")
|
||||
|
||||
@@ -357,8 +357,6 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess
|
||||
return
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: "direct", ID: fromUserID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"from_user_id": fromUserID,
|
||||
"context_token": msg.ContextToken,
|
||||
@@ -377,7 +375,21 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess
|
||||
c.persistContextTokens()
|
||||
}
|
||||
|
||||
c.HandleMessage(ctx, peer, messageID, fromUserID, fromUserID, content, mediaRefs, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "weixin",
|
||||
ChatID: fromUserID,
|
||||
ChatType: "direct",
|
||||
SenderID: fromUserID,
|
||||
MessageID: messageID,
|
||||
Raw: metadata,
|
||||
}
|
||||
if msg.ContextToken != "" {
|
||||
inboundCtx.ReplyHandles = map[string]string{
|
||||
"context_token": msg.ContextToken,
|
||||
}
|
||||
}
|
||||
|
||||
c.HandleInboundContext(ctx, fromUserID, content, mediaRefs, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// Send implements channels.Channel by sending a text message to the WeChat user.
|
||||
|
||||
@@ -227,13 +227,6 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
|
||||
metadata["user_name"] = userName
|
||||
}
|
||||
|
||||
var peer bus.Peer
|
||||
if chatID == senderID {
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "group", ID: chatID}
|
||||
}
|
||||
|
||||
logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{
|
||||
"sender": senderID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
@@ -252,5 +245,18 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "whatsapp",
|
||||
ChatID: chatID,
|
||||
SenderID: senderID,
|
||||
MessageID: messageID,
|
||||
Raw: metadata,
|
||||
}
|
||||
if chatID == senderID {
|
||||
inboundCtx.ChatType = "direct"
|
||||
} else {
|
||||
inboundCtx.ChatType = "group"
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
@@ -377,7 +377,6 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) {
|
||||
if evt.Info.Chat.Server == types.GroupServer {
|
||||
peerKind = "group"
|
||||
}
|
||||
peer := bus.Peer{Kind: peerKind, ID: chatID}
|
||||
messageID := evt.Info.ID
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "whatsapp",
|
||||
@@ -395,7 +394,17 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) {
|
||||
"WhatsApp message received",
|
||||
map[string]any{"sender_id": senderID, "content_preview": utils.Truncate(content, 50)},
|
||||
)
|
||||
c.HandleMessage(c.runCtx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
|
||||
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "whatsapp",
|
||||
ChatID: chatID,
|
||||
SenderID: senderID,
|
||||
MessageID: messageID,
|
||||
ChatType: peerKind,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.runCtx, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
func (c *WhatsAppNativeChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
|
||||
|
||||
+33
-24
@@ -30,10 +30,10 @@ func init() {
|
||||
|
||||
// Config is the current config structure with version support.
|
||||
type Config struct {
|
||||
Version int `json:"version" yaml:"-"` // Config schema version for migration
|
||||
// Config schema version for migration.
|
||||
Version int `json:"version" yaml:"-"`
|
||||
Isolation IsolationConfig `json:"isolation,omitempty" yaml:"-"`
|
||||
Agents AgentsConfig `json:"agents" yaml:"-"`
|
||||
Bindings []AgentBinding `json:"bindings,omitempty" yaml:"-"`
|
||||
Session SessionConfig `json:"session,omitempty" yaml:"-"`
|
||||
Channels ChannelsConfig `json:"channel_list" yaml:"channel_list"`
|
||||
ModelList SecureModelList `json:"model_list" yaml:"model_list"` // New model-centric provider configuration
|
||||
@@ -120,7 +120,7 @@ type BuildInfo struct {
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for Config
|
||||
// to omit providers section when empty and session when empty
|
||||
// to omit providers section when empty and session when empty.
|
||||
func (c *Config) MarshalJSON() ([]byte, error) {
|
||||
type Alias Config
|
||||
aux := &struct {
|
||||
@@ -130,17 +130,18 @@ func (c *Config) MarshalJSON() ([]byte, error) {
|
||||
Alias: (*Alias)(c),
|
||||
}
|
||||
|
||||
// Only include session if not empty
|
||||
if c.Session.DMScope != "" || len(c.Session.IdentityLinks) > 0 {
|
||||
aux.Session = &c.Session
|
||||
if len(c.Session.Dimensions) > 0 || len(c.Session.IdentityLinks) > 0 {
|
||||
sessionCfg := c.Session
|
||||
aux.Session = &sessionCfg
|
||||
}
|
||||
|
||||
return json.Marshal(aux)
|
||||
}
|
||||
|
||||
type AgentsConfig struct {
|
||||
Defaults AgentDefaults `json:"defaults"`
|
||||
List []AgentConfig `json:"list,omitempty"`
|
||||
Defaults AgentDefaults `json:"defaults"`
|
||||
List []AgentConfig `json:"list,omitempty"`
|
||||
Dispatch *DispatchConfig `json:"dispatch,omitempty"`
|
||||
}
|
||||
|
||||
// AgentModelConfig supports both string and structured model config.
|
||||
@@ -197,26 +198,29 @@ type SubagentsConfig struct {
|
||||
Model *AgentModelConfig `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
type PeerMatch struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
type DispatchConfig struct {
|
||||
Rules []DispatchRule `json:"rules,omitempty"`
|
||||
}
|
||||
|
||||
type BindingMatch struct {
|
||||
Channel string `json:"channel"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
Peer *PeerMatch `json:"peer,omitempty"`
|
||||
GuildID string `json:"guild_id,omitempty"`
|
||||
TeamID string `json:"team_id,omitempty"`
|
||||
type DispatchRule struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Agent string `json:"agent"`
|
||||
When DispatchSelector `json:"when"`
|
||||
SessionDimensions []string `json:"session_dimensions,omitempty"`
|
||||
}
|
||||
|
||||
type AgentBinding struct {
|
||||
AgentID string `json:"agent_id"`
|
||||
Match BindingMatch `json:"match"`
|
||||
type DispatchSelector struct {
|
||||
Channel string `json:"channel,omitempty"`
|
||||
Account string `json:"account,omitempty"`
|
||||
Space string `json:"space,omitempty"`
|
||||
Chat string `json:"chat,omitempty"`
|
||||
Topic string `json:"topic,omitempty"`
|
||||
Sender string `json:"sender,omitempty"`
|
||||
Mentioned *bool `json:"mentioned,omitempty"`
|
||||
}
|
||||
|
||||
type SessionConfig struct {
|
||||
DMScope string `json:"dm_scope,omitempty"`
|
||||
Dimensions []string `json:"dimensions,omitempty"`
|
||||
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
|
||||
}
|
||||
|
||||
@@ -509,9 +513,10 @@ type DevicesConfig struct {
|
||||
}
|
||||
|
||||
type VoiceConfig struct {
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"`
|
||||
TTSModelName string `json:"tts_model_name,omitempty" env:"PICOCLAW_VOICE_TTS_MODEL_NAME"`
|
||||
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"`
|
||||
TTSModelName string `json:"tts_model_name,omitempty" env:"PICOCLAW_VOICE_TTS_MODEL_NAME"`
|
||||
EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"`
|
||||
ElevenLabsAPIKey string `json:"elevenlabs_api_key,omitempty" env:"PICOCLAW_VOICE_ELEVENLABS_API_KEY"`
|
||||
}
|
||||
|
||||
// ModelConfig represents a model-centric provider configuration.
|
||||
@@ -1066,6 +1071,8 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return nil, fmt.Errorf("unsupported config version: %d", versionInfo.Version)
|
||||
}
|
||||
|
||||
applyLegacyBindingsMigration(data, cfg)
|
||||
|
||||
if err = env.Parse(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1256,6 +1263,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
UserAgent: m.UserAgent,
|
||||
isVirtual: true,
|
||||
}
|
||||
expanded = append(expanded, additionalEntry)
|
||||
@@ -1277,6 +1285,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
UserAgent: m.UserAgent,
|
||||
APIKeys: SimpleSecureStrings(keys[0]),
|
||||
}
|
||||
|
||||
|
||||
+243
-38
@@ -109,18 +109,8 @@ func TestAgentConfig_FullParse(t *testing.T) {
|
||||
}
|
||||
]
|
||||
},
|
||||
"bindings": [
|
||||
{
|
||||
"agent_id": "support",
|
||||
"match": {
|
||||
"channel": "telegram",
|
||||
"account_id": "*",
|
||||
"peer": {"kind": "direct", "id": "user123"}
|
||||
}
|
||||
}
|
||||
],
|
||||
"session": {
|
||||
"dm_scope": "per-peer",
|
||||
"dimensions": ["sender"],
|
||||
"identity_links": {
|
||||
"john": ["telegram:123", "discord:john#1234"]
|
||||
}
|
||||
@@ -158,19 +148,8 @@ func TestAgentConfig_FullParse(t *testing.T) {
|
||||
t.Errorf("support.Subagents = %+v", support.Subagents)
|
||||
}
|
||||
|
||||
if len(cfg.Bindings) != 1 {
|
||||
t.Fatalf("bindings len = %d, want 1", len(cfg.Bindings))
|
||||
}
|
||||
binding := cfg.Bindings[0]
|
||||
if binding.AgentID != "support" || binding.Match.Channel != "telegram" {
|
||||
t.Errorf("binding = %+v", binding)
|
||||
}
|
||||
if binding.Match.Peer == nil || binding.Match.Peer.Kind != "direct" || binding.Match.Peer.ID != "user123" {
|
||||
t.Errorf("binding.Match.Peer = %+v", binding.Match.Peer)
|
||||
}
|
||||
|
||||
if cfg.Session.DMScope != "per-peer" {
|
||||
t.Errorf("Session.DMScope = %q", cfg.Session.DMScope)
|
||||
if len(cfg.Session.Dimensions) != 1 || cfg.Session.Dimensions[0] != "sender" {
|
||||
t.Errorf("Session.Dimensions = %v", cfg.Session.Dimensions)
|
||||
}
|
||||
if len(cfg.Session.IdentityLinks) != 1 {
|
||||
t.Errorf("Session.IdentityLinks = %v", cfg.Session.IdentityLinks)
|
||||
@@ -236,8 +215,242 @@ func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) {
|
||||
if len(cfg.Agents.List) != 0 {
|
||||
t.Errorf("agents.list should be empty for backward compat, got %d", len(cfg.Agents.List))
|
||||
}
|
||||
if len(cfg.Bindings) != 0 {
|
||||
t.Errorf("bindings should be empty, got %d", len(cfg.Bindings))
|
||||
}
|
||||
|
||||
func TestAgentConfig_ParsesDispatchRules(t *testing.T) {
|
||||
jsonData := `{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7"
|
||||
},
|
||||
"list": [
|
||||
{ "id": "main", "default": true },
|
||||
{ "id": "support" }
|
||||
],
|
||||
"dispatch": {
|
||||
"rules": [
|
||||
{
|
||||
"name": "support-vip",
|
||||
"agent": "support",
|
||||
"when": {
|
||||
"channel": "telegram",
|
||||
"chat": "group:-100123",
|
||||
"sender": "12345",
|
||||
"mentioned": true
|
||||
},
|
||||
"session_dimensions": ["chat", "sender"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
cfg := DefaultConfig()
|
||||
if err := json.Unmarshal([]byte(jsonData), cfg); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if cfg.Agents.Dispatch == nil {
|
||||
t.Fatal("Agents.Dispatch should not be nil")
|
||||
}
|
||||
if len(cfg.Agents.Dispatch.Rules) != 1 {
|
||||
t.Fatalf("Dispatch.Rules len = %d, want 1", len(cfg.Agents.Dispatch.Rules))
|
||||
}
|
||||
rule := cfg.Agents.Dispatch.Rules[0]
|
||||
if rule.Name != "support-vip" || rule.Agent != "support" {
|
||||
t.Fatalf("rule = %+v", rule)
|
||||
}
|
||||
if rule.When.Channel != "telegram" || rule.When.Chat != "group:-100123" || rule.When.Sender != "12345" {
|
||||
t.Fatalf("rule.When = %+v", rule.When)
|
||||
}
|
||||
if rule.When.Mentioned == nil || !*rule.When.Mentioned {
|
||||
t.Fatalf("rule.When.Mentioned = %+v, want true", rule.When.Mentioned)
|
||||
}
|
||||
if got := rule.SessionDimensions; len(got) != 2 || got[0] != "chat" || got[1] != "sender" {
|
||||
t.Fatalf("rule.SessionDimensions = %v, want [chat sender]", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_MigratesLegacyBindingsToDispatchRules(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
raw := `{
|
||||
"version": 2,
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7"
|
||||
},
|
||||
"list": [
|
||||
{ "id": "main", "default": true },
|
||||
{ "id": "support" },
|
||||
{ "id": "ops" },
|
||||
{ "id": "slack" }
|
||||
]
|
||||
},
|
||||
"bindings": [
|
||||
{
|
||||
"agent_id": "support",
|
||||
"match": {
|
||||
"channel": "telegram",
|
||||
"peer": { "kind": "group", "id": "-100123" }
|
||||
}
|
||||
},
|
||||
{
|
||||
"agent_id": "ops",
|
||||
"match": {
|
||||
"channel": "discord",
|
||||
"guild_id": "guild-1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"agent_id": "slack",
|
||||
"match": {
|
||||
"channel": "slack",
|
||||
"account_id": "*"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(configPath): %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if cfg.Agents.Dispatch == nil {
|
||||
t.Fatal("Agents.Dispatch should not be nil")
|
||||
}
|
||||
if len(cfg.Agents.Dispatch.Rules) != 3 {
|
||||
t.Fatalf("Dispatch.Rules len = %d, want 3", len(cfg.Agents.Dispatch.Rules))
|
||||
}
|
||||
|
||||
first := cfg.Agents.Dispatch.Rules[0]
|
||||
if first.Agent != "support" {
|
||||
t.Fatalf("first.Agent = %q, want %q", first.Agent, "support")
|
||||
}
|
||||
if first.When.Channel != "telegram" || first.When.Chat != "group:-100123" {
|
||||
t.Fatalf("first.When = %+v", first.When)
|
||||
}
|
||||
if first.When.Account != legacyDefaultAccountID {
|
||||
t.Fatalf("first.When.Account = %q, want %q", first.When.Account, legacyDefaultAccountID)
|
||||
}
|
||||
|
||||
second := cfg.Agents.Dispatch.Rules[1]
|
||||
if second.Agent != "ops" || second.When.Space != "guild:guild-1" {
|
||||
t.Fatalf("second = %+v", second)
|
||||
}
|
||||
|
||||
third := cfg.Agents.Dispatch.Rules[2]
|
||||
if third.Agent != "slack" {
|
||||
t.Fatalf("third.Agent = %q, want %q", third.Agent, "slack")
|
||||
}
|
||||
if third.When.Channel != "slack" || third.When.Account != "" {
|
||||
t.Fatalf("third.When = %+v", third.When)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_PrefersDispatchRulesOverLegacyBindings(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
raw := `{
|
||||
"version": 2,
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7"
|
||||
},
|
||||
"list": [
|
||||
{ "id": "main", "default": true },
|
||||
{ "id": "support" }
|
||||
],
|
||||
"dispatch": {
|
||||
"rules": [
|
||||
{
|
||||
"name": "explicit",
|
||||
"agent": "support",
|
||||
"when": {
|
||||
"channel": "telegram",
|
||||
"chat": "group:-100123"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"bindings": [
|
||||
{
|
||||
"agent_id": "main",
|
||||
"match": {
|
||||
"channel": "telegram",
|
||||
"account_id": "*"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(configPath): %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if cfg.Agents.Dispatch == nil {
|
||||
t.Fatal("Agents.Dispatch should not be nil")
|
||||
}
|
||||
if len(cfg.Agents.Dispatch.Rules) != 1 {
|
||||
t.Fatalf("Dispatch.Rules len = %d, want 1", len(cfg.Agents.Dispatch.Rules))
|
||||
}
|
||||
if cfg.Agents.Dispatch.Rules[0].Name != "explicit" {
|
||||
t.Fatalf("Dispatch.Rules[0].Name = %q, want %q", cfg.Agents.Dispatch.Rules[0].Name, "explicit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_MigratesLegacyDirectBindingsWithIdentityLinks(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
raw := `{
|
||||
"version": 2,
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7"
|
||||
},
|
||||
"list": [
|
||||
{ "id": "main", "default": true },
|
||||
{ "id": "support" }
|
||||
]
|
||||
},
|
||||
"session": {
|
||||
"identity_links": {
|
||||
"john": ["telegram:123", "123"]
|
||||
}
|
||||
},
|
||||
"bindings": [
|
||||
{
|
||||
"agent_id": "support",
|
||||
"match": {
|
||||
"channel": "telegram",
|
||||
"peer": { "kind": "direct", "id": "123" }
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(configPath): %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if cfg.Agents.Dispatch == nil || len(cfg.Agents.Dispatch.Rules) != 1 {
|
||||
t.Fatalf("Dispatch.Rules = %+v, want 1 migrated rule", cfg.Agents.Dispatch)
|
||||
}
|
||||
if got := cfg.Agents.Dispatch.Rules[0].When.Sender; got != "john" {
|
||||
t.Fatalf("migrated sender selector = %q, want %q", got, "john")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -374,13 +587,6 @@ func TestDefaultConfig_WebTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_ReadFileMode(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if cfg.Tools.ReadFile.EffectiveMode() != ReadFileModeBytes {
|
||||
t.Fatalf("expected default read_file mode %q, got %q", ReadFileModeBytes, cfg.Tools.ReadFile.EffectiveMode())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveConfig_FilePermissions(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("file permission bits are not enforced on Windows")
|
||||
@@ -825,7 +1031,7 @@ func TestLoadConfig_HooksProcessConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_DMScope verifies the default dm_scope value
|
||||
// TestDefaultConfig_SessionDimensions verifies the default session dimensions
|
||||
// TestDefaultConfig_SummarizationThresholds verifies summarization defaults
|
||||
func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
@@ -838,11 +1044,11 @@ func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_DMScope(t *testing.T) {
|
||||
func TestDefaultConfig_SessionDimensions(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Session.DMScope != "per-channel-peer" {
|
||||
t.Errorf("Session.DMScope = %q, want 'per-channel-peer'", cfg.Session.DMScope)
|
||||
if len(cfg.Session.Dimensions) != 1 || cfg.Session.Dimensions[0] != "chat" {
|
||||
t.Errorf("Session.Dimensions = %v, want [chat]", cfg.Session.Dimensions)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1076,7 +1282,6 @@ func TestLoadConfig_TelegramPlaceholderTextAcceptsSingleString(t *testing.T) {
|
||||
data := `{
|
||||
"version": 1,
|
||||
"agents": { "defaults": { "workspace": "", "model": "", "max_tokens": 0, "max_tool_iterations": 0 } },
|
||||
"bindings": [],
|
||||
"session": {},
|
||||
"channels": {
|
||||
"telegram": {
|
||||
|
||||
@@ -41,9 +41,8 @@ func DefaultConfig() *Config {
|
||||
SplitOnMarker: false,
|
||||
},
|
||||
},
|
||||
Bindings: []AgentBinding{},
|
||||
Session: SessionConfig{
|
||||
DMScope: "per-channel-peer",
|
||||
Dimensions: []string{"chat"},
|
||||
},
|
||||
Channels: defaultChannels(),
|
||||
Hooks: HooksConfig{
|
||||
@@ -422,7 +421,9 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
Voice: VoiceConfig{
|
||||
ModelName: "",
|
||||
TTSModelName: "",
|
||||
EchoTranscription: false,
|
||||
ElevenLabsAPIKey: "",
|
||||
},
|
||||
BuildInfo: BuildInfo{
|
||||
Version: Version,
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const legacyDefaultAccountID = "default"
|
||||
|
||||
type legacyBindingsEnvelope struct {
|
||||
Bindings json.RawMessage `json:"bindings"`
|
||||
}
|
||||
|
||||
type legacyAgentBinding struct {
|
||||
AgentID string `json:"agent_id"`
|
||||
Match legacyBindingMatch `json:"match"`
|
||||
}
|
||||
|
||||
type legacyBindingMatch struct {
|
||||
Channel string `json:"channel"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
Peer *legacyPeerMatch `json:"peer,omitempty"`
|
||||
GuildID string `json:"guild_id,omitempty"`
|
||||
TeamID string `json:"team_id,omitempty"`
|
||||
}
|
||||
|
||||
type legacyPeerMatch struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func applyLegacyBindingsMigration(data []byte, cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
bindings, found, err := decodeLegacyBindings(data)
|
||||
if err != nil {
|
||||
logger.WarnF(
|
||||
"legacy bindings config detected but could not be decoded",
|
||||
map[string]any{"error": err},
|
||||
)
|
||||
return
|
||||
}
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.Agents.Dispatch != nil && len(cfg.Agents.Dispatch.Rules) > 0 {
|
||||
logger.WarnF(
|
||||
"legacy bindings config is deprecated and ignored because agents.dispatch.rules is configured",
|
||||
map[string]any{"bindings": len(bindings), "dispatch_rules": len(cfg.Agents.Dispatch.Rules)},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
rules, dropped := migrateLegacyBindings(bindings, cfg.Session.IdentityLinks)
|
||||
if len(rules) == 0 {
|
||||
logger.WarnF(
|
||||
"legacy bindings config is deprecated and could not be migrated",
|
||||
map[string]any{"bindings": len(bindings), "dropped_bindings": dropped},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.Agents.Dispatch == nil {
|
||||
cfg.Agents.Dispatch = &DispatchConfig{}
|
||||
}
|
||||
cfg.Agents.Dispatch.Rules = rules
|
||||
|
||||
fields := map[string]any{
|
||||
"bindings": len(bindings),
|
||||
"dispatch_rules": len(rules),
|
||||
}
|
||||
if dropped > 0 {
|
||||
fields["dropped_bindings"] = dropped
|
||||
}
|
||||
logger.WarnF("legacy bindings config is deprecated; migrated to agents.dispatch.rules in memory", fields)
|
||||
}
|
||||
|
||||
func decodeLegacyBindings(data []byte) ([]legacyAgentBinding, bool, error) {
|
||||
var envelope legacyBindingsEnvelope
|
||||
if err := json.Unmarshal(data, &envelope); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if len(envelope.Bindings) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
var bindings []legacyAgentBinding
|
||||
if err := json.Unmarshal(envelope.Bindings, &bindings); err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return bindings, true, nil
|
||||
}
|
||||
|
||||
func migrateLegacyBindings(bindings []legacyAgentBinding, identityLinks map[string][]string) ([]DispatchRule, int) {
|
||||
if len(bindings) == 0 {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
type prioritizedRule struct {
|
||||
rule DispatchRule
|
||||
index int
|
||||
kind int
|
||||
}
|
||||
|
||||
prioritized := make([]prioritizedRule, 0, len(bindings))
|
||||
dropped := 0
|
||||
for i, binding := range bindings {
|
||||
rule, kind, ok := migrateLegacyBinding(binding, i, identityLinks)
|
||||
if !ok {
|
||||
dropped++
|
||||
continue
|
||||
}
|
||||
prioritized = append(prioritized, prioritizedRule{rule: rule, index: i, kind: kind})
|
||||
}
|
||||
if len(prioritized) == 0 {
|
||||
return nil, dropped
|
||||
}
|
||||
|
||||
rules := make([]DispatchRule, 0, len(prioritized))
|
||||
for kind := 0; kind <= 4; kind++ {
|
||||
for _, item := range prioritized {
|
||||
if item.kind == kind {
|
||||
rules = append(rules, item.rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
return rules, dropped
|
||||
}
|
||||
|
||||
func migrateLegacyBinding(
|
||||
binding legacyAgentBinding,
|
||||
index int,
|
||||
identityLinks map[string][]string,
|
||||
) (DispatchRule, int, bool) {
|
||||
channel := strings.ToLower(strings.TrimSpace(binding.Match.Channel))
|
||||
agentID := strings.TrimSpace(binding.AgentID)
|
||||
if channel == "" || agentID == "" {
|
||||
return DispatchRule{}, 0, false
|
||||
}
|
||||
|
||||
rule := DispatchRule{
|
||||
Name: fmt.Sprintf("legacy-binding-%d", index+1),
|
||||
Agent: agentID,
|
||||
When: DispatchSelector{
|
||||
Channel: channel,
|
||||
},
|
||||
}
|
||||
|
||||
switch normalizeLegacyAccountSelector(binding.Match.AccountID) {
|
||||
case "":
|
||||
case "*":
|
||||
default:
|
||||
rule.When.Account = normalizeLegacyAccountSelector(binding.Match.AccountID)
|
||||
}
|
||||
|
||||
if peer := binding.Match.Peer; peer != nil {
|
||||
peerKind := strings.ToLower(strings.TrimSpace(peer.Kind))
|
||||
peerID := strings.TrimSpace(peer.ID)
|
||||
if peerID == "" {
|
||||
return DispatchRule{}, 0, false
|
||||
}
|
||||
switch peerKind {
|
||||
case "direct":
|
||||
rule.When.Sender = canonicalLegacyBindingSenderID(channel, peerID, identityLinks)
|
||||
return rule, 0, true
|
||||
case "group", "channel":
|
||||
rule.When.Chat = peerKind + ":" + peerID
|
||||
return rule, 0, true
|
||||
case "topic":
|
||||
rule.When.Topic = "topic:" + peerID
|
||||
return rule, 0, true
|
||||
default:
|
||||
return DispatchRule{}, 0, false
|
||||
}
|
||||
}
|
||||
|
||||
if guildID := strings.TrimSpace(binding.Match.GuildID); guildID != "" {
|
||||
rule.When.Space = "guild:" + guildID
|
||||
return rule, 1, true
|
||||
}
|
||||
|
||||
if teamID := strings.TrimSpace(binding.Match.TeamID); teamID != "" {
|
||||
rule.When.Space = "team:" + teamID
|
||||
return rule, 2, true
|
||||
}
|
||||
|
||||
accountSelector := normalizeLegacyAccountSelector(binding.Match.AccountID)
|
||||
if accountSelector == "*" {
|
||||
rule.When.Account = ""
|
||||
return rule, 4, true
|
||||
}
|
||||
|
||||
rule.When.Account = accountSelector
|
||||
return rule, 3, true
|
||||
}
|
||||
|
||||
func normalizeLegacyAccountSelector(accountID string) string {
|
||||
accountID = strings.TrimSpace(accountID)
|
||||
switch accountID {
|
||||
case "":
|
||||
return legacyDefaultAccountID
|
||||
case "*":
|
||||
return "*"
|
||||
default:
|
||||
return strings.ToLower(accountID)
|
||||
}
|
||||
}
|
||||
|
||||
func canonicalLegacyBindingSenderID(channel, peerID string, identityLinks map[string][]string) string {
|
||||
peerID = strings.TrimSpace(peerID)
|
||||
if peerID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if linked := resolveLegacyBindingLinkedID(identityLinks, channel, peerID); linked != "" {
|
||||
return strings.ToLower(linked)
|
||||
}
|
||||
|
||||
return strings.ToLower(peerID)
|
||||
}
|
||||
|
||||
func resolveLegacyBindingLinkedID(identityLinks map[string][]string, channel, peerID string) string {
|
||||
if len(identityLinks) == 0 {
|
||||
return ""
|
||||
}
|
||||
peerID = strings.TrimSpace(peerID)
|
||||
if peerID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
candidates := make(map[string]struct{})
|
||||
rawCandidate := strings.ToLower(peerID)
|
||||
if rawCandidate != "" {
|
||||
candidates[rawCandidate] = struct{}{}
|
||||
}
|
||||
channel = strings.ToLower(strings.TrimSpace(channel))
|
||||
if channel != "" {
|
||||
candidates[channel+":"+rawCandidate] = struct{}{}
|
||||
}
|
||||
if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 {
|
||||
candidates[rawCandidate[idx+1:]] = struct{}{}
|
||||
}
|
||||
|
||||
for canonical, ids := range identityLinks {
|
||||
canonical = strings.TrimSpace(canonical)
|
||||
if canonical == "" {
|
||||
continue
|
||||
}
|
||||
for _, id := range ids {
|
||||
normalized := strings.ToLower(strings.TrimSpace(id))
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := candidates[normalized]; ok {
|
||||
return canonical
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
@@ -131,8 +131,7 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: platform,
|
||||
ChatID: userID,
|
||||
Context: bus.NewOutboundContext(platform, userID, ""),
|
||||
Content: msg,
|
||||
})
|
||||
|
||||
|
||||
@@ -339,8 +339,7 @@ func (hs *HeartbeatService) sendResponse(response string) {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: platform,
|
||||
ChatID: userID,
|
||||
Context: bus.NewOutboundContext(platform, userID, ""),
|
||||
Content: response,
|
||||
})
|
||||
|
||||
|
||||
+334
-15
@@ -32,14 +32,19 @@ const (
|
||||
maxLineSize = 10 * 1024 * 1024 // 10 MB
|
||||
)
|
||||
|
||||
// sessionMeta holds per-session metadata stored in a .meta.json file.
|
||||
type sessionMeta struct {
|
||||
Key string `json:"key"`
|
||||
Summary string `json:"summary"`
|
||||
Skip int `json:"skip"`
|
||||
Count int `json:"count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
// SessionMeta holds per-session metadata stored in a .meta.json file.
|
||||
//
|
||||
// Scope is stored as raw JSON so pkg/memory can stay decoupled from the
|
||||
// higher-level session package while still preserving structured scope data.
|
||||
type SessionMeta struct {
|
||||
Key string `json:"key"`
|
||||
Summary string `json:"summary"`
|
||||
Skip int `json:"skip"`
|
||||
Count int `json:"count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Scope json.RawMessage `json:"scope,omitempty"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
// JSONLStore implements Store using append-only JSONL files.
|
||||
@@ -98,25 +103,31 @@ func sanitizeKey(key string) string {
|
||||
|
||||
// readMeta loads the metadata file for a session.
|
||||
// Returns a zero-value sessionMeta if the file does not exist.
|
||||
func (s *JSONLStore) readMeta(key string) (sessionMeta, error) {
|
||||
func (s *JSONLStore) readMeta(key string) (SessionMeta, error) {
|
||||
data, err := os.ReadFile(s.metaPath(key))
|
||||
if os.IsNotExist(err) {
|
||||
return sessionMeta{Key: key}, nil
|
||||
return SessionMeta{Key: key}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return sessionMeta{}, fmt.Errorf("memory: read meta: %w", err)
|
||||
return SessionMeta{}, fmt.Errorf("memory: read meta: %w", err)
|
||||
}
|
||||
var meta sessionMeta
|
||||
var meta SessionMeta
|
||||
err = json.Unmarshal(data, &meta)
|
||||
if err != nil {
|
||||
return sessionMeta{}, fmt.Errorf("memory: decode meta: %w", err)
|
||||
return SessionMeta{}, fmt.Errorf("memory: decode meta: %w", err)
|
||||
}
|
||||
if meta.Key == "" {
|
||||
meta.Key = key
|
||||
}
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
// writeMeta atomically writes the metadata file using the project's
|
||||
// standard WriteFileAtomic (temp + fsync + rename).
|
||||
func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error {
|
||||
func (s *JSONLStore) writeMeta(key string, meta SessionMeta) error {
|
||||
if strings.TrimSpace(meta.Key) == "" {
|
||||
meta.Key = key
|
||||
}
|
||||
data, err := json.MarshalIndent(meta, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("memory: encode meta: %w", err)
|
||||
@@ -124,6 +135,314 @@ func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error {
|
||||
return fileutil.WriteFileAtomic(s.metaPath(key), data, 0o644)
|
||||
}
|
||||
|
||||
func cloneRawJSON(data json.RawMessage) json.RawMessage {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return append(json.RawMessage(nil), data...)
|
||||
}
|
||||
|
||||
func normalizeAliases(canonicalKey string, aliases []string) []string {
|
||||
if len(aliases) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := make([]string, 0, len(aliases))
|
||||
seen := make(map[string]struct{}, len(aliases))
|
||||
canonicalKey = strings.TrimSpace(canonicalKey)
|
||||
for _, alias := range aliases {
|
||||
alias = strings.TrimSpace(alias)
|
||||
if alias == "" || alias == canonicalKey {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[alias]; ok {
|
||||
continue
|
||||
}
|
||||
seen[alias] = struct{}{}
|
||||
normalized = append(normalized, alias)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func (s *JSONLStore) sessionExists(key string) bool {
|
||||
if key == "" {
|
||||
return false
|
||||
}
|
||||
if _, err := os.Stat(s.jsonlPath(key)); err == nil {
|
||||
return true
|
||||
}
|
||||
if _, err := os.Stat(s.metaPath(key)); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetSessionMeta returns the current metadata snapshot for sessionKey.
|
||||
func (s *JSONLStore) GetSessionMeta(_ context.Context, sessionKey string) (SessionMeta, error) {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return SessionMeta{}, err
|
||||
}
|
||||
meta.Scope = cloneRawJSON(meta.Scope)
|
||||
if len(meta.Aliases) > 0 {
|
||||
meta.Aliases = append([]string(nil), meta.Aliases...)
|
||||
}
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
// UpsertSessionMeta stores structured session metadata while preserving
|
||||
// summary/count/skip timestamps maintained by the core JSONL store.
|
||||
func (s *JSONLStore) UpsertSessionMeta(
|
||||
_ context.Context,
|
||||
sessionKey string,
|
||||
scope json.RawMessage,
|
||||
aliases []string,
|
||||
) error {
|
||||
l := s.sessionLock(sessionKey)
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
meta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
meta.Scope = cloneRawJSON(scope)
|
||||
meta.Aliases = normalizeAliases(sessionKey, aliases)
|
||||
now := time.Now()
|
||||
if meta.CreatedAt.IsZero() {
|
||||
meta.CreatedAt = now
|
||||
}
|
||||
meta.UpdatedAt = now
|
||||
|
||||
return s.writeMeta(sessionKey, meta)
|
||||
}
|
||||
|
||||
// PromoteAliasHistory atomically promotes the first non-empty alias session
|
||||
// into the canonical session when the canonical session is still empty.
|
||||
func (s *JSONLStore) PromoteAliasHistory(
|
||||
_ context.Context,
|
||||
sessionKey string,
|
||||
scope json.RawMessage,
|
||||
aliases []string,
|
||||
) (bool, error) {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
aliases = normalizeAliases(sessionKey, aliases)
|
||||
for _, alias := range aliases {
|
||||
unlock := s.lockSessionPair(sessionKey, alias)
|
||||
promoted, err := s.promoteAliasHistoryLocked(sessionKey, alias, scope, aliases)
|
||||
unlock()
|
||||
if err != nil || promoted {
|
||||
return promoted, err
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// ResolveSessionKey returns the canonical session key for a candidate key.
|
||||
// It short-circuits direct canonical keys when possible, then scans metadata
|
||||
// once to resolve aliases or canonical metadata keys.
|
||||
func (s *JSONLStore) ResolveSessionKey(_ context.Context, sessionKey string) (string, bool, error) {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
hasDirectSession := s.sessionExists(sessionKey)
|
||||
if hasDirectSession && shouldShortCircuitSessionResolve(sessionKey) {
|
||||
return sessionKey, true, nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(s.dir)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("memory: read sessions dir: %w", err)
|
||||
}
|
||||
|
||||
var directMetaMatch string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") {
|
||||
continue
|
||||
}
|
||||
|
||||
data, readErr := os.ReadFile(filepath.Join(s.dir, entry.Name()))
|
||||
if readErr != nil {
|
||||
log.Printf("memory: skipping unreadable meta %s: %v", entry.Name(), readErr)
|
||||
continue
|
||||
}
|
||||
|
||||
var meta SessionMeta
|
||||
if err := json.Unmarshal(data, &meta); err != nil {
|
||||
log.Printf("memory: skipping corrupt meta %s: %v", entry.Name(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
if meta.Key == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if meta.Key == sessionKey {
|
||||
directMetaMatch = meta.Key
|
||||
}
|
||||
|
||||
for _, alias := range meta.Aliases {
|
||||
if alias == sessionKey && meta.Key != sessionKey {
|
||||
return meta.Key, true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if directMetaMatch != "" {
|
||||
return directMetaMatch, true, nil
|
||||
}
|
||||
|
||||
if hasDirectSession {
|
||||
return sessionKey, true, nil
|
||||
}
|
||||
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
func shouldShortCircuitSessionResolve(sessionKey string) bool {
|
||||
sessionKey = strings.TrimSpace(strings.ToLower(sessionKey))
|
||||
if sessionKey == "" {
|
||||
return false
|
||||
}
|
||||
return !strings.ContainsAny(sessionKey, ":/\\")
|
||||
}
|
||||
|
||||
func (s *JSONLStore) lockSessionPair(keyA, keyB string) func() {
|
||||
lockA := s.sessionLock(keyA)
|
||||
lockB := s.sessionLock(keyB)
|
||||
if lockA == lockB {
|
||||
lockA.Lock()
|
||||
return func() { lockA.Unlock() }
|
||||
}
|
||||
if keyA <= keyB {
|
||||
lockA.Lock()
|
||||
lockB.Lock()
|
||||
return func() {
|
||||
lockB.Unlock()
|
||||
lockA.Unlock()
|
||||
}
|
||||
}
|
||||
lockB.Lock()
|
||||
lockA.Lock()
|
||||
return func() {
|
||||
lockA.Unlock()
|
||||
lockB.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *JSONLStore) promoteAliasHistoryLocked(
|
||||
sessionKey string,
|
||||
alias string,
|
||||
scope json.RawMessage,
|
||||
aliases []string,
|
||||
) (bool, error) {
|
||||
canonicalMeta, err := s.readMeta(sessionKey)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
canonicalHasContent, err := s.sessionHasVisibleContentLocked(sessionKey, canonicalMeta)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if canonicalHasContent {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
aliasMeta, err := s.readMeta(alias)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
aliasHistory, err := readMessages(s.jsonlPath(alias), aliasMeta.Skip)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
aliasSummary := strings.TrimSpace(aliasMeta.Summary)
|
||||
if len(aliasHistory) == 0 && aliasSummary == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
previousJSONL, hadPreviousJSONL, err := s.readRawJSONL(sessionKey)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if canonicalMeta.CreatedAt.IsZero() {
|
||||
canonicalMeta.CreatedAt = now
|
||||
}
|
||||
canonicalMeta.Scope = cloneRawJSON(scope)
|
||||
canonicalMeta.Aliases = normalizeAliases(sessionKey, aliases)
|
||||
canonicalMeta.Skip = 0
|
||||
canonicalMeta.Count = len(aliasHistory)
|
||||
canonicalMeta.UpdatedAt = now
|
||||
if aliasSummary != "" {
|
||||
canonicalMeta.Summary = aliasSummary
|
||||
}
|
||||
|
||||
if err := s.rewriteJSONL(sessionKey, aliasHistory); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := s.writeMeta(sessionKey, canonicalMeta); err != nil {
|
||||
if rollbackErr := s.restoreRawJSONL(sessionKey, previousJSONL, hadPreviousJSONL); rollbackErr != nil {
|
||||
return false, fmt.Errorf("memory: write promoted meta: %w (rollback jsonl: %v)", err, rollbackErr)
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *JSONLStore) sessionHasVisibleContentLocked(sessionKey string, meta SessionMeta) (bool, error) {
|
||||
if meta.Count-meta.Skip > 0 || strings.TrimSpace(meta.Summary) != "" {
|
||||
return true, nil
|
||||
}
|
||||
if meta.Count != 0 || meta.Skip != 0 {
|
||||
return false, nil
|
||||
}
|
||||
history, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return len(history) > 0, nil
|
||||
}
|
||||
|
||||
func (s *JSONLStore) readRawJSONL(sessionKey string) ([]byte, bool, error) {
|
||||
data, err := os.ReadFile(s.jsonlPath(sessionKey))
|
||||
if os.IsNotExist(err) {
|
||||
return nil, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("memory: read jsonl: %w", err)
|
||||
}
|
||||
return data, true, nil
|
||||
}
|
||||
|
||||
func (s *JSONLStore) restoreRawJSONL(sessionKey string, data []byte, existed bool) error {
|
||||
path := s.jsonlPath(sessionKey)
|
||||
if !existed {
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("memory: remove jsonl rollback: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if err := fileutil.WriteFileAtomic(path, data, 0o644); err != nil {
|
||||
return fmt.Errorf("memory: restore jsonl rollback: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readMessages reads valid JSON lines from a .jsonl file, skipping
|
||||
// the first `skip` lines without unmarshaling them. This avoids the
|
||||
// cost of json.Unmarshal on logically truncated messages.
|
||||
@@ -471,7 +790,7 @@ func (s *JSONLStore) ListSessions() []string {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var meta sessionMeta
|
||||
var meta SessionMeta
|
||||
if err := json.Unmarshal(data, &meta); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -2,8 +2,10 @@ package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@@ -241,6 +243,142 @@ func TestSetSummary_GetSummary(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionMetaScopeAndAliasesPersist(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
scope := json.RawMessage(`{"version":1,"channel":"telegram","values":{"chat":"group:c1"}}`)
|
||||
aliases := []string{"legacy:one", "legacy:one", "canonical"}
|
||||
if err := store.UpsertSessionMeta(ctx, "canonical", scope, aliases); err != nil {
|
||||
t.Fatalf("UpsertSessionMeta() error = %v", err)
|
||||
}
|
||||
|
||||
meta, err := store.GetSessionMeta(ctx, "canonical")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSessionMeta() error = %v", err)
|
||||
}
|
||||
var gotScope map[string]any
|
||||
if err := json.Unmarshal(meta.Scope, &gotScope); err != nil {
|
||||
t.Fatalf("Unmarshal(meta.Scope) error = %v", err)
|
||||
}
|
||||
var wantScope map[string]any
|
||||
if err := json.Unmarshal(scope, &wantScope); err != nil {
|
||||
t.Fatalf("Unmarshal(scope) error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotScope, wantScope) {
|
||||
t.Fatalf("meta.Scope = %#v, want %#v", gotScope, wantScope)
|
||||
}
|
||||
if len(meta.Aliases) != 1 || meta.Aliases[0] != "legacy:one" {
|
||||
t.Fatalf("meta.Aliases = %#v, want [legacy:one]", meta.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSessionKeyByAlias(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
if err := store.AddMessage(ctx, "canonical", "user", "hello"); err != nil {
|
||||
t.Fatalf("AddMessage() error = %v", err)
|
||||
}
|
||||
if err := store.UpsertSessionMeta(ctx, "canonical", nil, []string{"legacy:key"}); err != nil {
|
||||
t.Fatalf("UpsertSessionMeta() error = %v", err)
|
||||
}
|
||||
|
||||
resolved, found, err := store.ResolveSessionKey(ctx, "legacy:key")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveSessionKey() error = %v", err)
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("ResolveSessionKey() did not find alias")
|
||||
}
|
||||
if resolved != "canonical" {
|
||||
t.Fatalf("resolved = %q, want %q", resolved, "canonical")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSessionKeyByAlias_PrefersMetadataOverLegacyFile(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
if err := store.AddMessage(ctx, "legacy:key", "user", "legacy"); err != nil {
|
||||
t.Fatalf("AddMessage(legacy) error = %v", err)
|
||||
}
|
||||
if err := store.AddMessage(ctx, "canonical", "user", "canonical"); err != nil {
|
||||
t.Fatalf("AddMessage(canonical) error = %v", err)
|
||||
}
|
||||
if err := store.UpsertSessionMeta(ctx, "canonical", nil, []string{"legacy:key"}); err != nil {
|
||||
t.Fatalf("UpsertSessionMeta() error = %v", err)
|
||||
}
|
||||
|
||||
resolved, found, err := store.ResolveSessionKey(ctx, "legacy:key")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveSessionKey() error = %v", err)
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("ResolveSessionKey() did not find alias")
|
||||
}
|
||||
if resolved != "canonical" {
|
||||
t.Fatalf("resolved = %q, want %q", resolved, "canonical")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSessionKey_DirectHitSkipsCorruptMetadata(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
if err := store.AddMessage(ctx, "canonical", "user", "hello"); err != nil {
|
||||
t.Fatalf("AddMessage() error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(
|
||||
filepath.Join(store.dir, "broken.meta.json"),
|
||||
[]byte("{not-json"),
|
||||
0o644,
|
||||
); err != nil {
|
||||
t.Fatalf("WriteFile(broken.meta.json) error = %v", err)
|
||||
}
|
||||
|
||||
resolved, found, err := store.ResolveSessionKey(ctx, "canonical")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveSessionKey() error = %v", err)
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("ResolveSessionKey() did not find direct session")
|
||||
}
|
||||
if resolved != "canonical" {
|
||||
t.Fatalf("resolved = %q, want %q", resolved, "canonical")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSessionKey_SkipsCorruptMetadataDuringAliasScan(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
if err := store.AddMessage(ctx, "canonical", "user", "hello"); err != nil {
|
||||
t.Fatalf("AddMessage() error = %v", err)
|
||||
}
|
||||
if err := store.UpsertSessionMeta(ctx, "canonical", nil, []string{"legacy:key"}); err != nil {
|
||||
t.Fatalf("UpsertSessionMeta() error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(
|
||||
filepath.Join(store.dir, "broken.meta.json"),
|
||||
[]byte("{not-json"),
|
||||
0o644,
|
||||
); err != nil {
|
||||
t.Fatalf("WriteFile(broken.meta.json) error = %v", err)
|
||||
}
|
||||
|
||||
resolved, found, err := store.ResolveSessionKey(ctx, "legacy:key")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveSessionKey() error = %v", err)
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("ResolveSessionKey() did not find alias")
|
||||
}
|
||||
if resolved != "canonical" {
|
||||
t.Fatalf("resolved = %q, want %q", resolved, "canonical")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateHistory_KeepLast(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
+245
-184
@@ -1,32 +1,29 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// RouteInput contains the routing context from an inbound message.
|
||||
type RouteInput struct {
|
||||
Channel string
|
||||
AccountID string
|
||||
Peer *RoutePeer
|
||||
ParentPeer *RoutePeer
|
||||
GuildID string
|
||||
TeamID string
|
||||
// SessionPolicy describes how a routed message should be mapped to a session.
|
||||
type SessionPolicy struct {
|
||||
Dimensions []string
|
||||
IdentityLinks map[string][]string
|
||||
}
|
||||
|
||||
// ResolvedRoute is the result of agent routing.
|
||||
type ResolvedRoute struct {
|
||||
AgentID string
|
||||
Channel string
|
||||
AccountID string
|
||||
SessionKey string
|
||||
MainSessionKey string
|
||||
MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default"
|
||||
AgentID string
|
||||
Channel string
|
||||
AccountID string
|
||||
SessionPolicy SessionPolicy
|
||||
MatchedBy string
|
||||
}
|
||||
|
||||
// RouteResolver determines which agent handles a message based on config bindings.
|
||||
// RouteResolver determines which agent handles a message.
|
||||
type RouteResolver struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
@@ -36,182 +33,32 @@ func NewRouteResolver(cfg *config.Config) *RouteResolver {
|
||||
return &RouteResolver{cfg: cfg}
|
||||
}
|
||||
|
||||
// ResolveRoute determines which agent handles the message and constructs session keys.
|
||||
// Implements the 7-level priority cascade:
|
||||
// peer > parent_peer > guild > team > account > channel_wildcard > default
|
||||
func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute {
|
||||
channel := strings.ToLower(strings.TrimSpace(input.Channel))
|
||||
accountID := NormalizeAccountID(input.AccountID)
|
||||
peer := input.Peer
|
||||
// ResolveRoute determines which agent handles the message from a normalized
|
||||
// inbound context and returns the session policy that should be used to
|
||||
// allocate session state.
|
||||
func (r *RouteResolver) ResolveRoute(inbound bus.InboundContext) ResolvedRoute {
|
||||
channel := strings.ToLower(strings.TrimSpace(inbound.Channel))
|
||||
accountID := NormalizeAccountID(inbound.Account)
|
||||
identityLinks := cloneIdentityLinks(r.cfg.Session.IdentityLinks)
|
||||
view := buildDispatchView(inbound, identityLinks)
|
||||
|
||||
dmScope := DMScope(r.cfg.Session.DMScope)
|
||||
if dmScope == "" {
|
||||
dmScope = DMScopeMain
|
||||
}
|
||||
identityLinks := r.cfg.Session.IdentityLinks
|
||||
|
||||
bindings := r.filterBindings(channel, accountID)
|
||||
|
||||
choose := func(agentID string, matchedBy string) ResolvedRoute {
|
||||
resolvedAgentID := r.pickAgentID(agentID)
|
||||
sessionKey := strings.ToLower(BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: resolvedAgentID,
|
||||
if rule := r.matchDispatchRule(view); rule != nil {
|
||||
return ResolvedRoute{
|
||||
AgentID: r.pickAgentID(rule.Agent),
|
||||
Channel: channel,
|
||||
AccountID: accountID,
|
||||
Peer: peer,
|
||||
DMScope: dmScope,
|
||||
IdentityLinks: identityLinks,
|
||||
}))
|
||||
mainSessionKey := strings.ToLower(BuildAgentMainSessionKey(resolvedAgentID))
|
||||
return ResolvedRoute{
|
||||
AgentID: resolvedAgentID,
|
||||
Channel: channel,
|
||||
AccountID: accountID,
|
||||
SessionKey: sessionKey,
|
||||
MainSessionKey: mainSessionKey,
|
||||
MatchedBy: matchedBy,
|
||||
SessionPolicy: r.sessionPolicy(rule),
|
||||
MatchedBy: matchedByForRule(rule),
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 1: Peer binding
|
||||
if peer != nil && strings.TrimSpace(peer.ID) != "" {
|
||||
if match := r.findPeerMatch(bindings, peer); match != nil {
|
||||
return choose(match.AgentID, "binding.peer")
|
||||
}
|
||||
return ResolvedRoute{
|
||||
AgentID: r.pickAgentID(r.resolveDefaultAgentID()),
|
||||
Channel: channel,
|
||||
AccountID: accountID,
|
||||
SessionPolicy: r.sessionPolicy(nil),
|
||||
MatchedBy: "default",
|
||||
}
|
||||
|
||||
// Priority 2: Parent peer binding
|
||||
parentPeer := input.ParentPeer
|
||||
if parentPeer != nil && strings.TrimSpace(parentPeer.ID) != "" {
|
||||
if match := r.findPeerMatch(bindings, parentPeer); match != nil {
|
||||
return choose(match.AgentID, "binding.peer.parent")
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: Guild binding
|
||||
guildID := strings.TrimSpace(input.GuildID)
|
||||
if guildID != "" {
|
||||
if match := r.findGuildMatch(bindings, guildID); match != nil {
|
||||
return choose(match.AgentID, "binding.guild")
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 4: Team binding
|
||||
teamID := strings.TrimSpace(input.TeamID)
|
||||
if teamID != "" {
|
||||
if match := r.findTeamMatch(bindings, teamID); match != nil {
|
||||
return choose(match.AgentID, "binding.team")
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 5: Account binding
|
||||
if match := r.findAccountMatch(bindings); match != nil {
|
||||
return choose(match.AgentID, "binding.account")
|
||||
}
|
||||
|
||||
// Priority 6: Channel wildcard binding
|
||||
if match := r.findChannelWildcardMatch(bindings); match != nil {
|
||||
return choose(match.AgentID, "binding.channel")
|
||||
}
|
||||
|
||||
// Priority 7: Default agent
|
||||
return choose(r.resolveDefaultAgentID(), "default")
|
||||
}
|
||||
|
||||
func (r *RouteResolver) filterBindings(channel, accountID string) []config.AgentBinding {
|
||||
var filtered []config.AgentBinding
|
||||
for _, b := range r.cfg.Bindings {
|
||||
matchChannel := strings.ToLower(strings.TrimSpace(b.Match.Channel))
|
||||
if matchChannel == "" || matchChannel != channel {
|
||||
continue
|
||||
}
|
||||
if !matchesAccountID(b.Match.AccountID, accountID) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, b)
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func matchesAccountID(matchAccountID, actual string) bool {
|
||||
trimmed := strings.TrimSpace(matchAccountID)
|
||||
if trimmed == "" {
|
||||
return actual == DefaultAccountID
|
||||
}
|
||||
if trimmed == "*" {
|
||||
return true
|
||||
}
|
||||
return strings.ToLower(trimmed) == strings.ToLower(actual)
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findPeerMatch(bindings []config.AgentBinding, peer *RoutePeer) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
if b.Match.Peer == nil {
|
||||
continue
|
||||
}
|
||||
peerKind := strings.ToLower(strings.TrimSpace(b.Match.Peer.Kind))
|
||||
peerID := strings.TrimSpace(b.Match.Peer.ID)
|
||||
if peerKind == "" || peerID == "" {
|
||||
continue
|
||||
}
|
||||
if peerKind == strings.ToLower(peer.Kind) && peerID == peer.ID {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findGuildMatch(bindings []config.AgentBinding, guildID string) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
matchGuild := strings.TrimSpace(b.Match.GuildID)
|
||||
if matchGuild != "" && matchGuild == guildID {
|
||||
return &bindings[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findTeamMatch(bindings []config.AgentBinding, teamID string) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
matchTeam := strings.TrimSpace(b.Match.TeamID)
|
||||
if matchTeam != "" && matchTeam == teamID {
|
||||
return &bindings[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findAccountMatch(bindings []config.AgentBinding) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
accountID := strings.TrimSpace(b.Match.AccountID)
|
||||
if accountID == "*" {
|
||||
continue
|
||||
}
|
||||
if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" {
|
||||
continue
|
||||
}
|
||||
return &bindings[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findChannelWildcardMatch(bindings []config.AgentBinding) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
accountID := strings.TrimSpace(b.Match.AccountID)
|
||||
if accountID != "*" {
|
||||
continue
|
||||
}
|
||||
if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" {
|
||||
continue
|
||||
}
|
||||
return &bindings[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) pickAgentID(agentID string) string {
|
||||
@@ -250,3 +97,217 @@ func (r *RouteResolver) resolveDefaultAgentID() string {
|
||||
}
|
||||
return DefaultAgentID
|
||||
}
|
||||
|
||||
func (r *RouteResolver) sessionPolicy(rule *config.DispatchRule) SessionPolicy {
|
||||
dimensions := r.cfg.Session.Dimensions
|
||||
if rule != nil && len(rule.SessionDimensions) > 0 {
|
||||
dimensions = rule.SessionDimensions
|
||||
}
|
||||
return SessionPolicy{
|
||||
Dimensions: normalizeSessionDimensions(dimensions),
|
||||
IdentityLinks: cloneIdentityLinks(r.cfg.Session.IdentityLinks),
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSessionDimensions(dimensions []string) []string {
|
||||
if len(dimensions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := make([]string, 0, len(dimensions))
|
||||
seen := make(map[string]struct{}, len(dimensions))
|
||||
for _, dimension := range dimensions {
|
||||
dimension = strings.ToLower(strings.TrimSpace(dimension))
|
||||
switch dimension {
|
||||
case "space", "chat", "topic", "sender":
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[dimension]; ok {
|
||||
continue
|
||||
}
|
||||
seen[dimension] = struct{}{}
|
||||
normalized = append(normalized, dimension)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func cloneIdentityLinks(src map[string][]string) map[string][]string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make(map[string][]string, len(src))
|
||||
for canonical, ids := range src {
|
||||
dup := make([]string, len(ids))
|
||||
copy(dup, ids)
|
||||
cloned[canonical] = dup
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
type dispatchView struct {
|
||||
Channel string
|
||||
Account string
|
||||
Space string
|
||||
Chat string
|
||||
Topic string
|
||||
Sender string
|
||||
Mentioned bool
|
||||
}
|
||||
|
||||
func (r *RouteResolver) matchDispatchRule(view dispatchView) *config.DispatchRule {
|
||||
if r.cfg == nil || r.cfg.Agents.Dispatch == nil || len(r.cfg.Agents.Dispatch.Rules) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range r.cfg.Agents.Dispatch.Rules {
|
||||
rule := &r.cfg.Agents.Dispatch.Rules[i]
|
||||
if !selectorHasAnyConstraint(rule.When) {
|
||||
continue
|
||||
}
|
||||
if ruleMatchesView(*rule, view) {
|
||||
return rule
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ruleMatchesView(rule config.DispatchRule, view dispatchView) bool {
|
||||
when := normalizeDispatchSelector(rule.When)
|
||||
if when.Channel != "" && when.Channel != view.Channel {
|
||||
return false
|
||||
}
|
||||
if when.Account != "" && when.Account != view.Account {
|
||||
return false
|
||||
}
|
||||
if when.Space != "" && when.Space != view.Space {
|
||||
return false
|
||||
}
|
||||
if when.Chat != "" && when.Chat != view.Chat {
|
||||
return false
|
||||
}
|
||||
if when.Topic != "" && when.Topic != view.Topic {
|
||||
return false
|
||||
}
|
||||
if when.Sender != "" && when.Sender != view.Sender {
|
||||
return false
|
||||
}
|
||||
if when.Mentioned != nil && *when.Mentioned != view.Mentioned {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func matchedByForRule(rule *config.DispatchRule) string {
|
||||
if rule == nil {
|
||||
return "default"
|
||||
}
|
||||
name := strings.TrimSpace(rule.Name)
|
||||
if name == "" {
|
||||
return "dispatch.rule"
|
||||
}
|
||||
return "dispatch.rule:" + strings.ToLower(name)
|
||||
}
|
||||
|
||||
func buildDispatchView(inbound bus.InboundContext, identityLinks map[string][]string) dispatchView {
|
||||
view := dispatchView{
|
||||
Channel: strings.ToLower(strings.TrimSpace(inbound.Channel)),
|
||||
Account: NormalizeAccountID(inbound.Account),
|
||||
Mentioned: inbound.Mentioned,
|
||||
}
|
||||
|
||||
if spaceID := strings.TrimSpace(inbound.SpaceID); spaceID != "" {
|
||||
spaceType := strings.ToLower(strings.TrimSpace(inbound.SpaceType))
|
||||
if spaceType == "" {
|
||||
spaceType = "space"
|
||||
}
|
||||
view.Space = fmt.Sprintf("%s:%s", spaceType, strings.ToLower(spaceID))
|
||||
}
|
||||
|
||||
if chatID := strings.TrimSpace(inbound.ChatID); chatID != "" {
|
||||
chatType := strings.ToLower(strings.TrimSpace(inbound.ChatType))
|
||||
if chatType == "" {
|
||||
chatType = "direct"
|
||||
}
|
||||
view.Chat = fmt.Sprintf("%s:%s", chatType, strings.ToLower(chatID))
|
||||
}
|
||||
|
||||
if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" {
|
||||
view.Topic = "topic:" + strings.ToLower(topicID)
|
||||
}
|
||||
|
||||
view.Sender = canonicalDispatchSenderID(inbound.Channel, inbound.SenderID, identityLinks)
|
||||
|
||||
return view
|
||||
}
|
||||
|
||||
func normalizeDispatchSelector(selector config.DispatchSelector) config.DispatchSelector {
|
||||
selector.Channel = strings.ToLower(strings.TrimSpace(selector.Channel))
|
||||
selector.Account = NormalizeAccountID(selector.Account)
|
||||
selector.Space = strings.ToLower(strings.TrimSpace(selector.Space))
|
||||
selector.Chat = strings.ToLower(strings.TrimSpace(selector.Chat))
|
||||
selector.Topic = strings.ToLower(strings.TrimSpace(selector.Topic))
|
||||
selector.Sender = strings.ToLower(strings.TrimSpace(selector.Sender))
|
||||
return selector
|
||||
}
|
||||
|
||||
func selectorHasAnyConstraint(selector config.DispatchSelector) bool {
|
||||
return strings.TrimSpace(selector.Channel) != "" ||
|
||||
strings.TrimSpace(selector.Account) != "" ||
|
||||
strings.TrimSpace(selector.Space) != "" ||
|
||||
strings.TrimSpace(selector.Chat) != "" ||
|
||||
strings.TrimSpace(selector.Topic) != "" ||
|
||||
strings.TrimSpace(selector.Sender) != "" ||
|
||||
selector.Mentioned != nil
|
||||
}
|
||||
|
||||
func canonicalDispatchSenderID(channel, rawID string, identityLinks map[string][]string) string {
|
||||
normalizedID := strings.TrimSpace(rawID)
|
||||
if normalizedID == "" {
|
||||
return ""
|
||||
}
|
||||
if linked := resolveLinkedDispatchID(identityLinks, channel, normalizedID); linked != "" {
|
||||
normalizedID = linked
|
||||
}
|
||||
return strings.ToLower(normalizedID)
|
||||
}
|
||||
|
||||
func resolveLinkedDispatchID(identityLinks map[string][]string, channel, peerID string) string {
|
||||
if len(identityLinks) == 0 {
|
||||
return ""
|
||||
}
|
||||
peerID = strings.TrimSpace(peerID)
|
||||
if peerID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
candidates := make(map[string]bool)
|
||||
rawCandidate := strings.ToLower(peerID)
|
||||
if rawCandidate != "" {
|
||||
candidates[rawCandidate] = true
|
||||
}
|
||||
channel = strings.ToLower(strings.TrimSpace(channel))
|
||||
if channel != "" {
|
||||
candidates[fmt.Sprintf("%s:%s", channel, rawCandidate)] = true
|
||||
}
|
||||
if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 {
|
||||
candidates[rawCandidate[idx+1:]] = true
|
||||
}
|
||||
|
||||
for canonical, ids := range identityLinks {
|
||||
canonicalName := strings.TrimSpace(canonical)
|
||||
if canonicalName == "" {
|
||||
continue
|
||||
}
|
||||
for _, id := range ids {
|
||||
normalized := strings.ToLower(strings.TrimSpace(id))
|
||||
if normalized != "" && candidates[normalized] {
|
||||
return canonicalName
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
+126
-190
@@ -3,10 +3,11 @@ package routing
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *config.Config {
|
||||
func testConfig(agents []config.AgentConfig) *config.Config {
|
||||
return &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
@@ -15,20 +16,20 @@ func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *co
|
||||
},
|
||||
List: agents,
|
||||
},
|
||||
Bindings: bindings,
|
||||
Session: config.SessionConfig{
|
||||
DMScope: "per-peer",
|
||||
Dimensions: []string{"sender"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) {
|
||||
cfg := testConfig(nil, nil)
|
||||
cfg := testConfig(nil)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
|
||||
route := r.ResolveRoute(bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
})
|
||||
|
||||
if route.AgentID != DefaultAgentID {
|
||||
@@ -37,202 +38,152 @@ func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) {
|
||||
if route.MatchedBy != "default" {
|
||||
t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy)
|
||||
}
|
||||
if len(route.SessionPolicy.Dimensions) != 1 || route.SessionPolicy.Dimensions[0] != "sender" {
|
||||
t.Errorf("SessionPolicy.Dimensions = %v, want [sender]", route.SessionPolicy.Dimensions)
|
||||
}
|
||||
if route.SessionPolicy.IdentityLinks != nil {
|
||||
t.Errorf("SessionPolicy.IdentityLinks = %v, want nil", route.SessionPolicy.IdentityLinks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_PeerBinding(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "sales", Default: true},
|
||||
{ID: "support"},
|
||||
func TestResolveRoute_UsesNormalizedInboundContextFields(t *testing.T) {
|
||||
cfg := testConfig([]config.AgentConfig{{ID: "sales", Default: true}})
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(bus.InboundContext{
|
||||
Channel: "Telegram",
|
||||
Account: "Bot2",
|
||||
ChatType: "direct",
|
||||
SenderID: "user123",
|
||||
})
|
||||
|
||||
if route.AgentID != "sales" {
|
||||
t.Errorf("AgentID = %q, want 'sales'", route.AgentID)
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "support",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "telegram",
|
||||
AccountID: "*",
|
||||
Peer: &config.PeerMatch{Kind: "direct", ID: "user123"},
|
||||
if route.Channel != "telegram" {
|
||||
t.Errorf("Channel = %q, want 'telegram'", route.Channel)
|
||||
}
|
||||
if route.AccountID != "bot2" {
|
||||
t.Errorf("AccountID = %q, want 'bot2'", route.AccountID)
|
||||
}
|
||||
if route.MatchedBy != "default" {
|
||||
t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_DispatchFirstMatchWins(t *testing.T) {
|
||||
cfg := testConfig([]config.AgentConfig{
|
||||
{ID: "main", Default: true},
|
||||
{ID: "support"},
|
||||
{ID: "sales"},
|
||||
})
|
||||
cfg.Agents.Dispatch = &config.DispatchConfig{
|
||||
Rules: []config.DispatchRule{
|
||||
{
|
||||
Name: "support-group",
|
||||
Agent: "support",
|
||||
When: config.DispatchSelector{
|
||||
Channel: "telegram",
|
||||
Chat: "group:-100123",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "vip-in-group",
|
||||
Agent: "sales",
|
||||
When: config.DispatchSelector{
|
||||
Channel: "telegram",
|
||||
Chat: "group:-100123",
|
||||
Sender: "12345",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
route := r.ResolveRoute(bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "-100123",
|
||||
ChatType: "group",
|
||||
SenderID: "12345",
|
||||
})
|
||||
|
||||
if route.AgentID != "support" {
|
||||
t.Errorf("AgentID = %q, want 'support'", route.AgentID)
|
||||
t.Fatalf("AgentID = %q, want support", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.peer" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy)
|
||||
if route.MatchedBy != "dispatch.rule:support-group" {
|
||||
t.Fatalf("MatchedBy = %q, want dispatch.rule:support-group", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_GuildBinding(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "general", Default: true},
|
||||
{ID: "gaming"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "gaming",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "discord",
|
||||
AccountID: "*",
|
||||
GuildID: "guild-abc",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "discord",
|
||||
GuildID: "guild-abc",
|
||||
Peer: &RoutePeer{Kind: "channel", ID: "ch1"},
|
||||
})
|
||||
|
||||
if route.AgentID != "gaming" {
|
||||
t.Errorf("AgentID = %q, want 'gaming'", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.guild" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.guild'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_TeamBinding(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "general", Default: true},
|
||||
{ID: "work"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "work",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "slack",
|
||||
AccountID: "*",
|
||||
TeamID: "T12345",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "slack",
|
||||
TeamID: "T12345",
|
||||
Peer: &RoutePeer{Kind: "channel", ID: "C001"},
|
||||
})
|
||||
|
||||
if route.AgentID != "work" {
|
||||
t.Errorf("AgentID = %q, want 'work'", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.team" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.team'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_AccountBinding(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "default-agent", Default: true},
|
||||
{ID: "premium"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "premium",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "telegram",
|
||||
AccountID: "bot2",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
AccountID: "bot2",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
|
||||
})
|
||||
|
||||
if route.AgentID != "premium" {
|
||||
t.Errorf("AgentID = %q, want 'premium'", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.account" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.account'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_ChannelWildcard(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
func TestResolveRoute_DispatchOverridesSessionDimensions(t *testing.T) {
|
||||
cfg := testConfig([]config.AgentConfig{
|
||||
{ID: "main", Default: true},
|
||||
{ID: "telegram-bot"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "telegram-bot",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "telegram",
|
||||
AccountID: "*",
|
||||
{ID: "support"},
|
||||
})
|
||||
cfg.Session.Dimensions = []string{"chat"}
|
||||
cfg.Agents.Dispatch = &config.DispatchConfig{
|
||||
Rules: []config.DispatchRule{
|
||||
{
|
||||
Name: "support-dm",
|
||||
Agent: "support",
|
||||
When: config.DispatchSelector{
|
||||
Channel: "telegram",
|
||||
Chat: "direct:user-1",
|
||||
},
|
||||
SessionDimensions: []string{"chat", "sender"},
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
|
||||
route := r.ResolveRoute(bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "user-1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-1",
|
||||
})
|
||||
|
||||
if route.AgentID != "telegram-bot" {
|
||||
t.Errorf("AgentID = %q, want 'telegram-bot'", route.AgentID)
|
||||
if route.AgentID != "support" {
|
||||
t.Fatalf("AgentID = %q, want support", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.channel" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.channel'", route.MatchedBy)
|
||||
if got := route.SessionPolicy.Dimensions; len(got) != 2 || got[0] != "chat" || got[1] != "sender" {
|
||||
t.Fatalf("SessionPolicy.Dimensions = %v, want [chat sender]", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_PriorityOrder_PeerBeatsGuild(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "general", Default: true},
|
||||
{ID: "vip"},
|
||||
{ID: "gaming"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "vip",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "discord",
|
||||
AccountID: "*",
|
||||
Peer: &config.PeerMatch{Kind: "direct", ID: "user-vip"},
|
||||
},
|
||||
},
|
||||
{
|
||||
AgentID: "gaming",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "discord",
|
||||
AccountID: "*",
|
||||
GuildID: "guild-1",
|
||||
func TestResolveRoute_DispatchMentionedRule(t *testing.T) {
|
||||
cfg := testConfig([]config.AgentConfig{
|
||||
{ID: "main", Default: true},
|
||||
{ID: "support"},
|
||||
})
|
||||
mentioned := true
|
||||
cfg.Agents.Dispatch = &config.DispatchConfig{
|
||||
Rules: []config.DispatchRule{
|
||||
{
|
||||
Name: "slack-mentions",
|
||||
Agent: "support",
|
||||
When: config.DispatchSelector{
|
||||
Channel: "slack",
|
||||
Space: "workspace:t001",
|
||||
Mentioned: &mentioned,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "discord",
|
||||
GuildID: "guild-1",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user-vip"},
|
||||
route := r.ResolveRoute(bus.InboundContext{
|
||||
Channel: "slack",
|
||||
ChatID: "C123",
|
||||
ChatType: "channel",
|
||||
SpaceID: "T001",
|
||||
SpaceType: "workspace",
|
||||
SenderID: "U123",
|
||||
Mentioned: true,
|
||||
})
|
||||
|
||||
if route.AgentID != "vip" {
|
||||
t.Errorf("AgentID = %q, want 'vip' (peer should beat guild)", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.peer" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy)
|
||||
if route.AgentID != "support" {
|
||||
t.Fatalf("AgentID = %q, want support", route.AgentID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,21 +191,10 @@ func TestResolveRoute_InvalidAgentFallsToDefault(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "main", Default: true},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "nonexistent",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "telegram",
|
||||
AccountID: "*",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
cfg := testConfig(agents)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
})
|
||||
route := r.ResolveRoute(bus.InboundContext{Channel: "telegram"})
|
||||
|
||||
if route.AgentID != "main" {
|
||||
t.Errorf("AgentID = %q, want 'main' (invalid agent should fall to default)", route.AgentID)
|
||||
@@ -267,12 +207,10 @@ func TestResolveRoute_DefaultAgentSelection(t *testing.T) {
|
||||
{ID: "beta", Default: true},
|
||||
{ID: "gamma"},
|
||||
}
|
||||
cfg := testConfig(agents, nil)
|
||||
cfg := testConfig(agents)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "cli",
|
||||
})
|
||||
route := r.ResolveRoute(bus.InboundContext{Channel: "cli"})
|
||||
|
||||
if route.AgentID != "beta" {
|
||||
t.Errorf("AgentID = %q, want 'beta' (marked as default)", route.AgentID)
|
||||
@@ -284,12 +222,10 @@ func TestResolveRoute_NoDefaultUsesFirst(t *testing.T) {
|
||||
{ID: "alpha"},
|
||||
{ID: "beta"},
|
||||
}
|
||||
cfg := testConfig(agents, nil)
|
||||
cfg := testConfig(agents)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "cli",
|
||||
})
|
||||
route := r.ResolveRoute(bus.InboundContext{Channel: "cli"})
|
||||
|
||||
if route.AgentID != "alpha" {
|
||||
t.Errorf("AgentID = %q, want 'alpha' (first in list)", route.AgentID)
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DMScope controls DM session isolation granularity.
|
||||
type DMScope string
|
||||
|
||||
const (
|
||||
DMScopeMain DMScope = "main"
|
||||
DMScopePerPeer DMScope = "per-peer"
|
||||
DMScopePerChannelPeer DMScope = "per-channel-peer"
|
||||
DMScopePerAccountChannelPeer DMScope = "per-account-channel-peer"
|
||||
)
|
||||
|
||||
// RoutePeer represents a chat peer with kind and ID.
|
||||
type RoutePeer struct {
|
||||
Kind string // "direct", "group", "channel"
|
||||
ID string
|
||||
}
|
||||
|
||||
// SessionKeyParams holds all inputs for session key construction.
|
||||
type SessionKeyParams struct {
|
||||
AgentID string
|
||||
Channel string
|
||||
AccountID string
|
||||
Peer *RoutePeer
|
||||
DMScope DMScope
|
||||
IdentityLinks map[string][]string
|
||||
}
|
||||
|
||||
// ParsedSessionKey is the result of parsing an agent-scoped session key.
|
||||
type ParsedSessionKey struct {
|
||||
AgentID string
|
||||
Rest string
|
||||
}
|
||||
|
||||
// BuildAgentMainSessionKey returns "agent:<agentId>:main".
|
||||
func BuildAgentMainSessionKey(agentID string) string {
|
||||
return fmt.Sprintf("agent:%s:%s", NormalizeAgentID(agentID), DefaultMainKey)
|
||||
}
|
||||
|
||||
// BuildAgentPeerSessionKey constructs a session key based on agent, channel, peer, and DM scope.
|
||||
func BuildAgentPeerSessionKey(params SessionKeyParams) string {
|
||||
agentID := NormalizeAgentID(params.AgentID)
|
||||
|
||||
peer := params.Peer
|
||||
if peer == nil {
|
||||
peer = &RoutePeer{Kind: "direct"}
|
||||
}
|
||||
peerKind := strings.TrimSpace(peer.Kind)
|
||||
if peerKind == "" {
|
||||
peerKind = "direct"
|
||||
}
|
||||
|
||||
if peerKind == "direct" {
|
||||
dmScope := params.DMScope
|
||||
if dmScope == "" {
|
||||
dmScope = DMScopeMain
|
||||
}
|
||||
peerID := strings.TrimSpace(peer.ID)
|
||||
|
||||
// Resolve identity links (cross-platform collapse)
|
||||
if dmScope != DMScopeMain && peerID != "" {
|
||||
if linked := resolveLinkedPeerID(params.IdentityLinks, params.Channel, peerID); linked != "" {
|
||||
peerID = linked
|
||||
}
|
||||
}
|
||||
peerID = strings.ToLower(peerID)
|
||||
|
||||
switch dmScope {
|
||||
case DMScopePerAccountChannelPeer:
|
||||
if peerID != "" {
|
||||
channel := normalizeChannel(params.Channel)
|
||||
accountID := NormalizeAccountID(params.AccountID)
|
||||
return fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, accountID, peerID)
|
||||
}
|
||||
case DMScopePerChannelPeer:
|
||||
if peerID != "" {
|
||||
channel := normalizeChannel(params.Channel)
|
||||
return fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID)
|
||||
}
|
||||
case DMScopePerPeer:
|
||||
if peerID != "" {
|
||||
return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID)
|
||||
}
|
||||
}
|
||||
return BuildAgentMainSessionKey(agentID)
|
||||
}
|
||||
|
||||
// Group/channel peers always get per-peer sessions
|
||||
channel := normalizeChannel(params.Channel)
|
||||
peerID := strings.ToLower(strings.TrimSpace(peer.ID))
|
||||
if peerID == "" {
|
||||
peerID = "unknown"
|
||||
}
|
||||
return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID)
|
||||
}
|
||||
|
||||
// ParseAgentSessionKey extracts agentId and rest from "agent:<agentId>:<rest>".
|
||||
func ParseAgentSessionKey(sessionKey string) *ParsedSessionKey {
|
||||
raw := strings.TrimSpace(sessionKey)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.SplitN(raw, ":", 3)
|
||||
if len(parts) < 3 {
|
||||
return nil
|
||||
}
|
||||
if parts[0] != "agent" {
|
||||
return nil
|
||||
}
|
||||
agentID := strings.TrimSpace(parts[1])
|
||||
rest := parts[2]
|
||||
if agentID == "" || rest == "" {
|
||||
return nil
|
||||
}
|
||||
return &ParsedSessionKey{AgentID: agentID, Rest: rest}
|
||||
}
|
||||
|
||||
// IsSubagentSessionKey returns true if the session key represents a subagent.
|
||||
func IsSubagentSessionKey(sessionKey string) bool {
|
||||
raw := strings.TrimSpace(sessionKey)
|
||||
if raw == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(raw), "subagent:") {
|
||||
return true
|
||||
}
|
||||
parsed := ParseAgentSessionKey(raw)
|
||||
if parsed == nil {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(strings.ToLower(parsed.Rest), "subagent:")
|
||||
}
|
||||
|
||||
func normalizeChannel(channel string) string {
|
||||
c := strings.TrimSpace(strings.ToLower(channel))
|
||||
if c == "" {
|
||||
return "unknown"
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string {
|
||||
if len(identityLinks) == 0 {
|
||||
return ""
|
||||
}
|
||||
peerID = strings.TrimSpace(peerID)
|
||||
if peerID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
candidates := make(map[string]bool)
|
||||
rawCandidate := strings.ToLower(peerID)
|
||||
if rawCandidate != "" {
|
||||
candidates[rawCandidate] = true
|
||||
}
|
||||
channel = strings.ToLower(strings.TrimSpace(channel))
|
||||
if channel != "" {
|
||||
scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID))
|
||||
candidates[scopedCandidate] = true
|
||||
}
|
||||
|
||||
// If peerID is already in canonical "platform:id" format, also add the
|
||||
// bare ID part as a candidate for backward compatibility with identity_links
|
||||
// that use raw IDs (e.g. "123" instead of "telegram:123").
|
||||
if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 {
|
||||
bareID := rawCandidate[idx+1:]
|
||||
candidates[bareID] = true
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
for canonical, ids := range identityLinks {
|
||||
canonicalName := strings.TrimSpace(canonical)
|
||||
if canonicalName == "" {
|
||||
continue
|
||||
}
|
||||
for _, id := range ids {
|
||||
normalized := strings.ToLower(strings.TrimSpace(id))
|
||||
if normalized != "" && candidates[normalized] {
|
||||
return canonicalName
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,207 +0,0 @@
|
||||
package routing
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildAgentMainSessionKey(t *testing.T) {
|
||||
got := BuildAgentMainSessionKey("sales")
|
||||
want := "agent:sales:main"
|
||||
if got != want {
|
||||
t.Errorf("BuildAgentMainSessionKey('sales') = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentMainSessionKey_Normalizes(t *testing.T) {
|
||||
got := BuildAgentMainSessionKey("Sales Bot")
|
||||
want := "agent:sales-bot:main"
|
||||
if got != want {
|
||||
t.Errorf("BuildAgentMainSessionKey('Sales Bot') = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopeMain(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopeMain,
|
||||
})
|
||||
want := "agent:main:main"
|
||||
if got != want {
|
||||
t.Errorf("DMScopeMain = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopePerPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopePerPeer,
|
||||
})
|
||||
want := "agent:main:direct:user123"
|
||||
if got != want {
|
||||
t.Errorf("DMScopePerPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopePerChannelPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopePerChannelPeer,
|
||||
})
|
||||
want := "agent:main:telegram:direct:user123"
|
||||
if got != want {
|
||||
t.Errorf("DMScopePerChannelPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopePerAccountChannelPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
AccountID: "bot1",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "User123"},
|
||||
DMScope: DMScopePerAccountChannelPeer,
|
||||
})
|
||||
want := "agent:main:telegram:bot1:direct:user123"
|
||||
if got != want {
|
||||
t.Errorf("DMScopePerAccountChannelPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_GroupPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "group", ID: "chat456"},
|
||||
DMScope: DMScopePerPeer,
|
||||
})
|
||||
want := "agent:main:telegram:group:chat456"
|
||||
if got != want {
|
||||
t.Errorf("GroupPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_NilPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: nil,
|
||||
DMScope: DMScopePerPeer,
|
||||
})
|
||||
// nil peer defaults to direct with empty ID, falls to main
|
||||
want := "agent:main:main"
|
||||
if got != want {
|
||||
t.Errorf("NilPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) {
|
||||
links := map[string][]string{
|
||||
"john": {"telegram:user123", "discord:john#1234"},
|
||||
}
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopePerPeer,
|
||||
IdentityLinks: links,
|
||||
})
|
||||
want := "agent:main:direct:john"
|
||||
if got != want {
|
||||
t.Errorf("IdentityLink = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLinkedPeerID_CanonicalPeerID(t *testing.T) {
|
||||
// When peerID is already in canonical "platform:id" format,
|
||||
// it should match identity_links that use the bare ID.
|
||||
links := map[string][]string{
|
||||
"john": {"123"},
|
||||
}
|
||||
got := resolveLinkedPeerID(links, "telegram", "telegram:123")
|
||||
if got != "john" {
|
||||
t.Errorf("resolveLinkedPeerID with canonical peerID = %q, want %q", got, "john")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLinkedPeerID_CanonicalInLinks(t *testing.T) {
|
||||
// When identity_links contain canonical IDs and peerID is canonical too
|
||||
links := map[string][]string{
|
||||
"john": {"telegram:123", "discord:456"},
|
||||
}
|
||||
got := resolveLinkedPeerID(links, "telegram", "telegram:123")
|
||||
if got != "john" {
|
||||
t.Errorf("resolveLinkedPeerID canonical in links = %q, want %q", got, "john")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLinkedPeerID_BarePeerIDMatchesCanonicalLink(t *testing.T) {
|
||||
// When peerID is bare "123" and links have "telegram:123",
|
||||
// the scoped candidate "telegram:123" should match.
|
||||
links := map[string][]string{
|
||||
"john": {"telegram:123"},
|
||||
}
|
||||
got := resolveLinkedPeerID(links, "telegram", "123")
|
||||
if got != "john" {
|
||||
t.Errorf("resolveLinkedPeerID bare peer matches canonical link = %q, want %q", got, "john")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLinkedPeerID_NoMatch(t *testing.T) {
|
||||
links := map[string][]string{
|
||||
"john": {"telegram:123"},
|
||||
}
|
||||
got := resolveLinkedPeerID(links, "discord", "999")
|
||||
if got != "" {
|
||||
t.Errorf("resolveLinkedPeerID no match = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAgentSessionKey_Valid(t *testing.T) {
|
||||
parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123")
|
||||
if parsed == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if parsed.AgentID != "sales" {
|
||||
t.Errorf("AgentID = %q, want 'sales'", parsed.AgentID)
|
||||
}
|
||||
if parsed.Rest != "telegram:direct:user123" {
|
||||
t.Errorf("Rest = %q, want 'telegram:direct:user123'", parsed.Rest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAgentSessionKey_Invalid(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"foo:bar",
|
||||
"notprefix:sales:main",
|
||||
"agent::main",
|
||||
"agent:sales:",
|
||||
}
|
||||
for _, input := range tests {
|
||||
if got := ParseAgentSessionKey(input); got != nil {
|
||||
t.Errorf("ParseAgentSessionKey(%q) = %+v, want nil", input, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSubagentSessionKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"subagent:task-1", true},
|
||||
{"agent:main:subagent:task-1", true},
|
||||
{"agent:main:main", false},
|
||||
{"agent:main:telegram:direct:user123", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := IsSubagentSessionKey(tt.input); got != tt.want {
|
||||
t.Errorf("IsSubagentSessionKey(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
)
|
||||
|
||||
// Allocation contains the concrete session keys selected for a routed turn.
|
||||
// The current implementation intentionally preserves the legacy session-key
|
||||
// layout while moving key construction out of the router.
|
||||
type Allocation struct {
|
||||
Scope SessionScope
|
||||
SessionKey string
|
||||
SessionAliases []string
|
||||
MainSessionKey string
|
||||
MainAliases []string
|
||||
}
|
||||
|
||||
// AllocationInput contains the routing result and peer context needed to
|
||||
// derive the session keys for a turn.
|
||||
type AllocationInput struct {
|
||||
AgentID string
|
||||
Context bus.InboundContext
|
||||
SessionPolicy routing.SessionPolicy
|
||||
}
|
||||
|
||||
// AllocateRouteSession maps a route decision onto a structured scope and the
|
||||
// current opaque session-key format.
|
||||
func AllocateRouteSession(input AllocationInput) Allocation {
|
||||
scope := buildSessionScope(input)
|
||||
legacySessionAliases := buildLegacySessionAliases(input)
|
||||
legacyMainSessionKey := strings.ToLower(BuildLegacyMainAlias(input.AgentID))
|
||||
return Allocation{
|
||||
Scope: scope,
|
||||
SessionKey: BuildSessionKey(scope),
|
||||
SessionAliases: legacySessionAliases,
|
||||
MainSessionKey: BuildOpaqueSessionKey(legacyMainSessionKey),
|
||||
MainAliases: []string{legacyMainSessionKey},
|
||||
}
|
||||
}
|
||||
|
||||
func buildSessionScope(input AllocationInput) SessionScope {
|
||||
inbound := input.Context
|
||||
includeTopicInChatDimension := shouldPreserveTelegramForumIsolation(input)
|
||||
scope := SessionScope{
|
||||
Version: ScopeVersionV1,
|
||||
AgentID: routing.NormalizeAgentID(input.AgentID),
|
||||
Channel: strings.ToLower(strings.TrimSpace(inbound.Channel)),
|
||||
Account: routing.NormalizeAccountID(inbound.Account),
|
||||
}
|
||||
if scope.Channel == "" {
|
||||
scope.Channel = "unknown"
|
||||
}
|
||||
|
||||
dimensions := make([]string, 0, len(input.SessionPolicy.Dimensions))
|
||||
values := make(map[string]string, len(input.SessionPolicy.Dimensions))
|
||||
|
||||
for _, dimension := range input.SessionPolicy.Dimensions {
|
||||
switch dimension {
|
||||
case "space":
|
||||
if spaceID := strings.TrimSpace(inbound.SpaceID); spaceID != "" {
|
||||
spaceType := strings.ToLower(strings.TrimSpace(inbound.SpaceType))
|
||||
if spaceType == "" {
|
||||
spaceType = "space"
|
||||
}
|
||||
dimensions = append(dimensions, "space")
|
||||
values["space"] = fmt.Sprintf("%s:%s", spaceType, strings.ToLower(spaceID))
|
||||
}
|
||||
case "chat":
|
||||
chatID := strings.TrimSpace(inbound.ChatID)
|
||||
if chatID == "" {
|
||||
continue
|
||||
}
|
||||
if includeTopicInChatDimension {
|
||||
if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" {
|
||||
chatID = chatID + "/" + topicID
|
||||
}
|
||||
}
|
||||
chatType := strings.ToLower(strings.TrimSpace(inbound.ChatType))
|
||||
if chatType == "" {
|
||||
chatType = "direct"
|
||||
}
|
||||
dimensions = append(dimensions, "chat")
|
||||
values["chat"] = fmt.Sprintf("%s:%s", chatType, strings.ToLower(chatID))
|
||||
case "topic":
|
||||
if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" {
|
||||
dimensions = append(dimensions, "topic")
|
||||
values["topic"] = "topic:" + strings.ToLower(topicID)
|
||||
}
|
||||
case "sender":
|
||||
senderID := CanonicalSessionIdentityID(
|
||||
inbound.Channel,
|
||||
inbound.SenderID,
|
||||
input.SessionPolicy.IdentityLinks,
|
||||
)
|
||||
if senderID == "" {
|
||||
continue
|
||||
}
|
||||
dimensions = append(dimensions, "sender")
|
||||
values["sender"] = senderID
|
||||
}
|
||||
}
|
||||
|
||||
if len(dimensions) > 0 {
|
||||
scope.Dimensions = dimensions
|
||||
scope.Values = values
|
||||
}
|
||||
|
||||
return scope
|
||||
}
|
||||
|
||||
func buildLegacySessionAliases(input AllocationInput) []string {
|
||||
aliases := []string{strings.ToLower(BuildLegacyMainAlias(input.AgentID))}
|
||||
inbound := input.Context
|
||||
|
||||
if strings.EqualFold(strings.TrimSpace(inbound.ChatType), "direct") {
|
||||
peerIDs := buildLegacyDirectPeerIDs(input)
|
||||
if len(peerIDs) == 0 {
|
||||
return uniqueAliases(aliases)
|
||||
}
|
||||
for _, peerID := range peerIDs {
|
||||
aliases = append(
|
||||
aliases,
|
||||
BuildLegacyDirectAliases(input.AgentID, inbound.Channel, inbound.Account, peerID)...,
|
||||
)
|
||||
}
|
||||
return uniqueAliases(aliases)
|
||||
}
|
||||
|
||||
peerID := strings.TrimSpace(inbound.ChatID)
|
||||
if peerID == "" {
|
||||
return uniqueAliases(aliases)
|
||||
}
|
||||
if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" {
|
||||
peerID = peerID + "/" + topicID
|
||||
}
|
||||
aliases = append(aliases, BuildLegacyPeerAlias(
|
||||
input.AgentID,
|
||||
inbound.Channel,
|
||||
strings.ToLower(strings.TrimSpace(inbound.ChatType)),
|
||||
peerID,
|
||||
))
|
||||
|
||||
return uniqueAliases(aliases)
|
||||
}
|
||||
|
||||
func shouldPreserveTelegramForumIsolation(input AllocationInput) bool {
|
||||
inbound := input.Context
|
||||
if !strings.EqualFold(strings.TrimSpace(inbound.Channel), "telegram") {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(inbound.TopicID) == "" {
|
||||
return false
|
||||
}
|
||||
for _, dimension := range input.SessionPolicy.Dimensions {
|
||||
if strings.EqualFold(strings.TrimSpace(dimension), "topic") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func buildLegacyDirectPeerIDs(input AllocationInput) []string {
|
||||
inbound := input.Context
|
||||
peerIDs := make([]string, 0, 3)
|
||||
|
||||
rawSenderID := strings.TrimSpace(inbound.SenderID)
|
||||
if rawSenderID != "" {
|
||||
peerIDs = append(peerIDs, strings.ToLower(rawSenderID))
|
||||
}
|
||||
|
||||
canonicalSenderID := CanonicalSessionIdentityID(
|
||||
inbound.Channel,
|
||||
inbound.SenderID,
|
||||
input.SessionPolicy.IdentityLinks,
|
||||
)
|
||||
if canonicalSenderID != "" {
|
||||
peerIDs = append(peerIDs, canonicalSenderID)
|
||||
}
|
||||
|
||||
chatID := strings.TrimSpace(inbound.ChatID)
|
||||
if chatID != "" {
|
||||
peerIDs = append(peerIDs, strings.ToLower(chatID))
|
||||
}
|
||||
|
||||
return uniqueAliases(peerIDs)
|
||||
}
|
||||
|
||||
func uniqueAliases(aliases []string) []string {
|
||||
if len(aliases) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := make([]string, 0, len(aliases))
|
||||
seen := make(map[string]struct{}, len(aliases))
|
||||
for _, alias := range aliases {
|
||||
alias = strings.TrimSpace(strings.ToLower(alias))
|
||||
if alias == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[alias]; ok {
|
||||
continue
|
||||
}
|
||||
seen[alias] = struct{}{}
|
||||
normalized = append(normalized, alias)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
@@ -0,0 +1,160 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
)
|
||||
|
||||
func TestAllocateRouteSession_PerPeerDM(t *testing.T) {
|
||||
allocation := AllocateRouteSession(AllocationInput{
|
||||
AgentID: "main",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
Account: "default",
|
||||
ChatID: "dm-123",
|
||||
ChatType: "direct",
|
||||
SenderID: "User123",
|
||||
},
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"sender"},
|
||||
},
|
||||
})
|
||||
|
||||
if allocation.SessionKey == "" || !IsOpaqueSessionKey(allocation.SessionKey) {
|
||||
t.Fatalf("SessionKey = %q, want opaque session key", allocation.SessionKey)
|
||||
}
|
||||
if !containsAlias(allocation.SessionAliases, "agent:main:direct:user123") {
|
||||
t.Fatalf("SessionAliases = %v, want to contain agent:main:direct:user123", allocation.SessionAliases)
|
||||
}
|
||||
if allocation.MainSessionKey == "" || !IsOpaqueSessionKey(allocation.MainSessionKey) {
|
||||
t.Fatalf("MainSessionKey = %q, want opaque session key", allocation.MainSessionKey)
|
||||
}
|
||||
if len(allocation.MainAliases) != 1 || allocation.MainAliases[0] != "agent:main:main" {
|
||||
t.Fatalf("MainAliases = %v, want [agent:main:main]", allocation.MainAliases)
|
||||
}
|
||||
if allocation.Scope.Version != ScopeVersionV1 {
|
||||
t.Fatalf("Scope.Version = %d, want %d", allocation.Scope.Version, ScopeVersionV1)
|
||||
}
|
||||
if len(allocation.Scope.Dimensions) != 1 || allocation.Scope.Dimensions[0] != "sender" {
|
||||
t.Fatalf("Scope.Dimensions = %v, want [sender]", allocation.Scope.Dimensions)
|
||||
}
|
||||
if allocation.Scope.Values["sender"] != "user123" {
|
||||
t.Fatalf("Scope.Values[sender] = %q, want user123", allocation.Scope.Values["sender"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateRouteSession_GroupPeer(t *testing.T) {
|
||||
allocation := AllocateRouteSession(AllocationInput{
|
||||
AgentID: "main",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "slack",
|
||||
Account: "workspace-a",
|
||||
ChatID: "C001",
|
||||
ChatType: "channel",
|
||||
SenderID: "U001",
|
||||
},
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"chat"},
|
||||
},
|
||||
})
|
||||
|
||||
if allocation.SessionKey == "" || !IsOpaqueSessionKey(allocation.SessionKey) {
|
||||
t.Fatalf("SessionKey = %q, want opaque session key", allocation.SessionKey)
|
||||
}
|
||||
if !containsAlias(allocation.SessionAliases, "agent:main:slack:channel:c001") {
|
||||
t.Fatalf("SessionAliases = %v, want to contain agent:main:slack:channel:c001", allocation.SessionAliases)
|
||||
}
|
||||
if allocation.MainSessionKey == "" || !IsOpaqueSessionKey(allocation.MainSessionKey) {
|
||||
t.Fatalf("MainSessionKey = %q, want opaque session key", allocation.MainSessionKey)
|
||||
}
|
||||
if len(allocation.MainAliases) != 1 || allocation.MainAliases[0] != "agent:main:main" {
|
||||
t.Fatalf("MainAliases = %v, want [agent:main:main]", allocation.MainAliases)
|
||||
}
|
||||
if len(allocation.Scope.Dimensions) != 1 || allocation.Scope.Dimensions[0] != "chat" {
|
||||
t.Fatalf("Scope.Dimensions = %v, want [chat]", allocation.Scope.Dimensions)
|
||||
}
|
||||
if allocation.Scope.Values["chat"] != "channel:c001" {
|
||||
t.Fatalf("Scope.Values[chat] = %q, want channel:c001", allocation.Scope.Values["chat"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateRouteSession_TelegramForumTopicsRemainIsolatedByDefault(t *testing.T) {
|
||||
first := AllocateRouteSession(AllocationInput{
|
||||
AgentID: "main",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "-1001234567890",
|
||||
ChatType: "group",
|
||||
TopicID: "42",
|
||||
SenderID: "7",
|
||||
},
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"chat"},
|
||||
},
|
||||
})
|
||||
second := AllocateRouteSession(AllocationInput{
|
||||
AgentID: "main",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "-1001234567890",
|
||||
ChatType: "group",
|
||||
TopicID: "99",
|
||||
SenderID: "7",
|
||||
},
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"chat"},
|
||||
},
|
||||
})
|
||||
|
||||
if first.SessionKey == second.SessionKey {
|
||||
t.Fatalf("forum topics should not share default session key: %q", first.SessionKey)
|
||||
}
|
||||
if got := first.Scope.Values["chat"]; got != "group:-1001234567890/42" {
|
||||
t.Fatalf("first.Scope.Values[chat] = %q, want %q", got, "group:-1001234567890/42")
|
||||
}
|
||||
if got := second.Scope.Values["chat"]; got != "group:-1001234567890/99" {
|
||||
t.Fatalf("second.Scope.Values[chat] = %q, want %q", got, "group:-1001234567890/99")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateRouteSession_PicoDirectAliasesIncludeLegacyChatKey(t *testing.T) {
|
||||
allocation := AllocateRouteSession(AllocationInput{
|
||||
AgentID: "main",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "pico",
|
||||
Account: "default",
|
||||
ChatID: "pico:session-123",
|
||||
ChatType: "direct",
|
||||
SenderID: "pico-user",
|
||||
},
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"sender"},
|
||||
},
|
||||
})
|
||||
|
||||
if !containsAlias(allocation.SessionAliases, "agent:main:pico:direct:pico:session-123") {
|
||||
t.Fatalf("SessionAliases = %v, want pico legacy alias", allocation.SessionAliases)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpaqueSessionKey_IsStable(t *testing.T) {
|
||||
first := BuildOpaqueSessionKey("agent:main:direct:user123")
|
||||
second := BuildOpaqueSessionKey("agent:main:direct:user123")
|
||||
if first != second {
|
||||
t.Fatalf("BuildOpaqueSessionKey() mismatch: %q != %q", first, second)
|
||||
}
|
||||
if !IsOpaqueSessionKey(first) {
|
||||
t.Fatalf("expected opaque session key, got %q", first)
|
||||
}
|
||||
}
|
||||
|
||||
func containsAlias(aliases []string, want string) bool {
|
||||
for _, alias := range aliases {
|
||||
if alias == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -2,7 +2,9 @@ package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
@@ -15,24 +17,123 @@ type JSONLBackend struct {
|
||||
store memory.Store
|
||||
}
|
||||
|
||||
type metaAwareStore interface {
|
||||
GetSessionMeta(ctx context.Context, sessionKey string) (memory.SessionMeta, error)
|
||||
UpsertSessionMeta(ctx context.Context, sessionKey string, scope json.RawMessage, aliases []string) error
|
||||
ResolveSessionKey(ctx context.Context, sessionKey string) (string, bool, error)
|
||||
}
|
||||
|
||||
type aliasPromotingStore interface {
|
||||
PromoteAliasHistory(ctx context.Context, sessionKey string, scope json.RawMessage, aliases []string) (bool, error)
|
||||
}
|
||||
|
||||
// MetadataAwareSessionStore exposes structured session metadata operations.
|
||||
type MetadataAwareSessionStore interface {
|
||||
EnsureSessionMetadata(sessionKey string, scope *SessionScope, aliases []string)
|
||||
ResolveSessionKey(sessionKey string) string
|
||||
GetSessionScope(sessionKey string) *SessionScope
|
||||
}
|
||||
|
||||
// NewJSONLBackend wraps a memory.Store for use as a SessionStore.
|
||||
func NewJSONLBackend(store memory.Store) *JSONLBackend {
|
||||
return &JSONLBackend{store: store}
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) resolveSessionKey(sessionKey string) string {
|
||||
metaStore, ok := b.store.(metaAwareStore)
|
||||
if !ok {
|
||||
return sessionKey
|
||||
}
|
||||
resolved, found, err := metaStore.ResolveSessionKey(context.Background(), sessionKey)
|
||||
if err != nil {
|
||||
log.Printf("session: resolve session key: %v", err)
|
||||
return sessionKey
|
||||
}
|
||||
if found && resolved != "" {
|
||||
return resolved
|
||||
}
|
||||
return sessionKey
|
||||
}
|
||||
|
||||
// ResolveSessionKey maps aliases onto their canonical session key when the
|
||||
// underlying store supports structured metadata. Unknown aliases fall back to
|
||||
// the original input so existing callers remain compatible.
|
||||
func (b *JSONLBackend) ResolveSessionKey(sessionKey string) string {
|
||||
return b.resolveSessionKey(sessionKey)
|
||||
}
|
||||
|
||||
// EnsureSessionMetadata persists scope and alias metadata for a session.
|
||||
func (b *JSONLBackend) EnsureSessionMetadata(sessionKey string, scope *SessionScope, aliases []string) {
|
||||
metaStore, ok := b.store.(metaAwareStore)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var rawScope json.RawMessage
|
||||
if scope != nil {
|
||||
data, err := json.Marshal(scope)
|
||||
if err != nil {
|
||||
log.Printf("session: encode session scope: %v", err)
|
||||
return
|
||||
}
|
||||
rawScope = data
|
||||
}
|
||||
ctx := context.Background()
|
||||
if err := metaStore.UpsertSessionMeta(ctx, sessionKey, rawScope, aliases); err != nil {
|
||||
log.Printf("session: upsert session metadata: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if promotingStore, ok := b.store.(aliasPromotingStore); ok {
|
||||
if _, err := promotingStore.PromoteAliasHistory(ctx, sessionKey, rawScope, aliases); err != nil {
|
||||
log.Printf("session: promote alias history: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetSessionScope reads structured scope metadata for a session key or alias.
|
||||
func (b *JSONLBackend) GetSessionScope(sessionKey string) *SessionScope {
|
||||
metaStore, ok := b.store.(metaAwareStore)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
sessionKey = b.resolveSessionKey(sessionKey)
|
||||
meta, err := metaStore.GetSessionMeta(context.Background(), sessionKey)
|
||||
if err != nil {
|
||||
log.Printf("session: get session metadata: %v", err)
|
||||
return nil
|
||||
}
|
||||
if len(meta.Scope) == 0 {
|
||||
return nil
|
||||
}
|
||||
var scope SessionScope
|
||||
if err := json.Unmarshal(meta.Scope, &scope); err != nil {
|
||||
log.Printf("session: decode session scope: %v", err)
|
||||
return nil
|
||||
}
|
||||
return CloneScope(&scope)
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) AddMessage(sessionKey, role, content string) {
|
||||
sessionKey = b.resolveSessionKey(sessionKey)
|
||||
if err := b.store.AddMessage(context.Background(), sessionKey, role, content); err != nil {
|
||||
log.Printf("session: add message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) AddFullMessage(sessionKey string, msg providers.Message) {
|
||||
sessionKey = b.resolveSessionKey(sessionKey)
|
||||
if err := b.store.AddFullMessage(context.Background(), sessionKey, msg); err != nil {
|
||||
log.Printf("session: add full message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) GetHistory(key string) []providers.Message {
|
||||
key = b.resolveSessionKey(key)
|
||||
msgs, err := b.store.GetHistory(context.Background(), key)
|
||||
if err != nil {
|
||||
log.Printf("session: get history: %v", err)
|
||||
@@ -42,6 +143,7 @@ func (b *JSONLBackend) GetHistory(key string) []providers.Message {
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) GetSummary(key string) string {
|
||||
key = b.resolveSessionKey(key)
|
||||
summary, err := b.store.GetSummary(context.Background(), key)
|
||||
if err != nil {
|
||||
log.Printf("session: get summary: %v", err)
|
||||
@@ -51,18 +153,21 @@ func (b *JSONLBackend) GetSummary(key string) string {
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) SetSummary(key, summary string) {
|
||||
key = b.resolveSessionKey(key)
|
||||
if err := b.store.SetSummary(context.Background(), key, summary); err != nil {
|
||||
log.Printf("session: set summary: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) SetHistory(key string, history []providers.Message) {
|
||||
key = b.resolveSessionKey(key)
|
||||
if err := b.store.SetHistory(context.Background(), key, history); err != nil {
|
||||
log.Printf("session: set history: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *JSONLBackend) TruncateHistory(key string, keepLast int) {
|
||||
key = b.resolveSessionKey(key)
|
||||
if err := b.store.TruncateHistory(context.Background(), key, keepLast); err != nil {
|
||||
log.Printf("session: truncate history: %v", err)
|
||||
}
|
||||
@@ -72,6 +177,7 @@ func (b *JSONLBackend) TruncateHistory(key string, keepLast int) {
|
||||
// immediately, the data is already durable. Save runs compaction to reclaim
|
||||
// space from logically truncated messages (no-op when there are none).
|
||||
func (b *JSONLBackend) Save(key string) error {
|
||||
key = b.resolveSessionKey(key)
|
||||
return b.store.Compact(context.Background(), key)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
@@ -177,3 +179,126 @@ func TestJSONLBackend_SummarizeFlow(t *testing.T) {
|
||||
t.Errorf("first message = %q, want %q", history[0].Content, "msg 16")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLBackend_ResolveAliasAndPersistMetadata(t *testing.T) {
|
||||
b := newBackend(t)
|
||||
|
||||
scope := &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Account: "default",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "group:c1",
|
||||
},
|
||||
}
|
||||
b.EnsureSessionMetadata("canonical", scope, []string{"legacy"})
|
||||
|
||||
if got := b.ResolveSessionKey("legacy"); got != "canonical" {
|
||||
t.Fatalf("ResolveSessionKey() = %q, want %q", got, "canonical")
|
||||
}
|
||||
|
||||
b.AddMessage("legacy", "user", "hello through alias")
|
||||
history := b.GetHistory("canonical")
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("len(history) = %d, want 1", len(history))
|
||||
}
|
||||
if history[0].Content != "hello through alias" {
|
||||
t.Fatalf("history[0].Content = %q, want %q", history[0].Content, "hello through alias")
|
||||
}
|
||||
|
||||
resolvedScope := b.GetSessionScope("legacy")
|
||||
if resolvedScope == nil {
|
||||
t.Fatal("GetSessionScope() returned nil")
|
||||
}
|
||||
if resolvedScope.AgentID != scope.AgentID || resolvedScope.Values["chat"] != scope.Values["chat"] {
|
||||
t.Fatalf("GetSessionScope() = %+v, want %+v", resolvedScope, scope)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLBackend_EnsureSessionMetadata_PromotesLegacyAliasHistory(t *testing.T) {
|
||||
b := newBackend(t)
|
||||
|
||||
legacyKey := "agent:main:direct:legacy-user"
|
||||
b.AddMessage(legacyKey, "user", "legacy history")
|
||||
b.SetSummary(legacyKey, "legacy summary")
|
||||
|
||||
canonicalKey := session.BuildOpaqueSessionKey(legacyKey)
|
||||
b.EnsureSessionMetadata(canonicalKey, &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
}, []string{legacyKey})
|
||||
|
||||
if got := b.ResolveSessionKey(legacyKey); got != canonicalKey {
|
||||
t.Fatalf("ResolveSessionKey() = %q, want %q", got, canonicalKey)
|
||||
}
|
||||
history := b.GetHistory(canonicalKey)
|
||||
if len(history) != 1 || history[0].Content != "legacy history" {
|
||||
t.Fatalf("promoted history = %+v", history)
|
||||
}
|
||||
if summary := b.GetSummary(canonicalKey); summary != "legacy summary" {
|
||||
t.Fatalf("promoted summary = %q, want %q", summary, "legacy summary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLBackend_EnsureSessionMetadata_PromotesLegacyPicoDirectAliasHistory(t *testing.T) {
|
||||
b := newBackend(t)
|
||||
|
||||
legacyKey := "agent:main:pico:direct:pico:session-123"
|
||||
b.AddMessage(legacyKey, "user", "legacy pico history")
|
||||
|
||||
scope := &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "pico",
|
||||
Account: "default",
|
||||
Dimensions: []string{"sender"},
|
||||
Values: map[string]string{
|
||||
"sender": "pico-user",
|
||||
},
|
||||
}
|
||||
allocation := session.AllocateRouteSession(session.AllocationInput{
|
||||
AgentID: "main",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "pico",
|
||||
Account: "default",
|
||||
ChatID: "pico:session-123",
|
||||
ChatType: "direct",
|
||||
SenderID: "pico-user",
|
||||
},
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"sender"},
|
||||
},
|
||||
})
|
||||
|
||||
b.EnsureSessionMetadata(allocation.SessionKey, scope, allocation.SessionAliases)
|
||||
|
||||
if got := b.ResolveSessionKey(legacyKey); got != allocation.SessionKey {
|
||||
t.Fatalf("ResolveSessionKey() = %q, want %q", got, allocation.SessionKey)
|
||||
}
|
||||
history := b.GetHistory(allocation.SessionKey)
|
||||
if len(history) != 1 || history[0].Content != "legacy pico history" {
|
||||
t.Fatalf("promoted history = %+v", history)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONLBackend_EnsureSessionMetadata_DoesNotOverwriteNonEmptyCanonicalHistory(t *testing.T) {
|
||||
b := newBackend(t)
|
||||
|
||||
canonicalKey := session.BuildOpaqueSessionKey("agent:main:direct:current-user")
|
||||
legacyKey := "agent:main:direct:legacy-user"
|
||||
|
||||
b.AddMessage(canonicalKey, "user", "current canonical history")
|
||||
b.AddMessage(legacyKey, "user", "legacy history")
|
||||
|
||||
b.EnsureSessionMetadata(canonicalKey, &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
}, []string{legacyKey})
|
||||
|
||||
history := b.GetHistory(canonicalKey)
|
||||
if len(history) != 1 || history[0].Content != "current canonical history" {
|
||||
t.Fatalf("canonical history overwritten: %+v", history)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
)
|
||||
|
||||
const (
|
||||
sessionKeyV1Prefix = "sk_v1_"
|
||||
legacyAgentSessionKeyPrefix = "agent:"
|
||||
)
|
||||
|
||||
type ParsedLegacySessionKey struct {
|
||||
AgentID string
|
||||
Rest string
|
||||
}
|
||||
|
||||
// BuildOpaqueSessionKey returns a stable opaque session key derived from a
|
||||
// canonical alias string. The alias remains available through metadata for
|
||||
// compatibility and migration purposes.
|
||||
func BuildOpaqueSessionKey(alias string) string {
|
||||
normalized := strings.TrimSpace(strings.ToLower(alias))
|
||||
if normalized == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(normalized))
|
||||
return sessionKeyV1Prefix + hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// IsOpaqueSessionKey returns true when the key matches the current opaque
|
||||
// session-key format.
|
||||
func IsOpaqueSessionKey(key string) bool {
|
||||
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(key)), sessionKeyV1Prefix)
|
||||
}
|
||||
|
||||
func IsLegacyAgentSessionKey(key string) bool {
|
||||
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(key)), legacyAgentSessionKeyPrefix)
|
||||
}
|
||||
|
||||
func IsExplicitSessionKey(key string) bool {
|
||||
return IsOpaqueSessionKey(key) || IsLegacyAgentSessionKey(key)
|
||||
}
|
||||
|
||||
func ParseLegacyAgentSessionKey(sessionKey string) *ParsedLegacySessionKey {
|
||||
raw := strings.TrimSpace(sessionKey)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.SplitN(raw, ":", 3)
|
||||
if len(parts) < 3 || parts[0] != "agent" {
|
||||
return nil
|
||||
}
|
||||
agentID := strings.TrimSpace(parts[1])
|
||||
rest := parts[2]
|
||||
if agentID == "" || rest == "" {
|
||||
return nil
|
||||
}
|
||||
return &ParsedLegacySessionKey{AgentID: agentID, Rest: rest}
|
||||
}
|
||||
|
||||
// ResolveAgentID returns the routed agent ID associated with a session. It
|
||||
// prefers structured session scope metadata when available and falls back to
|
||||
// legacy agent-scoped session keys for compatibility.
|
||||
func ResolveAgentID(store any, sessionKey string) string {
|
||||
if scopeReader, ok := store.(interface {
|
||||
GetSessionScope(sessionKey string) *SessionScope
|
||||
}); ok {
|
||||
scope := scopeReader.GetSessionScope(sessionKey)
|
||||
if scope != nil && strings.TrimSpace(scope.AgentID) != "" {
|
||||
return routing.NormalizeAgentID(scope.AgentID)
|
||||
}
|
||||
}
|
||||
|
||||
if parsed := ParseLegacyAgentSessionKey(sessionKey); parsed != nil {
|
||||
return routing.NormalizeAgentID(parsed.AgentID)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func BuildLegacyMainAlias(agentID string) string {
|
||||
return fmt.Sprintf("agent:%s:main", routing.NormalizeAgentID(agentID))
|
||||
}
|
||||
|
||||
// BuildMainSessionKey returns the canonical opaque main-session key for an
|
||||
// agent. The corresponding legacy alias remains available via
|
||||
// BuildLegacyMainAlias for compatibility and migration logic.
|
||||
func BuildMainSessionKey(agentID string) string {
|
||||
return BuildOpaqueSessionKey(BuildLegacyMainAlias(agentID))
|
||||
}
|
||||
|
||||
func BuildLegacyDirectAliases(agentID, channel, account, peerID string) []string {
|
||||
agentID = routing.NormalizeAgentID(agentID)
|
||||
channel = normalizeLegacyChannel(channel)
|
||||
account = routing.NormalizeAccountID(account)
|
||||
peerID = strings.ToLower(strings.TrimSpace(peerID))
|
||||
if peerID == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{
|
||||
fmt.Sprintf("agent:%s:direct:%s", agentID, peerID),
|
||||
fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID),
|
||||
fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, account, peerID),
|
||||
}
|
||||
}
|
||||
|
||||
func BuildLegacyPeerAlias(agentID, channel, peerKind, peerID string) string {
|
||||
agentID = routing.NormalizeAgentID(agentID)
|
||||
channel = normalizeLegacyChannel(channel)
|
||||
peerKind = strings.ToLower(strings.TrimSpace(peerKind))
|
||||
if peerKind == "" {
|
||||
peerKind = "unknown"
|
||||
}
|
||||
peerID = strings.ToLower(strings.TrimSpace(peerID))
|
||||
if peerID == "" {
|
||||
peerID = "unknown"
|
||||
}
|
||||
return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID)
|
||||
}
|
||||
|
||||
// CanonicalSessionIdentityID collapses an identity using identity_links when
|
||||
// possible, then returns a normalized lowercase identifier.
|
||||
func CanonicalSessionIdentityID(channel, rawID string, identityLinks map[string][]string) string {
|
||||
normalizedID := strings.TrimSpace(rawID)
|
||||
if normalizedID == "" {
|
||||
return ""
|
||||
}
|
||||
if linked := resolveLinkedPeerID(identityLinks, channel, normalizedID); linked != "" {
|
||||
normalizedID = linked
|
||||
}
|
||||
return strings.ToLower(normalizedID)
|
||||
}
|
||||
|
||||
func normalizeLegacyChannel(channel string) string {
|
||||
channel = strings.ToLower(strings.TrimSpace(channel))
|
||||
if channel == "" {
|
||||
return "unknown"
|
||||
}
|
||||
return channel
|
||||
}
|
||||
|
||||
func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string {
|
||||
if len(identityLinks) == 0 {
|
||||
return ""
|
||||
}
|
||||
peerID = strings.TrimSpace(peerID)
|
||||
if peerID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
candidates := make(map[string]bool)
|
||||
rawCandidate := strings.ToLower(peerID)
|
||||
if rawCandidate != "" {
|
||||
candidates[rawCandidate] = true
|
||||
}
|
||||
channel = strings.ToLower(strings.TrimSpace(channel))
|
||||
if channel != "" {
|
||||
candidates[fmt.Sprintf("%s:%s", channel, rawCandidate)] = true
|
||||
}
|
||||
if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 {
|
||||
candidates[rawCandidate[idx+1:]] = true
|
||||
}
|
||||
|
||||
for canonical, ids := range identityLinks {
|
||||
canonicalName := strings.TrimSpace(canonical)
|
||||
if canonicalName == "" {
|
||||
continue
|
||||
}
|
||||
for _, id := range ids {
|
||||
normalized := strings.ToLower(strings.TrimSpace(id))
|
||||
if normalized != "" && candidates[normalized] {
|
||||
return canonicalName
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// CanonicalScopeSignature returns a stable serialized representation of scope.
|
||||
func CanonicalScopeSignature(scope SessionScope) string {
|
||||
parts := []string{
|
||||
fmt.Sprintf("v=%d", scope.Version),
|
||||
fmt.Sprintf("agent=%s", strings.TrimSpace(strings.ToLower(scope.AgentID))),
|
||||
fmt.Sprintf("channel=%s", strings.TrimSpace(strings.ToLower(scope.Channel))),
|
||||
fmt.Sprintf("account=%s", strings.TrimSpace(strings.ToLower(scope.Account))),
|
||||
}
|
||||
for _, dimension := range scope.Dimensions {
|
||||
dimension = strings.TrimSpace(strings.ToLower(dimension))
|
||||
if dimension == "" {
|
||||
continue
|
||||
}
|
||||
value := strings.TrimSpace(strings.ToLower(scope.Values[dimension]))
|
||||
parts = append(parts, fmt.Sprintf("%s=%s", dimension, value))
|
||||
}
|
||||
return strings.Join(parts, "|")
|
||||
}
|
||||
|
||||
// BuildSessionKey returns the current opaque key for a structured session scope.
|
||||
func BuildSessionKey(scope SessionScope) string {
|
||||
return BuildOpaqueSessionKey(CanonicalScopeSignature(scope))
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
package session
|
||||
|
||||
import "testing"
|
||||
|
||||
type testScopeReader struct {
|
||||
scope *SessionScope
|
||||
}
|
||||
|
||||
func (r testScopeReader) GetSessionScope(sessionKey string) *SessionScope {
|
||||
return CloneScope(r.scope)
|
||||
}
|
||||
|
||||
func TestIsExplicitSessionKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
key string
|
||||
want bool
|
||||
}{
|
||||
{"sk_v1_abc", true},
|
||||
{"agent:main:direct:user123", true},
|
||||
{"custom-key", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := IsExplicitSessionKey(tt.key); got != tt.want {
|
||||
t.Fatalf("IsExplicitSessionKey(%q) = %v, want %v", tt.key, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLegacyAgentSessionKey(t *testing.T) {
|
||||
parsed := ParseLegacyAgentSessionKey("agent:sales:telegram:direct:user123")
|
||||
if parsed == nil {
|
||||
t.Fatal("expected parsed legacy key, got nil")
|
||||
}
|
||||
if parsed.AgentID != "sales" {
|
||||
t.Fatalf("AgentID = %q, want sales", parsed.AgentID)
|
||||
}
|
||||
if parsed.Rest != "telegram:direct:user123" {
|
||||
t.Fatalf("Rest = %q, want telegram:direct:user123", parsed.Rest)
|
||||
}
|
||||
|
||||
if got := ParseLegacyAgentSessionKey("sk_v1_abc"); got != nil {
|
||||
t.Fatalf("expected nil for opaque key, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildLegacyDirectAliases(t *testing.T) {
|
||||
aliases := BuildLegacyDirectAliases("Main", "Telegram", "BotA", "User123")
|
||||
want := []string{
|
||||
"agent:main:direct:user123",
|
||||
"agent:main:telegram:direct:user123",
|
||||
"agent:main:telegram:bota:direct:user123",
|
||||
}
|
||||
if len(aliases) != len(want) {
|
||||
t.Fatalf("len(aliases) = %d, want %d", len(aliases), len(want))
|
||||
}
|
||||
for i := range want {
|
||||
if aliases[i] != want[i] {
|
||||
t.Fatalf("aliases[%d] = %q, want %q", i, aliases[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildLegacyPeerAlias(t *testing.T) {
|
||||
got := BuildLegacyPeerAlias("Main", "Slack", "channel", "C001")
|
||||
if got != "agent:main:slack:channel:c001" {
|
||||
t.Fatalf("BuildLegacyPeerAlias() = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMainSessionKey(t *testing.T) {
|
||||
got := BuildMainSessionKey("Main")
|
||||
if !IsOpaqueSessionKey(got) {
|
||||
t.Fatalf("BuildMainSessionKey() = %q, want opaque key", got)
|
||||
}
|
||||
if got != BuildOpaqueSessionKey("agent:main:main") {
|
||||
t.Fatalf("BuildMainSessionKey() = %q, want stable main-key hash", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAgentID_PrefersSessionScope(t *testing.T) {
|
||||
store := testScopeReader{
|
||||
scope: &SessionScope{
|
||||
Version: ScopeVersionV1,
|
||||
AgentID: "Support",
|
||||
Channel: "slack",
|
||||
},
|
||||
}
|
||||
|
||||
if got := ResolveAgentID(store, "sk_v1_anything"); got != "support" {
|
||||
t.Fatalf("ResolveAgentID() = %q, want support", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAgentID_FallsBackToLegacyKey(t *testing.T) {
|
||||
if got := ResolveAgentID(nil, "agent:Sales:telegram:direct:user123"); got != "sales" {
|
||||
t.Fatalf("ResolveAgentID() = %q, want sales", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package session
|
||||
|
||||
// ScopeVersionV1 is the first structured session-scope schema version.
|
||||
const ScopeVersionV1 = 1
|
||||
|
||||
// SessionScope describes the semantic session partition selected for a turn.
|
||||
type SessionScope struct {
|
||||
Version int `json:"version"`
|
||||
AgentID string `json:"agent_id"`
|
||||
Channel string `json:"channel"`
|
||||
Account string `json:"account"`
|
||||
Dimensions []string `json:"dimensions"`
|
||||
Values map[string]string `json:"values"`
|
||||
}
|
||||
|
||||
// CloneScope returns a deep copy of scope.
|
||||
func CloneScope(scope *SessionScope) *SessionScope {
|
||||
if scope == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *scope
|
||||
if len(scope.Dimensions) > 0 {
|
||||
cloned.Dimensions = append([]string(nil), scope.Dimensions...)
|
||||
}
|
||||
if len(scope.Values) > 0 {
|
||||
cloned.Values = make(map[string]string, len(scope.Values))
|
||||
for key, value := range scope.Values {
|
||||
cloned.Values[key] = value
|
||||
}
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
+38
-1
@@ -1,6 +1,10 @@
|
||||
package tools
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
// Tool is the interface that all tools must implement.
|
||||
type Tool interface {
|
||||
@@ -25,6 +29,9 @@ var (
|
||||
ctxKeyChatID = &toolCtxKey{"chatID"}
|
||||
ctxKeyMessageID = &toolCtxKey{"messageID"}
|
||||
ctxKeyReplyToMessageID = &toolCtxKey{"replyToMessageID"}
|
||||
ctxKeyAgentID = &toolCtxKey{"agentID"}
|
||||
ctxKeySessionKey = &toolCtxKey{"sessionKey"}
|
||||
ctxKeySessionScope = &toolCtxKey{"sessionScope"}
|
||||
)
|
||||
|
||||
// WithToolContext returns a child context carrying channel and chatID.
|
||||
@@ -51,6 +58,18 @@ func WithToolInboundContext(
|
||||
return ctx
|
||||
}
|
||||
|
||||
// WithToolSessionContext returns a child context carrying turn-scoped session metadata.
|
||||
func WithToolSessionContext(
|
||||
ctx context.Context,
|
||||
agentID, sessionKey string,
|
||||
scope *session.SessionScope,
|
||||
) context.Context {
|
||||
ctx = context.WithValue(ctx, ctxKeyAgentID, agentID)
|
||||
ctx = context.WithValue(ctx, ctxKeySessionKey, sessionKey)
|
||||
ctx = context.WithValue(ctx, ctxKeySessionScope, session.CloneScope(scope))
|
||||
return ctx
|
||||
}
|
||||
|
||||
// ToolChannel extracts the channel from ctx, or "" if unset.
|
||||
func ToolChannel(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeyChannel).(string)
|
||||
@@ -75,6 +94,24 @@ func ToolReplyToMessageID(ctx context.Context) string {
|
||||
return v
|
||||
}
|
||||
|
||||
// ToolAgentID extracts the active turn's agent ID from ctx, or "" if unset.
|
||||
func ToolAgentID(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeyAgentID).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// ToolSessionKey extracts the active turn's session key from ctx, or "" if unset.
|
||||
func ToolSessionKey(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeySessionKey).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// ToolSessionScope extracts the active turn's structured session scope from ctx.
|
||||
func ToolSessionScope(ctx context.Context) *session.SessionScope {
|
||||
scope, _ := ctx.Value(ctxKeySessionScope).(*session.SessionScope)
|
||||
return session.CloneScope(scope)
|
||||
}
|
||||
|
||||
// AsyncCallback is a function type that async tools use to notify completion.
|
||||
// When an async tool finishes its work, it calls this callback with the result.
|
||||
//
|
||||
|
||||
+2
-4
@@ -311,8 +311,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Context: bus.NewOutboundContext(channel, chatID, ""),
|
||||
Content: output,
|
||||
})
|
||||
return "ok"
|
||||
@@ -335,8 +334,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Context: bus.NewOutboundContext(channel, chatID, ""),
|
||||
Content: output,
|
||||
})
|
||||
return "ok"
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type SendCallback func(channel, chatID, content, replyToMessageID string) error
|
||||
type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error
|
||||
|
||||
// sentTarget records the channel+chatID that the message tool sent to.
|
||||
type sentTarget struct {
|
||||
@@ -15,7 +15,7 @@ type sentTarget struct {
|
||||
}
|
||||
|
||||
type MessageTool struct {
|
||||
sendCallback SendCallback
|
||||
sendCallback SendCallbackWithContext
|
||||
mu sync.Mutex
|
||||
sentTargets []sentTarget // Tracks all targets sent to in the current round
|
||||
}
|
||||
@@ -86,7 +86,7 @@ func (t *MessageTool) HasSentTo(channel, chatID string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *MessageTool) SetSendCallback(callback SendCallback) {
|
||||
func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) {
|
||||
t.sendCallback = callback
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
|
||||
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
|
||||
}
|
||||
|
||||
if err := t.sendCallback(channel, chatID, content, replyToMessageID); err != nil {
|
||||
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil {
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("sending message: %v", err),
|
||||
IsError: true,
|
||||
|
||||
@@ -4,16 +4,22 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentChannel, sentChatID, sentContent string
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
sentContent = content
|
||||
if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil {
|
||||
t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v",
|
||||
ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -61,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentChannel, sentChatID string
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
return nil
|
||||
@@ -96,7 +102,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
sendErr := errors.New("network error")
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
return sendErr
|
||||
})
|
||||
|
||||
@@ -149,7 +155,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
// No WithToolContext — channel/chatID are empty
|
||||
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -266,7 +272,7 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentReplyTo string
|
||||
tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
sentReplyTo = replyToMessageID
|
||||
return nil
|
||||
})
|
||||
@@ -285,3 +291,41 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
|
||||
t.Fatalf("expected reply_to_message_id msg-123, got %q", sentReplyTo)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var gotAgentID, gotSessionKey string
|
||||
var gotScope *session.SessionScope
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
gotAgentID = ToolAgentID(ctx)
|
||||
gotSessionKey = ToolSessionKey(ctx)
|
||||
gotScope = ToolSessionScope(ctx)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
ctx = WithToolSessionContext(ctx, "main", "sk_v1_tool", &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "direct:test-chat-id",
|
||||
},
|
||||
})
|
||||
|
||||
result := tool.Execute(ctx, map[string]any{"content": "Hello, world!"})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if gotAgentID != "main" {
|
||||
t.Fatalf("ToolAgentID() = %q, want main", gotAgentID)
|
||||
}
|
||||
if gotSessionKey != "sk_v1_tool" {
|
||||
t.Fatalf("ToolSessionKey() = %q, want sk_v1_tool", gotSessionKey)
|
||||
}
|
||||
if gotScope == nil || gotScope.Values["chat"] != "direct:test-chat-id" {
|
||||
t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope)
|
||||
}
|
||||
}
|
||||
|
||||
+276
-117
@@ -13,7 +13,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
@@ -49,26 +51,12 @@ type sessionChatMessage struct {
|
||||
Media []string `json:"media,omitempty"`
|
||||
}
|
||||
|
||||
type sessionMetaFile struct {
|
||||
Key string `json:"key"`
|
||||
Summary string `json:"summary"`
|
||||
Skip int `json:"skip"`
|
||||
Count int `json:"count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// picoSessionPrefix is the key prefix used by the gateway's routing for Pico
|
||||
// channel sessions. The full key format is:
|
||||
//
|
||||
// agent:main:pico:direct:pico:<session-uuid>
|
||||
//
|
||||
// The sanitized filename replaces ':' with '_', so on disk it becomes:
|
||||
//
|
||||
// agent_main_pico_direct_pico_<session-uuid>.json
|
||||
// legacyPicoSessionPrefix is the legacy key prefix used by older Pico JSON/JSONL
|
||||
// sessions before structured scope metadata existed.
|
||||
const (
|
||||
picoSessionPrefix = "agent:main:pico:direct:pico:"
|
||||
sanitizedPicoSessionPrefix = "agent_main_pico_direct_pico_"
|
||||
legacyPicoSessionPrefix = "agent:main:pico:direct:pico:"
|
||||
picoSessionPrefix = legacyPicoSessionPrefix
|
||||
|
||||
// Keep the session API aligned with the shared JSONL store reader limit in
|
||||
// pkg/memory/jsonl.go so oversized lines fail consistently everywhere.
|
||||
maxSessionJSONLLineSize = 10 * 1024 * 1024
|
||||
@@ -82,28 +70,23 @@ func defaultToolFeedbackMaxArgsLength() int {
|
||||
return defaults.GetToolFeedbackMaxArgsLength()
|
||||
}
|
||||
|
||||
// extractPicoSessionID extracts the session UUID from a full session key.
|
||||
// extractLegacyPicoSessionID extracts the session UUID from an old Pico key.
|
||||
// Returns the UUID and true if the key matches the Pico session pattern.
|
||||
func extractPicoSessionID(key string) (string, bool) {
|
||||
if strings.HasPrefix(key, picoSessionPrefix) {
|
||||
return strings.TrimPrefix(key, picoSessionPrefix), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func extractPicoSessionIDFromSanitizedKey(key string) (string, bool) {
|
||||
if strings.HasPrefix(key, sanitizedPicoSessionPrefix) {
|
||||
return strings.TrimPrefix(key, sanitizedPicoSessionPrefix), true
|
||||
func extractLegacyPicoSessionID(key string) (string, bool) {
|
||||
if strings.HasPrefix(key, legacyPicoSessionPrefix) {
|
||||
return strings.TrimPrefix(key, legacyPicoSessionPrefix), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func sanitizeSessionKey(key string) string {
|
||||
return strings.ReplaceAll(key, ":", "_")
|
||||
key = strings.ReplaceAll(key, ":", "_")
|
||||
key = strings.ReplaceAll(key, "/", "_")
|
||||
key = strings.ReplaceAll(key, "\\", "_")
|
||||
return key
|
||||
}
|
||||
|
||||
func (h *Handler) readLegacySession(dir, sessionID string) (sessionFile, error) {
|
||||
path := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+sessionID)+".json")
|
||||
func (h *Handler) readLegacySession(path string) (sessionFile, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return sessionFile{}, err
|
||||
@@ -116,18 +99,18 @@ func (h *Handler) readLegacySession(dir, sessionID string) (sessionFile, error)
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (h *Handler) readSessionMeta(path, sessionKey string) (sessionMetaFile, error) {
|
||||
func (h *Handler) readSessionMeta(path, sessionKey string) (memory.SessionMeta, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if os.IsNotExist(err) {
|
||||
return sessionMetaFile{Key: sessionKey}, nil
|
||||
return memory.SessionMeta{Key: sessionKey}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return sessionMetaFile{}, err
|
||||
return memory.SessionMeta{}, err
|
||||
}
|
||||
|
||||
var meta sessionMetaFile
|
||||
var meta memory.SessionMeta
|
||||
if err := json.Unmarshal(data, &meta); err != nil {
|
||||
return sessionMetaFile{}, err
|
||||
return memory.SessionMeta{}, err
|
||||
}
|
||||
if meta.Key == "" {
|
||||
meta.Key = sessionKey
|
||||
@@ -170,8 +153,7 @@ func (h *Handler) readSessionMessages(path string, skip int) ([]providers.Messag
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
func (h *Handler) readJSONLSession(dir, sessionID string) (sessionFile, error) {
|
||||
sessionKey := picoSessionPrefix + sessionID
|
||||
func (h *Handler) readJSONLSession(dir, sessionKey string) (sessionFile, error) {
|
||||
base := filepath.Join(dir, sanitizeSessionKey(sessionKey))
|
||||
jsonlPath := base + ".jsonl"
|
||||
metaPath := base + ".meta.json"
|
||||
@@ -208,6 +190,213 @@ func (h *Handler) readJSONLSession(dir, sessionID string) (sessionFile, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
type picoJSONLSessionRef struct {
|
||||
ID string
|
||||
Key string
|
||||
}
|
||||
|
||||
type picoLegacySessionRef struct {
|
||||
ID string
|
||||
Path string
|
||||
}
|
||||
|
||||
func extractPicoSessionIDFromScope(scope session.SessionScope) (string, bool) {
|
||||
if !strings.EqualFold(strings.TrimSpace(scope.Channel), "pico") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
candidates := []string{
|
||||
strings.TrimSpace(scope.Values["sender"]),
|
||||
strings.TrimSpace(scope.Values["chat"]),
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if idx := strings.Index(candidate, "pico:"); idx >= 0 {
|
||||
sessionID := strings.TrimSpace(candidate[idx+len("pico:"):])
|
||||
if sessionID != "" {
|
||||
return sessionID, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func sessionRefFromMeta(meta memory.SessionMeta) (picoJSONLSessionRef, bool) {
|
||||
if len(meta.Scope) == 0 {
|
||||
if sessionID, ok := extractLegacyPicoSessionID(meta.Key); ok {
|
||||
return picoJSONLSessionRef{ID: sessionID, Key: meta.Key}, true
|
||||
}
|
||||
for _, alias := range meta.Aliases {
|
||||
if sessionID, ok := extractLegacyPicoSessionID(alias); ok {
|
||||
return picoJSONLSessionRef{ID: sessionID, Key: meta.Key}, true
|
||||
}
|
||||
}
|
||||
return picoJSONLSessionRef{}, false
|
||||
}
|
||||
var scope session.SessionScope
|
||||
if err := json.Unmarshal(meta.Scope, &scope); err != nil {
|
||||
return picoJSONLSessionRef{}, false
|
||||
}
|
||||
sessionID, ok := extractPicoSessionIDFromScope(scope)
|
||||
if !ok {
|
||||
if legacySessionID, ok := extractLegacyPicoSessionID(meta.Key); ok {
|
||||
return picoJSONLSessionRef{ID: legacySessionID, Key: meta.Key}, true
|
||||
}
|
||||
for _, alias := range meta.Aliases {
|
||||
if legacySessionID, ok := extractLegacyPicoSessionID(alias); ok {
|
||||
return picoJSONLSessionRef{ID: legacySessionID, Key: meta.Key}, true
|
||||
}
|
||||
}
|
||||
return picoJSONLSessionRef{}, false
|
||||
}
|
||||
return picoJSONLSessionRef{ID: sessionID, Key: meta.Key}, true
|
||||
}
|
||||
|
||||
func (h *Handler) findPicoJSONLSessions(dir string) ([]picoJSONLSessionRef, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
refs := make([]picoJSONLSessionRef, 0)
|
||||
seen := make(map[string]struct{})
|
||||
metaBackedBases := make(map[string]struct{})
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
metaPath := filepath.Join(dir, name)
|
||||
meta, err := h.readSessionMeta(metaPath, "")
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ref, ok := sessionRefFromMeta(meta)
|
||||
if !ok || ref.Key == "" || ref.ID == "" {
|
||||
continue
|
||||
}
|
||||
metaBackedBases[strings.TrimSuffix(name, ".meta.json")] = struct{}{}
|
||||
if _, exists := seen[ref.ID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[ref.ID] = struct{}{}
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".jsonl") {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
base := strings.TrimSuffix(name, ".jsonl")
|
||||
if _, ok := metaBackedBases[base]; ok {
|
||||
continue
|
||||
}
|
||||
ref, ok := jsonlSessionRefFromFilename(name)
|
||||
if !ok || ref.Key == "" || ref.ID == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[ref.ID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[ref.ID] = struct{}{}
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
return refs, nil
|
||||
}
|
||||
|
||||
func (h *Handler) findPicoJSONLSession(dir, sessionID string) (picoJSONLSessionRef, error) {
|
||||
refs, err := h.findPicoJSONLSessions(dir)
|
||||
if err != nil {
|
||||
return picoJSONLSessionRef{}, err
|
||||
}
|
||||
for _, ref := range refs {
|
||||
if ref.ID == sessionID {
|
||||
return ref, nil
|
||||
}
|
||||
}
|
||||
return picoJSONLSessionRef{}, os.ErrNotExist
|
||||
}
|
||||
|
||||
func (h *Handler) findLegacyPicoSessions(dir string) ([]picoLegacySessionRef, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
refs := make([]picoLegacySessionRef, 0)
|
||||
seen := make(map[string]struct{})
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if entry.IsDir() || filepath.Ext(name) != ".json" || strings.HasSuffix(name, ".meta.json") {
|
||||
continue
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, entry.Name())
|
||||
sess, err := h.readLegacySession(path)
|
||||
if err != nil || isEmptySession(sess) {
|
||||
continue
|
||||
}
|
||||
|
||||
sessionID, ok := extractLegacyPicoSessionID(sess.Key)
|
||||
if !ok || sessionID == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[sessionID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[sessionID] = struct{}{}
|
||||
refs = append(refs, picoLegacySessionRef{ID: sessionID, Path: path})
|
||||
}
|
||||
return refs, nil
|
||||
}
|
||||
|
||||
func jsonlSessionRefFromFilename(name string) (picoJSONLSessionRef, bool) {
|
||||
if !strings.HasSuffix(name, ".jsonl") {
|
||||
return picoJSONLSessionRef{}, false
|
||||
}
|
||||
base := strings.TrimSuffix(name, ".jsonl")
|
||||
if base == "" {
|
||||
return picoJSONLSessionRef{}, false
|
||||
}
|
||||
|
||||
legacyPrefix := sanitizeSessionKey(legacyPicoSessionPrefix)
|
||||
if strings.HasPrefix(base, legacyPrefix) {
|
||||
sessionID := strings.TrimPrefix(base, legacyPrefix)
|
||||
if sessionID == "" {
|
||||
return picoJSONLSessionRef{}, false
|
||||
}
|
||||
return picoJSONLSessionRef{
|
||||
ID: sessionID,
|
||||
Key: legacyPicoSessionPrefix + sessionID,
|
||||
}, true
|
||||
}
|
||||
|
||||
if session.IsOpaqueSessionKey(base) {
|
||||
return picoJSONLSessionRef{
|
||||
ID: base,
|
||||
Key: base,
|
||||
}, true
|
||||
}
|
||||
|
||||
return picoJSONLSessionRef{}, false
|
||||
}
|
||||
|
||||
func (h *Handler) findLegacyPicoSession(dir, sessionID string) (picoLegacySessionRef, error) {
|
||||
refs, err := h.findLegacyPicoSessions(dir)
|
||||
if err != nil {
|
||||
return picoLegacySessionRef{}, err
|
||||
}
|
||||
for _, ref := range refs {
|
||||
if ref.ID == sessionID {
|
||||
return ref, nil
|
||||
}
|
||||
}
|
||||
return picoLegacySessionRef{}, os.ErrNotExist
|
||||
}
|
||||
|
||||
func buildSessionListItem(sessionID string, sess sessionFile, toolFeedbackMaxArgsLength int) sessionListItem {
|
||||
preview := ""
|
||||
for _, msg := range sess.Messages {
|
||||
@@ -458,8 +647,7 @@ func (h *Handler) handleListSessions(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if _, err := os.ReadDir(dir); err != nil {
|
||||
// Directory doesn't exist yet = no sessions
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode([]sessionListItem{})
|
||||
@@ -469,74 +657,29 @@ func (h *Handler) handleListSessions(w http.ResponseWriter, r *http.Request) {
|
||||
items := []sessionListItem{}
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
if refs, findErr := h.findPicoJSONLSessions(dir); findErr == nil {
|
||||
for _, ref := range refs {
|
||||
sess, loadErr := h.readJSONLSession(dir, ref.Key)
|
||||
if loadErr != nil || isEmptySession(sess) {
|
||||
continue
|
||||
}
|
||||
seen[ref.ID] = struct{}{}
|
||||
items = append(items, buildSessionListItem(ref.ID, sess, toolFeedbackMaxArgsLength))
|
||||
}
|
||||
}
|
||||
|
||||
name := entry.Name()
|
||||
var (
|
||||
sessionID string
|
||||
sess sessionFile
|
||||
loadErr error
|
||||
ok bool
|
||||
)
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(name, ".jsonl"):
|
||||
sessionID, ok = extractPicoSessionIDFromSanitizedKey(strings.TrimSuffix(name, ".jsonl"))
|
||||
if !ok {
|
||||
if legacyRefs, findErr := h.findLegacyPicoSessions(dir); findErr == nil {
|
||||
for _, ref := range legacyRefs {
|
||||
if _, exists := seen[ref.ID]; exists {
|
||||
continue
|
||||
}
|
||||
sess, loadErr = h.readJSONLSession(dir, sessionID)
|
||||
if loadErr == nil && isEmptySession(sess) {
|
||||
sess, loadErr := h.readLegacySession(ref.Path)
|
||||
if loadErr != nil || isEmptySession(sess) {
|
||||
continue
|
||||
}
|
||||
case strings.HasSuffix(name, ".meta.json"):
|
||||
continue
|
||||
case filepath.Ext(name) == ".json":
|
||||
base := strings.TrimSuffix(name, ".json")
|
||||
if _, statErr := os.Stat(filepath.Join(dir, base+".jsonl")); statErr == nil {
|
||||
if jsonlSessionID, found := extractPicoSessionIDFromSanitizedKey(base); found {
|
||||
if jsonlSess, jsonlErr := h.readJSONLSession(
|
||||
dir,
|
||||
jsonlSessionID,
|
||||
); jsonlErr == nil &&
|
||||
!isEmptySession(jsonlSess) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
data, err := os.ReadFile(filepath.Join(dir, name))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if err := json.Unmarshal(data, &sess); err != nil {
|
||||
continue
|
||||
}
|
||||
if isEmptySession(sess) {
|
||||
continue
|
||||
}
|
||||
sessionID, ok = extractPicoSessionID(sess.Key)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[sessionID]; exists {
|
||||
continue
|
||||
}
|
||||
default:
|
||||
continue
|
||||
seen[ref.ID] = struct{}{}
|
||||
items = append(items, buildSessionListItem(ref.ID, sess, toolFeedbackMaxArgsLength))
|
||||
}
|
||||
|
||||
if loadErr != nil {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[sessionID]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
seen[sessionID] = struct{}{}
|
||||
items = append(items, buildSessionListItem(sessionID, sess, toolFeedbackMaxArgsLength))
|
||||
}
|
||||
|
||||
// Sort by updated descending (most recent first)
|
||||
@@ -590,13 +733,20 @@ func (h *Handler) handleGetSession(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sess, err := h.readJSONLSession(dir, sessionID)
|
||||
ref, refErr := h.findPicoJSONLSession(dir, sessionID)
|
||||
var sess sessionFile
|
||||
err = refErr
|
||||
if refErr == nil {
|
||||
sess, err = h.readJSONLSession(dir, ref.Key)
|
||||
}
|
||||
if err == nil && isEmptySession(sess) {
|
||||
err = os.ErrNotExist
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
sess, err = h.readLegacySession(dir, sessionID)
|
||||
if legacyRef, legacyErr := h.findLegacyPicoSession(dir, sessionID); legacyErr == nil {
|
||||
sess, err = h.readLegacySession(legacyRef.Path)
|
||||
}
|
||||
if err == nil && isEmptySession(sess) {
|
||||
err = os.ErrNotExist
|
||||
}
|
||||
@@ -639,21 +789,30 @@ func (h *Handler) handleDeleteSession(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
base := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+sessionID))
|
||||
jsonlPath := base + ".jsonl"
|
||||
metaPath := base + ".meta.json"
|
||||
legacyPath := base + ".json"
|
||||
|
||||
removed := false
|
||||
for _, path := range []string{jsonlPath, metaPath, legacyPath} {
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
if ref, err := h.findPicoJSONLSession(dir, sessionID); err == nil {
|
||||
base := filepath.Join(dir, sanitizeSessionKey(ref.Key))
|
||||
for _, path := range []string{base + ".jsonl", base + ".meta.json"} {
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
http.Error(w, "failed to delete session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Error(w, "failed to delete session", http.StatusInternalServerError)
|
||||
return
|
||||
removed = true
|
||||
}
|
||||
}
|
||||
|
||||
if legacyRef, err := h.findLegacyPicoSession(dir, sessionID); err == nil {
|
||||
if err := os.Remove(legacyRef.Path); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
http.Error(w, "failed to delete session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
removed = true
|
||||
}
|
||||
removed = true
|
||||
}
|
||||
|
||||
if !removed {
|
||||
|
||||
+166
-12
@@ -36,12 +36,12 @@ func TestHandleListSessions_JSONLStorage(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
store, storeErr := memory.NewJSONLStore(dir)
|
||||
if storeErr != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", storeErr)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "history-jsonl"
|
||||
sessionKey := legacyPicoSessionPrefix + "history-jsonl"
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "user",
|
||||
Content: "Explain why the history API is empty after migration.",
|
||||
@@ -106,12 +106,12 @@ func TestHandleListSessions_TitleUsesFirstUserMessage(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
store, storeErr := memory.NewJSONLStore(dir)
|
||||
if storeErr != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", storeErr)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "summary-title"
|
||||
sessionKey := legacyPicoSessionPrefix + "summary-title"
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "user",
|
||||
Content: "fallback preview",
|
||||
@@ -164,7 +164,7 @@ func TestHandleGetSession_JSONLStorage(t *testing.T) {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "detail-jsonl"
|
||||
sessionKey := legacyPicoSessionPrefix + "detail-jsonl"
|
||||
for _, msg := range []providers.Message{
|
||||
{Role: "user", Content: "first"},
|
||||
{Role: "assistant", Content: "second"},
|
||||
@@ -218,6 +218,81 @@ func TestHandleGetSession_JSONLStorage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSessions_JSONLScopeDiscovery(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, storeErr := memory.NewJSONLStore(dir)
|
||||
if storeErr != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", storeErr)
|
||||
}
|
||||
|
||||
sessionKey := "sk_v1_scope_discovery"
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "user",
|
||||
Content: "scope discovered session",
|
||||
}); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
if err := store.SetSummary(nil, sessionKey, "scope summary"); err != nil {
|
||||
t.Fatalf("SetSummary() error = %v", err)
|
||||
}
|
||||
|
||||
scopeData, err := json.Marshal(session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "pico",
|
||||
Account: "default",
|
||||
Dimensions: []string{"sender"},
|
||||
Values: map[string]string{
|
||||
"sender": "pico:scope-jsonl",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal(scope) error = %v", err)
|
||||
}
|
||||
if err := store.UpsertSessionMeta(nil, sessionKey, scopeData, nil); err != nil {
|
||||
t.Fatalf("UpsertSessionMeta() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
listRec := httptest.NewRecorder()
|
||||
listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
|
||||
mux.ServeHTTP(listRec, listReq)
|
||||
if listRec.Code != http.StatusOK {
|
||||
t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String())
|
||||
}
|
||||
|
||||
var items []sessionListItem
|
||||
if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil {
|
||||
t.Fatalf("Unmarshal(list) error = %v", err)
|
||||
}
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("len(items) = %d, want 1", len(items))
|
||||
}
|
||||
if items[0].ID != "scope-jsonl" {
|
||||
t.Fatalf("items[0].ID = %q, want %q", items[0].ID, "scope-jsonl")
|
||||
}
|
||||
|
||||
detailRec := httptest.NewRecorder()
|
||||
detailReq := httptest.NewRequest(http.MethodGet, "/api/sessions/scope-jsonl", nil)
|
||||
mux.ServeHTTP(detailRec, detailReq)
|
||||
if detailRec.Code != http.StatusOK {
|
||||
t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusOK, detailRec.Body.String())
|
||||
}
|
||||
|
||||
deleteRec := httptest.NewRecorder()
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, "/api/sessions/scope-jsonl", nil)
|
||||
mux.ServeHTTP(deleteRec, deleteReq)
|
||||
if deleteRec.Code != http.StatusNoContent {
|
||||
t.Fatalf("delete status = %d, want %d, body=%s", deleteRec.Code, http.StatusNoContent, deleteRec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_OmitsTransientThoughtMessages(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
@@ -784,7 +859,7 @@ func TestHandleDeleteSession_JSONLStorage(t *testing.T) {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "delete-jsonl"
|
||||
sessionKey := legacyPicoSessionPrefix + "delete-jsonl"
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "user",
|
||||
Content: "delete me",
|
||||
@@ -821,7 +896,7 @@ func TestHandleGetSession_LegacyJSONFallback(t *testing.T) {
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
manager := session.NewSessionManager(dir)
|
||||
sessionKey := picoSessionPrefix + "legacy-json"
|
||||
sessionKey := legacyPicoSessionPrefix + "legacy-json"
|
||||
manager.AddMessage(sessionKey, "user", "legacy user")
|
||||
manager.AddMessage(sessionKey, "assistant", "legacy assistant")
|
||||
if err := manager.Save(sessionKey); err != nil {
|
||||
@@ -846,7 +921,7 @@ func TestHandleSessions_FiltersEmptyJSONLFiles(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
base := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+"empty-jsonl"))
|
||||
base := filepath.Join(dir, sanitizeSessionKey(legacyPicoSessionPrefix+"empty-jsonl"))
|
||||
if err := os.WriteFile(base+".jsonl", []byte{}, 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(jsonl) error = %v", err)
|
||||
}
|
||||
@@ -879,3 +954,82 @@ func TestHandleSessions_FiltersEmptyJSONLFiles(t *testing.T) {
|
||||
t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusNotFound, detailRec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSessions_ListsLegacyJSONLWithoutMeta(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
sessionKey := legacyPicoSessionPrefix + "missing-meta"
|
||||
base := filepath.Join(dir, sanitizeSessionKey(sessionKey))
|
||||
line, err := json.Marshal(providers.Message{Role: "user", Content: "recover me"})
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal(message) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(base+".jsonl", append(line, '\n'), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(jsonl) error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
listRec := httptest.NewRecorder()
|
||||
listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
|
||||
mux.ServeHTTP(listRec, listReq)
|
||||
|
||||
if listRec.Code != http.StatusOK {
|
||||
t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String())
|
||||
}
|
||||
|
||||
var items []sessionListItem
|
||||
if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil {
|
||||
t.Fatalf("Unmarshal(list) error = %v", err)
|
||||
}
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("len(items) = %d, want 1", len(items))
|
||||
}
|
||||
if items[0].ID != "missing-meta" {
|
||||
t.Fatalf("items[0].ID = %q, want %q", items[0].ID, "missing-meta")
|
||||
}
|
||||
|
||||
detailRec := httptest.NewRecorder()
|
||||
detailReq := httptest.NewRequest(http.MethodGet, "/api/sessions/missing-meta", nil)
|
||||
mux.ServeHTTP(detailRec, detailReq)
|
||||
|
||||
if detailRec.Code != http.StatusOK {
|
||||
t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusOK, detailRec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSessions_IgnoresMetaJSONInLegacyFallback(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
metaOnly := filepath.Join(dir, "agent_main_pico_direct_pico_meta-only.meta.json")
|
||||
metaOnlyContent := []byte(`{"key":"agent:main:pico:direct:pico:meta-only","summary":"meta only"}`)
|
||||
if err := os.WriteFile(metaOnly, metaOnlyContent, 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(meta) error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
listRec := httptest.NewRecorder()
|
||||
listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
|
||||
mux.ServeHTTP(listRec, listReq)
|
||||
|
||||
if listRec.Code != http.StatusOK {
|
||||
t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String())
|
||||
}
|
||||
|
||||
var items []sessionListItem
|
||||
if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil {
|
||||
t.Fatalf("Unmarshal(list) error = %v", err)
|
||||
}
|
||||
if len(items) != 0 {
|
||||
t.Fatalf("len(items) = %d, want 0", len(items))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user