Merge branch 'upstream-main' into feat/subturn-poc

This commit is contained in:
Administrator
2026-03-19 22:12:51 +08:00
22 changed files with 1125 additions and 72 deletions
+14 -55
View File
@@ -3,13 +3,13 @@ package agent
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/memory"
"github.com/sipeed/picoclaw/pkg/providers"
@@ -85,9 +85,11 @@ func NewAgentInstance(
if cfg.Tools.IsToolEnabled("exec") {
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg, allowReadPaths)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
logger.ErrorCF("agent", "Failed to initialize exec tool; continuing without exec",
map[string]any{"error": err.Error()})
} else {
toolsRegistry.Register(execTool)
}
toolsRegistry.Register(execTool)
}
if cfg.Tools.IsToolEnabled("edit_file") {
@@ -150,59 +152,14 @@ func NewAgentInstance(
}
// Resolve fallback candidates
modelCfg := providers.ModelConfig{
Primary: model,
Fallbacks: fallbacks,
}
resolveFromModelList := func(raw string) (string, bool) {
ensureProtocol := func(model string) string {
model = strings.TrimSpace(model)
if model == "" {
return ""
}
if strings.Contains(model, "/") {
return model
}
return "openai/" + model
}
raw = strings.TrimSpace(raw)
if raw == "" {
return "", false
}
if cfg != nil {
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
return ensureProtocol(mc.Model), true
}
for i := range cfg.ModelList {
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
if fullModel == "" {
continue
}
if fullModel == raw {
return ensureProtocol(fullModel), true
}
_, modelID := providers.ExtractProtocol(fullModel)
if modelID == raw {
return ensureProtocol(fullModel), true
}
}
}
return "", false
}
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks)
// Model routing setup: pre-resolve light model candidates at creation time
// to avoid repeated model_list lookups on every incoming message.
var router *routing.Router
var lightCandidates []providers.FallbackCandidate
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
resolved := resolveModelCandidates(cfg, defaults.Provider, rc.LightModel, nil)
if len(resolved) > 0 {
router = routing.New(routing.RouterConfig{
LightModel: rc.LightModel,
@@ -210,8 +167,8 @@ func NewAgentInstance(
})
lightCandidates = resolved
} else {
log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q",
rc.LightModel, agentID)
logger.WarnCF("agent", "Routing light model not found; routing disabled",
map[string]any{"light_model": rc.LightModel, "agent_id": agentID})
}
}
@@ -320,7 +277,8 @@ func (a *AgentInstance) Close() error {
func initSessionStore(dir string) session.SessionStore {
store, err := memory.NewJSONLStore(dir)
if err != nil {
log.Printf("memory: init store: %v; using json sessions", err)
logger.WarnCF("agent", "Memory JSONL store init failed; falling back to json sessions",
map[string]any{"error": err.Error()})
return session.NewSessionManager(dir)
}
@@ -328,11 +286,12 @@ func initSessionStore(dir string) session.SessionStore {
// Migration failure means the store could not write data.
// Fall back to SessionManager to avoid a split state where
// some sessions are in JSONL and others remain in JSON.
log.Printf("memory: migration failed: %v; falling back to json sessions", merr)
logger.WarnCF("agent", "Memory migration failed; falling back to json sessions",
map[string]any{"error": merr.Error()})
store.Close()
return session.NewSessionManager(dir)
} else if n > 0 {
log.Printf("memory: migrated %d session(s) to jsonl", n)
logger.InfoCF("agent", "Memory migrated to JSONL", map[string]any{"sessions_migrated": n})
}
return session.NewJSONLBackend(store)
+34
View File
@@ -246,3 +246,37 @@ func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
t.Fatalf("exec output missing media content: %s", execResult.ForLLM)
}
}
func TestNewAgentInstance_InvalidExecConfigDoesNotExit(t *testing.T) {
workspace := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: workspace,
ModelName: "test-model",
},
},
Tools: config.ToolsConfig{
ReadFile: config.ReadFileToolConfig{Enabled: true},
Exec: config.ExecConfig{
ToolConfig: config.ToolConfig{Enabled: true},
EnableDenyPatterns: true,
CustomDenyPatterns: []string{"[invalid-regex"},
},
},
}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
if agent == nil {
t.Fatal("expected agent instance, got nil")
}
if _, ok := agent.Tools.Get("exec"); ok {
t.Fatal("exec tool should not be registered when exec config is invalid")
}
if _, ok := agent.Tools.Get("read_file"); !ok {
t.Fatal("read_file tool should still be registered")
}
}
+30 -4
View File
@@ -1760,7 +1760,7 @@ func (al *AgentLoop) selectCandidates(
history []providers.Message,
) (candidates []providers.FallbackCandidate, model string) {
if agent.Router == nil || len(agent.LightCandidates) == 0 {
return agent.Candidates, agent.Model
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
}
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
@@ -1771,7 +1771,7 @@ func (al *AgentLoop) selectCandidates(
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.Candidates, agent.Model
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
}
logger.InfoCF("agent", "Model routing: light model selected",
@@ -1781,7 +1781,7 @@ func (al *AgentLoop) selectCandidates(
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.LightCandidates, agent.Router.LightModel()
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel())
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
@@ -2271,11 +2271,37 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
}
if agent != nil {
rt.GetModelInfo = func() (string, string) {
return agent.Model, cfg.Agents.Defaults.Provider
return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider)
}
rt.SwitchModel = func(value string) (string, error) {
value = strings.TrimSpace(value)
modelCfg, err := resolvedModelConfig(cfg, value, agent.Workspace)
if err != nil {
return "", err
}
nextProvider, _, err := providers.CreateProviderFromConfig(modelCfg)
if err != nil {
return "", fmt.Errorf("failed to initialize model %q: %w", value, err)
}
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks)
if len(nextCandidates) == 0 {
return "", fmt.Errorf("model %q did not resolve to any provider candidates", value)
}
oldModel := agent.Model
oldProvider := agent.Provider
agent.Model = value
agent.Provider = nextProvider
agent.Candidates = nextCandidates
agent.ThinkingLevel = parseThinkingLevel(modelCfg.ThinkingLevel)
if oldProvider != nil && oldProvider != nextProvider {
if stateful, ok := oldProvider.(providers.StatefulProvider); ok {
stateful.Close()
}
}
return oldModel, nil
}
+242 -4
View File
@@ -2,7 +2,10 @@ package agent
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"slices"
@@ -444,6 +447,46 @@ type testHelper struct {
al *AgentLoop
}
func newChatCompletionTestServer(
t *testing.T,
label string,
response string,
calls *int,
model *string,
) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/chat/completions" {
t.Fatalf("%s server path = %q, want /chat/completions", label, r.URL.Path)
}
*calls = *calls + 1
defer r.Body.Close()
var req struct {
Model string `json:"model"`
}
decodeErr := json.NewDecoder(r.Body).Decode(&req)
if decodeErr != nil {
t.Fatalf("decode %s request: %v", label, decodeErr)
}
*model = req.Model
w.Header().Set("Content-Type", "application/json")
encodeErr := json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{"content": response},
"finish_reason": "stop",
},
},
})
if encodeErr != nil {
t.Fatalf("encode %s response: %v", label, encodeErr)
}
}))
}
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
// Use a short timeout to avoid hanging
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
@@ -605,11 +648,25 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Provider: "openai",
Model: "before-switch",
Model: "local",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
ModelList: []config.ModelConfig{
{
ModelName: "local",
Model: "openai/local-model",
APIKey: "test-key",
APIBase: "https://local.example.invalid/v1",
},
{
ModelName: "deepseek",
Model: "openrouter/deepseek/deepseek-v3.2",
APIKey: "test-key",
APIBase: "https://openrouter.ai/api/v1",
},
},
}
msgBus := bus.NewMessageBus()
@@ -621,13 +678,13 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "/switch model to after-switch",
Content: "/switch model to deepseek",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") {
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
t.Fatalf("unexpected /switch reply: %q", switchResp)
}
@@ -641,7 +698,7 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
ID: "user1",
},
})
if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") {
if !strings.Contains(showResp, "Current Model: deepseek (Provider: openrouter)") {
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
}
@@ -650,6 +707,187 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
}
}
func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Provider: "openai",
Model: "local",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
ModelList: []config.ModelConfig{
{
ModelName: "local",
Model: "openai/local-model",
APIKey: "test-key",
APIBase: "https://local.example.invalid/v1",
},
},
}
msgBus := bus.NewMessageBus()
provider := &countingMockProvider{response: "LLM reply"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "/switch model to missing",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if switchResp != `model "missing" not found in model_list or providers` {
t.Fatalf("unexpected /switch error reply: %q", switchResp)
}
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "/show model",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(showResp, "Current Model: local (Provider: openai)") {
t.Fatalf("unexpected /show model reply after rejected switch: %q", showResp)
}
if provider.calls != 0 {
t.Fatalf("LLM should not be called for rejected /switch and /show, calls=%d", provider.calls)
}
}
func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
localCalls := 0
localModel := ""
localServer := newChatCompletionTestServer(t, "local", "local reply", &localCalls, &localModel)
defer localServer.Close()
remoteCalls := 0
remoteModel := ""
remoteServer := newChatCompletionTestServer(t, "remote", "remote reply", &remoteCalls, &remoteModel)
defer remoteServer.Close()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Provider: "openai",
Model: "local",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
ModelList: []config.ModelConfig{
{
ModelName: "local",
Model: "openai/Qwen3.5-35B-A3B",
APIKey: "local-key",
APIBase: localServer.URL,
},
{
ModelName: "deepseek",
Model: "openrouter/deepseek/deepseek-v3.2",
APIKey: "remote-key",
APIBase: remoteServer.URL,
},
},
}
msgBus := bus.NewMessageBus()
provider, _, err := providers.CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
firstResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "hello before switch",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if firstResp != "local reply" {
t.Fatalf("unexpected response before switch: %q", firstResp)
}
if localCalls != 1 {
t.Fatalf("local calls before switch = %d, want 1", localCalls)
}
if remoteCalls != 0 {
t.Fatalf("remote calls before switch = %d, want 0", remoteCalls)
}
if localModel != "Qwen3.5-35B-A3B" {
t.Fatalf("local model before switch = %q, want %q", localModel, "Qwen3.5-35B-A3B")
}
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "/switch model to deepseek",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
t.Fatalf("unexpected /switch reply: %q", switchResp)
}
secondResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "hello after switch",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if secondResp != "remote reply" {
t.Fatalf("unexpected response after switch: %q", secondResp)
}
if localCalls != 1 {
t.Fatalf("local calls after switch = %d, want 1", localCalls)
}
if remoteCalls != 1 {
t.Fatalf("remote calls after switch = %d, want 1", remoteCalls)
}
if remoteModel != "deepseek-v3.2" {
t.Fatalf(
"remote model after switch = %q, want %q",
remoteModel,
"deepseek-v3.2",
)
}
}
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
+97
View File
@@ -0,0 +1,97 @@
package agent
import (
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool) {
ensureProtocol := func(model string) string {
model = strings.TrimSpace(model)
if model == "" {
return ""
}
if strings.Contains(model, "/") {
return model
}
return "openai/" + model
}
return func(raw string) (string, bool) {
raw = strings.TrimSpace(raw)
if raw == "" || cfg == nil {
return "", false
}
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
return ensureProtocol(mc.Model), true
}
for i := range cfg.ModelList {
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
if fullModel == "" {
continue
}
if fullModel == raw {
return ensureProtocol(fullModel), true
}
_, modelID := providers.ExtractProtocol(fullModel)
if modelID == raw {
return ensureProtocol(fullModel), true
}
}
return "", false
}
}
func resolveModelCandidates(
cfg *config.Config,
defaultProvider string,
primary string,
fallbacks []string,
) []providers.FallbackCandidate {
return providers.ResolveCandidatesWithLookup(
providers.ModelConfig{
Primary: primary,
Fallbacks: fallbacks,
},
defaultProvider,
buildModelListResolver(cfg),
)
}
func resolvedCandidateModel(candidates []providers.FallbackCandidate, fallback string) string {
if len(candidates) > 0 && strings.TrimSpace(candidates[0].Model) != "" {
return candidates[0].Model
}
return fallback
}
func resolvedCandidateProvider(candidates []providers.FallbackCandidate, fallback string) string {
if len(candidates) > 0 && strings.TrimSpace(candidates[0].Provider) != "" {
return candidates[0].Provider
}
return fallback
}
func resolvedModelConfig(cfg *config.Config, modelName, workspace string) (*config.ModelConfig, error) {
if cfg == nil {
return nil, fmt.Errorf("config is nil")
}
modelCfg, err := cfg.GetModelConfig(strings.TrimSpace(modelName))
if err != nil {
return nil, err
}
clone := *modelCfg
if clone.Workspace == "" {
clone.Workspace = workspace
}
return &clone, nil
}
+9
View File
@@ -924,6 +924,15 @@ func LoadConfig(path string) (*Config, error) {
cfg.ModelList = ConvertProvidersToModelList(cfg)
}
// Inherit credentials from providers to model_list entries (#1635).
// When both providers and model_list are present, model_list entries
// whose api_key/api_base are empty will inherit from the matching
// provider (matched by protocol prefix). Explicit model_list values
// always take precedence.
if cfg.HasProvidersConfig() {
InheritProviderCredentials(cfg.ModelList, cfg.Providers)
}
// Validate model_list for uniqueness and required fields
if err := cfg.ValidateModelList(); err != nil {
return nil, err
+81
View File
@@ -468,3 +468,84 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return result
}
// protocolProviderMapping maps a model protocol prefix (the part before "/" in
// the Model field) to a function that extracts the corresponding ProviderConfig
// from the legacy ProvidersConfig. Used by InheritProviderCredentials.
var protocolProviderMapping = map[string]func(p ProvidersConfig) ProviderConfig{
"openai": func(p ProvidersConfig) ProviderConfig { return p.OpenAI.ProviderConfig },
"anthropic": func(p ProvidersConfig) ProviderConfig { return p.Anthropic },
"litellm": func(p ProvidersConfig) ProviderConfig { return p.LiteLLM },
"openrouter": func(p ProvidersConfig) ProviderConfig { return p.OpenRouter },
"groq": func(p ProvidersConfig) ProviderConfig { return p.Groq },
"zhipu": func(p ProvidersConfig) ProviderConfig { return p.Zhipu },
"vllm": func(p ProvidersConfig) ProviderConfig { return p.VLLM },
"gemini": func(p ProvidersConfig) ProviderConfig { return p.Gemini },
"nvidia": func(p ProvidersConfig) ProviderConfig { return p.Nvidia },
"ollama": func(p ProvidersConfig) ProviderConfig { return p.Ollama },
"moonshot": func(p ProvidersConfig) ProviderConfig { return p.Moonshot },
"shengsuanyun": func(p ProvidersConfig) ProviderConfig { return p.ShengSuanYun },
"deepseek": func(p ProvidersConfig) ProviderConfig { return p.DeepSeek },
"cerebras": func(p ProvidersConfig) ProviderConfig { return p.Cerebras },
"vivgrid": func(p ProvidersConfig) ProviderConfig { return p.Vivgrid },
"volcengine": func(p ProvidersConfig) ProviderConfig { return p.VolcEngine },
"github-copilot": func(p ProvidersConfig) ProviderConfig { return p.GitHubCopilot },
"antigravity": func(p ProvidersConfig) ProviderConfig { return p.Antigravity },
"qwen": func(p ProvidersConfig) ProviderConfig { return p.Qwen },
"mistral": func(p ProvidersConfig) ProviderConfig { return p.Mistral },
"avian": func(p ProvidersConfig) ProviderConfig { return p.Avian },
"minimax": func(p ProvidersConfig) ProviderConfig { return p.Minimax },
"longcat": func(p ProvidersConfig) ProviderConfig { return p.LongCat },
"modelscope": func(p ProvidersConfig) ProviderConfig { return p.ModelScope },
"novita": func(p ProvidersConfig) ProviderConfig { return p.Novita },
}
// InheritProviderCredentials fills in missing api_key, api_base, proxy, and
// request_timeout on model_list entries from the matching legacy providers
// configuration. The match is determined by the protocol prefix in the Model
// field (e.g. "deepseek/deepseek-chat" matches providers.deepseek).
//
// Only empty fields are filled — any value explicitly set on a model_list entry
// takes precedence. This function modifies the slice in place.
//
// This bridges the gap described in issue #1635: users who configure
// credentials once in the providers section expect model_list entries using
// the same protocol to "just work" without duplicating credentials.
func InheritProviderCredentials(models []ModelConfig, providers ProvidersConfig) {
if providers.IsEmpty() {
return
}
for i := range models {
m := &models[i]
// Extract protocol prefix from Model field
protocol := ""
if idx := strings.Index(m.Model, "/"); idx > 0 {
protocol = strings.ToLower(m.Model[:idx])
}
if protocol == "" {
continue
}
getProvider, ok := protocolProviderMapping[protocol]
if !ok {
continue
}
pc := getProvider(providers)
// Only fill empty fields — explicit model_list values win
if m.APIKey == "" && pc.APIKey != "" {
m.APIKey = pc.APIKey
}
if m.APIBase == "" && pc.APIBase != "" {
m.APIBase = pc.APIBase
}
if m.Proxy == "" && pc.Proxy != "" {
m.Proxy = pc.Proxy
}
if m.RequestTimeout == 0 && pc.RequestTimeout != 0 {
m.RequestTimeout = pc.RequestTimeout
}
}
}
+140
View File
@@ -613,3 +613,143 @@ func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T)
t.Errorf("Model = %q, want %q (should not duplicate prefix)", result[0].Model, "openrouter/auto")
}
}
// ---------- InheritProviderCredentials tests ----------
func TestInheritProviderCredentials_FillsMissingAPIKey(t *testing.T) {
models := []ModelConfig{
{ModelName: "my-deepseek", Model: "deepseek/deepseek-chat"},
}
providers := ProvidersConfig{
DeepSeek: ProviderConfig{
APIKey: "sk-deepseek-from-providers",
APIBase: "https://api.deepseek.com/v1",
},
}
InheritProviderCredentials(models, providers)
if models[0].APIKey != "sk-deepseek-from-providers" {
t.Errorf("APIKey = %q, want %q", models[0].APIKey, "sk-deepseek-from-providers")
}
if models[0].APIBase != "https://api.deepseek.com/v1" {
t.Errorf("APIBase = %q, want %q", models[0].APIBase, "https://api.deepseek.com/v1")
}
}
func TestInheritProviderCredentials_ExplicitValuesTakePrecedence(t *testing.T) {
models := []ModelConfig{
{
ModelName: "my-openai",
Model: "openai/gpt-5.4",
APIKey: "sk-explicit-model-key",
APIBase: "https://my-custom-endpoint.com/v1",
},
}
providers := ProvidersConfig{
OpenAI: OpenAIProviderConfig{
ProviderConfig: ProviderConfig{
APIKey: "sk-provider-key",
APIBase: "https://api.openai.com/v1",
},
},
}
InheritProviderCredentials(models, providers)
if models[0].APIKey != "sk-explicit-model-key" {
t.Errorf("APIKey = %q, want %q (explicit should win)", models[0].APIKey, "sk-explicit-model-key")
}
if models[0].APIBase != "https://my-custom-endpoint.com/v1" {
t.Errorf("APIBase = %q, want %q (explicit should win)", models[0].APIBase, "https://my-custom-endpoint.com/v1")
}
}
func TestInheritProviderCredentials_MultipleModels(t *testing.T) {
models := []ModelConfig{
{ModelName: "groq-llama", Model: "groq/llama-3.1-70b"},
{ModelName: "zhipu-glm", Model: "zhipu/glm-4"},
{ModelName: "custom-openai", Model: "openai/gpt-5.4", APIKey: "sk-already-set"},
}
providers := ProvidersConfig{
Groq: ProviderConfig{APIKey: "gsk-groq-key", Proxy: "http://proxy:8080"},
Zhipu: ProviderConfig{APIKey: "zhipu-key-123", APIBase: "https://zhipu.example.com"},
OpenAI: OpenAIProviderConfig{
ProviderConfig: ProviderConfig{APIKey: "sk-should-not-override"},
},
}
InheritProviderCredentials(models, providers)
// groq model should inherit
if models[0].APIKey != "gsk-groq-key" {
t.Errorf("groq APIKey = %q, want %q", models[0].APIKey, "gsk-groq-key")
}
if models[0].Proxy != "http://proxy:8080" {
t.Errorf("groq Proxy = %q, want %q", models[0].Proxy, "http://proxy:8080")
}
// zhipu model should inherit
if models[1].APIKey != "zhipu-key-123" {
t.Errorf("zhipu APIKey = %q, want %q", models[1].APIKey, "zhipu-key-123")
}
if models[1].APIBase != "https://zhipu.example.com" {
t.Errorf("zhipu APIBase = %q, want %q", models[1].APIBase, "https://zhipu.example.com")
}
// openai model already has key — should NOT be overridden
if models[2].APIKey != "sk-already-set" {
t.Errorf("openai APIKey = %q, want %q (should not be overridden)", models[2].APIKey, "sk-already-set")
}
}
func TestInheritProviderCredentials_NoMatchingProvider(t *testing.T) {
models := []ModelConfig{
{ModelName: "my-model", Model: "novelai/some-model"},
}
providers := ProvidersConfig{
DeepSeek: ProviderConfig{APIKey: "sk-deepseek"},
}
InheritProviderCredentials(models, providers)
// No matching provider for "novelai" protocol — should stay empty
if models[0].APIKey != "" {
t.Errorf("APIKey = %q, want empty (no matching provider)", models[0].APIKey)
}
}
func TestInheritProviderCredentials_EmptyProviders(t *testing.T) {
models := []ModelConfig{
{ModelName: "my-model", Model: "openai/gpt-5.4"},
}
providers := ProvidersConfig{} // all empty
InheritProviderCredentials(models, providers)
// Empty providers — nothing to inherit
if models[0].APIKey != "" {
t.Errorf("APIKey = %q, want empty", models[0].APIKey)
}
}
func TestInheritProviderCredentials_InheritsRequestTimeout(t *testing.T) {
models := []ModelConfig{
{ModelName: "my-ollama", Model: "ollama/llama3.2:3b"},
}
providers := ProvidersConfig{
Ollama: ProviderConfig{
APIBase: "http://localhost:11434",
RequestTimeout: 120,
},
}
InheritProviderCredentials(models, providers)
if models[0].APIBase != "http://localhost:11434" {
t.Errorf("APIBase = %q, want %q", models[0].APIBase, "http://localhost:11434")
}
if models[0].RequestTimeout != 120 {
t.Errorf("RequestTimeout = %d, want 120", models[0].RequestTimeout)
}
}
@@ -221,6 +221,10 @@ func buildRequestBody(
// Add tool_use blocks
for _, tc := range msg.ToolCalls {
if strings.TrimSpace(tc.Name) == "" {
continue
}
// Handle nil Arguments (GLM-4 may return null input)
input := tc.Arguments
if input == nil {
@@ -492,6 +492,20 @@ func TestBuildRequestBodyEdgeCases(t *testing.T) {
},
wantErr: false,
},
{
name: "skip tool calls with empty names",
messages: []Message{
{Role: "assistant", Content: "Calling tool", ToolCalls: []ToolCall{
{ID: "tool-empty", Name: "", Arguments: map[string]any{"ignored": true}},
{ID: "tool-valid", Name: "test_tool", Arguments: map[string]any{"arg": "value"}},
}},
},
model: "test-model",
options: map[string]any{
"max_tokens": 8192,
},
wantErr: false,
},
}
for _, tt := range tests {
@@ -513,6 +527,37 @@ func TestBuildRequestBodyEdgeCases(t *testing.T) {
if got["model"] != tt.model {
t.Errorf("model = %v, want %v", got["model"], tt.model)
}
if tt.name == "skip tool calls with empty names" {
messages, ok := got["messages"].([]any)
if !ok || len(messages) != 1 {
t.Fatalf("messages = %#v, want single assistant message", got["messages"])
}
assistantMsg, ok := messages[0].(map[string]any)
if !ok {
t.Fatalf("assistant message = %#v, want map", messages[0])
}
content, ok := assistantMsg["content"].([]any)
if !ok {
t.Fatalf("assistant content = %#v, want []any", assistantMsg["content"])
}
if len(content) != 2 {
t.Fatalf("assistant content length = %d, want 2", len(content))
}
toolUse, ok := content[1].(map[string]any)
if !ok {
t.Fatalf("tool_use block = %#v, want map", content[1])
}
if gotName := toolUse["name"]; gotName != "test_tool" {
t.Fatalf("tool_use name = %v, want %q", gotName, "test_tool")
}
if gotID := toolUse["id"]; gotID != "tool-valid" {
t.Fatalf("tool_use id = %v, want %q", gotID, "tool-valid")
}
}
})
}
}
+26 -2
View File
@@ -115,8 +115,9 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
"minimax", "longcat", "modelscope", "novita":
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
"qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
"coding-plan", "alibaba-coding", "qwen-coding":
// All other OpenAI-compatible HTTP providers
if cfg.APIKey == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
@@ -173,6 +174,21 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.RequestTimeout,
), modelID, nil
case "coding-plan-anthropic", "alibaba-coding-anthropic":
// Alibaba Coding Plan with Anthropic-compatible API
apiBase := cfg.APIBase
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
if cfg.APIKey == "" {
return nil, "", fmt.Errorf("api_key is required for %q protocol (model: %s)", protocol, cfg.Model)
}
return anthropicmessages.NewProviderWithTimeout(
cfg.APIKey,
apiBase,
cfg.RequestTimeout,
), modelID, nil
case "antigravity":
return NewAntigravityProvider(), modelID, nil
@@ -245,6 +261,14 @@ func getDefaultAPIBase(protocol string) string {
return "https://ark.cn-beijing.volces.com/api/v3"
case "qwen":
return "https://dashscope.aliyuncs.com/compatible-mode/v1"
case "qwen-intl", "qwen-international", "dashscope-intl":
return "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
case "qwen-us", "dashscope-us":
return "https://dashscope-us.aliyuncs.com/compatible-mode/v1"
case "coding-plan", "alibaba-coding", "qwen-coding":
return "https://coding-intl.dashscope.aliyuncs.com/v1"
case "coding-plan-anthropic", "alibaba-coding-anthropic":
return "https://coding-intl.dashscope.aliyuncs.com/apps/anthropic"
case "vllm":
return "http://localhost:8000/v1"
case "mistral":
+131
View File
@@ -472,3 +472,134 @@ func TestCreateProviderFromConfig_AzureMissingAPIBase(t *testing.T) {
t.Fatal("CreateProviderFromConfig() expected error for missing API base")
}
}
func TestCreateProviderFromConfig_QwenInternationalAlias(t *testing.T) {
tests := []struct {
name string
protocol string
}{
{"qwen-international", "qwen-international"},
{"dashscope-intl", "dashscope-intl"},
{"qwen-intl", "qwen-intl"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-" + tt.protocol,
Model: tt.protocol + "/qwen-max",
APIKey: "test-key",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "qwen-max" {
t.Errorf("modelID = %q, want %q", modelID, "qwen-max")
}
if _, ok := provider.(*HTTPProvider); !ok {
t.Fatalf("expected *HTTPProvider, got %T", provider)
}
})
}
}
func TestCreateProviderFromConfig_QwenUSAlias(t *testing.T) {
tests := []struct {
name string
protocol string
}{
{"qwen-us", "qwen-us"},
{"dashscope-us", "dashscope-us"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-" + tt.protocol,
Model: tt.protocol + "/qwen-max",
APIKey: "test-key",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "qwen-max" {
t.Errorf("modelID = %q, want %q", modelID, "qwen-max")
}
if _, ok := provider.(*HTTPProvider); !ok {
t.Fatalf("expected *HTTPProvider, got %T", provider)
}
})
}
}
func TestCreateProviderFromConfig_CodingPlanAnthropic(t *testing.T) {
tests := []struct {
name string
protocol string
}{
{"coding-plan-anthropic", "coding-plan-anthropic"},
{"alibaba-coding-anthropic", "alibaba-coding-anthropic"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-" + tt.protocol,
Model: tt.protocol + "/claude-sonnet-4-20250514",
APIKey: "test-key",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "claude-sonnet-4-20250514" {
t.Errorf("modelID = %q, want %q", modelID, "claude-sonnet-4-20250514")
}
// coding-plan-anthropic uses Anthropic Messages provider
// Verify it's the anthropic messages provider by checking interface
var _ LLMProvider = provider
})
}
}
func TestGetDefaultAPIBase_CodingPlanAnthropic(t *testing.T) {
expectedURL := "https://coding-intl.dashscope.aliyuncs.com/apps/anthropic"
if got := getDefaultAPIBase("coding-plan-anthropic"); got != expectedURL {
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "coding-plan-anthropic", got, expectedURL)
}
if got := getDefaultAPIBase("alibaba-coding-anthropic"); got != expectedURL {
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "alibaba-coding-anthropic", got, expectedURL)
}
}
func TestGetDefaultAPIBase_QwenIntlAliases(t *testing.T) {
expectedURL := "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
for _, protocol := range []string{"qwen-intl", "qwen-international", "dashscope-intl"} {
if got := getDefaultAPIBase(protocol); got != expectedURL {
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", protocol, got, expectedURL)
}
}
}
func TestGetDefaultAPIBase_QwenUSAliases(t *testing.T) {
expectedURL := "https://dashscope-us.aliyuncs.com/compatible-mode/v1"
for _, protocol := range []string{"qwen-us", "dashscope-us"} {
if got := getDefaultAPIBase(protocol); got != expectedURL {
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", protocol, got, expectedURL)
}
}
}
+8
View File
@@ -53,6 +53,14 @@ func NormalizeProvider(provider string) string {
return "zhipu"
case "google":
return "gemini"
case "alibaba-coding", "qwen-coding":
return "coding-plan"
case "alibaba-coding-anthropic":
return "coding-plan-anthropic"
case "qwen-international", "dashscope-intl":
return "qwen-intl"
case "dashscope-us":
return "qwen-us"
}
return p
+8
View File
@@ -73,6 +73,14 @@ func TestNormalizeProvider(t *testing.T) {
{"glm", "zhipu"},
{"google", "gemini"},
{"groq", "groq"},
// Alibaba Coding Plan aliases
{"alibaba-coding", "coding-plan"},
{"qwen-coding", "coding-plan"},
{"alibaba-coding-anthropic", "coding-plan-anthropic"},
// Qwen international aliases
{"qwen-international", "qwen-intl"},
{"dashscope-intl", "qwen-intl"},
{"dashscope-us", "qwen-us"},
{"", ""},
}