mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
chore: resolve conflicts with upstream/main
This commit is contained in:
+41
-1
@@ -92,7 +92,47 @@ func NewAgentInstance(
|
||||
Primary: model,
|
||||
Fallbacks: fallbacks,
|
||||
}
|
||||
candidates := providers.ResolveCandidates(modelCfg, defaults.Provider)
|
||||
resolveFromModelList := func(raw string) (string, bool) {
|
||||
ensureProtocol := func(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
return model
|
||||
}
|
||||
return "openai/" + model
|
||||
}
|
||||
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if cfg != nil {
|
||||
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
|
||||
return ensureProtocol(mc.Model), true
|
||||
}
|
||||
|
||||
for i := range cfg.ModelList {
|
||||
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
|
||||
if fullModel == "" {
|
||||
continue
|
||||
}
|
||||
if fullModel == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
_, modelID := providers.ExtractProtocol(fullModel)
|
||||
if modelID == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
|
||||
|
||||
return &AgentInstance{
|
||||
ID: agentID,
|
||||
|
||||
@@ -93,3 +93,77 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) {
|
||||
t.Fatalf("Temperature = %f, want %f", agent.Temperature, 0.7)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "step-3.5-flash",
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "step-3.5-flash",
|
||||
Model: "openrouter/stepfun/step-3.5-flash:free",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != "openrouter" {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openrouter")
|
||||
}
|
||||
if agent.Candidates[0].Model != "stepfun/step-3.5-flash:free" {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "stepfun/step-3.5-flash:free")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_ResolveCandidatesFromModelListAliasWithoutProtocol(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "glm-5",
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "glm-5",
|
||||
Model: "glm-5",
|
||||
APIBase: "https://api.z.ai/api/coding/paas/v4",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
|
||||
|
||||
if len(agent.Candidates) != 1 {
|
||||
t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
|
||||
}
|
||||
if agent.Candidates[0].Provider != "openai" {
|
||||
t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openai")
|
||||
}
|
||||
if agent.Candidates[0].Model != "glm-5" {
|
||||
t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "glm-5")
|
||||
}
|
||||
}
|
||||
|
||||
+181
-35
@@ -10,6 +10,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
@@ -38,6 +40,7 @@ type AgentLoop struct {
|
||||
summarizing sync.Map
|
||||
fallback *providers.FallbackChain
|
||||
channelManager *channels.Manager
|
||||
mediaStore media.MediaStore
|
||||
}
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
@@ -52,6 +55,8 @@ type processOptions struct {
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
}
|
||||
|
||||
const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
|
||||
|
||||
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
|
||||
registry := NewAgentRegistry(cfg, provider)
|
||||
|
||||
@@ -119,12 +124,13 @@ func registerSharedTools(
|
||||
// Message tool
|
||||
messageTool := tools.NewMessageTool()
|
||||
messageTool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
agent.Tools.Register(messageTool)
|
||||
|
||||
@@ -165,33 +171,61 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
continue
|
||||
}
|
||||
|
||||
response, err := al.processMessage(ctx, msg)
|
||||
if err != nil {
|
||||
response = fmt.Sprintf("Error processing message: %v", err)
|
||||
}
|
||||
// Process message
|
||||
func() {
|
||||
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
|
||||
// Currently disabled because files are deleted before the LLM can access their content.
|
||||
// defer func() {
|
||||
// if al.mediaStore != nil && msg.MediaScope != "" {
|
||||
// if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil {
|
||||
// logger.WarnCF("agent", "Failed to release media", map[string]any{
|
||||
// "scope": msg.MediaScope,
|
||||
// "error": releaseErr.Error(),
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
// }()
|
||||
|
||||
if response != "" {
|
||||
// Check if the message tool already sent a response during this round.
|
||||
// If so, skip publishing to avoid duplicate messages to the user.
|
||||
// Use default agent's tools to check (message tool is shared).
|
||||
alreadySent := false
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent != nil {
|
||||
if tool, ok := defaultAgent.Tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
alreadySent = mt.HasSentInRound()
|
||||
response, err := al.processMessage(ctx, msg)
|
||||
if err != nil {
|
||||
response = fmt.Sprintf("Error processing message: %v", err)
|
||||
}
|
||||
|
||||
if response != "" {
|
||||
// Check if the message tool already sent a response during this round.
|
||||
// If so, skip publishing to avoid duplicate messages to the user.
|
||||
// Use default agent's tools to check (message tool is shared).
|
||||
alreadySent := false
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent != nil {
|
||||
if tool, ok := defaultAgent.Tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
alreadySent = mt.HasSentInRound()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !alreadySent {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Content: response,
|
||||
})
|
||||
if !alreadySent {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Content: response,
|
||||
})
|
||||
logger.InfoCF("agent", "Published outbound response",
|
||||
map[string]any{
|
||||
"channel": msg.Channel,
|
||||
"chat_id": msg.ChatID,
|
||||
"content_len": len(response),
|
||||
})
|
||||
} else {
|
||||
logger.DebugCF(
|
||||
"agent",
|
||||
"Skipped outbound (message tool already sent)",
|
||||
map[string]any{"channel": msg.Channel},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,6 +248,41 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
|
||||
al.channelManager = cm
|
||||
}
|
||||
|
||||
// SetMediaStore injects a MediaStore for media lifecycle management.
|
||||
func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
|
||||
al.mediaStore = s
|
||||
}
|
||||
|
||||
// inferMediaType determines the media type ("image", "audio", "video", "file")
|
||||
// from a filename and MIME content type.
|
||||
func inferMediaType(filename, contentType string) string {
|
||||
ct := strings.ToLower(contentType)
|
||||
fn := strings.ToLower(filename)
|
||||
|
||||
if strings.HasPrefix(ct, "image/") {
|
||||
return "image"
|
||||
}
|
||||
if strings.HasPrefix(ct, "audio/") || ct == "application/ogg" {
|
||||
return "audio"
|
||||
}
|
||||
if strings.HasPrefix(ct, "video/") {
|
||||
return "video"
|
||||
}
|
||||
|
||||
// Fallback: infer from extension
|
||||
ext := filepath.Ext(fn)
|
||||
switch ext {
|
||||
case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".svg":
|
||||
return "image"
|
||||
case ".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma", ".opus":
|
||||
return "audio"
|
||||
case ".mp4", ".avi", ".mov", ".webm", ".mkv":
|
||||
return "video"
|
||||
}
|
||||
|
||||
return "file"
|
||||
}
|
||||
|
||||
// RecordLastChannel records the last active channel for this workspace.
|
||||
// This uses the atomic state save mechanism to prevent data loss on crash.
|
||||
func (al *AgentLoop) RecordLastChannel(channel string) error {
|
||||
@@ -255,12 +324,15 @@ func (al *AgentLoop) ProcessDirectWithChannel(
|
||||
// Each heartbeat is independent and doesn't accumulate context.
|
||||
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no default agent for heartbeat")
|
||||
}
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
SessionKey: "heartbeat",
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
UserMessage: content,
|
||||
DefaultResponse: "I've completed processing but have no response to give.",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
NoHistory: true, // Don't load session history for heartbeat
|
||||
@@ -307,6 +379,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
if !ok {
|
||||
agent = al.registry.GetDefaultAgent()
|
||||
}
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
|
||||
}
|
||||
|
||||
// Reset message-tool state for this round so we don't skip publishing due to a previous round.
|
||||
if tool, ok := agent.Tools.Get("message"); ok {
|
||||
if mt, ok := tool.(tools.ContextualTool); ok {
|
||||
mt.SetContext(msg.Channel, msg.ChatID)
|
||||
}
|
||||
}
|
||||
|
||||
// Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron)
|
||||
sessionKey := route.SessionKey
|
||||
@@ -326,7 +408,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
UserMessage: msg.Content,
|
||||
DefaultResponse: "I've completed processing but have no response to give.",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
})
|
||||
@@ -373,6 +455,9 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
|
||||
// Use default agent for system messages
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no default agent for system message")
|
||||
}
|
||||
|
||||
// Use the origin session for context
|
||||
sessionKey := routing.BuildAgentMainSessionKey(agent.ID)
|
||||
@@ -448,7 +533,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
|
||||
|
||||
// 8. Optional: send response via bus
|
||||
if opts.SendResponse {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: finalContent,
|
||||
@@ -468,6 +553,34 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
|
||||
return finalContent, nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string) {
|
||||
if al.channelManager == nil {
|
||||
return ""
|
||||
}
|
||||
if ch, ok := al.channelManager.GetChannel(channelName); ok {
|
||||
return ch.ReasoningChannelID()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, channelName, channelID string) {
|
||||
if reasoningContent == "" || channelName == "" || channelID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Check context cancellation before attempting to publish,
|
||||
// since PublishOutbound's select may race between send and ctx.Done().
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: channelName,
|
||||
ChatID: channelID,
|
||||
Content: reasoningContent,
|
||||
})
|
||||
}
|
||||
|
||||
// runLLMIteration executes the LLM call loop with tool handling.
|
||||
func (al *AgentLoop) runLLMIteration(
|
||||
ctx context.Context,
|
||||
@@ -565,7 +678,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
})
|
||||
|
||||
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: "Context window exceeded. Compressing history and retrying...",
|
||||
@@ -594,6 +707,18 @@ func (al *AgentLoop) runLLMIteration(
|
||||
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel))
|
||||
|
||||
logger.DebugCF("agent", "LLM response",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"content_chars": len(response.Content),
|
||||
"tool_calls": len(response.ToolCalls),
|
||||
"reasoning": response.Reasoning,
|
||||
"target_channel": al.targetReasoningChannelID(opts.Channel),
|
||||
"channel": opts.Channel,
|
||||
})
|
||||
// Check if no tool calls - we're done
|
||||
if len(response.ToolCalls) == 0 {
|
||||
finalContent = response.Content
|
||||
@@ -695,7 +820,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
|
||||
// Send ForUser content to user immediately if not Silent
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: toolResult.ForUser,
|
||||
@@ -707,6 +832,28 @@ func (al *AgentLoop) runLLMIteration(
|
||||
})
|
||||
}
|
||||
|
||||
// If tool returned media refs, publish them as outbound media
|
||||
if len(toolResult.Media) > 0 && opts.SendResponse {
|
||||
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
|
||||
for _, ref := range toolResult.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
// Populate metadata from MediaStore when available
|
||||
if al.mediaStore != nil {
|
||||
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
|
||||
part.Filename = meta.Filename
|
||||
part.ContentType = meta.ContentType
|
||||
part.Type = inferMediaType(meta.Filename, meta.ContentType)
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Parts: parts,
|
||||
})
|
||||
}
|
||||
|
||||
// Determine content for LLM based on tool result
|
||||
contentForLLM := toolResult.ForLLM
|
||||
if contentForLLM == "" && toolResult.Err != nil {
|
||||
@@ -1119,21 +1266,20 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage)
|
||||
return "", false
|
||||
}
|
||||
|
||||
// extractPeer extracts the routing peer from inbound message metadata.
|
||||
// extractPeer extracts the routing peer from the inbound message's structured Peer field.
|
||||
func extractPeer(msg bus.InboundMessage) *routing.RoutePeer {
|
||||
peerKind := msg.Metadata["peer_kind"]
|
||||
if peerKind == "" {
|
||||
if msg.Peer.Kind == "" {
|
||||
return nil
|
||||
}
|
||||
peerID := msg.Metadata["peer_id"]
|
||||
peerID := msg.Peer.ID
|
||||
if peerID == "" {
|
||||
if peerKind == "direct" {
|
||||
if msg.Peer.Kind == "direct" {
|
||||
peerID = msg.SenderID
|
||||
} else {
|
||||
peerID = msg.ChatID
|
||||
}
|
||||
}
|
||||
return &routing.RoutePeer{Kind: peerKind, ID: peerID}
|
||||
return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID}
|
||||
}
|
||||
|
||||
// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata.
|
||||
|
||||
@@ -10,11 +10,23 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
type fakeChannel struct{ id string }
|
||||
|
||||
func (f *fakeChannel) Name() string { return "fake" }
|
||||
func (f *fakeChannel) Start(ctx context.Context) error { return nil }
|
||||
func (f *fakeChannel) Stop(ctx context.Context) error { return nil }
|
||||
func (f *fakeChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { return nil }
|
||||
func (f *fakeChannel) IsRunning() bool { return true }
|
||||
func (f *fakeChannel) IsAllowed(string) bool { return true }
|
||||
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
|
||||
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
|
||||
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
@@ -620,3 +632,158 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{})
|
||||
chManager, err := channels.NewManager(&config.Config{}, bus.NewMessageBus(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel manager: %v", err)
|
||||
}
|
||||
for name, id := range map[string]string{
|
||||
"whatsapp": "rid-whatsapp",
|
||||
"telegram": "rid-telegram",
|
||||
"feishu": "rid-feishu",
|
||||
"discord": "rid-discord",
|
||||
"maixcam": "rid-maixcam",
|
||||
"qq": "rid-qq",
|
||||
"dingtalk": "rid-dingtalk",
|
||||
"slack": "rid-slack",
|
||||
"line": "rid-line",
|
||||
"onebot": "rid-onebot",
|
||||
"wecom": "rid-wecom",
|
||||
"wecom_app": "rid-wecom-app",
|
||||
} {
|
||||
chManager.RegisterChannel(name, &fakeChannel{id: id})
|
||||
}
|
||||
al.SetChannelManager(chManager)
|
||||
tests := []struct {
|
||||
channel string
|
||||
wantID string
|
||||
}{
|
||||
{channel: "whatsapp", wantID: "rid-whatsapp"},
|
||||
{channel: "telegram", wantID: "rid-telegram"},
|
||||
{channel: "feishu", wantID: "rid-feishu"},
|
||||
{channel: "discord", wantID: "rid-discord"},
|
||||
{channel: "maixcam", wantID: "rid-maixcam"},
|
||||
{channel: "qq", wantID: "rid-qq"},
|
||||
{channel: "dingtalk", wantID: "rid-dingtalk"},
|
||||
{channel: "slack", wantID: "rid-slack"},
|
||||
{channel: "line", wantID: "rid-line"},
|
||||
{channel: "onebot", wantID: "rid-onebot"},
|
||||
{channel: "wecom", wantID: "rid-wecom"},
|
||||
{channel: "wecom_app", wantID: "rid-wecom-app"},
|
||||
{channel: "unknown", wantID: ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.channel, func(t *testing.T) {
|
||||
got := al.targetReasoningChannelID(tt.channel)
|
||||
if got != tt.wantID {
|
||||
t.Fatalf("targetReasoningChannelID(%q) = %q, want %q", tt.channel, got, tt.wantID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleReasoning(t *testing.T) {
|
||||
newLoop := func(t *testing.T) (*AgentLoop, *bus.MessageBus) {
|
||||
t.Helper()
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
msgBus := bus.NewMessageBus()
|
||||
return NewAgentLoop(cfg, msgBus, &mockProvider{}), msgBus
|
||||
}
|
||||
|
||||
t.Run("skips when any required field is empty", func(t *testing.T) {
|
||||
al, msgBus := newLoop(t)
|
||||
al.handleReasoning(context.Background(), "reasoning", "telegram", "")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
if msg, ok := msgBus.SubscribeOutbound(ctx); ok {
|
||||
t.Fatalf("expected no outbound message, got %+v", msg)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("publishes one message for non telegram", func(t *testing.T) {
|
||||
al, msgBus := newLoop(t)
|
||||
al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
msg, ok := msgBus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected an outbound message")
|
||||
}
|
||||
if msg.Channel != "slack" || msg.ChatID != "channel-1" || msg.Content != "hello reasoning" {
|
||||
t.Fatalf("unexpected outbound message: %+v", msg)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("publishes one message for telegram", func(t *testing.T) {
|
||||
al, msgBus := newLoop(t)
|
||||
reasoning := "hello telegram reasoning"
|
||||
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
msg, ok := msgBus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected outbound message")
|
||||
}
|
||||
|
||||
if msg.Channel != "telegram" {
|
||||
t.Fatalf("expected telegram channel message, got %+v", msg)
|
||||
}
|
||||
if msg.ChatID != "tg-chat" {
|
||||
t.Fatalf("expected chatID tg-chat, got %+v", msg)
|
||||
}
|
||||
if msg.Content != reasoning {
|
||||
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
|
||||
}
|
||||
})
|
||||
t.Run("expired ctx", func(t *testing.T) {
|
||||
al, msgBus := newLoop(t)
|
||||
reasoning := "hello telegram reasoning"
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
al.handleReasoning(ctx, reasoning, "telegram", "tg-chat")
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
msg, ok := msgBus.SubscribeOutbound(ctx)
|
||||
if ok {
|
||||
t.Fatalf("expected no outbound message, got %+v", msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user