mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
providers: finalize PR213 review fixes
Phase 1: centralize protocol message/tool/response types in protocoltypes and keep compatibility aliases in providers and protocol packages. Phase 1: preserve HTTPProvider constructor compatibility and route Anthropic api_base through factory auth/provider constructors with base URL normalization. Phase 2: expand provider routing/auth tests (deepseek/nvidia/shengsuanyun, codex/claude oauth/codex-cli) and add openai_compat + anthropic coverage for proxy transport, model normalization, numeric option coercion, token-source refresh, and base URL behavior. Phase 3: apply gofmt and validate with Dockerized tests (go test ./pkg/providers/... ./pkg/migrate and go test ./...).
This commit is contained in:
@@ -4,74 +4,59 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function *FunctionCall `json:"function,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
}
|
||||
type ToolCall = protocoltypes.ToolCall
|
||||
type FunctionCall = protocoltypes.FunctionCall
|
||||
type LLMResponse = protocoltypes.LLMResponse
|
||||
type UsageInfo = protocoltypes.UsageInfo
|
||||
type Message = protocoltypes.Message
|
||||
type ToolDefinition = protocoltypes.ToolDefinition
|
||||
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
type LLMResponse struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Usage *UsageInfo `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type UsageInfo struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type ToolDefinition struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunctionDefinition `json:"function"`
|
||||
}
|
||||
|
||||
type ToolFunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
const defaultBaseURL = "https://api.anthropic.com"
|
||||
|
||||
type Provider struct {
|
||||
client *anthropic.Client
|
||||
tokenSource func() (string, error)
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func NewProvider(token string) *Provider {
|
||||
return NewProviderWithBaseURL(token, "")
|
||||
}
|
||||
|
||||
func NewProviderWithBaseURL(token, apiBase string) *Provider {
|
||||
baseURL := normalizeBaseURL(apiBase)
|
||||
client := anthropic.NewClient(
|
||||
option.WithAuthToken(token),
|
||||
option.WithBaseURL("https://api.anthropic.com"),
|
||||
option.WithBaseURL(baseURL),
|
||||
)
|
||||
return &Provider{client: &client}
|
||||
return &Provider{
|
||||
client: &client,
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
func NewProviderWithClient(client *anthropic.Client) *Provider {
|
||||
return &Provider{client: client}
|
||||
return &Provider{
|
||||
client: client,
|
||||
baseURL: defaultBaseURL,
|
||||
}
|
||||
}
|
||||
|
||||
func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider {
|
||||
p := NewProvider(token)
|
||||
return NewProviderWithTokenSourceAndBaseURL(token, tokenSource, "")
|
||||
}
|
||||
|
||||
func NewProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *Provider {
|
||||
p := NewProviderWithBaseURL(token, apiBase)
|
||||
p.tokenSource = tokenSource
|
||||
return p
|
||||
}
|
||||
@@ -103,6 +88,10 @@ func (p *Provider) GetDefaultModel() string {
|
||||
return "claude-sonnet-4-5-20250929"
|
||||
}
|
||||
|
||||
func (p *Provider) BaseURL() string {
|
||||
return p.baseURL
|
||||
}
|
||||
|
||||
func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
|
||||
var system []anthropic.TextBlockParam
|
||||
var anthropicMessages []anthropic.MessageParam
|
||||
@@ -208,6 +197,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
|
||||
tu := block.AsToolUse()
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal(tu.Input, &args); err != nil {
|
||||
log.Printf("anthropic: failed to decode tool call input for %q: %v", tu.Name, err)
|
||||
args = map[string]interface{}{"raw": string(tu.Input)}
|
||||
}
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
@@ -239,3 +229,20 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeBaseURL(apiBase string) string {
|
||||
base := strings.TrimSpace(apiBase)
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
base = strings.TrimRight(base, "/")
|
||||
if strings.HasSuffix(base, "/v1") {
|
||||
base = strings.TrimSuffix(base, "/v1")
|
||||
}
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
@@ -199,6 +200,62 @@ func TestProvider_GetDefaultModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_NewProviderWithBaseURL_NormalizesV1Suffix(t *testing.T) {
|
||||
p := NewProviderWithBaseURL("token", "https://api.anthropic.com/v1/")
|
||||
if got := p.BaseURL(); got != "https://api.anthropic.com" {
|
||||
t.Fatalf("BaseURL() = %q, want %q", got, "https://api.anthropic.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_ChatUsesTokenSource(t *testing.T) {
|
||||
var requests int32
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/messages" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
atomic.AddInt32(&requests, 1)
|
||||
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
json.NewDecoder(r.Body).Decode(&reqBody)
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "msg_test",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": reqBody["model"],
|
||||
"stop_reason": "end_turn",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": "ok"},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) {
|
||||
return "refreshed-token", nil
|
||||
}, server.URL)
|
||||
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if got := atomic.LoadInt32(&requests); got != 1 {
|
||||
t.Fatalf("requests = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
|
||||
c := anthropic.NewClient(
|
||||
anthropicoption.WithAuthToken(token),
|
||||
|
||||
@@ -3,8 +3,6 @@ package providers
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic"
|
||||
)
|
||||
|
||||
@@ -18,28 +16,34 @@ func NewClaudeProvider(token string) *ClaudeProvider {
|
||||
}
|
||||
}
|
||||
|
||||
func NewClaudeProviderWithBaseURL(token, apiBase string) *ClaudeProvider {
|
||||
return &ClaudeProvider{
|
||||
delegate: anthropicprovider.NewProviderWithBaseURL(token, apiBase),
|
||||
}
|
||||
}
|
||||
|
||||
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
|
||||
return &ClaudeProvider{
|
||||
delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource),
|
||||
}
|
||||
}
|
||||
|
||||
func NewClaudeProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *ClaudeProvider {
|
||||
return &ClaudeProvider{
|
||||
delegate: anthropicprovider.NewProviderWithTokenSourceAndBaseURL(token, tokenSource, apiBase),
|
||||
}
|
||||
}
|
||||
|
||||
func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider {
|
||||
return &ClaudeProvider{delegate: delegate}
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
resp, err := p.delegate.Chat(
|
||||
ctx,
|
||||
toAnthropicProviderMessages(messages),
|
||||
toAnthropicProviderTools(tools),
|
||||
model,
|
||||
options,
|
||||
)
|
||||
resp, err := p.delegate.Chat(ctx, messages, tools, model, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fromAnthropicProviderResponse(resp), nil
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) GetDefaultModel() string {
|
||||
@@ -48,7 +52,7 @@ func (p *ClaudeProvider) GetDefaultModel() string {
|
||||
|
||||
func createClaudeTokenSource() func() (string, error) {
|
||||
return func() (string, error) {
|
||||
cred, err := auth.GetCredential("anthropic")
|
||||
cred, err := getCredential("anthropic")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
@@ -58,95 +62,3 @@ func createClaudeTokenSource() func() (string, error) {
|
||||
return cred.AccessToken, nil
|
||||
}
|
||||
}
|
||||
|
||||
func toAnthropicProviderMessages(messages []Message) []anthropicprovider.Message {
|
||||
out := make([]anthropicprovider.Message, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
out = append(out, anthropicprovider.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ToolCalls: toAnthropicProviderToolCalls(msg.ToolCalls),
|
||||
ToolCallID: msg.ToolCallID,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toAnthropicProviderTools(tools []ToolDefinition) []anthropicprovider.ToolDefinition {
|
||||
out := make([]anthropicprovider.ToolDefinition, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
out = append(out, anthropicprovider.ToolDefinition{
|
||||
Type: t.Type,
|
||||
Function: anthropicprovider.ToolFunctionDefinition{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: t.Function.Parameters,
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toAnthropicProviderToolCalls(toolCalls []ToolCall) []anthropicprovider.ToolCall {
|
||||
out := make([]anthropicprovider.ToolCall, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
var fn *anthropicprovider.FunctionCall
|
||||
if tc.Function != nil {
|
||||
fn = &anthropicprovider.FunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
}
|
||||
}
|
||||
out = append(out, anthropicprovider.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Function: fn,
|
||||
Name: tc.Name,
|
||||
Arguments: tc.Arguments,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func fromAnthropicProviderResponse(resp *anthropicprovider.LLMResponse) *LLMResponse {
|
||||
if resp == nil {
|
||||
return &LLMResponse{}
|
||||
}
|
||||
|
||||
var usage *UsageInfo
|
||||
if resp.Usage != nil {
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
CompletionTokens: resp.Usage.CompletionTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: resp.Content,
|
||||
ToolCalls: fromAnthropicProviderToolCalls(resp.ToolCalls),
|
||||
FinishReason: resp.FinishReason,
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
func fromAnthropicProviderToolCalls(toolCalls []anthropicprovider.ToolCall) []ToolCall {
|
||||
out := make([]ToolCall, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
var fn *FunctionCall
|
||||
if tc.Function != nil {
|
||||
fn = &FunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
}
|
||||
}
|
||||
out = append(out, ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Function: fn,
|
||||
Name: tc.Name,
|
||||
Arguments: tc.Arguments,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
const defaultAnthropicAPIBase = "https://api.anthropic.com/v1"
|
||||
|
||||
var getCredential = auth.GetCredential
|
||||
|
||||
type providerType int
|
||||
|
||||
const (
|
||||
@@ -30,19 +34,22 @@ type providerSelection struct {
|
||||
connectMode string
|
||||
}
|
||||
|
||||
func createClaudeAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("anthropic")
|
||||
func createClaudeAuthProvider(apiBase string) (LLMProvider, error) {
|
||||
if apiBase == "" {
|
||||
apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
cred, err := getCredential("anthropic")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||
}
|
||||
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
|
||||
return NewClaudeProviderWithTokenSourceAndBaseURL(cred.AccessToken, createClaudeTokenSource(), apiBase), nil
|
||||
}
|
||||
|
||||
func createCodexAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("openai")
|
||||
cred, err := getCredential("openai")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
@@ -69,6 +76,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Groq.APIKey
|
||||
sel.apiBase = cfg.Providers.Groq.APIBase
|
||||
sel.proxy = cfg.Providers.Groq.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
@@ -85,6 +93,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
}
|
||||
sel.apiKey = cfg.Providers.OpenAI.APIKey
|
||||
sel.apiBase = cfg.Providers.OpenAI.APIBase
|
||||
sel.proxy = cfg.Providers.OpenAI.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
@@ -92,18 +101,24 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
case "anthropic", "claude":
|
||||
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
sel.providerType = providerTypeClaudeAuth
|
||||
return sel, nil
|
||||
}
|
||||
sel.apiKey = cfg.Providers.Anthropic.APIKey
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
sel.proxy = cfg.Providers.Anthropic.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.anthropic.com/v1"
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
}
|
||||
case "openrouter":
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
sel.proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
sel.apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
@@ -114,6 +129,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Zhipu.APIKey
|
||||
sel.apiBase = cfg.Providers.Zhipu.APIBase
|
||||
sel.proxy = cfg.Providers.Zhipu.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
@@ -122,6 +138,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Gemini.APIKey
|
||||
sel.apiBase = cfg.Providers.Gemini.APIBase
|
||||
sel.proxy = cfg.Providers.Gemini.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
@@ -130,15 +147,26 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
if cfg.Providers.VLLM.APIBase != "" {
|
||||
sel.apiKey = cfg.Providers.VLLM.APIKey
|
||||
sel.apiBase = cfg.Providers.VLLM.APIBase
|
||||
sel.proxy = cfg.Providers.VLLM.Proxy
|
||||
}
|
||||
case "shengsuanyun":
|
||||
if cfg.Providers.ShengSuanYun.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.ShengSuanYun.APIKey
|
||||
sel.apiBase = cfg.Providers.ShengSuanYun.APIBase
|
||||
sel.proxy = cfg.Providers.ShengSuanYun.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://router.shengsuanyun.com/api/v1"
|
||||
}
|
||||
}
|
||||
case "nvidia":
|
||||
if cfg.Providers.Nvidia.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Nvidia.APIKey
|
||||
sel.apiBase = cfg.Providers.Nvidia.APIBase
|
||||
sel.proxy = cfg.Providers.Nvidia.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claude-code", "claudecode":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
@@ -159,6 +187,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
if cfg.Providers.DeepSeek.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.DeepSeek.APIKey
|
||||
sel.apiBase = cfg.Providers.DeepSeek.APIBase
|
||||
sel.proxy = cfg.Providers.DeepSeek.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.deepseek.com/v1"
|
||||
}
|
||||
@@ -204,6 +233,10 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) &&
|
||||
(cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
sel.providerType = providerTypeClaudeAuth
|
||||
return sel, nil
|
||||
}
|
||||
@@ -211,7 +244,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
sel.proxy = cfg.Providers.Anthropic.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.anthropic.com/v1"
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) &&
|
||||
(cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||
@@ -303,7 +336,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
|
||||
switch sel.providerType {
|
||||
case providerTypeClaudeAuth:
|
||||
return createClaudeAuthProvider()
|
||||
return createClaudeAuthProvider(sel.apiBase)
|
||||
case providerTypeCodexAuth:
|
||||
return createCodexAuthProvider()
|
||||
case providerTypeCodexCLIToken:
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
@@ -32,6 +33,40 @@ func TestResolveProviderSelection(t *testing.T) {
|
||||
wantType: providerTypeGitHubCopilot,
|
||||
wantAPIBase: "localhost:4321",
|
||||
},
|
||||
{
|
||||
name: "explicit deepseek provider uses deepseek defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "deepseek"
|
||||
cfg.Agents.Defaults.Model = "deepseek/deepseek-chat"
|
||||
cfg.Providers.DeepSeek.APIKey = "deepseek-key"
|
||||
cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.deepseek.com/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit shengsuanyun provider uses defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "shengsuanyun"
|
||||
cfg.Providers.ShengSuanYun.APIKey = "ssy-key"
|
||||
cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://router.shengsuanyun.com/api/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit nvidia provider uses defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "nvidia"
|
||||
cfg.Providers.Nvidia.APIKey = "nvapi-test"
|
||||
cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://integrate.api.nvidia.com/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "openrouter model uses openrouter defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
@@ -202,3 +237,63 @@ func TestCreateProviderReturnsCodexProviderForCodexCliAuthMethod(t *testing.T) {
|
||||
t.Fatalf("provider type = %T, want *CodexProvider", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) {
|
||||
originalGetCredential := getCredential
|
||||
t.Cleanup(func() { getCredential = originalGetCredential })
|
||||
|
||||
getCredential = func(provider string) (*auth.AuthCredential, error) {
|
||||
if provider != "anthropic" {
|
||||
t.Fatalf("provider = %q, want anthropic", provider)
|
||||
}
|
||||
return &auth.AuthCredential{
|
||||
AccessToken: "anthropic-token",
|
||||
}, nil
|
||||
}
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "anthropic"
|
||||
cfg.Providers.Anthropic.AuthMethod = "oauth"
|
||||
cfg.Providers.Anthropic.APIBase = "https://proxy.example.com/v1"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
|
||||
claudeProvider, ok := provider.(*ClaudeProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *ClaudeProvider", provider)
|
||||
}
|
||||
if got := claudeProvider.delegate.BaseURL(); got != "https://proxy.example.com" {
|
||||
t.Fatalf("anthropic baseURL = %q, want %q", got, "https://proxy.example.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderReturnsCodexProviderForOpenAIOAuth(t *testing.T) {
|
||||
originalGetCredential := getCredential
|
||||
t.Cleanup(func() { getCredential = originalGetCredential })
|
||||
|
||||
getCredential = func(provider string) (*auth.AuthCredential, error) {
|
||||
if provider != "openai" {
|
||||
t.Fatalf("provider = %q, want openai", provider)
|
||||
}
|
||||
return &auth.AuthCredential{
|
||||
AccessToken: "openai-token",
|
||||
AccountID: "acct_123",
|
||||
}, nil
|
||||
}
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "openai"
|
||||
cfg.Providers.OpenAI.AuthMethod = "oauth"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := provider.(*CodexProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *CodexProvider", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,116 +15,16 @@ type HTTPProvider struct {
|
||||
delegate *openai_compat.Provider
|
||||
}
|
||||
|
||||
func NewHTTPProvider(apiKey, apiBase string, proxy ...string) *HTTPProvider {
|
||||
proxyURL := ""
|
||||
if len(proxy) > 0 {
|
||||
proxyURL = proxy[0]
|
||||
}
|
||||
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
return &HTTPProvider{
|
||||
delegate: openai_compat.NewProvider(apiKey, apiBase, proxyURL),
|
||||
delegate: openai_compat.NewProvider(apiKey, apiBase, proxy),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
compatResp, err := p.delegate.Chat(ctx, toOpenAICompatMessages(messages), toOpenAICompatTools(tools), model, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fromOpenAICompatResponse(compatResp), nil
|
||||
return p.delegate.Chat(ctx, messages, tools, model, options)
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func toOpenAICompatMessages(messages []Message) []openai_compat.Message {
|
||||
out := make([]openai_compat.Message, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
out = append(out, openai_compat.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ToolCalls: toOpenAICompatToolCalls(msg.ToolCalls),
|
||||
ToolCallID: msg.ToolCallID,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toOpenAICompatTools(tools []ToolDefinition) []openai_compat.ToolDefinition {
|
||||
out := make([]openai_compat.ToolDefinition, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
out = append(out, openai_compat.ToolDefinition{
|
||||
Type: t.Type,
|
||||
Function: openai_compat.ToolFunctionDefinition{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: t.Function.Parameters,
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toOpenAICompatToolCalls(toolCalls []ToolCall) []openai_compat.ToolCall {
|
||||
out := make([]openai_compat.ToolCall, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
var fn *openai_compat.FunctionCall
|
||||
if tc.Function != nil {
|
||||
fn = &openai_compat.FunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
}
|
||||
}
|
||||
out = append(out, openai_compat.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Function: fn,
|
||||
Name: tc.Name,
|
||||
Arguments: tc.Arguments,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func fromOpenAICompatResponse(resp *openai_compat.LLMResponse) *LLMResponse {
|
||||
if resp == nil {
|
||||
return &LLMResponse{}
|
||||
}
|
||||
|
||||
var usage *UsageInfo
|
||||
if resp.Usage != nil {
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
CompletionTokens: resp.Usage.CompletionTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: resp.Content,
|
||||
ToolCalls: fromOpenAICompatToolCalls(resp.ToolCalls),
|
||||
FinishReason: resp.FinishReason,
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
func fromOpenAICompatToolCalls(toolCalls []openai_compat.ToolCall) []ToolCall {
|
||||
out := make([]ToolCall, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
var fn *FunctionCall
|
||||
if tc.Function != nil {
|
||||
fn = &FunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
}
|
||||
}
|
||||
out = append(out, ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Function: fn,
|
||||
Name: tc.Name,
|
||||
Arguments: tc.Arguments,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -6,55 +6,22 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function *FunctionCall `json:"function,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
type LLMResponse struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Usage *UsageInfo `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type UsageInfo struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type ToolDefinition struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunctionDefinition `json:"function"`
|
||||
}
|
||||
|
||||
type ToolFunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
type ToolCall = protocoltypes.ToolCall
|
||||
type FunctionCall = protocoltypes.FunctionCall
|
||||
type LLMResponse = protocoltypes.LLMResponse
|
||||
type UsageInfo = protocoltypes.UsageInfo
|
||||
type Message = protocoltypes.Message
|
||||
type ToolDefinition = protocoltypes.ToolDefinition
|
||||
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
|
||||
type Provider struct {
|
||||
apiKey string
|
||||
@@ -62,21 +29,19 @@ type Provider struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase string, proxy ...string) *Provider {
|
||||
proxyURL := ""
|
||||
if len(proxy) > 0 {
|
||||
proxyURL = proxy[0]
|
||||
}
|
||||
func NewProvider(apiKey, apiBase, proxy string) *Provider {
|
||||
client := &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
if proxyURL != "" {
|
||||
parsed, err := url.Parse(proxyURL)
|
||||
if proxy != "" {
|
||||
parsed, err := url.Parse(proxy)
|
||||
if err == nil {
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(parsed),
|
||||
}
|
||||
} else {
|
||||
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,13 +57,7 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef
|
||||
return nil, fmt.Errorf("API base not configured")
|
||||
}
|
||||
|
||||
// Strip provider prefix for OpenAI-compatible backends.
|
||||
if idx := strings.Index(model, "/"); idx != -1 {
|
||||
prefix := model[:idx]
|
||||
if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" {
|
||||
model = model[idx+1:]
|
||||
}
|
||||
}
|
||||
model = normalizeModel(model, p.apiBase)
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"model": model,
|
||||
@@ -110,7 +69,7 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
|
||||
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||
if maxTokens, ok := asInt(options["max_tokens"]); ok {
|
||||
lowerModel := strings.ToLower(model)
|
||||
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
|
||||
requestBody["max_completion_tokens"] = maxTokens
|
||||
@@ -119,7 +78,7 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef
|
||||
}
|
||||
}
|
||||
|
||||
if temperature, ok := options["temperature"].(float64); ok {
|
||||
if temperature, ok := asFloat(options["temperature"]); ok {
|
||||
lowerModel := strings.ToLower(model)
|
||||
// Kimi k2 models only support temperature=1.
|
||||
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
|
||||
@@ -198,17 +157,11 @@ func parseResponse(body []byte) (*LLMResponse, error) {
|
||||
arguments := make(map[string]interface{})
|
||||
name := ""
|
||||
|
||||
if tc.Type == "function" && tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
if tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
|
||||
arguments["raw"] = tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
} else if tc.Function != nil {
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
if tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
|
||||
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
|
||||
arguments["raw"] = tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
@@ -228,3 +181,52 @@ func parseResponse(body []byte) (*LLMResponse, error) {
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeModel(model, apiBase string) string {
|
||||
idx := strings.Index(model, "/")
|
||||
if idx == -1 {
|
||||
return model
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") {
|
||||
return model
|
||||
}
|
||||
|
||||
prefix := strings.ToLower(model[:idx])
|
||||
switch prefix {
|
||||
case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu":
|
||||
return model[idx+1:]
|
||||
default:
|
||||
return model
|
||||
}
|
||||
}
|
||||
|
||||
func asInt(v interface{}) (int, bool) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val, true
|
||||
case int64:
|
||||
return int(val), true
|
||||
case float64:
|
||||
return int(val), true
|
||||
case float32:
|
||||
return int(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func asFloat(v interface{}) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val, true
|
||||
case float32:
|
||||
return float64(val), true
|
||||
case int:
|
||||
return float64(val), true
|
||||
case int64:
|
||||
return float64(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -32,7 +33,7 @@ func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL)
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
@@ -78,7 +79,7 @@ func TestProviderChat_ParsesToolCalls(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL)
|
||||
p := NewProvider("key", server.URL, "")
|
||||
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
@@ -100,7 +101,7 @@ func TestProviderChat_HTTPError(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL)
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
@@ -128,7 +129,7 @@ func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testin
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL)
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
@@ -164,6 +165,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
|
||||
input: "ollama/qwen2.5:14b",
|
||||
wantModel: "qwen2.5:14b",
|
||||
},
|
||||
{
|
||||
name: "strips deepseek prefix",
|
||||
input: "deepseek/deepseek-chat",
|
||||
wantModel: "deepseek-chat",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -188,7 +194,7 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL)
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
@@ -200,3 +206,72 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_ProxyConfigured(t *testing.T) {
|
||||
proxyURL := "http://127.0.0.1:8080"
|
||||
p := NewProvider("key", "https://example.com", proxyURL)
|
||||
|
||||
transport, ok := p.httpClient.Transport.(*http.Transport)
|
||||
if !ok || transport == nil {
|
||||
t.Fatalf("expected http transport with proxy, got %T", p.httpClient.Transport)
|
||||
}
|
||||
|
||||
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
|
||||
gotProxy, err := transport.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("proxy function returned error: %v", err)
|
||||
}
|
||||
if gotProxy == nil || gotProxy.String() != proxyURL {
|
||||
t.Fatalf("proxy = %v, want %s", gotProxy, proxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) {
|
||||
var requestBody map[string]interface{}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"message": map[string]interface{}{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"gpt-4o",
|
||||
map[string]interface{}{"max_tokens": float64(512), "temperature": 1},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if requestBody["max_tokens"] != float64(512) {
|
||||
t.Fatalf("max_tokens = %v, want 512", requestBody["max_tokens"])
|
||||
}
|
||||
if requestBody["temperature"] != float64(1) {
|
||||
t.Fatalf("temperature = %v, want 1", requestBody["temperature"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeModel_UsesAPIBase(t *testing.T) {
|
||||
if got := normalizeModel("deepseek/deepseek-chat", "https://api.deepseek.com/v1"); got != "deepseek-chat" {
|
||||
t.Fatalf("normalizeModel(deepseek) = %q, want %q", got, "deepseek-chat")
|
||||
}
|
||||
if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" {
|
||||
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
package protocoltypes
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function *FunctionCall `json:"function,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
type LLMResponse struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Usage *UsageInfo `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type UsageInfo struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type ToolDefinition struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunctionDefinition `json:"function"`
|
||||
}
|
||||
|
||||
type ToolFunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
+11
-43
@@ -1,52 +1,20 @@
|
||||
package providers
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function *FunctionCall `json:"function,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
}
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
type LLMResponse struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Usage *UsageInfo `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type UsageInfo struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
type ToolCall = protocoltypes.ToolCall
|
||||
type FunctionCall = protocoltypes.FunctionCall
|
||||
type LLMResponse = protocoltypes.LLMResponse
|
||||
type UsageInfo = protocoltypes.UsageInfo
|
||||
type Message = protocoltypes.Message
|
||||
type ToolDefinition = protocoltypes.ToolDefinition
|
||||
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
|
||||
type LLMProvider interface {
|
||||
Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error)
|
||||
GetDefaultModel() string
|
||||
}
|
||||
|
||||
type ToolDefinition struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunctionDefinition `json:"function"`
|
||||
}
|
||||
|
||||
type ToolFunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user