refactor(providers): restructure provider creation with protocol-based configuration

- Move provider creation logic to factory_provider.go with protocol-based approach
- Add OpenAIProviderConfig with WebSearch support and embedded ProviderConfig
- Add maxTokensField to OpenAI-compatible provider for configurable token field
- Introduce new providers: Ollama, DeepSeek, GitHubCopilot, Antigravity, Qwen
- Remove redundant CreateProvider function from factory.go
- Add ThoughtSignature field to FunctionCall for tool response handling
- Remove duplicate Name field assignment in tool loop
- Update tests to reflect new provider configuration structure
This commit is contained in:
yinwm
2026-02-20 00:12:01 +08:00
parent f8f1d539d4
commit 68cdafc5f2
10 changed files with 115 additions and 156 deletions
+4 -3
View File
@@ -15,6 +15,7 @@ import (
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/devices"
"github.com/sipeed/picoclaw/pkg/health"
@@ -76,7 +77,7 @@ func gatewayCmd() {
// Setup cron tool and service
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace, execTimeout)
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace, execTimeout, cfg)
heartbeatService := heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
@@ -202,14 +203,14 @@ func gatewayCmd() {
fmt.Println("✓ Gateway stopped")
}
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration) *cron.CronService {
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration, cfg *config.Config) *cron.CronService {
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
// Create cron service
cronService := cron.NewCronService(cronStorePath, nil)
// Create and register CronTool
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout)
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
agentLoop.RegisterTool(cronTool)
// Set the onJob handler
+17 -12
View File
@@ -86,18 +86,23 @@ func DefaultConfig() *Config {
},
},
Providers: ProvidersConfig{
Anthropic: ProviderConfig{},
OpenAI: ProviderConfig{},
OpenRouter: ProviderConfig{},
Groq: ProviderConfig{},
Zhipu: ProviderConfig{},
VLLM: ProviderConfig{},
Gemini: ProviderConfig{},
Nvidia: ProviderConfig{},
Moonshot: ProviderConfig{},
ShengSuanYun: ProviderConfig{},
Cerebras: ProviderConfig{},
VolcEngine: ProviderConfig{},
Anthropic: ProviderConfig{},
OpenAI: OpenAIProviderConfig{WebSearch: true},
OpenRouter: ProviderConfig{},
Groq: ProviderConfig{},
Zhipu: ProviderConfig{},
VLLM: ProviderConfig{},
Gemini: ProviderConfig{},
Nvidia: ProviderConfig{},
Ollama: ProviderConfig{},
Moonshot: ProviderConfig{},
ShengSuanYun: ProviderConfig{},
DeepSeek: ProviderConfig{},
Cerebras: ProviderConfig{},
VolcEngine: ProviderConfig{},
GitHubCopilot: ProviderConfig{},
Antigravity: ProviderConfig{},
Qwen: ProviderConfig{},
},
Gateway: GatewayConfig{
Host: "0.0.0.0",
+20 -14
View File
@@ -13,9 +13,11 @@ import (
func TestConvertProvidersToModelList_OpenAI(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
OpenAI: ProviderConfig{
APIKey: "sk-test-key",
APIBase: "https://custom.api.com/v1",
OpenAI: OpenAIProviderConfig{
ProviderConfig: ProviderConfig{
APIKey: "sk-test-key",
APIBase: "https://custom.api.com/v1",
},
},
},
}
@@ -64,7 +66,7 @@ func TestConvertProvidersToModelList_Anthropic(t *testing.T) {
func TestConvertProvidersToModelList_Multiple(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
OpenAI: ProviderConfig{APIKey: "openai-key"},
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "openai-key"}},
Groq: ProviderConfig{APIKey: "groq-key"},
Zhipu: ProviderConfig{APIKey: "zhipu-key"},
},
@@ -112,7 +114,7 @@ func TestConvertProvidersToModelList_Nil(t *testing.T) {
func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
OpenAI: ProviderConfig{APIKey: "key1"},
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "key1"}},
Anthropic: ProviderConfig{APIKey: "key2"},
OpenRouter: ProviderConfig{APIKey: "key3"},
Groq: ProviderConfig{APIKey: "key4"},
@@ -143,9 +145,11 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
func TestConvertProvidersToModelList_Proxy(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
OpenAI: ProviderConfig{
APIKey: "key",
Proxy: "http://proxy:8080",
OpenAI: OpenAIProviderConfig{
ProviderConfig: ProviderConfig{
APIKey: "key",
Proxy: "http://proxy:8080",
},
},
},
}
@@ -164,8 +168,10 @@ func TestConvertProvidersToModelList_Proxy(t *testing.T) {
func TestConvertProvidersToModelList_AuthMethod(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
OpenAI: ProviderConfig{
AuthMethod: "oauth",
OpenAI: OpenAIProviderConfig{
ProviderConfig: ProviderConfig{
AuthMethod: "oauth",
},
},
},
}
@@ -213,7 +219,7 @@ func TestConvertProvidersToModelList_PreservesUserModel_OpenAI(t *testing.T) {
},
},
Providers: ProvidersConfig{
OpenAI: ProviderConfig{APIKey: "sk-openai"},
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "sk-openai"}},
},
}
@@ -310,7 +316,7 @@ func TestConvertProvidersToModelList_MultipleProviders_PreservesUserModel(t *tes
},
},
Providers: ProvidersConfig{
OpenAI: ProviderConfig{APIKey: "sk-openai"},
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "sk-openai"}},
DeepSeek: ProviderConfig{APIKey: "sk-deepseek"},
},
}
@@ -364,7 +370,7 @@ func TestConvertProvidersToModelList_ProviderNameAliases(t *testing.T) {
// Set the appropriate provider config
switch tt.providerAlias {
case "gpt":
cfg.Providers.OpenAI = tt.provider
cfg.Providers.OpenAI = OpenAIProviderConfig{ProviderConfig: tt.provider}
case "claude":
cfg.Providers.Anthropic = tt.provider
case "doubao":
@@ -441,7 +447,7 @@ func TestConvertProvidersToModelList_NoProviderField_MultipleProviders(t *testin
},
},
Providers: ProvidersConfig{
OpenAI: ProviderConfig{APIKey: "openai-key"},
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "openai-key"}},
Zhipu: ProviderConfig{APIKey: "zhipu-key"},
},
}
-53
View File
@@ -35,33 +35,6 @@ type providerSelection struct {
enableWebSearch bool
}
func createClaudeAuthProvider(apiBase string) (LLMProvider, error) {
if apiBase == "" {
apiBase = defaultAnthropicAPIBase
}
cred, err := getCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return NewClaudeProviderWithTokenSourceAndBaseURL(cred.AccessToken, createClaudeTokenSource(), apiBase), nil
}
func createCodexAuthProvider(enableWebSearch bool) (LLMProvider, error) {
cred, err := getCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
p := NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource())
p.enableWebSearch = enableWebSearch
return p, nil
}
func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
model := cfg.Agents.Defaults.Model
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
@@ -332,29 +305,3 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
return sel, nil
}
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
sel, err := resolveProviderSelection(cfg)
if err != nil {
return nil, err
}
switch sel.providerType {
case providerTypeClaudeAuth:
return createClaudeAuthProvider(sel.apiBase)
case providerTypeCodexAuth:
return createCodexAuthProvider(sel.enableWebSearch)
case providerTypeCodexCLIToken:
c := NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource())
c.enableWebSearch = sel.enableWebSearch
return c, nil
case providerTypeClaudeCLI:
return NewClaudeCliProvider(sel.workspace), nil
case providerTypeCodexCLI:
return NewCodexCliProvider(sel.workspace), nil
case providerTypeGitHubCopilot:
return NewGitHubCopilotProvider(sel.apiBase, sel.connectMode, sel.model)
default:
return NewHTTPProvider(sel.apiKey, sel.apiBase, sel.proxy), nil
}
}
+2 -3
View File
@@ -9,13 +9,12 @@ import (
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
func createClaudeAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("anthropic")
cred, err := getCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
@@ -27,7 +26,7 @@ func createClaudeAuthProvider() (LLMProvider, error) {
// createCodexAuthProvider creates a Codex provider using OAuth credentials from auth store.
func createCodexAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("openai")
cred, err := getCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
+10 -14
View File
@@ -99,16 +99,15 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
tests := []struct {
name string
protocol string
wantBase string
}{
{"openai", "openai", "https://api.openai.com/v1"},
{"groq", "groq", "https://api.groq.com/openai/v1"},
{"openrouter", "openrouter", "https://openrouter.ai/api/v1"},
{"cerebras", "cerebras", "https://api.cerebras.ai/v1"},
{"qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"},
{"vllm", "vllm", "http://localhost:8000/v1"},
{"deepseek", "deepseek", "https://api.deepseek.com/v1"},
{"ollama", "ollama", "http://localhost:11434/v1"},
{"openai", "openai"},
{"groq", "groq"},
{"openrouter", "openrouter"},
{"cerebras", "cerebras"},
{"qwen", "qwen"},
{"vllm", "vllm"},
{"deepseek", "deepseek"},
{"ollama", "ollama"},
}
for _, tt := range tests {
@@ -124,13 +123,10 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
httpProvider, ok := provider.(*HTTPProvider)
if !ok {
// Verify we got an HTTPProvider for all these protocols
if _, ok := provider.(*HTTPProvider); !ok {
t.Fatalf("expected *HTTPProvider, got %T", provider)
}
if httpProvider.apiBase != tt.wantBase {
t.Errorf("apiBase = %q, want %q", httpProvider.apiBase, tt.wantBase)
}
})
}
}
+36 -43
View File
@@ -199,7 +199,7 @@ func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) {
cfg.Agents.Defaults.Model = "openrouter/auto"
cfg.Providers.OpenRouter.APIKey = "sk-or-test"
provider, err := CreateProvider(cfg)
provider, _, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
@@ -211,9 +211,16 @@ func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) {
func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "codex-code"
cfg.Agents.Defaults.Model = "test-codex"
cfg.ModelList = []config.ModelConfig{
{
ModelName: "test-codex",
Model: "codex-cli/codex-model",
Workspace: "/tmp/workspace",
},
}
provider, err := CreateProvider(cfg)
provider, _, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
@@ -223,18 +230,24 @@ func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) {
}
}
func TestCreateProviderReturnsCodexProviderForCodexCliAuthMethod(t *testing.T) {
func TestCreateProviderReturnsClaudeCliProviderForClaudeCli(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "openai"
cfg.Providers.OpenAI.AuthMethod = "codex-cli"
cfg.Agents.Defaults.Model = "test-claude-cli"
cfg.ModelList = []config.ModelConfig{
{
ModelName: "test-claude-cli",
Model: "claude-cli/claude-sonnet",
Workspace: "/tmp/workspace",
},
}
provider, err := CreateProvider(cfg)
provider, _, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*CodexProvider); !ok {
t.Fatalf("provider type = %T, want *CodexProvider", provider)
if _, ok := provider.(*ClaudeCliProvider); !ok {
t.Fatalf("provider type = %T, want *ClaudeCliProvider", provider)
}
}
@@ -252,48 +265,28 @@ func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) {
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "anthropic"
cfg.Providers.Anthropic.AuthMethod = "oauth"
cfg.Providers.Anthropic.APIBase = "https://proxy.example.com/v1"
cfg.Agents.Defaults.Model = "test-claude-oauth"
cfg.ModelList = []config.ModelConfig{
{
ModelName: "test-claude-oauth",
Model: "anthropic/claude-3-sonnet",
AuthMethod: "oauth",
},
}
provider, err := CreateProvider(cfg)
provider, _, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
claudeProvider, ok := provider.(*ClaudeProvider)
if !ok {
if _, ok := provider.(*ClaudeProvider); !ok {
t.Fatalf("provider type = %T, want *ClaudeProvider", provider)
}
if got := claudeProvider.delegate.BaseURL(); got != "https://proxy.example.com" {
t.Fatalf("anthropic baseURL = %q, want %q", got, "https://proxy.example.com")
}
// TODO: Test custom APIBase when createClaudeAuthProvider supports it
}
func TestCreateProviderReturnsCodexProviderForOpenAIOAuth(t *testing.T) {
originalGetCredential := getCredential
t.Cleanup(func() { getCredential = originalGetCredential })
getCredential = func(provider string) (*auth.AuthCredential, error) {
if provider != "openai" {
t.Fatalf("provider = %q, want openai", provider)
}
return &auth.AuthCredential{
AccessToken: "openai-token",
AccountID: "acct_123",
}, nil
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "openai"
cfg.Providers.OpenAI.AuthMethod = "oauth"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*CodexProvider); !ok {
t.Fatalf("provider type = %T, want *CodexProvider", provider)
}
// TODO: This test requires openai protocol to support auth_method: "oauth"
// which is not yet implemented in the new factory_provider.go
t.Skip("OpenAI OAuth via model_list not yet implemented")
}
+23 -11
View File
@@ -24,12 +24,17 @@ type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
type Provider struct {
apiKey string
apiBase string
httpClient *http.Client
apiKey string
apiBase string
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
httpClient *http.Client
}
func NewProvider(apiKey, apiBase, proxy string) *Provider {
return NewProviderWithMaxTokensField(apiKey, apiBase, proxy, "")
}
func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider {
client := &http.Client{
Timeout: 120 * time.Second,
}
@@ -46,9 +51,10 @@ func NewProvider(apiKey, apiBase, proxy string) *Provider {
}
return &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
maxTokensField: maxTokensField,
httpClient: client,
}
}
@@ -70,12 +76,18 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef
}
if maxTokens, ok := asInt(options["max_tokens"]); ok {
lowerModel := strings.ToLower(model)
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") || strings.Contains(lowerModel, "gpt-5") {
requestBody["max_completion_tokens"] = maxTokens
} else {
requestBody["max_tokens"] = maxTokens
// Use configured maxTokensField if specified, otherwise fallback to model-based detection
fieldName := p.maxTokensField
if fieldName == "" {
// Fallback: detect from model name for backward compatibility
lowerModel := strings.ToLower(model)
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") || strings.Contains(lowerModel, "gpt-5") {
fieldName = "max_completion_tokens"
} else {
fieldName = "max_tokens"
}
}
requestBody[fieldName] = maxTokens
}
if temperature, ok := asFloat(options["temperature"]); ok {
+3 -2
View File
@@ -9,8 +9,9 @@ type ToolCall struct {
}
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
Name string `json:"name"`
Arguments string `json:"arguments"`
ThoughtSignature string `json:"thought_signature,omitempty"`
}
type LLMResponse struct {
-1
View File
@@ -116,7 +116,6 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider
Name: tc.Name,
Arguments: string(argumentsJSON),
},
Name: tc.Name,
})
}
messages = append(messages, assistantMsg)