mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'upstream-main' into feat/subturn-poc
This commit is contained in:
+24
-5
@@ -52,7 +52,7 @@ func (cb *ContextBuilder) WithToolDiscovery(useBM25, useRegex bool) *ContextBuil
|
||||
}
|
||||
|
||||
func getGlobalConfigDir() string {
|
||||
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return home
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
@@ -65,7 +65,7 @@ func getGlobalConfigDir() string {
|
||||
func NewContextBuilder(workspace string) *ContextBuilder {
|
||||
// builtin skills: skills directory in current project
|
||||
// Use the skills/ directory under the current working directory
|
||||
builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS"))
|
||||
builtinSkillsDir := strings.TrimSpace(os.Getenv(config.EnvBuiltinSkills))
|
||||
if builtinSkillsDir == "" {
|
||||
wd, _ := os.Getwd()
|
||||
builtinSkillsDir = filepath.Join(wd, "skills")
|
||||
@@ -458,7 +458,23 @@ func (cb *ContextBuilder) LoadBootstrapFiles() string {
|
||||
//
|
||||
// See: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
|
||||
// See: https://platform.openai.com/docs/guides/prompt-caching
|
||||
func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string {
|
||||
func formatCurrentSenderLine(senderID, senderDisplayName string) string {
|
||||
senderID = strings.TrimSpace(senderID)
|
||||
senderDisplayName = strings.TrimSpace(senderDisplayName)
|
||||
|
||||
switch {
|
||||
case senderDisplayName != "" && senderID != "":
|
||||
return fmt.Sprintf("Current sender: %s (ID: %s)", senderDisplayName, senderID)
|
||||
case senderDisplayName != "":
|
||||
return fmt.Sprintf("Current sender: %s", senderDisplayName)
|
||||
case senderID != "":
|
||||
return fmt.Sprintf("Current sender: %s", senderID)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) buildDynamicContext(channel, chatID, senderID, senderDisplayName string) string {
|
||||
now := time.Now().Format("2006-01-02 15:04 (Monday)")
|
||||
rt := fmt.Sprintf("%s %s, Go %s", runtime.GOOS, runtime.GOARCH, runtime.Version())
|
||||
|
||||
@@ -468,6 +484,9 @@ func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string {
|
||||
if channel != "" && chatID != "" {
|
||||
fmt.Fprintf(&sb, "\n\n## Current Session\nChannel: %s\nChat ID: %s", channel, chatID)
|
||||
}
|
||||
if senderLine := formatCurrentSenderLine(senderID, senderDisplayName); senderLine != "" {
|
||||
fmt.Fprintf(&sb, "\n\n## Current Sender\n%s", senderLine)
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -477,7 +496,7 @@ func (cb *ContextBuilder) BuildMessages(
|
||||
summary string,
|
||||
currentMessage string,
|
||||
media []string,
|
||||
channel, chatID string,
|
||||
channel, chatID, senderID, senderDisplayName string,
|
||||
) []providers.Message {
|
||||
messages := []providers.Message{}
|
||||
|
||||
@@ -493,7 +512,7 @@ func (cb *ContextBuilder) BuildMessages(
|
||||
staticPrompt := cb.BuildSystemPromptWithCache()
|
||||
|
||||
// Build short dynamic context (time, runtime, session) — changes per request
|
||||
dynamicCtx := cb.buildDynamicContext(channel, chatID)
|
||||
dynamicCtx := cb.buildDynamicContext(channel, chatID, senderID, senderDisplayName)
|
||||
|
||||
// Compose a single system message: static (cached) + dynamic + optional summary.
|
||||
// Keeping all system content in one message ensures every provider adapter can
|
||||
|
||||
@@ -82,7 +82,7 @@ func TestSingleSystemMessage(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1")
|
||||
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1", "", "")
|
||||
|
||||
systemCount := 0
|
||||
for _, m := range msgs {
|
||||
@@ -126,6 +126,68 @@ func TestSingleSystemMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessages_CurrentSenderDynamicContext(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"IDENTITY.md": "# Identity\nTest agent.",
|
||||
})
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
senderID string
|
||||
senderDisplayName string
|
||||
wantLine string
|
||||
wantSection bool
|
||||
}{
|
||||
{
|
||||
name: "both id and display name",
|
||||
senderID: "feishu:ou_xxx",
|
||||
senderDisplayName: "Zhang San",
|
||||
wantLine: "Current sender: Zhang San (ID: feishu:ou_xxx)",
|
||||
wantSection: true,
|
||||
},
|
||||
{
|
||||
name: "display name only",
|
||||
senderDisplayName: "Alice",
|
||||
wantLine: "Current sender: Alice",
|
||||
wantSection: true,
|
||||
},
|
||||
{
|
||||
name: "id only",
|
||||
senderID: "discord:123",
|
||||
wantLine: "Current sender: discord:123",
|
||||
wantSection: true,
|
||||
},
|
||||
{
|
||||
name: "no sender info",
|
||||
wantSection: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgs := cb.BuildMessages(nil, "", "hello", nil, "discord", "chat1", tt.senderID, tt.senderDisplayName)
|
||||
sys := msgs[0].Content
|
||||
|
||||
if tt.wantSection {
|
||||
if !strings.Contains(sys, "## Current Sender") {
|
||||
t.Fatalf("system prompt missing Current Sender section:\n%s", sys)
|
||||
}
|
||||
if !strings.Contains(sys, tt.wantLine) {
|
||||
t.Fatalf("system prompt missing sender line %q:\n%s", tt.wantLine, sys)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(sys, "## Current Sender") {
|
||||
t.Fatalf("system prompt should omit Current Sender section:\n%s", sys)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMtimeAutoInvalidation verifies that the cache detects source file changes
|
||||
// via mtime without requiring explicit InvalidateCache().
|
||||
// Fix: original implementation had no auto-invalidation — edits to bootstrap files,
|
||||
@@ -576,7 +638,7 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
|
||||
}
|
||||
|
||||
// Also exercise BuildMessages concurrently
|
||||
msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat")
|
||||
msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat", "", "")
|
||||
if len(msgs) < 2 {
|
||||
errs <- "BuildMessages returned fewer than 2 messages"
|
||||
return
|
||||
@@ -664,6 +726,6 @@ func BenchmarkBuildMessagesWithCache(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test")
|
||||
_ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test", "", "")
|
||||
}
|
||||
}
|
||||
|
||||
+72
-15
@@ -61,6 +61,8 @@ type processOptions struct {
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
SenderID string // Current sender ID for dynamic context
|
||||
SenderDisplayName string // Current sender display name for dynamic context
|
||||
UserMessage string // User message content (may include prefix)
|
||||
Media []string // media:// refs from inbound message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
@@ -166,7 +168,12 @@ func registerSharedTools(
|
||||
}
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("web_fetch") {
|
||||
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
|
||||
fetchTool, err := tools.NewWebFetchToolWithProxy(
|
||||
50000,
|
||||
cfg.Tools.Web.Proxy,
|
||||
cfg.Tools.Web.Format,
|
||||
cfg.Tools.Web.FetchLimitBytes,
|
||||
cfg.Tools.Web.PrivateHostWhitelist)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
@@ -338,10 +345,9 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
msg, ok := al.bus.ConsumeInbound(ctx)
|
||||
case msg, ok := <-al.bus.InboundChan():
|
||||
if !ok {
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start a goroutine that drains the bus while processMessage is
|
||||
@@ -408,6 +414,8 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
}()
|
||||
default:
|
||||
time.Sleep(time.Microsecond * 200)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -419,9 +427,15 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
// is active and stops when drainCtx is canceled (i.e., processMessage returns).
|
||||
func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
|
||||
for {
|
||||
msg, ok := al.bus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
var msg bus.InboundMessage
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case m, ok := <-al.bus.InboundChan():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
msg = m
|
||||
}
|
||||
|
||||
// Transcribe audio if needed before steering, so the agent sees text.
|
||||
@@ -861,14 +875,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
})
|
||||
|
||||
opts := processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
UserMessage: msg.Content,
|
||||
Media: msg.Media,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
SessionKey: sessionKey,
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
SenderID: msg.SenderID,
|
||||
SenderDisplayName: msg.Sender.DisplayName,
|
||||
UserMessage: msg.Content,
|
||||
Media: msg.Media,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
}
|
||||
|
||||
// context-dependent commands check their own Runtime fields and report
|
||||
@@ -1039,6 +1055,8 @@ func (al *AgentLoop) runAgentLoop(
|
||||
opts.Media,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
opts.SenderID,
|
||||
opts.SenderDisplayName,
|
||||
)
|
||||
|
||||
// Resolve media:// refs: images→base64 data URLs, non-images→local paths in content
|
||||
@@ -1256,6 +1274,19 @@ func (al *AgentLoop) runLLMIteration(
|
||||
// Build tool definitions
|
||||
providerToolDefs := agent.Tools.ToProviderDefs()
|
||||
|
||||
// Determine whether the provider's native web search should replace
|
||||
// the client-side web_search tool for this request. Only enable when web
|
||||
// search is actually enabled and registered (so users who disabled web
|
||||
// access do not get provider-side search or billing).
|
||||
_, hasWebSearch := agent.Tools.Get("web_search")
|
||||
useNativeSearch := al.cfg.Tools.Web.PreferNative &&
|
||||
isNativeSearchProvider(agent.Provider) &&
|
||||
hasWebSearch
|
||||
|
||||
if useNativeSearch {
|
||||
providerToolDefs = filterClientWebSearch(providerToolDefs)
|
||||
}
|
||||
|
||||
// Log LLM request details
|
||||
logger.DebugCF("agent", "LLM request",
|
||||
map[string]any{
|
||||
@@ -1264,6 +1295,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
"model": activeModel,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"native_search": useNativeSearch,
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"system_prompt_len": len(messages[0].Content),
|
||||
@@ -1286,6 +1318,9 @@ func (al *AgentLoop) runLLMIteration(
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
}
|
||||
if useNativeSearch {
|
||||
llmOpts["native_search"] = true
|
||||
}
|
||||
// parseThinkingLevel guarantees ThinkingOff for empty/unknown values,
|
||||
// so checking != ThinkingOff is sufficient.
|
||||
if agent.ThinkingLevel != ThinkingOff {
|
||||
@@ -1387,7 +1422,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
|
||||
messages = agent.ContextBuilder.BuildMessages(
|
||||
newHistory, newSummary, "",
|
||||
nil, opts.Channel, opts.ChatID,
|
||||
nil, opts.Channel, opts.ChatID, opts.SenderID, opts.SenderDisplayName,
|
||||
)
|
||||
continue
|
||||
}
|
||||
@@ -2246,6 +2281,28 @@ func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
|
||||
return &routing.RoutePeer{Kind: parentKind, ID: parentID}
|
||||
}
|
||||
|
||||
// isNativeSearchProvider reports whether the given LLM provider implements
|
||||
// NativeSearchCapable and returns true for SupportsNativeSearch.
|
||||
func isNativeSearchProvider(p providers.LLMProvider) bool {
|
||||
if ns, ok := p.(providers.NativeSearchCapable); ok {
|
||||
return ns.SupportsNativeSearch()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// filterClientWebSearch returns a copy of tools with the client-side
|
||||
// web_search tool removed. Used when native provider search is preferred.
|
||||
func filterClientWebSearch(tools []providers.ToolDefinition) []providers.ToolDefinition {
|
||||
result := make([]providers.ToolDefinition, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
if strings.EqualFold(t.Function.Name, "web_search") {
|
||||
continue
|
||||
}
|
||||
result = append(result, t)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Helper to extract provider from registry for cleanup
|
||||
func extractProvider(registry *AgentRegistry) (providers.LLMProvider, bool) {
|
||||
if registry == nil {
|
||||
|
||||
+228
-39
@@ -30,6 +30,28 @@ func (f *fakeChannel) IsAllowed(string) bool {
|
||||
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
|
||||
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
|
||||
|
||||
type recordingProvider struct {
|
||||
lastMessages []providers.Message
|
||||
}
|
||||
|
||||
func (r *recordingProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
r.lastMessages = append([]providers.Message(nil), messages...)
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *recordingProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
func newTestAgentLoop(
|
||||
t *testing.T,
|
||||
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
|
||||
@@ -54,6 +76,59 @@ func newTestAgentLoop(
|
||||
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
|
||||
}
|
||||
|
||||
func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "discord",
|
||||
SenderID: "discord:123",
|
||||
Sender: bus.SenderInfo{
|
||||
DisplayName: "Alice",
|
||||
},
|
||||
ChatID: "group-1",
|
||||
Content: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if len(provider.lastMessages) == 0 {
|
||||
t.Fatal("provider did not receive any messages")
|
||||
}
|
||||
|
||||
systemPrompt := provider.lastMessages[0].Content
|
||||
wantSender := "## Current Sender\nCurrent sender: Alice (ID: discord:123)"
|
||||
if !strings.Contains(systemPrompt, wantSender) {
|
||||
t.Fatalf("system prompt missing sender context %q:\n%s", wantSender, systemPrompt)
|
||||
}
|
||||
|
||||
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
|
||||
if lastMessage.Role != "user" || lastMessage.Content != "hello" {
|
||||
t.Fatalf("last provider message = %+v, want unchanged user message", lastMessage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
@@ -922,10 +997,25 @@ func TestHandleReasoning(t *testing.T) {
|
||||
al, msgBus := newLoop(t)
|
||||
al.handleReasoning(context.Background(), "reasoning", "telegram", "")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if msg, ok := msgBus.SubscribeOutbound(ctx); ok {
|
||||
t.Fatalf("expected no outbound message, got %+v", msg)
|
||||
for {
|
||||
select {
|
||||
case msg, ok := <-msgBus.OutboundChan():
|
||||
if !ok {
|
||||
t.Fatalf("expected no outbound message, got %+v", msg)
|
||||
}
|
||||
if msg.Content == "reasoning" {
|
||||
t.Fatalf("expected no message for empty chatID, got %+v", msg)
|
||||
}
|
||||
return
|
||||
case <-ctx.Done():
|
||||
t.Log("expected an outbound message, got none within timeout")
|
||||
return
|
||||
default:
|
||||
// Continue to check for message
|
||||
time.Sleep(5 * time.Millisecond) // Avoid busy loop
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -933,9 +1023,7 @@ func TestHandleReasoning(t *testing.T) {
|
||||
al, msgBus := newLoop(t)
|
||||
al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
msg, ok := msgBus.SubscribeOutbound(ctx)
|
||||
msg, ok := <-msgBus.OutboundChan()
|
||||
if !ok {
|
||||
t.Fatal("expected an outbound message")
|
||||
}
|
||||
@@ -949,35 +1037,52 @@ func TestHandleReasoning(t *testing.T) {
|
||||
reasoning := "hello telegram reasoning"
|
||||
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
msg, ok := msgBus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected outbound message")
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("expected an outbound message, got none within timeout")
|
||||
return
|
||||
case msg, ok := <-msgBus.OutboundChan():
|
||||
if !ok {
|
||||
t.Fatal("expected outbound message")
|
||||
}
|
||||
|
||||
if msg.Channel != "telegram" {
|
||||
t.Fatalf("expected telegram channel message, got %+v", msg)
|
||||
}
|
||||
if msg.ChatID != "tg-chat" {
|
||||
t.Fatalf("expected chatID tg-chat, got %+v", msg)
|
||||
}
|
||||
if msg.Content != reasoning {
|
||||
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
|
||||
if msg.Channel != "telegram" {
|
||||
t.Fatalf("expected telegram channel message, got %+v", msg)
|
||||
}
|
||||
if msg.ChatID != "tg-chat" {
|
||||
t.Fatalf("expected chatID tg-chat, got %+v", msg)
|
||||
}
|
||||
if msg.Content != reasoning {
|
||||
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
t.Run("expired ctx", func(t *testing.T) {
|
||||
al, msgBus := newLoop(t)
|
||||
reasoning := "hello telegram reasoning"
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
al.handleReasoning(ctx, reasoning, "telegram", "tg-chat")
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
msg, ok := msgBus.SubscribeOutbound(ctx)
|
||||
if ok {
|
||||
t.Fatalf("expected no outbound message, got %+v", msg)
|
||||
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
|
||||
|
||||
consumeCtx, consumeCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer consumeCancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg, ok := <-msgBus.OutboundChan():
|
||||
if !ok {
|
||||
t.Fatalf("expected no outbound message, but received: %+v", msg)
|
||||
}
|
||||
t.Logf("Received unexpected outbound message: %+v", msg)
|
||||
return
|
||||
case <-consumeCtx.Done():
|
||||
t.Fatalf("failed: no message received within timeout")
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1017,20 +1122,23 @@ func TestHandleReasoning(t *testing.T) {
|
||||
|
||||
// Drain the bus and verify the reasoning message was NOT published
|
||||
// (it should have been dropped due to timeout).
|
||||
drainCtx, drainCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer drainCancel()
|
||||
foundReasoning := false
|
||||
timeer := time.After(1 * time.Second)
|
||||
for {
|
||||
msg, ok := msgBus.SubscribeOutbound(drainCtx)
|
||||
if !ok {
|
||||
break
|
||||
select {
|
||||
case <-timeer:
|
||||
t.Logf(
|
||||
"no reasoning message received after draining bus for 1s, as expected,length=%d",
|
||||
len(msgBus.OutboundChan()),
|
||||
)
|
||||
return
|
||||
case msg, ok := <-msgBus.OutboundChan():
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if msg.Content == "should timeout" {
|
||||
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
|
||||
}
|
||||
}
|
||||
if msg.Content == "should timeout" {
|
||||
foundReasoning = true
|
||||
}
|
||||
}
|
||||
if foundReasoning {
|
||||
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1318,3 +1426,84 @@ func TestResolveMediaRefs_MixedImageAndFile(t *testing.T) {
|
||||
t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Native search helper tests ---
|
||||
|
||||
type nativeSearchProvider struct {
|
||||
supported bool
|
||||
}
|
||||
|
||||
func (p *nativeSearchProvider) Chat(
|
||||
ctx context.Context, msgs []providers.Message, tools []providers.ToolDefinition,
|
||||
model string, opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{Content: "ok"}, nil
|
||||
}
|
||||
|
||||
func (p *nativeSearchProvider) GetDefaultModel() string { return "test-model" }
|
||||
|
||||
func (p *nativeSearchProvider) SupportsNativeSearch() bool { return p.supported }
|
||||
|
||||
type plainProvider struct{}
|
||||
|
||||
func (p *plainProvider) Chat(
|
||||
ctx context.Context, msgs []providers.Message, tools []providers.ToolDefinition,
|
||||
model string, opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{Content: "ok"}, nil
|
||||
}
|
||||
|
||||
func (p *plainProvider) GetDefaultModel() string { return "test-model" }
|
||||
|
||||
func TestIsNativeSearchProvider_Supported(t *testing.T) {
|
||||
if !isNativeSearchProvider(&nativeSearchProvider{supported: true}) {
|
||||
t.Fatal("expected true for provider that supports native search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNativeSearchProvider_NotSupported(t *testing.T) {
|
||||
if isNativeSearchProvider(&nativeSearchProvider{supported: false}) {
|
||||
t.Fatal("expected false for provider that does not support native search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNativeSearchProvider_NoInterface(t *testing.T) {
|
||||
if isNativeSearchProvider(&plainProvider{}) {
|
||||
t.Fatal("expected false for provider that does not implement NativeSearchCapable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterClientWebSearch_RemovesWebSearch(t *testing.T) {
|
||||
defs := []providers.ToolDefinition{
|
||||
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "web_search"}},
|
||||
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "read_file"}},
|
||||
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "exec"}},
|
||||
}
|
||||
result := filterClientWebSearch(defs)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("len(result) = %d, want 2", len(result))
|
||||
}
|
||||
for _, td := range result {
|
||||
if td.Function.Name == "web_search" {
|
||||
t.Fatal("web_search should be filtered out")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterClientWebSearch_NoWebSearch(t *testing.T) {
|
||||
defs := []providers.ToolDefinition{
|
||||
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "read_file"}},
|
||||
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "exec"}},
|
||||
}
|
||||
result := filterClientWebSearch(defs)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("len(result) = %d, want 2", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterClientWebSearch_EmptyInput(t *testing.T) {
|
||||
result := filterClientWebSearch(nil)
|
||||
if len(result) != 0 {
|
||||
t.Fatalf("len(result) = %d, want 0", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user