mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into version
This commit is contained in:
+14
-55
@@ -3,13 +3,13 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
@@ -85,9 +85,11 @@ func NewAgentInstance(
|
||||
if cfg.Tools.IsToolEnabled("exec") {
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg, allowReadPaths)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
|
||||
logger.ErrorCF("agent", "Failed to initialize exec tool; continuing without exec",
|
||||
map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
toolsRegistry.Register(execTool)
|
||||
}
|
||||
toolsRegistry.Register(execTool)
|
||||
}
|
||||
|
||||
if cfg.Tools.IsToolEnabled("edit_file") {
|
||||
@@ -150,59 +152,14 @@ func NewAgentInstance(
|
||||
}
|
||||
|
||||
// Resolve fallback candidates
|
||||
modelCfg := providers.ModelConfig{
|
||||
Primary: model,
|
||||
Fallbacks: fallbacks,
|
||||
}
|
||||
resolveFromModelList := func(raw string) (string, bool) {
|
||||
ensureProtocol := func(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
return model
|
||||
}
|
||||
return "openai/" + model
|
||||
}
|
||||
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if cfg != nil {
|
||||
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
|
||||
return ensureProtocol(mc.Model), true
|
||||
}
|
||||
|
||||
for i := range cfg.ModelList {
|
||||
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
|
||||
if fullModel == "" {
|
||||
continue
|
||||
}
|
||||
if fullModel == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
_, modelID := providers.ExtractProtocol(fullModel)
|
||||
if modelID == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
|
||||
candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks)
|
||||
|
||||
// Model routing setup: pre-resolve light model candidates at creation time
|
||||
// to avoid repeated model_list lookups on every incoming message.
|
||||
var router *routing.Router
|
||||
var lightCandidates []providers.FallbackCandidate
|
||||
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
|
||||
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
|
||||
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
|
||||
resolved := resolveModelCandidates(cfg, defaults.Provider, rc.LightModel, nil)
|
||||
if len(resolved) > 0 {
|
||||
router = routing.New(routing.RouterConfig{
|
||||
LightModel: rc.LightModel,
|
||||
@@ -210,8 +167,8 @@ func NewAgentInstance(
|
||||
})
|
||||
lightCandidates = resolved
|
||||
} else {
|
||||
log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q",
|
||||
rc.LightModel, agentID)
|
||||
logger.WarnCF("agent", "Routing light model not found; routing disabled",
|
||||
map[string]any{"light_model": rc.LightModel, "agent_id": agentID})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -320,7 +277,8 @@ func (a *AgentInstance) Close() error {
|
||||
func initSessionStore(dir string) session.SessionStore {
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
log.Printf("memory: init store: %v; using json sessions", err)
|
||||
logger.WarnCF("agent", "Memory JSONL store init failed; falling back to json sessions",
|
||||
map[string]any{"error": err.Error()})
|
||||
return session.NewSessionManager(dir)
|
||||
}
|
||||
|
||||
@@ -328,11 +286,12 @@ func initSessionStore(dir string) session.SessionStore {
|
||||
// Migration failure means the store could not write data.
|
||||
// Fall back to SessionManager to avoid a split state where
|
||||
// some sessions are in JSONL and others remain in JSON.
|
||||
log.Printf("memory: migration failed: %v; falling back to json sessions", merr)
|
||||
logger.WarnCF("agent", "Memory migration failed; falling back to json sessions",
|
||||
map[string]any{"error": merr.Error()})
|
||||
store.Close()
|
||||
return session.NewSessionManager(dir)
|
||||
} else if n > 0 {
|
||||
log.Printf("memory: migrated %d session(s) to jsonl", n)
|
||||
logger.InfoCF("agent", "Memory migrated to JSONL", map[string]any{"sessions_migrated": n})
|
||||
}
|
||||
|
||||
return session.NewJSONLBackend(store)
|
||||
|
||||
@@ -246,3 +246,37 @@ func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
|
||||
t.Fatalf("exec output missing media content: %s", execResult.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_InvalidExecConfigDoesNotExit(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "test-model",
|
||||
},
|
||||
},
|
||||
Tools: config.ToolsConfig{
|
||||
ReadFile: config.ReadFileToolConfig{Enabled: true},
|
||||
Exec: config.ExecConfig{
|
||||
ToolConfig: config.ToolConfig{Enabled: true},
|
||||
EnableDenyPatterns: true,
|
||||
CustomDenyPatterns: []string{"[invalid-regex"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
|
||||
if agent == nil {
|
||||
t.Fatal("expected agent instance, got nil")
|
||||
}
|
||||
|
||||
if _, ok := agent.Tools.Get("exec"); ok {
|
||||
t.Fatal("exec tool should not be registered when exec config is invalid")
|
||||
}
|
||||
|
||||
if _, ok := agent.Tools.Get("read_file"); !ok {
|
||||
t.Fatal("read_file tool should still be registered")
|
||||
}
|
||||
}
|
||||
|
||||
+139
-52
@@ -70,7 +70,8 @@ type processOptions struct {
|
||||
}
|
||||
|
||||
const (
|
||||
defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
|
||||
defaultResponse = "The model returned an empty response. This may indicate a provider error or token limit."
|
||||
toolLimitResponse = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps."
|
||||
sessionKeyAgentPrefix = "agent:"
|
||||
metadataKeyAccountID = "account_id"
|
||||
metadataKeyGuildID = "guild_id"
|
||||
@@ -292,58 +293,64 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
// Process message
|
||||
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
|
||||
// Currently disabled because files are deleted before the LLM can access their content.
|
||||
// defer func() {
|
||||
// if al.mediaStore != nil && msg.MediaScope != "" {
|
||||
// if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil {
|
||||
// logger.WarnCF("agent", "Failed to release media", map[string]any{
|
||||
// "scope": msg.MediaScope,
|
||||
// "error": releaseErr.Error(),
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
// }()
|
||||
func() {
|
||||
defer func() {
|
||||
if al.channelManager != nil {
|
||||
al.channelManager.InvokeTypingStop(msg.Channel, msg.ChatID)
|
||||
}
|
||||
}()
|
||||
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
|
||||
// Currently disabled because files are deleted before the LLM can access their content.
|
||||
// defer func() {
|
||||
// if al.mediaStore != nil && msg.MediaScope != "" {
|
||||
// if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil {
|
||||
// logger.WarnCF("agent", "Failed to release media", map[string]any{
|
||||
// "scope": msg.MediaScope,
|
||||
// "error": releaseErr.Error(),
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
// }()
|
||||
|
||||
response, err := al.processMessage(ctx, msg)
|
||||
if err != nil {
|
||||
response = fmt.Sprintf("Error processing message: %v", err)
|
||||
}
|
||||
response, err := al.processMessage(ctx, msg)
|
||||
if err != nil {
|
||||
response = fmt.Sprintf("Error processing message: %v", err)
|
||||
}
|
||||
|
||||
if response != "" {
|
||||
// Check if the message tool already sent a response during this round.
|
||||
// If so, skip publishing to avoid duplicate messages to the user.
|
||||
// Use default agent's tools to check (message tool is shared).
|
||||
alreadySent := false
|
||||
defaultAgent := al.GetRegistry().GetDefaultAgent()
|
||||
if defaultAgent != nil {
|
||||
if tool, ok := defaultAgent.Tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
alreadySent = mt.HasSentInRound()
|
||||
if response != "" {
|
||||
// Check if the message tool already sent a response during this round.
|
||||
// If so, skip publishing to avoid duplicate messages to the user.
|
||||
// Use default agent's tools to check (message tool is shared).
|
||||
alreadySent := false
|
||||
defaultAgent := al.GetRegistry().GetDefaultAgent()
|
||||
if defaultAgent != nil {
|
||||
if tool, ok := defaultAgent.Tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
alreadySent = mt.HasSentInRound()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !alreadySent {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Content: response,
|
||||
})
|
||||
logger.InfoCF("agent", "Published outbound response",
|
||||
map[string]any{
|
||||
"channel": msg.Channel,
|
||||
"chat_id": msg.ChatID,
|
||||
"content_len": len(response),
|
||||
if !alreadySent {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Content: response,
|
||||
})
|
||||
} else {
|
||||
logger.DebugCF(
|
||||
"agent",
|
||||
"Skipped outbound (message tool already sent)",
|
||||
map[string]any{"channel": msg.Channel},
|
||||
)
|
||||
logger.InfoCF("agent", "Published outbound response",
|
||||
map[string]any{
|
||||
"channel": msg.Channel,
|
||||
"chat_id": msg.ChatID,
|
||||
"content_len": len(response),
|
||||
})
|
||||
} else {
|
||||
logger.DebugCF(
|
||||
"agent",
|
||||
"Skipped outbound (message tool already sent)",
|
||||
map[string]any{"channel": msg.Channel},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
default:
|
||||
time.Sleep(time.Microsecond * 200)
|
||||
}
|
||||
@@ -943,7 +950,11 @@ func (al *AgentLoop) runAgentLoop(
|
||||
|
||||
// 4. Handle empty response
|
||||
if finalContent == "" {
|
||||
finalContent = opts.DefaultResponse
|
||||
if iteration >= agent.MaxIterations && agent.MaxIterations > 0 {
|
||||
finalContent = toolLimitResponse
|
||||
} else {
|
||||
finalContent = opts.DefaultResponse
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Save final assistant message to session
|
||||
@@ -1034,6 +1045,7 @@ func (al *AgentLoop) handleReasoning(
|
||||
}
|
||||
|
||||
// runLLMIteration executes the LLM call loop with tool handling.
|
||||
// Returns (finalContent, iteration, error).
|
||||
func (al *AgentLoop) runLLMIteration(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
@@ -1043,6 +1055,13 @@ func (al *AgentLoop) runLLMIteration(
|
||||
iteration := 0
|
||||
var finalContent string
|
||||
|
||||
// Check if both the provider and channel support streaming
|
||||
streamProvider, providerCanStream := agent.Provider.(providers.StreamingProvider)
|
||||
var streamer bus.Streamer
|
||||
if providerCanStream && !opts.NoHistory && !constants.IsInternalChannel(opts.Channel) {
|
||||
streamer, _ = al.bus.GetStreamer(ctx, opts.Channel, opts.ChatID)
|
||||
}
|
||||
|
||||
// Determine effective model tier for this conversation turn.
|
||||
// selectCandidates evaluates routing once and the decision is sticky for
|
||||
// all tool-follow-up iterations within the same turn so that a multi-step
|
||||
@@ -1124,6 +1143,16 @@ func (al *AgentLoop) runLLMIteration(
|
||||
al.activeRequests.Add(1)
|
||||
defer al.activeRequests.Done()
|
||||
|
||||
// Use streaming when available (streamer obtained, provider supports it)
|
||||
if streamer != nil && streamProvider != nil {
|
||||
return streamProvider.ChatStream(
|
||||
ctx, messages, providerToolDefs, activeModel, llmOpts,
|
||||
func(accumulated string) {
|
||||
streamer.Update(ctx, accumulated)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if len(activeCandidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(
|
||||
ctx,
|
||||
@@ -1251,15 +1280,31 @@ func (al *AgentLoop) runLLMIteration(
|
||||
if finalContent == "" && response.ReasoningContent != "" {
|
||||
finalContent = response.ReasoningContent
|
||||
}
|
||||
|
||||
// If we were streaming, finalize the message (sends the permanent message)
|
||||
if streamer != nil {
|
||||
if err := streamer.Finalize(ctx, finalContent); err != nil {
|
||||
logger.WarnCF("agent", "Stream finalize failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"content_chars": len(finalContent),
|
||||
"streamed": streamer != nil,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
// Tool calls detected — cancel any active stream (draft auto-expires)
|
||||
if streamer != nil {
|
||||
streamer.Cancel(ctx)
|
||||
}
|
||||
|
||||
normalizedToolCalls := make([]providers.ToolCall, 0, len(response.ToolCalls))
|
||||
for _, tc := range response.ToolCalls {
|
||||
normalizedToolCalls = append(normalizedToolCalls, providers.NormalizeToolCall(tc))
|
||||
@@ -1336,6 +1381,22 @@ func (al *AgentLoop) runLLMIteration(
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// Send tool feedback to chat channel if enabled
|
||||
if al.cfg.Agents.Defaults.IsToolFeedbackEnabled() && opts.Channel != "" {
|
||||
feedbackPreview := utils.Truncate(
|
||||
string(argsJSON),
|
||||
al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(),
|
||||
)
|
||||
feedbackMsg := fmt.Sprintf("\U0001f527 `%s`\n```\n%s\n```", tc.Name, feedbackPreview)
|
||||
fbCtx, fbCancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
_ = al.bus.PublishOutbound(fbCtx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: feedbackMsg,
|
||||
})
|
||||
fbCancel()
|
||||
}
|
||||
|
||||
// Create async callback for tools that implement AsyncExecutor.
|
||||
// When the background work completes, this publishes the result
|
||||
// as an inbound system message so processSystemMessage routes it
|
||||
@@ -1475,7 +1536,7 @@ func (al *AgentLoop) selectCandidates(
|
||||
history []providers.Message,
|
||||
) (candidates []providers.FallbackCandidate, model string) {
|
||||
if agent.Router == nil || len(agent.LightCandidates) == 0 {
|
||||
return agent.Candidates, agent.Model
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
|
||||
}
|
||||
|
||||
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
|
||||
@@ -1486,7 +1547,7 @@ func (al *AgentLoop) selectCandidates(
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.Candidates, agent.Model
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Model routing: light model selected",
|
||||
@@ -1496,7 +1557,7 @@ func (al *AgentLoop) selectCandidates(
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.LightCandidates, agent.Router.LightModel()
|
||||
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel())
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
@@ -1959,11 +2020,37 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
|
||||
}
|
||||
if agent != nil {
|
||||
rt.GetModelInfo = func() (string, string) {
|
||||
return agent.Model, cfg.Agents.Defaults.Provider
|
||||
return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider)
|
||||
}
|
||||
rt.SwitchModel = func(value string) (string, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
modelCfg, err := resolvedModelConfig(cfg, value, agent.Workspace)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nextProvider, _, err := providers.CreateProviderFromConfig(modelCfg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to initialize model %q: %w", value, err)
|
||||
}
|
||||
|
||||
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks)
|
||||
if len(nextCandidates) == 0 {
|
||||
return "", fmt.Errorf("model %q did not resolve to any provider candidates", value)
|
||||
}
|
||||
|
||||
oldModel := agent.Model
|
||||
oldProvider := agent.Provider
|
||||
agent.Model = value
|
||||
agent.Provider = nextProvider
|
||||
agent.Candidates = nextCandidates
|
||||
agent.ThinkingLevel = parseThinkingLevel(modelCfg.ThinkingLevel)
|
||||
|
||||
if oldProvider != nil && oldProvider != nextProvider {
|
||||
if stateful, ok := oldProvider.(providers.StatefulProvider); ok {
|
||||
stateful.Close()
|
||||
}
|
||||
}
|
||||
return oldModel, nil
|
||||
}
|
||||
|
||||
|
||||
+393
-4
@@ -2,7 +2,10 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
@@ -417,6 +420,29 @@ func (m *countingMockProvider) GetDefaultModel() string {
|
||||
return "counting-mock-model"
|
||||
}
|
||||
|
||||
type toolLimitOnlyProvider struct{}
|
||||
|
||||
func (m *toolLimitOnlyProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_tool_limit_test",
|
||||
Type: "function",
|
||||
Name: "tool_limit_test_tool",
|
||||
Arguments: map[string]any{"value": "x"},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *toolLimitOnlyProvider) GetDefaultModel() string {
|
||||
return "tool-limit-only-model"
|
||||
}
|
||||
|
||||
// mockCustomTool is a simple mock tool for registration testing
|
||||
type mockCustomTool struct{}
|
||||
|
||||
@@ -439,11 +465,74 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool
|
||||
return tools.SilentResult("Custom tool executed")
|
||||
}
|
||||
|
||||
type toolLimitTestTool struct{}
|
||||
|
||||
func (m *toolLimitTestTool) Name() string {
|
||||
return "tool_limit_test_tool"
|
||||
}
|
||||
|
||||
func (m *toolLimitTestTool) Description() string {
|
||||
return "Tool used to exhaust the iteration budget in tests"
|
||||
}
|
||||
|
||||
func (m *toolLimitTestTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"value": map[string]any{"type": "string"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *toolLimitTestTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
return tools.SilentResult("tool limit test result")
|
||||
}
|
||||
|
||||
// testHelper executes a message and returns the response
|
||||
type testHelper struct {
|
||||
al *AgentLoop
|
||||
}
|
||||
|
||||
func newChatCompletionTestServer(
|
||||
t *testing.T,
|
||||
label string,
|
||||
response string,
|
||||
calls *int,
|
||||
model *string,
|
||||
) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
t.Fatalf("%s server path = %q, want /chat/completions", label, r.URL.Path)
|
||||
}
|
||||
*calls = *calls + 1
|
||||
defer r.Body.Close()
|
||||
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
decodeErr := json.NewDecoder(r.Body).Decode(&req)
|
||||
if decodeErr != nil {
|
||||
t.Fatalf("decode %s request: %v", label, decodeErr)
|
||||
}
|
||||
*model = req.Model
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
encodeErr := json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": response},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
})
|
||||
if encodeErr != nil {
|
||||
t.Fatalf("encode %s response: %v", label, encodeErr)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
|
||||
// Use a short timeout to avoid hanging
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
|
||||
@@ -605,12 +694,34 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
ModelName: "before-switch",
|
||||
ModelName: "local",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "local",
|
||||
Model: "openai/local-model",
|
||||
APIBase: "https://local.example.invalid/v1",
|
||||
},
|
||||
{
|
||||
ModelName: "deepseek",
|
||||
Model: "openrouter/deepseek/deepseek-v3.2",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg.WithSecurity(&config.SecurityConfig{
|
||||
ModelList: map[string]config.ModelSecurityEntry{
|
||||
"local": {
|
||||
APIKeys: []string{"test-key"},
|
||||
},
|
||||
"deepseek": {
|
||||
APIKeys: []string{"test-key"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &countingMockProvider{response: "LLM reply"}
|
||||
@@ -621,13 +732,13 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to after-switch",
|
||||
Content: "/switch model to deepseek",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") {
|
||||
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
}
|
||||
|
||||
@@ -641,7 +752,7 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") {
|
||||
if !strings.Contains(showResp, "Current Model: deepseek (Provider: openrouter)") {
|
||||
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
|
||||
}
|
||||
|
||||
@@ -650,6 +761,201 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
ModelName: "local",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "local",
|
||||
Model: "openai/local-model",
|
||||
APIBase: "https://local.example.invalid/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg.WithSecurity(&config.SecurityConfig{
|
||||
ModelList: map[string]config.ModelSecurityEntry{
|
||||
"local": {
|
||||
APIKeys: []string{"test-key"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &countingMockProvider{response: "LLM reply"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to missing",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if switchResp != `model "missing" not found in model_list or providers` {
|
||||
t.Fatalf("unexpected /switch error reply: %q", switchResp)
|
||||
}
|
||||
|
||||
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/show model",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: local (Provider: openai)") {
|
||||
t.Fatalf("unexpected /show model reply after rejected switch: %q", showResp)
|
||||
}
|
||||
|
||||
if provider.calls != 0 {
|
||||
t.Fatalf("LLM should not be called for rejected /switch and /show, calls=%d", provider.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
localCalls := 0
|
||||
localModel := ""
|
||||
localServer := newChatCompletionTestServer(t, "local", "local reply", &localCalls, &localModel)
|
||||
defer localServer.Close()
|
||||
|
||||
remoteCalls := 0
|
||||
remoteModel := ""
|
||||
remoteServer := newChatCompletionTestServer(t, "remote", "remote reply", &remoteCalls, &remoteModel)
|
||||
defer remoteServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
ModelName: "local",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "local",
|
||||
Model: "openai/Qwen3.5-35B-A3B",
|
||||
APIBase: localServer.URL,
|
||||
},
|
||||
{
|
||||
ModelName: "deepseek",
|
||||
Model: "openrouter/deepseek/deepseek-v3.2",
|
||||
APIBase: remoteServer.URL,
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg.WithSecurity(&config.SecurityConfig{
|
||||
ModelList: map[string]config.ModelSecurityEntry{
|
||||
"local": {
|
||||
APIKeys: []string{"local-key"},
|
||||
},
|
||||
"deepseek": {
|
||||
APIKeys: []string{"remote-key"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
firstResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello before switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if firstResp != "local reply" {
|
||||
t.Fatalf("unexpected response before switch: %q", firstResp)
|
||||
}
|
||||
if localCalls != 1 {
|
||||
t.Fatalf("local calls before switch = %d, want 1", localCalls)
|
||||
}
|
||||
if remoteCalls != 0 {
|
||||
t.Fatalf("remote calls before switch = %d, want 0", remoteCalls)
|
||||
}
|
||||
if localModel != "Qwen3.5-35B-A3B" {
|
||||
t.Fatalf("local model before switch = %q, want %q", localModel, "Qwen3.5-35B-A3B")
|
||||
}
|
||||
|
||||
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to deepseek",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
}
|
||||
|
||||
secondResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello after switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if secondResp != "remote reply" {
|
||||
t.Fatalf("unexpected response after switch: %q", secondResp)
|
||||
}
|
||||
if localCalls != 1 {
|
||||
t.Fatalf("local calls after switch = %d, want 1", localCalls)
|
||||
}
|
||||
if remoteCalls != 1 {
|
||||
t.Fatalf("remote calls after switch = %d, want 1", remoteCalls)
|
||||
}
|
||||
if remoteModel != "deepseek-v3.2" {
|
||||
t.Fatalf(
|
||||
"remote model after switch = %q, want %q",
|
||||
remoteModel,
|
||||
"deepseek-v3.2",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
||||
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
@@ -845,6 +1151,89 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmptyModelResponseUsesAccurateFallback(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: ""}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.ProcessDirectWithChannel(context.Background(), "hello", "empty-response", "test", "chat1")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
if response != defaultResponse {
|
||||
t.Fatalf("response = %q, want %q", response, defaultResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_ToolLimitUsesDedicatedFallback(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &toolLimitOnlyProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(&toolLimitTestTool{})
|
||||
|
||||
response, err := al.ProcessDirectWithChannel(context.Background(), "hello", "tool-limit", "test", "chat1")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
if response != toolLimitResponse {
|
||||
t.Fatalf("response = %q, want %q", response, toolLimitResponse)
|
||||
}
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("No default agent found")
|
||||
}
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: "test",
|
||||
Peer: &routing.RoutePeer{
|
||||
Kind: "direct",
|
||||
ID: "cron",
|
||||
},
|
||||
})
|
||||
history := defaultAgent.Sessions.GetHistory(route.SessionKey)
|
||||
if len(history) != 4 {
|
||||
t.Fatalf("history len = %d, want 4", len(history))
|
||||
}
|
||||
assertRoles(t, history, "user", "assistant", "tool", "assistant")
|
||||
if history[3].Content != toolLimitResponse {
|
||||
t.Fatalf("final assistant content = %q, want %q", history[3].Content, toolLimitResponse)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessDirectWithChannel_TriggersMCPInitialization verifies that
|
||||
// ProcessDirectWithChannel triggers MCP initialization when MCP is enabled.
|
||||
// Note: Manager is only initialized when at least one MCP server is configured
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool) {
|
||||
ensureProtocol := func(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
return model
|
||||
}
|
||||
return "openai/" + model
|
||||
}
|
||||
|
||||
return func(raw string) (string, bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" || cfg == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
|
||||
return ensureProtocol(mc.Model), true
|
||||
}
|
||||
|
||||
for i := range cfg.ModelList {
|
||||
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
|
||||
if fullModel == "" {
|
||||
continue
|
||||
}
|
||||
if fullModel == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
_, modelID := providers.ExtractProtocol(fullModel)
|
||||
if modelID == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func resolveModelCandidates(
|
||||
cfg *config.Config,
|
||||
defaultProvider string,
|
||||
primary string,
|
||||
fallbacks []string,
|
||||
) []providers.FallbackCandidate {
|
||||
return providers.ResolveCandidatesWithLookup(
|
||||
providers.ModelConfig{
|
||||
Primary: primary,
|
||||
Fallbacks: fallbacks,
|
||||
},
|
||||
defaultProvider,
|
||||
buildModelListResolver(cfg),
|
||||
)
|
||||
}
|
||||
|
||||
func resolvedCandidateModel(candidates []providers.FallbackCandidate, fallback string) string {
|
||||
if len(candidates) > 0 && strings.TrimSpace(candidates[0].Model) != "" {
|
||||
return candidates[0].Model
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func resolvedCandidateProvider(candidates []providers.FallbackCandidate, fallback string) string {
|
||||
if len(candidates) > 0 && strings.TrimSpace(candidates[0].Provider) != "" {
|
||||
return candidates[0].Provider
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func resolvedModelConfig(cfg *config.Config, modelName, workspace string) (*config.ModelConfig, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
}
|
||||
|
||||
modelCfg, err := cfg.GetModelConfig(strings.TrimSpace(modelName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clone := *modelCfg
|
||||
if clone.Workspace == "" {
|
||||
clone.Workspace = workspace
|
||||
}
|
||||
|
||||
return &clone, nil
|
||||
}
|
||||
Reference in New Issue
Block a user