From c639e2c21677aaff50796d2b68af3183d6d61fb5 Mon Sep 17 00:00:00 2001 From: Alix-007 Date: Tue, 17 Mar 2026 23:31:56 +0800 Subject: [PATCH] feat(agent): include current sender in dynamic context (#1696) * feat(agent): include current sender in dynamic context * test(agent): keep current-sender regression ASCII-only --------- Co-authored-by: Alix-007 <267018309+Alix-007@users.noreply.github.com> --- pkg/agent/context.go | 25 +++++++++-- pkg/agent/context_cache_test.go | 68 ++++++++++++++++++++++++++++-- pkg/agent/loop.go | 42 ++++++++++-------- pkg/agent/loop_test.go | 75 +++++++++++++++++++++++++++++++++ 4 files changed, 186 insertions(+), 24 deletions(-) diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 5a84c45e2..830edf875 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -458,7 +458,23 @@ func (cb *ContextBuilder) LoadBootstrapFiles() string { // // See: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching // See: https://platform.openai.com/docs/guides/prompt-caching -func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string { +func formatCurrentSenderLine(senderID, senderDisplayName string) string { + senderID = strings.TrimSpace(senderID) + senderDisplayName = strings.TrimSpace(senderDisplayName) + + switch { + case senderDisplayName != "" && senderID != "": + return fmt.Sprintf("Current sender: %s (ID: %s)", senderDisplayName, senderID) + case senderDisplayName != "": + return fmt.Sprintf("Current sender: %s", senderDisplayName) + case senderID != "": + return fmt.Sprintf("Current sender: %s", senderID) + default: + return "" + } +} + +func (cb *ContextBuilder) buildDynamicContext(channel, chatID, senderID, senderDisplayName string) string { now := time.Now().Format("2006-01-02 15:04 (Monday)") rt := fmt.Sprintf("%s %s, Go %s", runtime.GOOS, runtime.GOARCH, runtime.Version()) @@ -468,6 +484,9 @@ func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string { if channel != "" && chatID != "" { fmt.Fprintf(&sb, "\n\n## Current Session\nChannel: %s\nChat ID: %s", channel, chatID) } + if senderLine := formatCurrentSenderLine(senderID, senderDisplayName); senderLine != "" { + fmt.Fprintf(&sb, "\n\n## Current Sender\n%s", senderLine) + } return sb.String() } @@ -477,7 +496,7 @@ func (cb *ContextBuilder) BuildMessages( summary string, currentMessage string, media []string, - channel, chatID string, + channel, chatID, senderID, senderDisplayName string, ) []providers.Message { messages := []providers.Message{} @@ -493,7 +512,7 @@ func (cb *ContextBuilder) BuildMessages( staticPrompt := cb.BuildSystemPromptWithCache() // Build short dynamic context (time, runtime, session) — changes per request - dynamicCtx := cb.buildDynamicContext(channel, chatID) + dynamicCtx := cb.buildDynamicContext(channel, chatID, senderID, senderDisplayName) // Compose a single system message: static (cached) + dynamic + optional summary. // Keeping all system content in one message ensures every provider adapter can diff --git a/pkg/agent/context_cache_test.go b/pkg/agent/context_cache_test.go index 707510820..c26976c3c 100644 --- a/pkg/agent/context_cache_test.go +++ b/pkg/agent/context_cache_test.go @@ -82,7 +82,7 @@ func TestSingleSystemMessage(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1") + msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1", "", "") systemCount := 0 for _, m := range msgs { @@ -126,6 +126,68 @@ func TestSingleSystemMessage(t *testing.T) { } } +func TestBuildMessages_CurrentSenderDynamicContext(t *testing.T) { + tmpDir := setupWorkspace(t, map[string]string{ + "IDENTITY.md": "# Identity\nTest agent.", + }) + defer os.RemoveAll(tmpDir) + + cb := NewContextBuilder(tmpDir) + + tests := []struct { + name string + senderID string + senderDisplayName string + wantLine string + wantSection bool + }{ + { + name: "both id and display name", + senderID: "feishu:ou_xxx", + senderDisplayName: "Zhang San", + wantLine: "Current sender: Zhang San (ID: feishu:ou_xxx)", + wantSection: true, + }, + { + name: "display name only", + senderDisplayName: "Alice", + wantLine: "Current sender: Alice", + wantSection: true, + }, + { + name: "id only", + senderID: "discord:123", + wantLine: "Current sender: discord:123", + wantSection: true, + }, + { + name: "no sender info", + wantSection: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msgs := cb.BuildMessages(nil, "", "hello", nil, "discord", "chat1", tt.senderID, tt.senderDisplayName) + sys := msgs[0].Content + + if tt.wantSection { + if !strings.Contains(sys, "## Current Sender") { + t.Fatalf("system prompt missing Current Sender section:\n%s", sys) + } + if !strings.Contains(sys, tt.wantLine) { + t.Fatalf("system prompt missing sender line %q:\n%s", tt.wantLine, sys) + } + return + } + + if strings.Contains(sys, "## Current Sender") { + t.Fatalf("system prompt should omit Current Sender section:\n%s", sys) + } + }) + } +} + // TestMtimeAutoInvalidation verifies that the cache detects source file changes // via mtime without requiring explicit InvalidateCache(). // Fix: original implementation had no auto-invalidation — edits to bootstrap files, @@ -576,7 +638,7 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) { } // Also exercise BuildMessages concurrently - msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat") + msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat", "", "") if len(msgs) < 2 { errs <- "BuildMessages returned fewer than 2 messages" return @@ -664,6 +726,6 @@ func BenchmarkBuildMessagesWithCache(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test") + _ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test", "", "") } } diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index c25650201..00c9d913a 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -55,15 +55,17 @@ type AgentLoop struct { // processOptions configures how a message is processed type processOptions struct { - SessionKey string // Session identifier for history/context - Channel string // Target channel for tool execution - ChatID string // Target chat ID for tool execution - UserMessage string // User message content (may include prefix) - Media []string // media:// refs from inbound message - DefaultResponse string // Response when LLM returns empty - EnableSummary bool // Whether to trigger summarization - SendResponse bool // Whether to send response via bus - NoHistory bool // If true, don't load session history (for heartbeat) + SessionKey string // Session identifier for history/context + Channel string // Target channel for tool execution + ChatID string // Target chat ID for tool execution + SenderID string // Current sender ID for dynamic context + SenderDisplayName string // Current sender display name for dynamic context + UserMessage string // User message content (may include prefix) + Media []string // media:// refs from inbound message + DefaultResponse string // Response when LLM returns empty + EnableSummary bool // Whether to trigger summarization + SendResponse bool // Whether to send response via bus + NoHistory bool // If true, don't load session history (for heartbeat) } const ( @@ -746,14 +748,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) }) opts := processOptions{ - SessionKey: sessionKey, - Channel: msg.Channel, - ChatID: msg.ChatID, - UserMessage: msg.Content, - Media: msg.Media, - DefaultResponse: defaultResponse, - EnableSummary: true, - SendResponse: false, + SessionKey: sessionKey, + Channel: msg.Channel, + ChatID: msg.ChatID, + SenderID: msg.SenderID, + SenderDisplayName: msg.Sender.DisplayName, + UserMessage: msg.Content, + Media: msg.Media, + DefaultResponse: defaultResponse, + EnableSummary: true, + SendResponse: false, } // context-dependent commands check their own Runtime fields and report @@ -893,6 +897,8 @@ func (al *AgentLoop) runAgentLoop( opts.Media, opts.Channel, opts.ChatID, + opts.SenderID, + opts.SenderDisplayName, ) // Resolve media:// refs: images→base64 data URLs, non-images→local paths in content @@ -1164,7 +1170,7 @@ func (al *AgentLoop) runLLMIteration( newSummary := agent.Sessions.GetSummary(opts.SessionKey) messages = agent.ContextBuilder.BuildMessages( newHistory, newSummary, "", - nil, opts.Channel, opts.ChatID, + nil, opts.Channel, opts.ChatID, opts.SenderID, opts.SenderDisplayName, ) continue } diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index a6604e87f..47c378771 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -30,6 +30,28 @@ func (f *fakeChannel) IsAllowed(string) bool { func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true } func (f *fakeChannel) ReasoningChannelID() string { return f.id } +type recordingProvider struct { + lastMessages []providers.Message +} + +func (r *recordingProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + r.lastMessages = append([]providers.Message(nil), messages...) + return &providers.LLMResponse{ + Content: "Mock response", + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (r *recordingProvider) GetDefaultModel() string { + return "mock-model" +} + func newTestAgentLoop( t *testing.T, ) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) { @@ -54,6 +76,59 @@ func newTestAgentLoop( return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) } } +func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &recordingProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + response, err := al.processMessage(context.Background(), bus.InboundMessage{ + Channel: "discord", + SenderID: "discord:123", + Sender: bus.SenderInfo{ + DisplayName: "Alice", + }, + ChatID: "group-1", + Content: "hello", + }) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response != "Mock response" { + t.Fatalf("processMessage() response = %q, want %q", response, "Mock response") + } + if len(provider.lastMessages) == 0 { + t.Fatal("provider did not receive any messages") + } + + systemPrompt := provider.lastMessages[0].Content + wantSender := "## Current Sender\nCurrent sender: Alice (ID: discord:123)" + if !strings.Contains(systemPrompt, wantSender) { + t.Fatalf("system prompt missing sender context %q:\n%s", wantSender, systemPrompt) + } + + lastMessage := provider.lastMessages[len(provider.lastMessages)-1] + if lastMessage.Role != "user" || lastMessage.Content != "hello" { + t.Fatalf("last provider message = %+v, want unchanged user message", lastMessage) + } +} + func TestRecordLastChannel(t *testing.T) { al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t) defer cleanup()