mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into version
This commit is contained in:
+393
-4
@@ -2,7 +2,10 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
@@ -417,6 +420,29 @@ func (m *countingMockProvider) GetDefaultModel() string {
|
||||
return "counting-mock-model"
|
||||
}
|
||||
|
||||
type toolLimitOnlyProvider struct{}
|
||||
|
||||
func (m *toolLimitOnlyProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_tool_limit_test",
|
||||
Type: "function",
|
||||
Name: "tool_limit_test_tool",
|
||||
Arguments: map[string]any{"value": "x"},
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *toolLimitOnlyProvider) GetDefaultModel() string {
|
||||
return "tool-limit-only-model"
|
||||
}
|
||||
|
||||
// mockCustomTool is a simple mock tool for registration testing
|
||||
type mockCustomTool struct{}
|
||||
|
||||
@@ -439,11 +465,74 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool
|
||||
return tools.SilentResult("Custom tool executed")
|
||||
}
|
||||
|
||||
type toolLimitTestTool struct{}
|
||||
|
||||
func (m *toolLimitTestTool) Name() string {
|
||||
return "tool_limit_test_tool"
|
||||
}
|
||||
|
||||
func (m *toolLimitTestTool) Description() string {
|
||||
return "Tool used to exhaust the iteration budget in tests"
|
||||
}
|
||||
|
||||
func (m *toolLimitTestTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"value": map[string]any{"type": "string"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *toolLimitTestTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
return tools.SilentResult("tool limit test result")
|
||||
}
|
||||
|
||||
// testHelper executes a message and returns the response
|
||||
type testHelper struct {
|
||||
al *AgentLoop
|
||||
}
|
||||
|
||||
func newChatCompletionTestServer(
|
||||
t *testing.T,
|
||||
label string,
|
||||
response string,
|
||||
calls *int,
|
||||
model *string,
|
||||
) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
t.Fatalf("%s server path = %q, want /chat/completions", label, r.URL.Path)
|
||||
}
|
||||
*calls = *calls + 1
|
||||
defer r.Body.Close()
|
||||
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
decodeErr := json.NewDecoder(r.Body).Decode(&req)
|
||||
if decodeErr != nil {
|
||||
t.Fatalf("decode %s request: %v", label, decodeErr)
|
||||
}
|
||||
*model = req.Model
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
encodeErr := json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": response},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
})
|
||||
if encodeErr != nil {
|
||||
t.Fatalf("encode %s response: %v", label, encodeErr)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
|
||||
// Use a short timeout to avoid hanging
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
|
||||
@@ -605,12 +694,34 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
ModelName: "before-switch",
|
||||
ModelName: "local",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "local",
|
||||
Model: "openai/local-model",
|
||||
APIBase: "https://local.example.invalid/v1",
|
||||
},
|
||||
{
|
||||
ModelName: "deepseek",
|
||||
Model: "openrouter/deepseek/deepseek-v3.2",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg.WithSecurity(&config.SecurityConfig{
|
||||
ModelList: map[string]config.ModelSecurityEntry{
|
||||
"local": {
|
||||
APIKeys: []string{"test-key"},
|
||||
},
|
||||
"deepseek": {
|
||||
APIKeys: []string{"test-key"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &countingMockProvider{response: "LLM reply"}
|
||||
@@ -621,13 +732,13 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to after-switch",
|
||||
Content: "/switch model to deepseek",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") {
|
||||
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
}
|
||||
|
||||
@@ -641,7 +752,7 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") {
|
||||
if !strings.Contains(showResp, "Current Model: deepseek (Provider: openrouter)") {
|
||||
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
|
||||
}
|
||||
|
||||
@@ -650,6 +761,201 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
ModelName: "local",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "local",
|
||||
Model: "openai/local-model",
|
||||
APIBase: "https://local.example.invalid/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg.WithSecurity(&config.SecurityConfig{
|
||||
ModelList: map[string]config.ModelSecurityEntry{
|
||||
"local": {
|
||||
APIKeys: []string{"test-key"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &countingMockProvider{response: "LLM reply"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to missing",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if switchResp != `model "missing" not found in model_list or providers` {
|
||||
t.Fatalf("unexpected /switch error reply: %q", switchResp)
|
||||
}
|
||||
|
||||
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/show model",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: local (Provider: openai)") {
|
||||
t.Fatalf("unexpected /show model reply after rejected switch: %q", showResp)
|
||||
}
|
||||
|
||||
if provider.calls != 0 {
|
||||
t.Fatalf("LLM should not be called for rejected /switch and /show, calls=%d", provider.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
localCalls := 0
|
||||
localModel := ""
|
||||
localServer := newChatCompletionTestServer(t, "local", "local reply", &localCalls, &localModel)
|
||||
defer localServer.Close()
|
||||
|
||||
remoteCalls := 0
|
||||
remoteModel := ""
|
||||
remoteServer := newChatCompletionTestServer(t, "remote", "remote reply", &remoteCalls, &remoteModel)
|
||||
defer remoteServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
ModelName: "local",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "local",
|
||||
Model: "openai/Qwen3.5-35B-A3B",
|
||||
APIBase: localServer.URL,
|
||||
},
|
||||
{
|
||||
ModelName: "deepseek",
|
||||
Model: "openrouter/deepseek/deepseek-v3.2",
|
||||
APIBase: remoteServer.URL,
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg.WithSecurity(&config.SecurityConfig{
|
||||
ModelList: map[string]config.ModelSecurityEntry{
|
||||
"local": {
|
||||
APIKeys: []string{"local-key"},
|
||||
},
|
||||
"deepseek": {
|
||||
APIKeys: []string{"remote-key"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
firstResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello before switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if firstResp != "local reply" {
|
||||
t.Fatalf("unexpected response before switch: %q", firstResp)
|
||||
}
|
||||
if localCalls != 1 {
|
||||
t.Fatalf("local calls before switch = %d, want 1", localCalls)
|
||||
}
|
||||
if remoteCalls != 0 {
|
||||
t.Fatalf("remote calls before switch = %d, want 0", remoteCalls)
|
||||
}
|
||||
if localModel != "Qwen3.5-35B-A3B" {
|
||||
t.Fatalf("local model before switch = %q, want %q", localModel, "Qwen3.5-35B-A3B")
|
||||
}
|
||||
|
||||
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to deepseek",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
}
|
||||
|
||||
secondResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello after switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if secondResp != "remote reply" {
|
||||
t.Fatalf("unexpected response after switch: %q", secondResp)
|
||||
}
|
||||
if localCalls != 1 {
|
||||
t.Fatalf("local calls after switch = %d, want 1", localCalls)
|
||||
}
|
||||
if remoteCalls != 1 {
|
||||
t.Fatalf("remote calls after switch = %d, want 1", remoteCalls)
|
||||
}
|
||||
if remoteModel != "deepseek-v3.2" {
|
||||
t.Fatalf(
|
||||
"remote model after switch = %q, want %q",
|
||||
remoteModel,
|
||||
"deepseek-v3.2",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
||||
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
@@ -845,6 +1151,89 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmptyModelResponseUsesAccurateFallback(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: ""}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.ProcessDirectWithChannel(context.Background(), "hello", "empty-response", "test", "chat1")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
if response != defaultResponse {
|
||||
t.Fatalf("response = %q, want %q", response, defaultResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_ToolLimitUsesDedicatedFallback(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &toolLimitOnlyProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(&toolLimitTestTool{})
|
||||
|
||||
response, err := al.ProcessDirectWithChannel(context.Background(), "hello", "tool-limit", "test", "chat1")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
if response != toolLimitResponse {
|
||||
t.Fatalf("response = %q, want %q", response, toolLimitResponse)
|
||||
}
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("No default agent found")
|
||||
}
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: "test",
|
||||
Peer: &routing.RoutePeer{
|
||||
Kind: "direct",
|
||||
ID: "cron",
|
||||
},
|
||||
})
|
||||
history := defaultAgent.Sessions.GetHistory(route.SessionKey)
|
||||
if len(history) != 4 {
|
||||
t.Fatalf("history len = %d, want 4", len(history))
|
||||
}
|
||||
assertRoles(t, history, "user", "assistant", "tool", "assistant")
|
||||
if history[3].Content != toolLimitResponse {
|
||||
t.Fatalf("final assistant content = %q, want %q", history[3].Content, toolLimitResponse)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessDirectWithChannel_TriggersMCPInitialization verifies that
|
||||
// ProcessDirectWithChannel triggers MCP initialization when MCP is enabled.
|
||||
// Note: Manager is only initialized when at least one MCP server is configured
|
||||
|
||||
Reference in New Issue
Block a user