mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'upstream-main' into feat/subturn-poc
This commit is contained in:
+14
-55
@@ -3,13 +3,13 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
@@ -85,9 +85,11 @@ func NewAgentInstance(
|
||||
if cfg.Tools.IsToolEnabled("exec") {
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg, allowReadPaths)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
|
||||
logger.ErrorCF("agent", "Failed to initialize exec tool; continuing without exec",
|
||||
map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
toolsRegistry.Register(execTool)
|
||||
}
|
||||
toolsRegistry.Register(execTool)
|
||||
}
|
||||
|
||||
if cfg.Tools.IsToolEnabled("edit_file") {
|
||||
@@ -150,59 +152,14 @@ func NewAgentInstance(
|
||||
}
|
||||
|
||||
// Resolve fallback candidates
|
||||
modelCfg := providers.ModelConfig{
|
||||
Primary: model,
|
||||
Fallbacks: fallbacks,
|
||||
}
|
||||
resolveFromModelList := func(raw string) (string, bool) {
|
||||
ensureProtocol := func(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
return model
|
||||
}
|
||||
return "openai/" + model
|
||||
}
|
||||
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if cfg != nil {
|
||||
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
|
||||
return ensureProtocol(mc.Model), true
|
||||
}
|
||||
|
||||
for i := range cfg.ModelList {
|
||||
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
|
||||
if fullModel == "" {
|
||||
continue
|
||||
}
|
||||
if fullModel == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
_, modelID := providers.ExtractProtocol(fullModel)
|
||||
if modelID == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
|
||||
candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks)
|
||||
|
||||
// Model routing setup: pre-resolve light model candidates at creation time
|
||||
// to avoid repeated model_list lookups on every incoming message.
|
||||
var router *routing.Router
|
||||
var lightCandidates []providers.FallbackCandidate
|
||||
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
|
||||
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
|
||||
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
|
||||
resolved := resolveModelCandidates(cfg, defaults.Provider, rc.LightModel, nil)
|
||||
if len(resolved) > 0 {
|
||||
router = routing.New(routing.RouterConfig{
|
||||
LightModel: rc.LightModel,
|
||||
@@ -210,8 +167,8 @@ func NewAgentInstance(
|
||||
})
|
||||
lightCandidates = resolved
|
||||
} else {
|
||||
log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q",
|
||||
rc.LightModel, agentID)
|
||||
logger.WarnCF("agent", "Routing light model not found; routing disabled",
|
||||
map[string]any{"light_model": rc.LightModel, "agent_id": agentID})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -320,7 +277,8 @@ func (a *AgentInstance) Close() error {
|
||||
func initSessionStore(dir string) session.SessionStore {
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
log.Printf("memory: init store: %v; using json sessions", err)
|
||||
logger.WarnCF("agent", "Memory JSONL store init failed; falling back to json sessions",
|
||||
map[string]any{"error": err.Error()})
|
||||
return session.NewSessionManager(dir)
|
||||
}
|
||||
|
||||
@@ -328,11 +286,12 @@ func initSessionStore(dir string) session.SessionStore {
|
||||
// Migration failure means the store could not write data.
|
||||
// Fall back to SessionManager to avoid a split state where
|
||||
// some sessions are in JSONL and others remain in JSON.
|
||||
log.Printf("memory: migration failed: %v; falling back to json sessions", merr)
|
||||
logger.WarnCF("agent", "Memory migration failed; falling back to json sessions",
|
||||
map[string]any{"error": merr.Error()})
|
||||
store.Close()
|
||||
return session.NewSessionManager(dir)
|
||||
} else if n > 0 {
|
||||
log.Printf("memory: migrated %d session(s) to jsonl", n)
|
||||
logger.InfoCF("agent", "Memory migrated to JSONL", map[string]any{"sessions_migrated": n})
|
||||
}
|
||||
|
||||
return session.NewJSONLBackend(store)
|
||||
|
||||
@@ -246,3 +246,37 @@ func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
|
||||
t.Fatalf("exec output missing media content: %s", execResult.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_InvalidExecConfigDoesNotExit(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "test-model",
|
||||
},
|
||||
},
|
||||
Tools: config.ToolsConfig{
|
||||
ReadFile: config.ReadFileToolConfig{Enabled: true},
|
||||
Exec: config.ExecConfig{
|
||||
ToolConfig: config.ToolConfig{Enabled: true},
|
||||
EnableDenyPatterns: true,
|
||||
CustomDenyPatterns: []string{"[invalid-regex"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
|
||||
if agent == nil {
|
||||
t.Fatal("expected agent instance, got nil")
|
||||
}
|
||||
|
||||
if _, ok := agent.Tools.Get("exec"); ok {
|
||||
t.Fatal("exec tool should not be registered when exec config is invalid")
|
||||
}
|
||||
|
||||
if _, ok := agent.Tools.Get("read_file"); !ok {
|
||||
t.Fatal("read_file tool should still be registered")
|
||||
}
|
||||
}
|
||||
|
||||
+30
-4
@@ -1760,7 +1760,7 @@ func (al *AgentLoop) selectCandidates(
|
||||
history []providers.Message,
|
||||
) (candidates []providers.FallbackCandidate, model string) {
|
||||
if agent.Router == nil || len(agent.LightCandidates) == 0 {
|
||||
return agent.Candidates, agent.Model
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
|
||||
}
|
||||
|
||||
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
|
||||
@@ -1771,7 +1771,7 @@ func (al *AgentLoop) selectCandidates(
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.Candidates, agent.Model
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Model routing: light model selected",
|
||||
@@ -1781,7 +1781,7 @@ func (al *AgentLoop) selectCandidates(
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.LightCandidates, agent.Router.LightModel()
|
||||
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel())
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
@@ -2271,11 +2271,37 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
|
||||
}
|
||||
if agent != nil {
|
||||
rt.GetModelInfo = func() (string, string) {
|
||||
return agent.Model, cfg.Agents.Defaults.Provider
|
||||
return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider)
|
||||
}
|
||||
rt.SwitchModel = func(value string) (string, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
modelCfg, err := resolvedModelConfig(cfg, value, agent.Workspace)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nextProvider, _, err := providers.CreateProviderFromConfig(modelCfg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to initialize model %q: %w", value, err)
|
||||
}
|
||||
|
||||
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks)
|
||||
if len(nextCandidates) == 0 {
|
||||
return "", fmt.Errorf("model %q did not resolve to any provider candidates", value)
|
||||
}
|
||||
|
||||
oldModel := agent.Model
|
||||
oldProvider := agent.Provider
|
||||
agent.Model = value
|
||||
agent.Provider = nextProvider
|
||||
agent.Candidates = nextCandidates
|
||||
agent.ThinkingLevel = parseThinkingLevel(modelCfg.ThinkingLevel)
|
||||
|
||||
if oldProvider != nil && oldProvider != nextProvider {
|
||||
if stateful, ok := oldProvider.(providers.StatefulProvider); ok {
|
||||
stateful.Close()
|
||||
}
|
||||
}
|
||||
return oldModel, nil
|
||||
}
|
||||
|
||||
|
||||
+242
-4
@@ -2,7 +2,10 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
@@ -444,6 +447,46 @@ type testHelper struct {
|
||||
al *AgentLoop
|
||||
}
|
||||
|
||||
func newChatCompletionTestServer(
|
||||
t *testing.T,
|
||||
label string,
|
||||
response string,
|
||||
calls *int,
|
||||
model *string,
|
||||
) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
t.Fatalf("%s server path = %q, want /chat/completions", label, r.URL.Path)
|
||||
}
|
||||
*calls = *calls + 1
|
||||
defer r.Body.Close()
|
||||
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
decodeErr := json.NewDecoder(r.Body).Decode(&req)
|
||||
if decodeErr != nil {
|
||||
t.Fatalf("decode %s request: %v", label, decodeErr)
|
||||
}
|
||||
*model = req.Model
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
encodeErr := json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": response},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
})
|
||||
if encodeErr != nil {
|
||||
t.Fatalf("encode %s response: %v", label, encodeErr)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
|
||||
// Use a short timeout to avoid hanging
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
|
||||
@@ -605,11 +648,25 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
Model: "before-switch",
|
||||
Model: "local",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "local",
|
||||
Model: "openai/local-model",
|
||||
APIKey: "test-key",
|
||||
APIBase: "https://local.example.invalid/v1",
|
||||
},
|
||||
{
|
||||
ModelName: "deepseek",
|
||||
Model: "openrouter/deepseek/deepseek-v3.2",
|
||||
APIKey: "test-key",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
@@ -621,13 +678,13 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to after-switch",
|
||||
Content: "/switch model to deepseek",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") {
|
||||
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
}
|
||||
|
||||
@@ -641,7 +698,7 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") {
|
||||
if !strings.Contains(showResp, "Current Model: deepseek (Provider: openrouter)") {
|
||||
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
|
||||
}
|
||||
|
||||
@@ -650,6 +707,187 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
Model: "local",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "local",
|
||||
Model: "openai/local-model",
|
||||
APIKey: "test-key",
|
||||
APIBase: "https://local.example.invalid/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &countingMockProvider{response: "LLM reply"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to missing",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if switchResp != `model "missing" not found in model_list or providers` {
|
||||
t.Fatalf("unexpected /switch error reply: %q", switchResp)
|
||||
}
|
||||
|
||||
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/show model",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: local (Provider: openai)") {
|
||||
t.Fatalf("unexpected /show model reply after rejected switch: %q", showResp)
|
||||
}
|
||||
|
||||
if provider.calls != 0 {
|
||||
t.Fatalf("LLM should not be called for rejected /switch and /show, calls=%d", provider.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
localCalls := 0
|
||||
localModel := ""
|
||||
localServer := newChatCompletionTestServer(t, "local", "local reply", &localCalls, &localModel)
|
||||
defer localServer.Close()
|
||||
|
||||
remoteCalls := 0
|
||||
remoteModel := ""
|
||||
remoteServer := newChatCompletionTestServer(t, "remote", "remote reply", &remoteCalls, &remoteModel)
|
||||
defer remoteServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
Model: "local",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "local",
|
||||
Model: "openai/Qwen3.5-35B-A3B",
|
||||
APIKey: "local-key",
|
||||
APIBase: localServer.URL,
|
||||
},
|
||||
{
|
||||
ModelName: "deepseek",
|
||||
Model: "openrouter/deepseek/deepseek-v3.2",
|
||||
APIKey: "remote-key",
|
||||
APIBase: remoteServer.URL,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
firstResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello before switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if firstResp != "local reply" {
|
||||
t.Fatalf("unexpected response before switch: %q", firstResp)
|
||||
}
|
||||
if localCalls != 1 {
|
||||
t.Fatalf("local calls before switch = %d, want 1", localCalls)
|
||||
}
|
||||
if remoteCalls != 0 {
|
||||
t.Fatalf("remote calls before switch = %d, want 0", remoteCalls)
|
||||
}
|
||||
if localModel != "Qwen3.5-35B-A3B" {
|
||||
t.Fatalf("local model before switch = %q, want %q", localModel, "Qwen3.5-35B-A3B")
|
||||
}
|
||||
|
||||
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to deepseek",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
}
|
||||
|
||||
secondResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello after switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if secondResp != "remote reply" {
|
||||
t.Fatalf("unexpected response after switch: %q", secondResp)
|
||||
}
|
||||
if localCalls != 1 {
|
||||
t.Fatalf("local calls after switch = %d, want 1", localCalls)
|
||||
}
|
||||
if remoteCalls != 1 {
|
||||
t.Fatalf("remote calls after switch = %d, want 1", remoteCalls)
|
||||
}
|
||||
if remoteModel != "deepseek-v3.2" {
|
||||
t.Fatalf(
|
||||
"remote model after switch = %q, want %q",
|
||||
remoteModel,
|
||||
"deepseek-v3.2",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
||||
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool) {
|
||||
ensureProtocol := func(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
return model
|
||||
}
|
||||
return "openai/" + model
|
||||
}
|
||||
|
||||
return func(raw string) (string, bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" || cfg == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
|
||||
return ensureProtocol(mc.Model), true
|
||||
}
|
||||
|
||||
for i := range cfg.ModelList {
|
||||
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
|
||||
if fullModel == "" {
|
||||
continue
|
||||
}
|
||||
if fullModel == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
_, modelID := providers.ExtractProtocol(fullModel)
|
||||
if modelID == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func resolveModelCandidates(
|
||||
cfg *config.Config,
|
||||
defaultProvider string,
|
||||
primary string,
|
||||
fallbacks []string,
|
||||
) []providers.FallbackCandidate {
|
||||
return providers.ResolveCandidatesWithLookup(
|
||||
providers.ModelConfig{
|
||||
Primary: primary,
|
||||
Fallbacks: fallbacks,
|
||||
},
|
||||
defaultProvider,
|
||||
buildModelListResolver(cfg),
|
||||
)
|
||||
}
|
||||
|
||||
func resolvedCandidateModel(candidates []providers.FallbackCandidate, fallback string) string {
|
||||
if len(candidates) > 0 && strings.TrimSpace(candidates[0].Model) != "" {
|
||||
return candidates[0].Model
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func resolvedCandidateProvider(candidates []providers.FallbackCandidate, fallback string) string {
|
||||
if len(candidates) > 0 && strings.TrimSpace(candidates[0].Provider) != "" {
|
||||
return candidates[0].Provider
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func resolvedModelConfig(cfg *config.Config, modelName, workspace string) (*config.ModelConfig, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
}
|
||||
|
||||
modelCfg, err := cfg.GetModelConfig(strings.TrimSpace(modelName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clone := *modelCfg
|
||||
if clone.Workspace == "" {
|
||||
clone.Workspace = workspace
|
||||
}
|
||||
|
||||
return &clone, nil
|
||||
}
|
||||
@@ -924,6 +924,15 @@ func LoadConfig(path string) (*Config, error) {
|
||||
cfg.ModelList = ConvertProvidersToModelList(cfg)
|
||||
}
|
||||
|
||||
// Inherit credentials from providers to model_list entries (#1635).
|
||||
// When both providers and model_list are present, model_list entries
|
||||
// whose api_key/api_base are empty will inherit from the matching
|
||||
// provider (matched by protocol prefix). Explicit model_list values
|
||||
// always take precedence.
|
||||
if cfg.HasProvidersConfig() {
|
||||
InheritProviderCredentials(cfg.ModelList, cfg.Providers)
|
||||
}
|
||||
|
||||
// Validate model_list for uniqueness and required fields
|
||||
if err := cfg.ValidateModelList(); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -468,3 +468,84 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// protocolProviderMapping maps a model protocol prefix (the part before "/" in
|
||||
// the Model field) to a function that extracts the corresponding ProviderConfig
|
||||
// from the legacy ProvidersConfig. Used by InheritProviderCredentials.
|
||||
var protocolProviderMapping = map[string]func(p ProvidersConfig) ProviderConfig{
|
||||
"openai": func(p ProvidersConfig) ProviderConfig { return p.OpenAI.ProviderConfig },
|
||||
"anthropic": func(p ProvidersConfig) ProviderConfig { return p.Anthropic },
|
||||
"litellm": func(p ProvidersConfig) ProviderConfig { return p.LiteLLM },
|
||||
"openrouter": func(p ProvidersConfig) ProviderConfig { return p.OpenRouter },
|
||||
"groq": func(p ProvidersConfig) ProviderConfig { return p.Groq },
|
||||
"zhipu": func(p ProvidersConfig) ProviderConfig { return p.Zhipu },
|
||||
"vllm": func(p ProvidersConfig) ProviderConfig { return p.VLLM },
|
||||
"gemini": func(p ProvidersConfig) ProviderConfig { return p.Gemini },
|
||||
"nvidia": func(p ProvidersConfig) ProviderConfig { return p.Nvidia },
|
||||
"ollama": func(p ProvidersConfig) ProviderConfig { return p.Ollama },
|
||||
"moonshot": func(p ProvidersConfig) ProviderConfig { return p.Moonshot },
|
||||
"shengsuanyun": func(p ProvidersConfig) ProviderConfig { return p.ShengSuanYun },
|
||||
"deepseek": func(p ProvidersConfig) ProviderConfig { return p.DeepSeek },
|
||||
"cerebras": func(p ProvidersConfig) ProviderConfig { return p.Cerebras },
|
||||
"vivgrid": func(p ProvidersConfig) ProviderConfig { return p.Vivgrid },
|
||||
"volcengine": func(p ProvidersConfig) ProviderConfig { return p.VolcEngine },
|
||||
"github-copilot": func(p ProvidersConfig) ProviderConfig { return p.GitHubCopilot },
|
||||
"antigravity": func(p ProvidersConfig) ProviderConfig { return p.Antigravity },
|
||||
"qwen": func(p ProvidersConfig) ProviderConfig { return p.Qwen },
|
||||
"mistral": func(p ProvidersConfig) ProviderConfig { return p.Mistral },
|
||||
"avian": func(p ProvidersConfig) ProviderConfig { return p.Avian },
|
||||
"minimax": func(p ProvidersConfig) ProviderConfig { return p.Minimax },
|
||||
"longcat": func(p ProvidersConfig) ProviderConfig { return p.LongCat },
|
||||
"modelscope": func(p ProvidersConfig) ProviderConfig { return p.ModelScope },
|
||||
"novita": func(p ProvidersConfig) ProviderConfig { return p.Novita },
|
||||
}
|
||||
|
||||
// InheritProviderCredentials fills in missing api_key, api_base, proxy, and
|
||||
// request_timeout on model_list entries from the matching legacy providers
|
||||
// configuration. The match is determined by the protocol prefix in the Model
|
||||
// field (e.g. "deepseek/deepseek-chat" matches providers.deepseek).
|
||||
//
|
||||
// Only empty fields are filled — any value explicitly set on a model_list entry
|
||||
// takes precedence. This function modifies the slice in place.
|
||||
//
|
||||
// This bridges the gap described in issue #1635: users who configure
|
||||
// credentials once in the providers section expect model_list entries using
|
||||
// the same protocol to "just work" without duplicating credentials.
|
||||
func InheritProviderCredentials(models []ModelConfig, providers ProvidersConfig) {
|
||||
if providers.IsEmpty() {
|
||||
return
|
||||
}
|
||||
|
||||
for i := range models {
|
||||
m := &models[i]
|
||||
|
||||
// Extract protocol prefix from Model field
|
||||
protocol := ""
|
||||
if idx := strings.Index(m.Model, "/"); idx > 0 {
|
||||
protocol = strings.ToLower(m.Model[:idx])
|
||||
}
|
||||
if protocol == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
getProvider, ok := protocolProviderMapping[protocol]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
pc := getProvider(providers)
|
||||
|
||||
// Only fill empty fields — explicit model_list values win
|
||||
if m.APIKey == "" && pc.APIKey != "" {
|
||||
m.APIKey = pc.APIKey
|
||||
}
|
||||
if m.APIBase == "" && pc.APIBase != "" {
|
||||
m.APIBase = pc.APIBase
|
||||
}
|
||||
if m.Proxy == "" && pc.Proxy != "" {
|
||||
m.Proxy = pc.Proxy
|
||||
}
|
||||
if m.RequestTimeout == 0 && pc.RequestTimeout != 0 {
|
||||
m.RequestTimeout = pc.RequestTimeout
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -613,3 +613,143 @@ func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T)
|
||||
t.Errorf("Model = %q, want %q (should not duplicate prefix)", result[0].Model, "openrouter/auto")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- InheritProviderCredentials tests ----------
|
||||
|
||||
func TestInheritProviderCredentials_FillsMissingAPIKey(t *testing.T) {
|
||||
models := []ModelConfig{
|
||||
{ModelName: "my-deepseek", Model: "deepseek/deepseek-chat"},
|
||||
}
|
||||
providers := ProvidersConfig{
|
||||
DeepSeek: ProviderConfig{
|
||||
APIKey: "sk-deepseek-from-providers",
|
||||
APIBase: "https://api.deepseek.com/v1",
|
||||
},
|
||||
}
|
||||
|
||||
InheritProviderCredentials(models, providers)
|
||||
|
||||
if models[0].APIKey != "sk-deepseek-from-providers" {
|
||||
t.Errorf("APIKey = %q, want %q", models[0].APIKey, "sk-deepseek-from-providers")
|
||||
}
|
||||
if models[0].APIBase != "https://api.deepseek.com/v1" {
|
||||
t.Errorf("APIBase = %q, want %q", models[0].APIBase, "https://api.deepseek.com/v1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInheritProviderCredentials_ExplicitValuesTakePrecedence(t *testing.T) {
|
||||
models := []ModelConfig{
|
||||
{
|
||||
ModelName: "my-openai",
|
||||
Model: "openai/gpt-5.4",
|
||||
APIKey: "sk-explicit-model-key",
|
||||
APIBase: "https://my-custom-endpoint.com/v1",
|
||||
},
|
||||
}
|
||||
providers := ProvidersConfig{
|
||||
OpenAI: OpenAIProviderConfig{
|
||||
ProviderConfig: ProviderConfig{
|
||||
APIKey: "sk-provider-key",
|
||||
APIBase: "https://api.openai.com/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
InheritProviderCredentials(models, providers)
|
||||
|
||||
if models[0].APIKey != "sk-explicit-model-key" {
|
||||
t.Errorf("APIKey = %q, want %q (explicit should win)", models[0].APIKey, "sk-explicit-model-key")
|
||||
}
|
||||
if models[0].APIBase != "https://my-custom-endpoint.com/v1" {
|
||||
t.Errorf("APIBase = %q, want %q (explicit should win)", models[0].APIBase, "https://my-custom-endpoint.com/v1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInheritProviderCredentials_MultipleModels(t *testing.T) {
|
||||
models := []ModelConfig{
|
||||
{ModelName: "groq-llama", Model: "groq/llama-3.1-70b"},
|
||||
{ModelName: "zhipu-glm", Model: "zhipu/glm-4"},
|
||||
{ModelName: "custom-openai", Model: "openai/gpt-5.4", APIKey: "sk-already-set"},
|
||||
}
|
||||
providers := ProvidersConfig{
|
||||
Groq: ProviderConfig{APIKey: "gsk-groq-key", Proxy: "http://proxy:8080"},
|
||||
Zhipu: ProviderConfig{APIKey: "zhipu-key-123", APIBase: "https://zhipu.example.com"},
|
||||
OpenAI: OpenAIProviderConfig{
|
||||
ProviderConfig: ProviderConfig{APIKey: "sk-should-not-override"},
|
||||
},
|
||||
}
|
||||
|
||||
InheritProviderCredentials(models, providers)
|
||||
|
||||
// groq model should inherit
|
||||
if models[0].APIKey != "gsk-groq-key" {
|
||||
t.Errorf("groq APIKey = %q, want %q", models[0].APIKey, "gsk-groq-key")
|
||||
}
|
||||
if models[0].Proxy != "http://proxy:8080" {
|
||||
t.Errorf("groq Proxy = %q, want %q", models[0].Proxy, "http://proxy:8080")
|
||||
}
|
||||
|
||||
// zhipu model should inherit
|
||||
if models[1].APIKey != "zhipu-key-123" {
|
||||
t.Errorf("zhipu APIKey = %q, want %q", models[1].APIKey, "zhipu-key-123")
|
||||
}
|
||||
if models[1].APIBase != "https://zhipu.example.com" {
|
||||
t.Errorf("zhipu APIBase = %q, want %q", models[1].APIBase, "https://zhipu.example.com")
|
||||
}
|
||||
|
||||
// openai model already has key — should NOT be overridden
|
||||
if models[2].APIKey != "sk-already-set" {
|
||||
t.Errorf("openai APIKey = %q, want %q (should not be overridden)", models[2].APIKey, "sk-already-set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInheritProviderCredentials_NoMatchingProvider(t *testing.T) {
|
||||
models := []ModelConfig{
|
||||
{ModelName: "my-model", Model: "novelai/some-model"},
|
||||
}
|
||||
providers := ProvidersConfig{
|
||||
DeepSeek: ProviderConfig{APIKey: "sk-deepseek"},
|
||||
}
|
||||
|
||||
InheritProviderCredentials(models, providers)
|
||||
|
||||
// No matching provider for "novelai" protocol — should stay empty
|
||||
if models[0].APIKey != "" {
|
||||
t.Errorf("APIKey = %q, want empty (no matching provider)", models[0].APIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInheritProviderCredentials_EmptyProviders(t *testing.T) {
|
||||
models := []ModelConfig{
|
||||
{ModelName: "my-model", Model: "openai/gpt-5.4"},
|
||||
}
|
||||
providers := ProvidersConfig{} // all empty
|
||||
|
||||
InheritProviderCredentials(models, providers)
|
||||
|
||||
// Empty providers — nothing to inherit
|
||||
if models[0].APIKey != "" {
|
||||
t.Errorf("APIKey = %q, want empty", models[0].APIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInheritProviderCredentials_InheritsRequestTimeout(t *testing.T) {
|
||||
models := []ModelConfig{
|
||||
{ModelName: "my-ollama", Model: "ollama/llama3.2:3b"},
|
||||
}
|
||||
providers := ProvidersConfig{
|
||||
Ollama: ProviderConfig{
|
||||
APIBase: "http://localhost:11434",
|
||||
RequestTimeout: 120,
|
||||
},
|
||||
}
|
||||
|
||||
InheritProviderCredentials(models, providers)
|
||||
|
||||
if models[0].APIBase != "http://localhost:11434" {
|
||||
t.Errorf("APIBase = %q, want %q", models[0].APIBase, "http://localhost:11434")
|
||||
}
|
||||
if models[0].RequestTimeout != 120 {
|
||||
t.Errorf("RequestTimeout = %d, want 120", models[0].RequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,6 +221,10 @@ func buildRequestBody(
|
||||
|
||||
// Add tool_use blocks
|
||||
for _, tc := range msg.ToolCalls {
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle nil Arguments (GLM-4 may return null input)
|
||||
input := tc.Arguments
|
||||
if input == nil {
|
||||
|
||||
@@ -492,6 +492,20 @@ func TestBuildRequestBodyEdgeCases(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "skip tool calls with empty names",
|
||||
messages: []Message{
|
||||
{Role: "assistant", Content: "Calling tool", ToolCalls: []ToolCall{
|
||||
{ID: "tool-empty", Name: "", Arguments: map[string]any{"ignored": true}},
|
||||
{ID: "tool-valid", Name: "test_tool", Arguments: map[string]any{"arg": "value"}},
|
||||
}},
|
||||
},
|
||||
model: "test-model",
|
||||
options: map[string]any{
|
||||
"max_tokens": 8192,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -513,6 +527,37 @@ func TestBuildRequestBodyEdgeCases(t *testing.T) {
|
||||
if got["model"] != tt.model {
|
||||
t.Errorf("model = %v, want %v", got["model"], tt.model)
|
||||
}
|
||||
|
||||
if tt.name == "skip tool calls with empty names" {
|
||||
messages, ok := got["messages"].([]any)
|
||||
if !ok || len(messages) != 1 {
|
||||
t.Fatalf("messages = %#v, want single assistant message", got["messages"])
|
||||
}
|
||||
|
||||
assistantMsg, ok := messages[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("assistant message = %#v, want map", messages[0])
|
||||
}
|
||||
|
||||
content, ok := assistantMsg["content"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("assistant content = %#v, want []any", assistantMsg["content"])
|
||||
}
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("assistant content length = %d, want 2", len(content))
|
||||
}
|
||||
|
||||
toolUse, ok := content[1].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("tool_use block = %#v, want map", content[1])
|
||||
}
|
||||
if gotName := toolUse["name"]; gotName != "test_tool" {
|
||||
t.Fatalf("tool_use name = %v, want %q", gotName, "test_tool")
|
||||
}
|
||||
if gotID := toolUse["id"]; gotID != "tool-valid" {
|
||||
t.Fatalf("tool_use id = %v, want %q", gotID, "tool-valid")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,8 +115,9 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
|
||||
"minimax", "longcat", "modelscope", "novita":
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
|
||||
"qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
|
||||
"coding-plan", "alibaba-coding", "qwen-coding":
|
||||
// All other OpenAI-compatible HTTP providers
|
||||
if cfg.APIKey == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
@@ -173,6 +174,21 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "coding-plan-anthropic", "alibaba-coding-anthropic":
|
||||
// Alibaba Coding Plan with Anthropic-compatible API
|
||||
apiBase := cfg.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
if cfg.APIKey == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for %q protocol (model: %s)", protocol, cfg.Model)
|
||||
}
|
||||
return anthropicmessages.NewProviderWithTimeout(
|
||||
cfg.APIKey,
|
||||
apiBase,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "antigravity":
|
||||
return NewAntigravityProvider(), modelID, nil
|
||||
|
||||
@@ -245,6 +261,14 @@ func getDefaultAPIBase(protocol string) string {
|
||||
return "https://ark.cn-beijing.volces.com/api/v3"
|
||||
case "qwen":
|
||||
return "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
case "qwen-intl", "qwen-international", "dashscope-intl":
|
||||
return "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
|
||||
case "qwen-us", "dashscope-us":
|
||||
return "https://dashscope-us.aliyuncs.com/compatible-mode/v1"
|
||||
case "coding-plan", "alibaba-coding", "qwen-coding":
|
||||
return "https://coding-intl.dashscope.aliyuncs.com/v1"
|
||||
case "coding-plan-anthropic", "alibaba-coding-anthropic":
|
||||
return "https://coding-intl.dashscope.aliyuncs.com/apps/anthropic"
|
||||
case "vllm":
|
||||
return "http://localhost:8000/v1"
|
||||
case "mistral":
|
||||
|
||||
@@ -472,3 +472,134 @@ func TestCreateProviderFromConfig_AzureMissingAPIBase(t *testing.T) {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for missing API base")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_QwenInternationalAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protocol string
|
||||
}{
|
||||
{"qwen-international", "qwen-international"},
|
||||
{"dashscope-intl", "dashscope-intl"},
|
||||
{"qwen-intl", "qwen-intl"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-" + tt.protocol,
|
||||
Model: tt.protocol + "/qwen-max",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "qwen-max" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "qwen-max")
|
||||
}
|
||||
if _, ok := provider.(*HTTPProvider); !ok {
|
||||
t.Fatalf("expected *HTTPProvider, got %T", provider)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_QwenUSAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protocol string
|
||||
}{
|
||||
{"qwen-us", "qwen-us"},
|
||||
{"dashscope-us", "dashscope-us"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-" + tt.protocol,
|
||||
Model: tt.protocol + "/qwen-max",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "qwen-max" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "qwen-max")
|
||||
}
|
||||
if _, ok := provider.(*HTTPProvider); !ok {
|
||||
t.Fatalf("expected *HTTPProvider, got %T", provider)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_CodingPlanAnthropic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protocol string
|
||||
}{
|
||||
{"coding-plan-anthropic", "coding-plan-anthropic"},
|
||||
{"alibaba-coding-anthropic", "alibaba-coding-anthropic"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-" + tt.protocol,
|
||||
Model: tt.protocol + "/claude-sonnet-4-20250514",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "claude-sonnet-4-20250514" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "claude-sonnet-4-20250514")
|
||||
}
|
||||
// coding-plan-anthropic uses Anthropic Messages provider
|
||||
// Verify it's the anthropic messages provider by checking interface
|
||||
var _ LLMProvider = provider
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultAPIBase_CodingPlanAnthropic(t *testing.T) {
|
||||
expectedURL := "https://coding-intl.dashscope.aliyuncs.com/apps/anthropic"
|
||||
if got := getDefaultAPIBase("coding-plan-anthropic"); got != expectedURL {
|
||||
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "coding-plan-anthropic", got, expectedURL)
|
||||
}
|
||||
if got := getDefaultAPIBase("alibaba-coding-anthropic"); got != expectedURL {
|
||||
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "alibaba-coding-anthropic", got, expectedURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultAPIBase_QwenIntlAliases(t *testing.T) {
|
||||
expectedURL := "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
|
||||
for _, protocol := range []string{"qwen-intl", "qwen-international", "dashscope-intl"} {
|
||||
if got := getDefaultAPIBase(protocol); got != expectedURL {
|
||||
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", protocol, got, expectedURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultAPIBase_QwenUSAliases(t *testing.T) {
|
||||
expectedURL := "https://dashscope-us.aliyuncs.com/compatible-mode/v1"
|
||||
for _, protocol := range []string{"qwen-us", "dashscope-us"} {
|
||||
if got := getDefaultAPIBase(protocol); got != expectedURL {
|
||||
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", protocol, got, expectedURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +53,14 @@ func NormalizeProvider(provider string) string {
|
||||
return "zhipu"
|
||||
case "google":
|
||||
return "gemini"
|
||||
case "alibaba-coding", "qwen-coding":
|
||||
return "coding-plan"
|
||||
case "alibaba-coding-anthropic":
|
||||
return "coding-plan-anthropic"
|
||||
case "qwen-international", "dashscope-intl":
|
||||
return "qwen-intl"
|
||||
case "dashscope-us":
|
||||
return "qwen-us"
|
||||
}
|
||||
|
||||
return p
|
||||
|
||||
@@ -73,6 +73,14 @@ func TestNormalizeProvider(t *testing.T) {
|
||||
{"glm", "zhipu"},
|
||||
{"google", "gemini"},
|
||||
{"groq", "groq"},
|
||||
// Alibaba Coding Plan aliases
|
||||
{"alibaba-coding", "coding-plan"},
|
||||
{"qwen-coding", "coding-plan"},
|
||||
{"alibaba-coding-anthropic", "coding-plan-anthropic"},
|
||||
// Qwen international aliases
|
||||
{"qwen-international", "qwen-intl"},
|
||||
{"dashscope-intl", "qwen-intl"},
|
||||
{"dashscope-us", "qwen-us"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user