mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge remote-tracking branch 'origin/main' into feat/refactor-provider-by-protocol
This commit is contained in:
@@ -0,0 +1,248 @@
|
||||
package anthropicprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type ToolCall = protocoltypes.ToolCall
|
||||
type FunctionCall = protocoltypes.FunctionCall
|
||||
type LLMResponse = protocoltypes.LLMResponse
|
||||
type UsageInfo = protocoltypes.UsageInfo
|
||||
type Message = protocoltypes.Message
|
||||
type ToolDefinition = protocoltypes.ToolDefinition
|
||||
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
|
||||
const defaultBaseURL = "https://api.anthropic.com"
|
||||
|
||||
type Provider struct {
|
||||
client *anthropic.Client
|
||||
tokenSource func() (string, error)
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func NewProvider(token string) *Provider {
|
||||
return NewProviderWithBaseURL(token, "")
|
||||
}
|
||||
|
||||
func NewProviderWithBaseURL(token, apiBase string) *Provider {
|
||||
baseURL := normalizeBaseURL(apiBase)
|
||||
client := anthropic.NewClient(
|
||||
option.WithAuthToken(token),
|
||||
option.WithBaseURL(baseURL),
|
||||
)
|
||||
return &Provider{
|
||||
client: &client,
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
func NewProviderWithClient(client *anthropic.Client) *Provider {
|
||||
return &Provider{
|
||||
client: client,
|
||||
baseURL: defaultBaseURL,
|
||||
}
|
||||
}
|
||||
|
||||
func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider {
|
||||
return NewProviderWithTokenSourceAndBaseURL(token, tokenSource, "")
|
||||
}
|
||||
|
||||
func NewProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *Provider {
|
||||
p := NewProviderWithBaseURL(token, apiBase)
|
||||
p.tokenSource = tokenSource
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
var opts []option.RequestOption
|
||||
if p.tokenSource != nil {
|
||||
tok, err := p.tokenSource()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
opts = append(opts, option.WithAuthToken(tok))
|
||||
}
|
||||
|
||||
params, err := buildParams(messages, tools, model, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := p.client.Messages.New(ctx, params, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("claude API call: %w", err)
|
||||
}
|
||||
|
||||
return parseResponse(resp), nil
|
||||
}
|
||||
|
||||
func (p *Provider) GetDefaultModel() string {
|
||||
return "claude-sonnet-4-5-20250929"
|
||||
}
|
||||
|
||||
func (p *Provider) BaseURL() string {
|
||||
return p.baseURL
|
||||
}
|
||||
|
||||
func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
|
||||
var system []anthropic.TextBlockParam
|
||||
var anthropicMessages []anthropic.MessageParam
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
|
||||
case "user":
|
||||
if msg.ToolCallID != "" {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||
)
|
||||
} else {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
|
||||
)
|
||||
}
|
||||
case "assistant":
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
var blocks []anthropic.ContentBlockParamUnion
|
||||
if msg.Content != "" {
|
||||
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
|
||||
}
|
||||
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
|
||||
} else {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
|
||||
)
|
||||
}
|
||||
case "tool":
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
maxTokens := int64(4096)
|
||||
if mt, ok := options["max_tokens"].(int); ok {
|
||||
maxTokens = int64(mt)
|
||||
}
|
||||
|
||||
params := anthropic.MessageNewParams{
|
||||
Model: anthropic.Model(model),
|
||||
Messages: anthropicMessages,
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
|
||||
if len(system) > 0 {
|
||||
params.System = system
|
||||
}
|
||||
|
||||
if temp, ok := options["temperature"].(float64); ok {
|
||||
params.Temperature = anthropic.Float(temp)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
params.Tools = translateTools(tools)
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
|
||||
result := make([]anthropic.ToolUnionParam, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
tool := anthropic.ToolParam{
|
||||
Name: t.Function.Name,
|
||||
InputSchema: anthropic.ToolInputSchemaParam{
|
||||
Properties: t.Function.Parameters["properties"],
|
||||
},
|
||||
}
|
||||
if desc := t.Function.Description; desc != "" {
|
||||
tool.Description = anthropic.String(desc)
|
||||
}
|
||||
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
|
||||
required := make([]string, 0, len(req))
|
||||
for _, r := range req {
|
||||
if s, ok := r.(string); ok {
|
||||
required = append(required, s)
|
||||
}
|
||||
}
|
||||
tool.InputSchema.Required = required
|
||||
}
|
||||
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseResponse(resp *anthropic.Message) *LLMResponse {
|
||||
var content string
|
||||
var toolCalls []ToolCall
|
||||
|
||||
for _, block := range resp.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
tb := block.AsText()
|
||||
content += tb.Text
|
||||
case "tool_use":
|
||||
tu := block.AsToolUse()
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal(tu.Input, &args); err != nil {
|
||||
log.Printf("anthropic: failed to decode tool call input for %q: %v", tu.Name, err)
|
||||
args = map[string]interface{}{"raw": string(tu.Input)}
|
||||
}
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: tu.ID,
|
||||
Name: tu.Name,
|
||||
Arguments: args,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
switch resp.StopReason {
|
||||
case anthropic.StopReasonToolUse:
|
||||
finishReason = "tool_calls"
|
||||
case anthropic.StopReasonMaxTokens:
|
||||
finishReason = "length"
|
||||
case anthropic.StopReasonEndTurn:
|
||||
finishReason = "stop"
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: &UsageInfo{
|
||||
PromptTokens: int(resp.Usage.InputTokens),
|
||||
CompletionTokens: int(resp.Usage.OutputTokens),
|
||||
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeBaseURL(apiBase string) string {
|
||||
base := strings.TrimSpace(apiBase)
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
base = strings.TrimRight(base, "/")
|
||||
if strings.HasSuffix(base, "/v1") {
|
||||
base = strings.TrimSuffix(base, "/v1")
|
||||
}
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
@@ -0,0 +1,265 @@
|
||||
package anthropicprovider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
|
||||
)
|
||||
|
||||
func TestBuildParams_BasicMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
|
||||
"max_tokens": 1024,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("buildParams() error: %v", err)
|
||||
}
|
||||
if string(params.Model) != "claude-sonnet-4-5-20250929" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
|
||||
}
|
||||
if params.MaxTokens != 1024 {
|
||||
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
|
||||
}
|
||||
if len(params.Messages) != 1 {
|
||||
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildParams_SystemMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
}
|
||||
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildParams() error: %v", err)
|
||||
}
|
||||
if len(params.System) != 1 {
|
||||
t.Fatalf("len(System) = %d, want 1", len(params.System))
|
||||
}
|
||||
if params.System[0].Text != "You are helpful" {
|
||||
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
|
||||
}
|
||||
if len(params.Messages) != 1 {
|
||||
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildParams_ToolCallMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]interface{}{"city": "SF"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildParams() error: %v", err)
|
||||
}
|
||||
if len(params.Messages) != 3 {
|
||||
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildParams_WithTools(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather for a city",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"city": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []interface{}{"city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
params, err := buildParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildParams() error: %v", err)
|
||||
}
|
||||
if len(params.Tools) != 1 {
|
||||
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_TextOnly(t *testing.T) {
|
||||
resp := &anthropic.Message{
|
||||
Content: []anthropic.ContentBlockUnion{},
|
||||
Usage: anthropic.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
},
|
||||
}
|
||||
result := parseResponse(resp)
|
||||
if result.Usage.PromptTokens != 10 {
|
||||
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
|
||||
}
|
||||
if result.Usage.CompletionTokens != 20 {
|
||||
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
|
||||
}
|
||||
if result.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_StopReasons(t *testing.T) {
|
||||
tests := []struct {
|
||||
stopReason anthropic.StopReason
|
||||
want string
|
||||
}{
|
||||
{anthropic.StopReasonEndTurn, "stop"},
|
||||
{anthropic.StopReasonMaxTokens, "length"},
|
||||
{anthropic.StopReasonToolUse, "tool_calls"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
resp := &anthropic.Message{
|
||||
StopReason: tt.stopReason,
|
||||
}
|
||||
result := parseResponse(resp)
|
||||
if result.FinishReason != tt.want {
|
||||
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_ChatRoundTrip(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/messages" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer test-token" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
json.NewDecoder(r.Body).Decode(&reqBody)
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "msg_test",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": reqBody["model"],
|
||||
"stop_reason": "end_turn",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": "Hello! How can I help you?"},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 15,
|
||||
"output_tokens": 8,
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token"))
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hello! How can I help you?" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage.PromptTokens != 15 {
|
||||
t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_GetDefaultModel(t *testing.T) {
|
||||
p := NewProvider("test-token")
|
||||
if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_NewProviderWithBaseURL_NormalizesV1Suffix(t *testing.T) {
|
||||
p := NewProviderWithBaseURL("token", "https://api.anthropic.com/v1/")
|
||||
if got := p.BaseURL(); got != "https://api.anthropic.com" {
|
||||
t.Fatalf("BaseURL() = %q, want %q", got, "https://api.anthropic.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_ChatUsesTokenSource(t *testing.T) {
|
||||
var requests int32
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/messages" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
atomic.AddInt32(&requests, 1)
|
||||
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
json.NewDecoder(r.Body).Decode(&reqBody)
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "msg_test",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": reqBody["model"],
|
||||
"stop_reason": "end_turn",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": "ok"},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) {
|
||||
return "refreshed-token", nil
|
||||
}, server.URL)
|
||||
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if got := atomic.LoadInt32(&requests); got != 1 {
|
||||
t.Fatalf("requests = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
|
||||
c := anthropic.NewClient(
|
||||
anthropicoption.WithAuthToken(token),
|
||||
anthropicoption.WithBaseURL(baseURL),
|
||||
)
|
||||
return &c
|
||||
}
|
||||
@@ -2,200 +2,58 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic"
|
||||
)
|
||||
|
||||
type ClaudeProvider struct {
|
||||
client *anthropic.Client
|
||||
tokenSource func() (string, error)
|
||||
delegate *anthropicprovider.Provider
|
||||
}
|
||||
|
||||
func NewClaudeProvider(token string) *ClaudeProvider {
|
||||
client := anthropic.NewClient(
|
||||
option.WithAuthToken(token),
|
||||
option.WithBaseURL("https://api.anthropic.com"),
|
||||
)
|
||||
return &ClaudeProvider{client: &client}
|
||||
return &ClaudeProvider{
|
||||
delegate: anthropicprovider.NewProvider(token),
|
||||
}
|
||||
}
|
||||
|
||||
func NewClaudeProviderWithBaseURL(token, apiBase string) *ClaudeProvider {
|
||||
return &ClaudeProvider{
|
||||
delegate: anthropicprovider.NewProviderWithBaseURL(token, apiBase),
|
||||
}
|
||||
}
|
||||
|
||||
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
|
||||
p := NewClaudeProvider(token)
|
||||
p.tokenSource = tokenSource
|
||||
return p
|
||||
return &ClaudeProvider{
|
||||
delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource),
|
||||
}
|
||||
}
|
||||
|
||||
func NewClaudeProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *ClaudeProvider {
|
||||
return &ClaudeProvider{
|
||||
delegate: anthropicprovider.NewProviderWithTokenSourceAndBaseURL(token, tokenSource, apiBase),
|
||||
}
|
||||
}
|
||||
|
||||
func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider {
|
||||
return &ClaudeProvider{delegate: delegate}
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
var opts []option.RequestOption
|
||||
if p.tokenSource != nil {
|
||||
tok, err := p.tokenSource()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
opts = append(opts, option.WithAuthToken(tok))
|
||||
}
|
||||
|
||||
params, err := buildClaudeParams(messages, tools, model, options)
|
||||
resp, err := p.delegate.Chat(ctx, messages, tools, model, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := p.client.Messages.New(ctx, params, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("claude API call: %w", err)
|
||||
}
|
||||
|
||||
return parseClaudeResponse(resp), nil
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) GetDefaultModel() string {
|
||||
return "claude-sonnet-4-5-20250929"
|
||||
}
|
||||
|
||||
func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
|
||||
var system []anthropic.TextBlockParam
|
||||
var anthropicMessages []anthropic.MessageParam
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
|
||||
case "user":
|
||||
if msg.ToolCallID != "" {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||
)
|
||||
} else {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
|
||||
)
|
||||
}
|
||||
case "assistant":
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
var blocks []anthropic.ContentBlockParamUnion
|
||||
if msg.Content != "" {
|
||||
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
|
||||
}
|
||||
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
|
||||
} else {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
|
||||
)
|
||||
}
|
||||
case "tool":
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
maxTokens := int64(4096)
|
||||
if mt, ok := options["max_tokens"].(int); ok {
|
||||
maxTokens = int64(mt)
|
||||
}
|
||||
|
||||
params := anthropic.MessageNewParams{
|
||||
Model: anthropic.Model(model),
|
||||
Messages: anthropicMessages,
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
|
||||
if len(system) > 0 {
|
||||
params.System = system
|
||||
}
|
||||
|
||||
if temp, ok := options["temperature"].(float64); ok {
|
||||
params.Temperature = anthropic.Float(temp)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
params.Tools = translateToolsForClaude(tools)
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam {
|
||||
result := make([]anthropic.ToolUnionParam, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
tool := anthropic.ToolParam{
|
||||
Name: t.Function.Name,
|
||||
InputSchema: anthropic.ToolInputSchemaParam{
|
||||
Properties: t.Function.Parameters["properties"],
|
||||
},
|
||||
}
|
||||
if desc := t.Function.Description; desc != "" {
|
||||
tool.Description = anthropic.String(desc)
|
||||
}
|
||||
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
|
||||
required := make([]string, 0, len(req))
|
||||
for _, r := range req {
|
||||
if s, ok := r.(string); ok {
|
||||
required = append(required, s)
|
||||
}
|
||||
}
|
||||
tool.InputSchema.Required = required
|
||||
}
|
||||
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseClaudeResponse(resp *anthropic.Message) *LLMResponse {
|
||||
var content string
|
||||
var toolCalls []ToolCall
|
||||
|
||||
for _, block := range resp.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
tb := block.AsText()
|
||||
content += tb.Text
|
||||
case "tool_use":
|
||||
tu := block.AsToolUse()
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal(tu.Input, &args); err != nil {
|
||||
args = map[string]interface{}{"raw": string(tu.Input)}
|
||||
}
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: tu.ID,
|
||||
Name: tu.Name,
|
||||
Arguments: args,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
switch resp.StopReason {
|
||||
case anthropic.StopReasonToolUse:
|
||||
finishReason = "tool_calls"
|
||||
case anthropic.StopReasonMaxTokens:
|
||||
finishReason = "length"
|
||||
case anthropic.StopReasonEndTurn:
|
||||
finishReason = "stop"
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: &UsageInfo{
|
||||
PromptTokens: int(resp.Usage.InputTokens),
|
||||
CompletionTokens: int(resp.Usage.OutputTokens),
|
||||
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
|
||||
},
|
||||
}
|
||||
return p.delegate.GetDefaultModel()
|
||||
}
|
||||
|
||||
func createClaudeTokenSource() func() (string, error) {
|
||||
return func() (string, error) {
|
||||
cred, err := auth.GetCredential("anthropic")
|
||||
cred, err := getCredential("anthropic")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
|
||||
@@ -8,140 +8,9 @@ import (
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
|
||||
anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic"
|
||||
)
|
||||
|
||||
func TestBuildClaudeParams_BasicMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
|
||||
"max_tokens": 1024,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||
}
|
||||
if string(params.Model) != "claude-sonnet-4-5-20250929" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
|
||||
}
|
||||
if params.MaxTokens != 1024 {
|
||||
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
|
||||
}
|
||||
if len(params.Messages) != 1 {
|
||||
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeParams_SystemMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
}
|
||||
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||
}
|
||||
if len(params.System) != 1 {
|
||||
t.Fatalf("len(System) = %d, want 1", len(params.System))
|
||||
}
|
||||
if params.System[0].Text != "You are helpful" {
|
||||
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
|
||||
}
|
||||
if len(params.Messages) != 1 {
|
||||
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeParams_ToolCallMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]interface{}{"city": "SF"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||
}
|
||||
if len(params.Messages) != 3 {
|
||||
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeParams_WithTools(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather for a city",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"city": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []interface{}{"city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||
}
|
||||
if len(params.Tools) != 1 {
|
||||
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeResponse_TextOnly(t *testing.T) {
|
||||
resp := &anthropic.Message{
|
||||
Content: []anthropic.ContentBlockUnion{},
|
||||
Usage: anthropic.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
},
|
||||
}
|
||||
result := parseClaudeResponse(resp)
|
||||
if result.Usage.PromptTokens != 10 {
|
||||
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
|
||||
}
|
||||
if result.Usage.CompletionTokens != 20 {
|
||||
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
|
||||
}
|
||||
if result.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeResponse_StopReasons(t *testing.T) {
|
||||
tests := []struct {
|
||||
stopReason anthropic.StopReason
|
||||
want string
|
||||
}{
|
||||
{anthropic.StopReasonEndTurn, "stop"},
|
||||
{anthropic.StopReasonMaxTokens, "length"},
|
||||
{anthropic.StopReasonToolUse, "tool_calls"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
resp := &anthropic.Message{
|
||||
StopReason: tt.stopReason,
|
||||
}
|
||||
result := parseClaudeResponse(resp)
|
||||
if result.FinishReason != tt.want {
|
||||
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/messages" {
|
||||
@@ -175,8 +44,8 @@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewClaudeProvider("test-token")
|
||||
provider.client = createAnthropicTestClient(server.URL, "test-token")
|
||||
delegate := anthropicprovider.NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token"))
|
||||
provider := newClaudeProviderWithDelegate(delegate)
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
//go:build integration
|
||||
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
exec "os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestIntegration_RealCodexCLI tests the CodexCliProvider with a real codex CLI.
|
||||
// Run with: go test -tags=integration ./pkg/providers/...
|
||||
func TestIntegration_RealCodexCLI(t *testing.T) {
|
||||
path, err := exec.LookPath("codex")
|
||||
if err != nil {
|
||||
t.Skip("codex CLI not found in PATH, skipping integration test")
|
||||
}
|
||||
t.Logf("Using codex CLI at: %s", path)
|
||||
|
||||
p := NewCodexCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() with real CLI error = %v", err)
|
||||
}
|
||||
|
||||
if resp.Content == "" {
|
||||
t.Error("Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
|
||||
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
t.Logf("Response content: %q", resp.Content)
|
||||
|
||||
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
|
||||
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealCodexCLI_WithSystemPrompt(t *testing.T) {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
t.Skip("codex CLI not found in PATH")
|
||||
}
|
||||
|
||||
p := NewCodexCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Response: %q", resp.Content)
|
||||
|
||||
if !strings.Contains(resp.Content, "4") {
|
||||
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealCodexCLI_ParsesRealJSONL(t *testing.T) {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
t.Skip("codex CLI not found in PATH")
|
||||
}
|
||||
|
||||
// Run codex directly and verify our parser handles real output
|
||||
cmd := exec.Command("codex", "exec",
|
||||
"--json",
|
||||
"--dangerously-bypass-approvals-and-sandbox",
|
||||
"--skip-git-repo-check",
|
||||
"--color", "never",
|
||||
"-C", t.TempDir(),
|
||||
"-")
|
||||
cmd.Stdin = strings.NewReader("Say hi")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
// codex may write diagnostic noise to stderr but still produce valid output
|
||||
if len(output) == 0 {
|
||||
t.Fatalf("codex CLI failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Raw CLI output (first 500 chars): %s", string(output[:min(len(output), 500)]))
|
||||
|
||||
// Verify our parser can handle real output
|
||||
p := NewCodexCliProvider("")
|
||||
resp, err := p.parseJSONLEvents(string(output))
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() failed on real CLI output: %v", err)
|
||||
}
|
||||
|
||||
if resp.Content == "" {
|
||||
t.Error("parsed Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
|
||||
}
|
||||
|
||||
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
|
||||
}
|
||||
@@ -18,9 +18,10 @@ const codexDefaultModel = "gpt-5.2"
|
||||
const codexDefaultInstructions = "You are Codex, a coding assistant."
|
||||
|
||||
type CodexProvider struct {
|
||||
client *openai.Client
|
||||
accountID string
|
||||
tokenSource func() (string, string, error)
|
||||
client *openai.Client
|
||||
accountID string
|
||||
tokenSource func() (string, string, error)
|
||||
enableWebSearch bool
|
||||
}
|
||||
|
||||
const defaultCodexInstructions = "You are Codex, a coding assistant."
|
||||
@@ -37,8 +38,9 @@ func NewCodexProvider(token, accountID string) *CodexProvider {
|
||||
}
|
||||
client := openai.NewClient(opts...)
|
||||
return &CodexProvider{
|
||||
client: &client,
|
||||
accountID: accountID,
|
||||
client: &client,
|
||||
accountID: accountID,
|
||||
enableWebSearch: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,7 +80,7 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
|
||||
})
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options)
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options, p.enableWebSearch)
|
||||
|
||||
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
|
||||
defer stream.Close()
|
||||
@@ -182,7 +184,7 @@ func resolveCodexModel(model string) (string, string) {
|
||||
return codexDefaultModel, "unsupported model family"
|
||||
}
|
||||
|
||||
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
|
||||
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, enableWebSearch bool) responses.ResponseNewParams {
|
||||
var inputItems responses.ResponseInputParam
|
||||
var instructions string
|
||||
|
||||
@@ -217,12 +219,18 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
})
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
name, args, ok := resolveCodexToolCall(tc)
|
||||
if !ok {
|
||||
logger.WarnCF("provider.codex", "Skipping invalid tool call in history", map[string]interface{}{
|
||||
"call_id": tc.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
|
||||
CallID: tc.ID,
|
||||
Name: tc.Name,
|
||||
Arguments: string(argsJSON),
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -260,20 +268,50 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
params.Instructions = openai.Opt(defaultCodexInstructions)
|
||||
}
|
||||
|
||||
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
params.Tools = translateToolsForCodex(tools)
|
||||
if len(tools) > 0 || enableWebSearch {
|
||||
params.Tools = translateToolsForCodex(tools, enableWebSearch)
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
|
||||
result := make([]responses.ToolUnionParam, 0, len(tools))
|
||||
func resolveCodexToolCall(tc ToolCall) (name string, arguments string, ok bool) {
|
||||
name = tc.Name
|
||||
if name == "" && tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
}
|
||||
if name == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
if len(tc.Arguments) > 0 {
|
||||
argsJSON, err := json.Marshal(tc.Arguments)
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
return name, string(argsJSON), true
|
||||
}
|
||||
|
||||
if tc.Function != nil && tc.Function.Arguments != "" {
|
||||
return name, tc.Function.Arguments, true
|
||||
}
|
||||
|
||||
return name, "{}", true
|
||||
}
|
||||
|
||||
func translateToolsForCodex(tools []ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam {
|
||||
capHint := len(tools)
|
||||
if enableWebSearch {
|
||||
capHint++
|
||||
}
|
||||
result := make([]responses.ToolUnionParam, 0, capHint)
|
||||
for _, t := range tools {
|
||||
if t.Type != "function" {
|
||||
continue
|
||||
}
|
||||
if enableWebSearch && strings.EqualFold(t.Function.Name, "web_search") {
|
||||
continue
|
||||
}
|
||||
ft := responses.FunctionToolParam{
|
||||
Name: t.Function.Name,
|
||||
Parameters: t.Function.Parameters,
|
||||
@@ -284,6 +322,9 @@ func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
|
||||
}
|
||||
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
|
||||
}
|
||||
if enableWebSearch {
|
||||
result = append(result, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
}, true)
|
||||
if params.Model != "gpt-4o" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
|
||||
}
|
||||
@@ -29,6 +29,9 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||
if params.Instructions.Or("") != defaultCodexInstructions {
|
||||
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), defaultCodexInstructions)
|
||||
}
|
||||
if params.MaxOutputTokens.Valid() {
|
||||
t.Fatalf("MaxOutputTokens should not be set for Codex backend")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
|
||||
@@ -36,7 +39,7 @@ func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
}
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, true)
|
||||
if !params.Instructions.Valid() {
|
||||
t.Fatal("Instructions should be set")
|
||||
}
|
||||
@@ -56,7 +59,7 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
|
||||
},
|
||||
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false)
|
||||
if params.Input.OfInputItemList == nil {
|
||||
t.Fatal("Input.OfInputItemList should not be nil")
|
||||
}
|
||||
@@ -65,6 +68,45 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_ToolCallFunctionFallback(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Read a file"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"README.md"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "ok", ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false)
|
||||
if params.Input.OfInputItemList == nil {
|
||||
t.Fatal("Input.OfInputItemList should not be nil")
|
||||
}
|
||||
if len(params.Input.OfInputItemList) != 3 {
|
||||
t.Fatalf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
|
||||
}
|
||||
|
||||
fc := params.Input.OfInputItemList[1].OfFunctionCall
|
||||
if fc == nil {
|
||||
t.Fatal("assistant tool call should be converted to function_call input item")
|
||||
}
|
||||
if fc.Name != "read_file" {
|
||||
t.Errorf("Function call name = %q, want %q", fc.Name, "read_file")
|
||||
}
|
||||
if fc.Arguments != `{"path":"README.md"}` {
|
||||
t.Errorf("Function call arguments = %q, want %q", fc.Arguments, `{"path":"README.md"}`)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
@@ -81,7 +123,7 @@ func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{})
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, false)
|
||||
if len(params.Tools) != 1 {
|
||||
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||
}
|
||||
@@ -94,12 +136,61 @@ func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{})
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, false)
|
||||
if !params.Store.Valid() || params.Store.Or(true) != false {
|
||||
t.Error("Store should be explicitly set to false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_DefaultWebSearchEnabled(t *testing.T) {
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, true)
|
||||
if len(params.Tools) != 1 {
|
||||
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||
}
|
||||
if params.Tools[0].OfWebSearch == nil {
|
||||
t.Fatal("Tool should include built-in web_search")
|
||||
}
|
||||
if params.Tools[0].OfWebSearch.Type != responses.WebSearchToolTypeWebSearch {
|
||||
t.Errorf("Web search tool type = %q, want %q", params.Tools[0].OfWebSearch.Type, responses.WebSearchToolTypeWebSearch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_WebSearchFunctionReplacedWithBuiltin(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "web_search",
|
||||
Description: "local web search",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "read_file",
|
||||
Description: "read file",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, true)
|
||||
if len(params.Tools) != 2 {
|
||||
t.Fatalf("len(Tools) = %d, want 2", len(params.Tools))
|
||||
}
|
||||
if params.Tools[0].OfFunction == nil || params.Tools[0].OfFunction.Name != "read_file" {
|
||||
t.Fatalf("first tool should be function read_file, got %#v", params.Tools[0])
|
||||
}
|
||||
if params.Tools[1].OfWebSearch == nil {
|
||||
t.Fatalf("second tool should be built-in web_search, got %#v", params.Tools[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCodexResponse_TextOutput(t *testing.T) {
|
||||
respJSON := `{
|
||||
"id": "resp_test",
|
||||
@@ -214,6 +305,20 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["max_output_tokens"]; ok {
|
||||
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
toolsAny, ok := reqBody["tools"].([]interface{})
|
||||
if !ok || len(toolsAny) != 1 {
|
||||
http.Error(w, "missing default web search tool", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
toolObj, ok := toolsAny[0].(map[string]interface{})
|
||||
if !ok || toolObj["type"] != "web_search" {
|
||||
http.Error(w, "expected web_search tool", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
@@ -261,6 +366,64 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_ChatRoundTrip_WebSearchDisabled(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["tools"]; ok {
|
||||
http.Error(w, "tools should be absent when web search disabled", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": []map[string]interface{}{
|
||||
{
|
||||
"id": "msg_1",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "output_text", "text": "Hi from Codex!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 4,
|
||||
"output_tokens": 3,
|
||||
"total_tokens": 7,
|
||||
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
writeCompletedSSE(w, resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewCodexProvider("test-token", "acc-123")
|
||||
provider.enableWebSearch = false
|
||||
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hi from Codex!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
@@ -293,6 +456,10 @@ func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T)
|
||||
http.Error(w, "temperature is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["max_output_tokens"]; ok {
|
||||
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["stream"] != true {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultFailureWindow = 24 * time.Hour
|
||||
)
|
||||
|
||||
// CooldownTracker manages per-provider cooldown state for the fallback chain.
|
||||
// Thread-safe via sync.RWMutex. In-memory only (resets on restart).
|
||||
type CooldownTracker struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]*cooldownEntry
|
||||
failureWindow time.Duration
|
||||
nowFunc func() time.Time // for testing
|
||||
}
|
||||
|
||||
type cooldownEntry struct {
|
||||
ErrorCount int
|
||||
FailureCounts map[FailoverReason]int
|
||||
CooldownEnd time.Time // standard cooldown expiry
|
||||
DisabledUntil time.Time // billing-specific disable expiry
|
||||
DisabledReason FailoverReason // reason for disable (billing)
|
||||
LastFailure time.Time
|
||||
}
|
||||
|
||||
// NewCooldownTracker creates a tracker with default 24h failure window.
|
||||
func NewCooldownTracker() *CooldownTracker {
|
||||
return &CooldownTracker{
|
||||
entries: make(map[string]*cooldownEntry),
|
||||
failureWindow: defaultFailureWindow,
|
||||
nowFunc: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// MarkFailure records a failure for a provider and sets appropriate cooldown.
|
||||
// Resets error counts if last failure was more than failureWindow ago.
|
||||
func (ct *CooldownTracker) MarkFailure(provider string, reason FailoverReason) {
|
||||
ct.mu.Lock()
|
||||
defer ct.mu.Unlock()
|
||||
|
||||
now := ct.nowFunc()
|
||||
entry := ct.getOrCreate(provider)
|
||||
|
||||
// 24h failure window reset: if no failure in failureWindow, reset counters.
|
||||
if !entry.LastFailure.IsZero() && now.Sub(entry.LastFailure) > ct.failureWindow {
|
||||
entry.ErrorCount = 0
|
||||
entry.FailureCounts = make(map[FailoverReason]int)
|
||||
}
|
||||
|
||||
entry.ErrorCount++
|
||||
entry.FailureCounts[reason]++
|
||||
entry.LastFailure = now
|
||||
|
||||
if reason == FailoverBilling {
|
||||
billingCount := entry.FailureCounts[FailoverBilling]
|
||||
entry.DisabledUntil = now.Add(calculateBillingCooldown(billingCount))
|
||||
entry.DisabledReason = FailoverBilling
|
||||
} else {
|
||||
entry.CooldownEnd = now.Add(calculateStandardCooldown(entry.ErrorCount))
|
||||
}
|
||||
}
|
||||
|
||||
// MarkSuccess resets all counters and cooldowns for a provider.
|
||||
func (ct *CooldownTracker) MarkSuccess(provider string) {
|
||||
ct.mu.Lock()
|
||||
defer ct.mu.Unlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
|
||||
entry.ErrorCount = 0
|
||||
entry.FailureCounts = make(map[FailoverReason]int)
|
||||
entry.CooldownEnd = time.Time{}
|
||||
entry.DisabledUntil = time.Time{}
|
||||
entry.DisabledReason = ""
|
||||
}
|
||||
|
||||
// IsAvailable returns true if the provider is not in cooldown or disabled.
|
||||
func (ct *CooldownTracker) IsAvailable(provider string) bool {
|
||||
ct.mu.RLock()
|
||||
defer ct.mu.RUnlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
now := ct.nowFunc()
|
||||
|
||||
// Billing disable takes precedence (longer cooldown).
|
||||
if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Standard cooldown.
|
||||
if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// CooldownRemaining returns how long until the provider becomes available.
|
||||
// Returns 0 if already available.
|
||||
func (ct *CooldownTracker) CooldownRemaining(provider string) time.Duration {
|
||||
ct.mu.RLock()
|
||||
defer ct.mu.RUnlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
now := ct.nowFunc()
|
||||
var remaining time.Duration
|
||||
|
||||
if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) {
|
||||
d := entry.DisabledUntil.Sub(now)
|
||||
if d > remaining {
|
||||
remaining = d
|
||||
}
|
||||
}
|
||||
|
||||
if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) {
|
||||
d := entry.CooldownEnd.Sub(now)
|
||||
if d > remaining {
|
||||
remaining = d
|
||||
}
|
||||
}
|
||||
|
||||
return remaining
|
||||
}
|
||||
|
||||
// ErrorCount returns the current error count for a provider.
|
||||
func (ct *CooldownTracker) ErrorCount(provider string) int {
|
||||
ct.mu.RLock()
|
||||
defer ct.mu.RUnlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return 0
|
||||
}
|
||||
return entry.ErrorCount
|
||||
}
|
||||
|
||||
// FailureCount returns the failure count for a specific reason.
|
||||
func (ct *CooldownTracker) FailureCount(provider string, reason FailoverReason) int {
|
||||
ct.mu.RLock()
|
||||
defer ct.mu.RUnlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return 0
|
||||
}
|
||||
return entry.FailureCounts[reason]
|
||||
}
|
||||
|
||||
func (ct *CooldownTracker) getOrCreate(provider string) *cooldownEntry {
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
entry = &cooldownEntry{
|
||||
FailureCounts: make(map[FailoverReason]int),
|
||||
}
|
||||
ct.entries[provider] = entry
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
// calculateStandardCooldown computes standard exponential backoff.
|
||||
// Formula from OpenClaw: min(1h, 1min * 5^min(n-1, 3))
|
||||
//
|
||||
// 1 error → 1 min
|
||||
// 2 errors → 5 min
|
||||
// 3 errors → 25 min
|
||||
// 4+ errors → 1 hour (cap)
|
||||
func calculateStandardCooldown(errorCount int) time.Duration {
|
||||
n := max(1, errorCount)
|
||||
exp := min(n-1, 3)
|
||||
ms := 60_000 * int(math.Pow(5, float64(exp)))
|
||||
ms = min(3_600_000, ms) // cap at 1 hour
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
|
||||
// calculateBillingCooldown computes billing-specific exponential backoff.
|
||||
// Formula from OpenClaw: min(24h, 5h * 2^min(n-1, 10))
|
||||
//
|
||||
// 1 error → 5 hours
|
||||
// 2 errors → 10 hours
|
||||
// 3 errors → 20 hours
|
||||
// 4+ errors → 24 hours (cap)
|
||||
func calculateBillingCooldown(billingErrorCount int) time.Duration {
|
||||
const baseMs = 5 * 60 * 60 * 1000 // 5 hours
|
||||
const maxMs = 24 * 60 * 60 * 1000 // 24 hours
|
||||
|
||||
n := max(1, billingErrorCount)
|
||||
exp := min(n-1, 10)
|
||||
raw := float64(baseMs) * math.Pow(2, float64(exp))
|
||||
ms := int(math.Min(float64(maxMs), raw))
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
@@ -0,0 +1,269 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newTestTracker(now time.Time) (*CooldownTracker, *time.Time) {
|
||||
current := now
|
||||
ct := NewCooldownTracker()
|
||||
ct.nowFunc = func() time.Time { return current }
|
||||
return ct, ¤t
|
||||
}
|
||||
|
||||
func TestCooldown_InitiallyAvailable(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("new provider should be available")
|
||||
}
|
||||
if ct.ErrorCount("openai") != 0 {
|
||||
t.Error("new provider should have 0 errors")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_StandardEscalation(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// 1st error → 1 min cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("should be in cooldown after 1st error")
|
||||
}
|
||||
|
||||
// Advance 61 seconds → available
|
||||
*current = now.Add(61 * time.Second)
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after 1 min cooldown")
|
||||
}
|
||||
|
||||
// 2nd error → 5 min cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
*current = now.Add(61*time.Second + 4*time.Minute)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("should be in cooldown (5 min) after 2nd error")
|
||||
}
|
||||
*current = now.Add(61*time.Second + 6*time.Minute)
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after 5 min cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_StandardCap(t *testing.T) {
|
||||
// Verify formula: 1m, 5m, 25m, 1h, 1h, 1h...
|
||||
expected := []time.Duration{
|
||||
1 * time.Minute,
|
||||
5 * time.Minute,
|
||||
25 * time.Minute,
|
||||
1 * time.Hour,
|
||||
1 * time.Hour,
|
||||
}
|
||||
|
||||
for i, want := range expected {
|
||||
got := calculateStandardCooldown(i + 1)
|
||||
if got != want {
|
||||
t.Errorf("calculateStandardCooldown(%d) = %v, want %v", i+1, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_BillingEscalation(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// 1st billing error → 5h cooldown
|
||||
ct.MarkFailure("openai", FailoverBilling)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("should be disabled after billing error")
|
||||
}
|
||||
|
||||
// Advance 4h → still disabled
|
||||
*current = now.Add(4 * time.Hour)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("should still be disabled (5h cooldown)")
|
||||
}
|
||||
|
||||
// Advance 5h + 1s → available
|
||||
*current = now.Add(5*time.Hour + 1*time.Second)
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after 5h billing cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_BillingCap(t *testing.T) {
|
||||
expected := []time.Duration{
|
||||
5 * time.Hour,
|
||||
10 * time.Hour,
|
||||
20 * time.Hour,
|
||||
24 * time.Hour,
|
||||
24 * time.Hour,
|
||||
}
|
||||
|
||||
for i, want := range expected {
|
||||
got := calculateBillingCooldown(i + 1)
|
||||
if got != want {
|
||||
t.Errorf("calculateBillingCooldown(%d) = %v, want %v", i+1, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_SuccessReset(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("openai", FailoverBilling)
|
||||
if ct.ErrorCount("openai") != 2 {
|
||||
t.Errorf("error count = %d, want 2", ct.ErrorCount("openai"))
|
||||
}
|
||||
|
||||
ct.MarkSuccess("openai")
|
||||
if ct.ErrorCount("openai") != 0 {
|
||||
t.Errorf("error count after success = %d, want 0", ct.ErrorCount("openai"))
|
||||
}
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after success")
|
||||
}
|
||||
if ct.FailureCount("openai", FailoverRateLimit) != 0 {
|
||||
t.Error("failure counts should be reset after success")
|
||||
}
|
||||
if ct.FailureCount("openai", FailoverBilling) != 0 {
|
||||
t.Error("billing failure count should be reset after success")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_FailureWindowReset(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// 4 errors → 1h cooldown
|
||||
for i := 0; i < 4; i++ {
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
*current = current.Add(2 * time.Second) // small advance between errors
|
||||
}
|
||||
if ct.ErrorCount("openai") != 4 {
|
||||
t.Errorf("error count = %d, want 4", ct.ErrorCount("openai"))
|
||||
}
|
||||
|
||||
// Advance 25 hours (past 24h failure window)
|
||||
*current = now.Add(25 * time.Hour)
|
||||
|
||||
// Next error should reset counters first, then increment to 1
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
if ct.ErrorCount("openai") != 1 {
|
||||
t.Errorf("error count after window reset = %d, want 1 (reset + 1)", ct.ErrorCount("openai"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_PerReasonTracking(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("openai", FailoverBilling)
|
||||
ct.MarkFailure("openai", FailoverAuth)
|
||||
|
||||
if ct.FailureCount("openai", FailoverRateLimit) != 2 {
|
||||
t.Errorf("rate_limit count = %d, want 2", ct.FailureCount("openai", FailoverRateLimit))
|
||||
}
|
||||
if ct.FailureCount("openai", FailoverBilling) != 1 {
|
||||
t.Errorf("billing count = %d, want 1", ct.FailureCount("openai", FailoverBilling))
|
||||
}
|
||||
if ct.FailureCount("openai", FailoverAuth) != 1 {
|
||||
t.Errorf("auth count = %d, want 1", ct.FailureCount("openai", FailoverAuth))
|
||||
}
|
||||
if ct.ErrorCount("openai") != 4 {
|
||||
t.Errorf("total error count = %d, want 4", ct.ErrorCount("openai"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_BillingTakesPrecedence(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// Standard cooldown (1 min) + billing disable (5h)
|
||||
ct.MarkFailure("openai", FailoverRateLimit) // 1 min cooldown
|
||||
ct.MarkFailure("openai", FailoverBilling) // 5h disable
|
||||
|
||||
// After 2 min: standard cooldown expired but billing still active
|
||||
*current = now.Add(2 * time.Minute)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("billing disable should take precedence over standard cooldown")
|
||||
}
|
||||
|
||||
// After 5h + 1s: both expired
|
||||
*current = now.Add(5*time.Hour + 1*time.Second)
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after all cooldowns expire")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_CooldownRemaining(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// No failures → 0 remaining
|
||||
if ct.CooldownRemaining("openai") != 0 {
|
||||
t.Error("expected 0 remaining for new provider")
|
||||
}
|
||||
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
|
||||
*current = now.Add(30 * time.Second)
|
||||
remaining := ct.CooldownRemaining("openai")
|
||||
if remaining <= 0 || remaining > 1*time.Minute {
|
||||
t.Errorf("remaining = %v, expected ~30s", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_SuccessOnUnknownProvider(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
// Should not panic
|
||||
ct.MarkSuccess("nonexistent")
|
||||
if !ct.IsAvailable("nonexistent") {
|
||||
t.Error("nonexistent provider should be available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_ConcurrentAccess(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ct.IsAvailable("openai")
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ct.MarkSuccess("openai")
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
// If we got here without panic, concurrent access is safe
|
||||
}
|
||||
|
||||
func TestCooldown_MultipleProviders(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("anthropic", FailoverBilling)
|
||||
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("openai should be in cooldown")
|
||||
}
|
||||
if ct.IsAvailable("anthropic") {
|
||||
t.Error("anthropic should be in cooldown")
|
||||
}
|
||||
// groq was never touched
|
||||
if !ct.IsAvailable("groq") {
|
||||
t.Error("groq should be available")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// errorPattern defines a single pattern (string or regex) for error classification.
|
||||
type errorPattern struct {
|
||||
substring string
|
||||
regex *regexp.Regexp
|
||||
}
|
||||
|
||||
func substr(s string) errorPattern { return errorPattern{substring: s} }
|
||||
func rxp(r string) errorPattern { return errorPattern{regex: regexp.MustCompile("(?i)" + r)} }
|
||||
|
||||
// Error patterns organized by FailoverReason, matching OpenClaw production (~40 patterns).
|
||||
var (
|
||||
rateLimitPatterns = []errorPattern{
|
||||
rxp(`rate[_ ]limit`),
|
||||
substr("too many requests"),
|
||||
substr("429"),
|
||||
substr("exceeded your current quota"),
|
||||
rxp(`exceeded.*quota`),
|
||||
rxp(`resource has been exhausted`),
|
||||
rxp(`resource.*exhausted`),
|
||||
substr("resource_exhausted"),
|
||||
substr("quota exceeded"),
|
||||
substr("usage limit"),
|
||||
}
|
||||
|
||||
overloadedPatterns = []errorPattern{
|
||||
rxp(`overloaded_error`),
|
||||
rxp(`"type"\s*:\s*"overloaded_error"`),
|
||||
substr("overloaded"),
|
||||
}
|
||||
|
||||
timeoutPatterns = []errorPattern{
|
||||
substr("timeout"),
|
||||
substr("timed out"),
|
||||
substr("deadline exceeded"),
|
||||
substr("context deadline exceeded"),
|
||||
}
|
||||
|
||||
billingPatterns = []errorPattern{
|
||||
rxp(`\b402\b`),
|
||||
substr("payment required"),
|
||||
substr("insufficient credits"),
|
||||
substr("credit balance"),
|
||||
substr("plans & billing"),
|
||||
substr("insufficient balance"),
|
||||
}
|
||||
|
||||
authPatterns = []errorPattern{
|
||||
rxp(`invalid[_ ]?api[_ ]?key`),
|
||||
substr("incorrect api key"),
|
||||
substr("invalid token"),
|
||||
substr("authentication"),
|
||||
substr("re-authenticate"),
|
||||
substr("oauth token refresh failed"),
|
||||
substr("unauthorized"),
|
||||
substr("forbidden"),
|
||||
substr("access denied"),
|
||||
substr("expired"),
|
||||
substr("token has expired"),
|
||||
rxp(`\b401\b`),
|
||||
rxp(`\b403\b`),
|
||||
substr("no credentials found"),
|
||||
substr("no api key found"),
|
||||
}
|
||||
|
||||
formatPatterns = []errorPattern{
|
||||
substr("string should match pattern"),
|
||||
substr("tool_use.id"),
|
||||
substr("tool_use_id"),
|
||||
substr("messages.1.content.1.tool_use.id"),
|
||||
substr("invalid request format"),
|
||||
}
|
||||
|
||||
imageDimensionPatterns = []errorPattern{
|
||||
rxp(`image dimensions exceed max`),
|
||||
}
|
||||
|
||||
imageSizePatterns = []errorPattern{
|
||||
rxp(`image exceeds.*mb`),
|
||||
}
|
||||
|
||||
// Transient HTTP status codes that map to timeout (server-side failures).
|
||||
transientStatusCodes = map[int]bool{
|
||||
500: true, 502: true, 503: true,
|
||||
521: true, 522: true, 523: true, 524: true,
|
||||
529: true,
|
||||
}
|
||||
)
|
||||
|
||||
// ClassifyError classifies an error into a FailoverError with reason.
|
||||
// Returns nil if the error is not classifiable (unknown errors should not trigger fallback).
|
||||
func ClassifyError(err error, provider, model string) *FailoverError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Context cancellation: user abort, never fallback.
|
||||
if err == context.Canceled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Context deadline exceeded: treat as timeout, always fallback.
|
||||
if err == context.DeadlineExceeded {
|
||||
return &FailoverError{
|
||||
Reason: FailoverTimeout,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
msg := strings.ToLower(err.Error())
|
||||
|
||||
// Image dimension/size errors: non-retriable, non-fallback.
|
||||
if IsImageDimensionError(msg) || IsImageSizeError(msg) {
|
||||
return &FailoverError{
|
||||
Reason: FailoverFormat,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Try HTTP status code extraction first.
|
||||
if status := extractHTTPStatus(msg); status > 0 {
|
||||
if reason := classifyByStatus(status); reason != "" {
|
||||
return &FailoverError{
|
||||
Reason: reason,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Status: status,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Message pattern matching (priority order from OpenClaw).
|
||||
if reason := classifyByMessage(msg); reason != "" {
|
||||
return &FailoverError{
|
||||
Reason: reason,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// classifyByStatus maps HTTP status codes to FailoverReason.
|
||||
func classifyByStatus(status int) FailoverReason {
|
||||
switch {
|
||||
case status == 401 || status == 403:
|
||||
return FailoverAuth
|
||||
case status == 402:
|
||||
return FailoverBilling
|
||||
case status == 408:
|
||||
return FailoverTimeout
|
||||
case status == 429:
|
||||
return FailoverRateLimit
|
||||
case status == 400:
|
||||
return FailoverFormat
|
||||
case transientStatusCodes[status]:
|
||||
return FailoverTimeout
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// classifyByMessage matches error messages against patterns.
|
||||
// Priority order matters (from OpenClaw classifyFailoverReason).
|
||||
func classifyByMessage(msg string) FailoverReason {
|
||||
if matchesAny(msg, rateLimitPatterns) {
|
||||
return FailoverRateLimit
|
||||
}
|
||||
if matchesAny(msg, overloadedPatterns) {
|
||||
return FailoverRateLimit // Overloaded treated as rate_limit
|
||||
}
|
||||
if matchesAny(msg, billingPatterns) {
|
||||
return FailoverBilling
|
||||
}
|
||||
if matchesAny(msg, timeoutPatterns) {
|
||||
return FailoverTimeout
|
||||
}
|
||||
if matchesAny(msg, authPatterns) {
|
||||
return FailoverAuth
|
||||
}
|
||||
if matchesAny(msg, formatPatterns) {
|
||||
return FailoverFormat
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractHTTPStatus extracts an HTTP status code from an error message.
|
||||
// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429".
|
||||
func extractHTTPStatus(msg string) int {
|
||||
// Common patterns in Go HTTP error messages
|
||||
patterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(`status[:\s]+(\d{3})`),
|
||||
regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`),
|
||||
}
|
||||
|
||||
for _, p := range patterns {
|
||||
if m := p.FindStringSubmatch(msg); len(m) > 1 {
|
||||
return parseDigits(m[1])
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsImageDimensionError returns true if the message indicates an image dimension error.
|
||||
func IsImageDimensionError(msg string) bool {
|
||||
return matchesAny(msg, imageDimensionPatterns)
|
||||
}
|
||||
|
||||
// IsImageSizeError returns true if the message indicates an image file size error.
|
||||
func IsImageSizeError(msg string) bool {
|
||||
return matchesAny(msg, imageSizePatterns)
|
||||
}
|
||||
|
||||
// matchesAny checks if msg matches any of the patterns.
|
||||
func matchesAny(msg string, patterns []errorPattern) bool {
|
||||
for _, p := range patterns {
|
||||
if p.regex != nil {
|
||||
if p.regex.MatchString(msg) {
|
||||
return true
|
||||
}
|
||||
} else if p.substring != "" {
|
||||
if strings.Contains(msg, p.substring) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseDigits converts a string of digits to an int.
|
||||
func parseDigits(s string) int {
|
||||
n := 0
|
||||
for _, c := range s {
|
||||
if c >= '0' && c <= '9' {
|
||||
n = n*10 + int(c-'0')
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -0,0 +1,337 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClassifyError_Nil(t *testing.T) {
|
||||
result := ClassifyError(nil, "openai", "gpt-4")
|
||||
if result != nil {
|
||||
t.Errorf("expected nil for nil error, got %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ContextCanceled(t *testing.T) {
|
||||
result := ClassifyError(context.Canceled, "openai", "gpt-4")
|
||||
if result != nil {
|
||||
t.Errorf("expected nil for context.Canceled (user abort), got %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ContextDeadlineExceeded(t *testing.T) {
|
||||
result := ClassifyError(context.DeadlineExceeded, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil for deadline exceeded")
|
||||
}
|
||||
if result.Reason != FailoverTimeout {
|
||||
t.Errorf("reason = %q, want timeout", result.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_StatusCodes(t *testing.T) {
|
||||
tests := []struct {
|
||||
status int
|
||||
reason FailoverReason
|
||||
}{
|
||||
{401, FailoverAuth},
|
||||
{403, FailoverAuth},
|
||||
{402, FailoverBilling},
|
||||
{408, FailoverTimeout},
|
||||
{429, FailoverRateLimit},
|
||||
{400, FailoverFormat},
|
||||
{500, FailoverTimeout},
|
||||
{502, FailoverTimeout},
|
||||
{503, FailoverTimeout},
|
||||
{521, FailoverTimeout},
|
||||
{522, FailoverTimeout},
|
||||
{523, FailoverTimeout},
|
||||
{524, FailoverTimeout},
|
||||
{529, FailoverTimeout},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
err := fmt.Errorf("API error: status: %d something went wrong", tt.status)
|
||||
result := ClassifyError(err, "test", "model")
|
||||
if result == nil {
|
||||
t.Errorf("status %d: expected non-nil", tt.status)
|
||||
continue
|
||||
}
|
||||
if result.Reason != tt.reason {
|
||||
t.Errorf("status %d: reason = %q, want %q", tt.status, result.Reason, tt.reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_RateLimitPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"rate limit exceeded",
|
||||
"rate_limit reached",
|
||||
"too many requests",
|
||||
"exceeded your current quota",
|
||||
"resource has been exhausted",
|
||||
"resource_exhausted",
|
||||
"quota exceeded",
|
||||
"usage limit reached",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverRateLimit {
|
||||
t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_OverloadedPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"overloaded_error",
|
||||
`{"type": "overloaded_error"}`,
|
||||
"server is overloaded",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "anthropic", "claude")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
// Overloaded is treated as rate_limit
|
||||
if result.Reason != FailoverRateLimit {
|
||||
t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_BillingPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"payment required",
|
||||
"insufficient credits",
|
||||
"credit balance too low",
|
||||
"plans & billing page",
|
||||
"insufficient balance",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverBilling {
|
||||
t.Errorf("pattern %q: reason = %q, want billing", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_TimeoutPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"request timeout",
|
||||
"connection timed out",
|
||||
"deadline exceeded",
|
||||
"context deadline exceeded",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverTimeout {
|
||||
t.Errorf("pattern %q: reason = %q, want timeout", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_AuthPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"invalid api key",
|
||||
"invalid_api_key",
|
||||
"incorrect api key",
|
||||
"invalid token",
|
||||
"authentication failed",
|
||||
"re-authenticate",
|
||||
"oauth token refresh failed",
|
||||
"unauthorized access",
|
||||
"forbidden",
|
||||
"access denied",
|
||||
"expired",
|
||||
"token has expired",
|
||||
"no credentials found",
|
||||
"no api key found",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverAuth {
|
||||
t.Errorf("pattern %q: reason = %q, want auth", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_FormatPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"string should match pattern",
|
||||
"tool_use.id is required",
|
||||
"invalid tool_use_id",
|
||||
"messages.1.content.1.tool_use.id must be valid",
|
||||
"invalid request format",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "anthropic", "claude")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverFormat {
|
||||
t.Errorf("pattern %q: reason = %q, want format", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ImageDimensionError(t *testing.T) {
|
||||
err := errors.New("image dimensions exceed max allowed 2048x2048")
|
||||
result := ClassifyError(err, "openai", "gpt-4o")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil for image dimension error")
|
||||
}
|
||||
if result.Reason != FailoverFormat {
|
||||
t.Errorf("reason = %q, want format", result.Reason)
|
||||
}
|
||||
if result.IsRetriable() {
|
||||
t.Error("image dimension error should not be retriable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ImageSizeError(t *testing.T) {
|
||||
err := errors.New("image exceeds 20 mb limit")
|
||||
result := ClassifyError(err, "openai", "gpt-4o")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil for image size error")
|
||||
}
|
||||
if result.Reason != FailoverFormat {
|
||||
t.Errorf("reason = %q, want format", result.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_UnknownError(t *testing.T) {
|
||||
err := errors.New("some completely random error")
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result != nil {
|
||||
t.Errorf("expected nil for unknown error, got %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ProviderModelPropagation(t *testing.T) {
|
||||
err := errors.New("rate limit exceeded")
|
||||
result := ClassifyError(err, "my-provider", "my-model")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if result.Provider != "my-provider" {
|
||||
t.Errorf("provider = %q, want my-provider", result.Provider)
|
||||
}
|
||||
if result.Model != "my-model" {
|
||||
t.Errorf("model = %q, want my-model", result.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverError_IsRetriable(t *testing.T) {
|
||||
tests := []struct {
|
||||
reason FailoverReason
|
||||
retriable bool
|
||||
}{
|
||||
{FailoverAuth, true},
|
||||
{FailoverRateLimit, true},
|
||||
{FailoverBilling, true},
|
||||
{FailoverTimeout, true},
|
||||
{FailoverOverloaded, true},
|
||||
{FailoverFormat, false},
|
||||
{FailoverUnknown, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
fe := &FailoverError{Reason: tt.reason}
|
||||
if fe.IsRetriable() != tt.retriable {
|
||||
t.Errorf("IsRetriable(%q) = %v, want %v", tt.reason, fe.IsRetriable(), tt.retriable)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverError_ErrorString(t *testing.T) {
|
||||
fe := &FailoverError{
|
||||
Reason: FailoverRateLimit,
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
Status: 429,
|
||||
Wrapped: errors.New("too many requests"),
|
||||
}
|
||||
s := fe.Error()
|
||||
if s == "" {
|
||||
t.Error("expected non-empty error string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverError_Unwrap(t *testing.T) {
|
||||
inner := errors.New("inner error")
|
||||
fe := &FailoverError{Reason: FailoverTimeout, Wrapped: inner}
|
||||
if fe.Unwrap() != inner {
|
||||
t.Error("Unwrap should return wrapped error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractHTTPStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
msg string
|
||||
want int
|
||||
}{
|
||||
{"status: 429 rate limited", 429},
|
||||
{"status 401 unauthorized", 401},
|
||||
{"HTTP/1.1 502 Bad Gateway", 502},
|
||||
{"no status code here", 0},
|
||||
{"random number 12345", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := extractHTTPStatus(tt.msg)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractHTTPStatus(%q) = %d, want %d", tt.msg, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsImageDimensionError(t *testing.T) {
|
||||
if !IsImageDimensionError("image dimensions exceed max 4096x4096") {
|
||||
t.Error("should match image dimensions exceed max")
|
||||
}
|
||||
if IsImageDimensionError("normal error message") {
|
||||
t.Error("should not match normal error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsImageSizeError(t *testing.T) {
|
||||
if !IsImageSizeError("image exceeds 20 mb") {
|
||||
t.Error("should match image exceeds mb")
|
||||
}
|
||||
if IsImageSizeError("normal error message") {
|
||||
t.Error("should not match normal error")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,360 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
const defaultAnthropicAPIBase = "https://api.anthropic.com/v1"
|
||||
|
||||
var getCredential = auth.GetCredential
|
||||
|
||||
type providerType int
|
||||
|
||||
const (
|
||||
providerTypeHTTPCompat providerType = iota
|
||||
providerTypeClaudeAuth
|
||||
providerTypeCodexAuth
|
||||
providerTypeCodexCLIToken
|
||||
providerTypeClaudeCLI
|
||||
providerTypeCodexCLI
|
||||
providerTypeGitHubCopilot
|
||||
)
|
||||
|
||||
type providerSelection struct {
|
||||
providerType providerType
|
||||
apiKey string
|
||||
apiBase string
|
||||
proxy string
|
||||
model string
|
||||
workspace string
|
||||
connectMode string
|
||||
enableWebSearch bool
|
||||
}
|
||||
|
||||
func createClaudeAuthProvider(apiBase string) (LLMProvider, error) {
|
||||
if apiBase == "" {
|
||||
apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
cred, err := getCredential("anthropic")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||
}
|
||||
return NewClaudeProviderWithTokenSourceAndBaseURL(cred.AccessToken, createClaudeTokenSource(), apiBase), nil
|
||||
}
|
||||
|
||||
func createCodexAuthProvider(enableWebSearch bool) (LLMProvider, error) {
|
||||
cred, err := getCredential("openai")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||
}
|
||||
p := NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource())
|
||||
p.enableWebSearch = enableWebSearch
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
model := cfg.Agents.Defaults.Model
|
||||
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
|
||||
lowerModel := strings.ToLower(model)
|
||||
|
||||
sel := providerSelection{
|
||||
providerType: providerTypeHTTPCompat,
|
||||
model: model,
|
||||
}
|
||||
|
||||
// First, prefer explicit provider configuration.
|
||||
if providerName != "" {
|
||||
switch providerName {
|
||||
case "groq":
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Groq.APIKey
|
||||
sel.apiBase = cfg.Providers.Groq.APIBase
|
||||
sel.proxy = cfg.Providers.Groq.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
}
|
||||
case "openai", "gpt":
|
||||
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
|
||||
sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
sel.providerType = providerTypeCodexCLIToken
|
||||
return sel, nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
sel.providerType = providerTypeCodexAuth
|
||||
return sel, nil
|
||||
}
|
||||
sel.apiKey = cfg.Providers.OpenAI.APIKey
|
||||
sel.apiBase = cfg.Providers.OpenAI.APIBase
|
||||
sel.proxy = cfg.Providers.OpenAI.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
}
|
||||
case "anthropic", "claude":
|
||||
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
sel.providerType = providerTypeClaudeAuth
|
||||
return sel, nil
|
||||
}
|
||||
sel.apiKey = cfg.Providers.Anthropic.APIKey
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
sel.proxy = cfg.Providers.Anthropic.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
}
|
||||
case "openrouter":
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
sel.proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
sel.apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
sel.apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
}
|
||||
case "zhipu", "glm":
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Zhipu.APIKey
|
||||
sel.apiBase = cfg.Providers.Zhipu.APIBase
|
||||
sel.proxy = cfg.Providers.Zhipu.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
}
|
||||
case "gemini", "google":
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Gemini.APIKey
|
||||
sel.apiBase = cfg.Providers.Gemini.APIBase
|
||||
sel.proxy = cfg.Providers.Gemini.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
}
|
||||
case "vllm":
|
||||
if cfg.Providers.VLLM.APIBase != "" {
|
||||
sel.apiKey = cfg.Providers.VLLM.APIKey
|
||||
sel.apiBase = cfg.Providers.VLLM.APIBase
|
||||
sel.proxy = cfg.Providers.VLLM.Proxy
|
||||
}
|
||||
case "shengsuanyun":
|
||||
if cfg.Providers.ShengSuanYun.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.ShengSuanYun.APIKey
|
||||
sel.apiBase = cfg.Providers.ShengSuanYun.APIBase
|
||||
sel.proxy = cfg.Providers.ShengSuanYun.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://router.shengsuanyun.com/api/v1"
|
||||
}
|
||||
}
|
||||
case "nvidia":
|
||||
if cfg.Providers.Nvidia.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Nvidia.APIKey
|
||||
sel.apiBase = cfg.Providers.Nvidia.APIBase
|
||||
sel.proxy = cfg.Providers.Nvidia.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claude-code", "claudecode":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
sel.providerType = providerTypeClaudeCLI
|
||||
sel.workspace = workspace
|
||||
return sel, nil
|
||||
case "codex-cli", "codex-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
sel.providerType = providerTypeCodexCLI
|
||||
sel.workspace = workspace
|
||||
return sel, nil
|
||||
case "deepseek":
|
||||
if cfg.Providers.DeepSeek.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.DeepSeek.APIKey
|
||||
sel.apiBase = cfg.Providers.DeepSeek.APIBase
|
||||
sel.proxy = cfg.Providers.DeepSeek.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.deepseek.com/v1"
|
||||
}
|
||||
if model != "deepseek-chat" && model != "deepseek-reasoner" {
|
||||
sel.model = "deepseek-chat"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
sel.providerType = providerTypeGitHubCopilot
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
sel.apiBase = cfg.Providers.GitHubCopilot.APIBase
|
||||
} else {
|
||||
sel.apiBase = "localhost:4321"
|
||||
}
|
||||
sel.connectMode = cfg.Providers.GitHubCopilot.ConnectMode
|
||||
return sel, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: infer provider from model and configured keys.
|
||||
if sel.apiKey == "" && sel.apiBase == "" {
|
||||
switch {
|
||||
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Moonshot.APIKey
|
||||
sel.apiBase = cfg.Providers.Moonshot.APIBase
|
||||
sel.proxy = cfg.Providers.Moonshot.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.moonshot.cn/v1"
|
||||
}
|
||||
case strings.HasPrefix(model, "openrouter/") ||
|
||||
strings.HasPrefix(model, "anthropic/") ||
|
||||
strings.HasPrefix(model, "openai/") ||
|
||||
strings.HasPrefix(model, "meta-llama/") ||
|
||||
strings.HasPrefix(model, "deepseek/") ||
|
||||
strings.HasPrefix(model, "google/"):
|
||||
sel.apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
sel.proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
sel.apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
sel.apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) &&
|
||||
(cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
sel.providerType = providerTypeClaudeAuth
|
||||
return sel, nil
|
||||
}
|
||||
sel.apiKey = cfg.Providers.Anthropic.APIKey
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
sel.proxy = cfg.Providers.Anthropic.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) &&
|
||||
(cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||
sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
sel.providerType = providerTypeCodexCLIToken
|
||||
return sel, nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
sel.providerType = providerTypeCodexAuth
|
||||
return sel, nil
|
||||
}
|
||||
sel.apiKey = cfg.Providers.OpenAI.APIKey
|
||||
sel.apiBase = cfg.Providers.OpenAI.APIBase
|
||||
sel.proxy = cfg.Providers.OpenAI.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Gemini.APIKey
|
||||
sel.apiBase = cfg.Providers.Gemini.APIBase
|
||||
sel.proxy = cfg.Providers.Gemini.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Zhipu.APIKey
|
||||
sel.apiBase = cfg.Providers.Zhipu.APIBase
|
||||
sel.proxy = cfg.Providers.Zhipu.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Groq.APIKey
|
||||
sel.apiBase = cfg.Providers.Groq.APIBase
|
||||
sel.proxy = cfg.Providers.Groq.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Nvidia.APIKey
|
||||
sel.apiBase = cfg.Providers.Nvidia.APIBase
|
||||
sel.proxy = cfg.Providers.Nvidia.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Ollama.APIKey
|
||||
sel.apiBase = cfg.Providers.Ollama.APIBase
|
||||
sel.proxy = cfg.Providers.Ollama.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "http://localhost:11434/v1"
|
||||
}
|
||||
case cfg.Providers.VLLM.APIBase != "":
|
||||
sel.apiKey = cfg.Providers.VLLM.APIKey
|
||||
sel.apiBase = cfg.Providers.VLLM.APIBase
|
||||
sel.proxy = cfg.Providers.VLLM.Proxy
|
||||
default:
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
sel.proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
sel.apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
sel.apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
} else {
|
||||
return providerSelection{}, fmt.Errorf("no API key configured for model: %s", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sel.providerType == providerTypeHTTPCompat {
|
||||
if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
|
||||
return providerSelection{}, fmt.Errorf("no API key configured for provider (model: %s)", model)
|
||||
}
|
||||
if sel.apiBase == "" {
|
||||
return providerSelection{}, fmt.Errorf("no API base configured for provider (model: %s)", model)
|
||||
}
|
||||
}
|
||||
|
||||
return sel, nil
|
||||
}
|
||||
|
||||
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
sel, err := resolveProviderSelection(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch sel.providerType {
|
||||
case providerTypeClaudeAuth:
|
||||
return createClaudeAuthProvider(sel.apiBase)
|
||||
case providerTypeCodexAuth:
|
||||
return createCodexAuthProvider(sel.enableWebSearch)
|
||||
case providerTypeCodexCLIToken:
|
||||
c := NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource())
|
||||
c.enableWebSearch = sel.enableWebSearch
|
||||
return c, nil
|
||||
case providerTypeClaudeCLI:
|
||||
return NewClaudeCliProvider(sel.workspace), nil
|
||||
case providerTypeCodexCLI:
|
||||
return NewCodexCliProvider(sel.workspace), nil
|
||||
case providerTypeGitHubCopilot:
|
||||
return NewGitHubCopilotProvider(sel.apiBase, sel.connectMode, sel.model)
|
||||
default:
|
||||
return NewHTTPProvider(sel.apiKey, sel.apiBase, sel.proxy), nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,299 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestResolveProviderSelection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*config.Config)
|
||||
wantType providerType
|
||||
wantAPIBase string
|
||||
wantProxy string
|
||||
wantErrSubstr string
|
||||
}{
|
||||
{
|
||||
name: "explicit claude-cli provider routes to cli provider type",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "claude-cli"
|
||||
cfg.Agents.Defaults.Workspace = "/tmp/ws"
|
||||
},
|
||||
wantType: providerTypeClaudeCLI,
|
||||
},
|
||||
{
|
||||
name: "explicit copilot provider routes to github copilot type",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "copilot"
|
||||
},
|
||||
wantType: providerTypeGitHubCopilot,
|
||||
wantAPIBase: "localhost:4321",
|
||||
},
|
||||
{
|
||||
name: "explicit deepseek provider uses deepseek defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "deepseek"
|
||||
cfg.Agents.Defaults.Model = "deepseek/deepseek-chat"
|
||||
cfg.Providers.DeepSeek.APIKey = "deepseek-key"
|
||||
cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.deepseek.com/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit shengsuanyun provider uses defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "shengsuanyun"
|
||||
cfg.Providers.ShengSuanYun.APIKey = "ssy-key"
|
||||
cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://router.shengsuanyun.com/api/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit nvidia provider uses defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "nvidia"
|
||||
cfg.Providers.Nvidia.APIKey = "nvapi-test"
|
||||
cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://integrate.api.nvidia.com/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "openrouter model uses openrouter defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "openrouter/auto"
|
||||
cfg.Providers.OpenRouter.APIKey = "sk-or-test"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://openrouter.ai/api/v1",
|
||||
},
|
||||
{
|
||||
name: "anthropic oauth routes to claude auth provider",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "claude-sonnet-4-5-20250929"
|
||||
cfg.Providers.Anthropic.AuthMethod = "oauth"
|
||||
},
|
||||
wantType: providerTypeClaudeAuth,
|
||||
},
|
||||
{
|
||||
name: "openai oauth routes to codex auth provider",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "gpt-4o"
|
||||
cfg.Providers.OpenAI.AuthMethod = "oauth"
|
||||
},
|
||||
wantType: providerTypeCodexAuth,
|
||||
},
|
||||
{
|
||||
name: "openai codex-cli auth routes to codex cli token provider",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "gpt-4o"
|
||||
cfg.Providers.OpenAI.AuthMethod = "codex-cli"
|
||||
},
|
||||
wantType: providerTypeCodexCLIToken,
|
||||
},
|
||||
{
|
||||
name: "explicit codex-code provider routes to codex cli provider type",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "codex-code"
|
||||
cfg.Agents.Defaults.Workspace = "/tmp/ws"
|
||||
},
|
||||
wantType: providerTypeCodexCLI,
|
||||
},
|
||||
{
|
||||
name: "zhipu model uses zhipu base default",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "glm-4.7"
|
||||
cfg.Providers.Zhipu.APIKey = "zhipu-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://open.bigmodel.cn/api/paas/v4",
|
||||
},
|
||||
{
|
||||
name: "groq model uses groq base default",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "groq/llama-3.3-70b"
|
||||
cfg.Providers.Groq.APIKey = "gsk-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.groq.com/openai/v1",
|
||||
},
|
||||
{
|
||||
name: "ollama model uses ollama base default",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "ollama/qwen2.5:14b"
|
||||
cfg.Providers.Ollama.APIKey = "ollama-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "http://localhost:11434/v1",
|
||||
},
|
||||
{
|
||||
name: "moonshot model keeps proxy and default base",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "moonshot/kimi-k2.5"
|
||||
cfg.Providers.Moonshot.APIKey = "moonshot-key"
|
||||
cfg.Providers.Moonshot.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.moonshot.cn/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "missing keys returns model config error",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "custom-model"
|
||||
},
|
||||
wantErrSubstr: "no API key configured for model",
|
||||
},
|
||||
{
|
||||
name: "openrouter prefix without key returns provider key error",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "openrouter/auto"
|
||||
},
|
||||
wantErrSubstr: "no API key configured for provider",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
tt.setup(cfg)
|
||||
|
||||
got, err := resolveProviderSelection(cfg)
|
||||
if tt.wantErrSubstr != "" {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
|
||||
t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErrSubstr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("resolveProviderSelection() error = %v", err)
|
||||
}
|
||||
if got.providerType != tt.wantType {
|
||||
t.Fatalf("providerType = %v, want %v", got.providerType, tt.wantType)
|
||||
}
|
||||
if tt.wantAPIBase != "" && got.apiBase != tt.wantAPIBase {
|
||||
t.Fatalf("apiBase = %q, want %q", got.apiBase, tt.wantAPIBase)
|
||||
}
|
||||
if tt.wantProxy != "" && got.proxy != tt.wantProxy {
|
||||
t.Fatalf("proxy = %q, want %q", got.proxy, tt.wantProxy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Model = "openrouter/auto"
|
||||
cfg.Providers.OpenRouter.APIKey = "sk-or-test"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := provider.(*HTTPProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *HTTPProvider", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "codex-code"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := provider.(*CodexCliProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *CodexCliProvider", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderReturnsCodexProviderForCodexCliAuthMethod(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "openai"
|
||||
cfg.Providers.OpenAI.AuthMethod = "codex-cli"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := provider.(*CodexProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *CodexProvider", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) {
|
||||
originalGetCredential := getCredential
|
||||
t.Cleanup(func() { getCredential = originalGetCredential })
|
||||
|
||||
getCredential = func(provider string) (*auth.AuthCredential, error) {
|
||||
if provider != "anthropic" {
|
||||
t.Fatalf("provider = %q, want anthropic", provider)
|
||||
}
|
||||
return &auth.AuthCredential{
|
||||
AccessToken: "anthropic-token",
|
||||
}, nil
|
||||
}
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "anthropic"
|
||||
cfg.Providers.Anthropic.AuthMethod = "oauth"
|
||||
cfg.Providers.Anthropic.APIBase = "https://proxy.example.com/v1"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
|
||||
claudeProvider, ok := provider.(*ClaudeProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *ClaudeProvider", provider)
|
||||
}
|
||||
if got := claudeProvider.delegate.BaseURL(); got != "https://proxy.example.com" {
|
||||
t.Fatalf("anthropic baseURL = %q, want %q", got, "https://proxy.example.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderReturnsCodexProviderForOpenAIOAuth(t *testing.T) {
|
||||
originalGetCredential := getCredential
|
||||
t.Cleanup(func() { getCredential = originalGetCredential })
|
||||
|
||||
getCredential = func(provider string) (*auth.AuthCredential, error) {
|
||||
if provider != "openai" {
|
||||
t.Fatalf("provider = %q, want openai", provider)
|
||||
}
|
||||
return &auth.AuthCredential{
|
||||
AccessToken: "openai-token",
|
||||
AccountID: "acct_123",
|
||||
}, nil
|
||||
}
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "openai"
|
||||
cfg.Providers.OpenAI.AuthMethod = "oauth"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := provider.(*CodexProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *CodexProvider", provider)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FallbackChain orchestrates model fallback across multiple candidates.
|
||||
type FallbackChain struct {
|
||||
cooldown *CooldownTracker
|
||||
}
|
||||
|
||||
// FallbackCandidate represents one model/provider to try.
|
||||
type FallbackCandidate struct {
|
||||
Provider string
|
||||
Model string
|
||||
}
|
||||
|
||||
// FallbackResult contains the successful response and metadata about all attempts.
|
||||
type FallbackResult struct {
|
||||
Response *LLMResponse
|
||||
Provider string
|
||||
Model string
|
||||
Attempts []FallbackAttempt
|
||||
}
|
||||
|
||||
// FallbackAttempt records one attempt in the fallback chain.
|
||||
type FallbackAttempt struct {
|
||||
Provider string
|
||||
Model string
|
||||
Error error
|
||||
Reason FailoverReason
|
||||
Duration time.Duration
|
||||
Skipped bool // true if skipped due to cooldown
|
||||
}
|
||||
|
||||
// NewFallbackChain creates a new fallback chain with the given cooldown tracker.
|
||||
func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain {
|
||||
return &FallbackChain{cooldown: cooldown}
|
||||
}
|
||||
|
||||
// ResolveCandidates parses model config into a deduplicated candidate list.
|
||||
func ResolveCandidates(cfg ModelConfig, defaultProvider string) []FallbackCandidate {
|
||||
seen := make(map[string]bool)
|
||||
var candidates []FallbackCandidate
|
||||
|
||||
addCandidate := func(raw string) {
|
||||
ref := ParseModelRef(raw, defaultProvider)
|
||||
if ref == nil {
|
||||
return
|
||||
}
|
||||
key := ModelKey(ref.Provider, ref.Model)
|
||||
if seen[key] {
|
||||
return
|
||||
}
|
||||
seen[key] = true
|
||||
candidates = append(candidates, FallbackCandidate{
|
||||
Provider: ref.Provider,
|
||||
Model: ref.Model,
|
||||
})
|
||||
}
|
||||
|
||||
// Primary first.
|
||||
addCandidate(cfg.Primary)
|
||||
|
||||
// Then fallbacks.
|
||||
for _, fb := range cfg.Fallbacks {
|
||||
addCandidate(fb)
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
// Execute runs the fallback chain for text/chat requests.
|
||||
// It tries each candidate in order, respecting cooldowns and error classification.
|
||||
//
|
||||
// Behavior:
|
||||
// - Candidates in cooldown are skipped (logged as skipped attempt).
|
||||
// - context.Canceled aborts immediately (user abort, no fallback).
|
||||
// - Non-retriable errors (format) abort immediately.
|
||||
// - Retriable errors trigger fallback to next candidate.
|
||||
// - Success marks provider as good (resets cooldown).
|
||||
// - If all fail, returns aggregate error with all attempts.
|
||||
func (fc *FallbackChain) Execute(
|
||||
ctx context.Context,
|
||||
candidates []FallbackCandidate,
|
||||
run func(ctx context.Context, provider, model string) (*LLMResponse, error),
|
||||
) (*FallbackResult, error) {
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("fallback: no candidates configured")
|
||||
}
|
||||
|
||||
result := &FallbackResult{
|
||||
Attempts: make([]FallbackAttempt, 0, len(candidates)),
|
||||
}
|
||||
|
||||
for i, candidate := range candidates {
|
||||
// Check context before each attempt.
|
||||
if ctx.Err() == context.Canceled {
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Check cooldown.
|
||||
if !fc.cooldown.IsAvailable(candidate.Provider) {
|
||||
remaining := fc.cooldown.CooldownRemaining(candidate.Provider)
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Skipped: true,
|
||||
Reason: FailoverRateLimit,
|
||||
Error: fmt.Errorf("provider %s in cooldown (%s remaining)", candidate.Provider, remaining.Round(time.Second)),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Execute the run function.
|
||||
start := time.Now()
|
||||
resp, err := run(ctx, candidate.Provider, candidate.Model)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
// Success.
|
||||
fc.cooldown.MarkSuccess(candidate.Provider)
|
||||
result.Response = resp
|
||||
result.Provider = candidate.Provider
|
||||
result.Model = candidate.Model
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Context cancellation: abort immediately, no fallback.
|
||||
if ctx.Err() == context.Canceled {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Classify the error.
|
||||
failErr := ClassifyError(err, candidate.Provider, candidate.Model)
|
||||
|
||||
if failErr == nil {
|
||||
// Unclassifiable error: do not fallback, return immediately.
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, fmt.Errorf("fallback: unclassified error from %s/%s: %w",
|
||||
candidate.Provider, candidate.Model, err)
|
||||
}
|
||||
|
||||
// Non-retriable error: abort immediately.
|
||||
if !failErr.IsRetriable() {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: failErr,
|
||||
Reason: failErr.Reason,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, failErr
|
||||
}
|
||||
|
||||
// Retriable error: mark failure and continue to next candidate.
|
||||
fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason)
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: failErr,
|
||||
Reason: failErr.Reason,
|
||||
Duration: elapsed,
|
||||
})
|
||||
|
||||
// If this was the last candidate, return aggregate error.
|
||||
if i == len(candidates)-1 {
|
||||
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
|
||||
}
|
||||
}
|
||||
|
||||
// All candidates were skipped (all in cooldown).
|
||||
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
|
||||
}
|
||||
|
||||
// ExecuteImage runs the fallback chain for image/vision requests.
|
||||
// Simpler than Execute: no cooldown checks (image endpoints have different rate limits).
|
||||
// Image dimension/size errors abort immediately (non-retriable).
|
||||
func (fc *FallbackChain) ExecuteImage(
|
||||
ctx context.Context,
|
||||
candidates []FallbackCandidate,
|
||||
run func(ctx context.Context, provider, model string) (*LLMResponse, error),
|
||||
) (*FallbackResult, error) {
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("image fallback: no candidates configured")
|
||||
}
|
||||
|
||||
result := &FallbackResult{
|
||||
Attempts: make([]FallbackAttempt, 0, len(candidates)),
|
||||
}
|
||||
|
||||
for i, candidate := range candidates {
|
||||
if ctx.Err() == context.Canceled {
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := run(ctx, candidate.Provider, candidate.Model)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
result.Response = resp
|
||||
result.Provider = candidate.Provider
|
||||
result.Model = candidate.Model
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if ctx.Err() == context.Canceled {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Image dimension/size errors are non-retriable.
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if IsImageDimensionError(errMsg) || IsImageSizeError(errMsg) {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Reason: FailoverFormat,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, &FailoverError{
|
||||
Reason: FailoverFormat,
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Any other error: record and try next.
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Duration: elapsed,
|
||||
})
|
||||
|
||||
if i == len(candidates)-1 {
|
||||
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
|
||||
}
|
||||
|
||||
// FallbackExhaustedError indicates all fallback candidates were tried and failed.
|
||||
type FallbackExhaustedError struct {
|
||||
Attempts []FallbackAttempt
|
||||
}
|
||||
|
||||
func (e *FallbackExhaustedError) Error() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("fallback: all %d candidates failed:", len(e.Attempts)))
|
||||
for i, a := range e.Attempts {
|
||||
if a.Skipped {
|
||||
sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: skipped (cooldown)", i+1, a.Provider, a.Model))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: %v (reason=%s, %s)",
|
||||
i+1, a.Provider, a.Model, a.Error, a.Reason, a.Duration.Round(time.Millisecond)))
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,473 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func makeCandidate(provider, model string) FallbackCandidate {
|
||||
return FallbackCandidate{Provider: provider, Model: model}
|
||||
}
|
||||
|
||||
func successRun(content string) func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return &LLMResponse{Content: content, FinishReason: "stop"}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func failRun(err error) func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_SingleCandidate_Success(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
|
||||
result, err := fc.Execute(context.Background(), candidates, successRun("hello"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Response.Content != "hello" {
|
||||
t.Errorf("content = %q, want hello", result.Response.Content)
|
||||
}
|
||||
if result.Provider != "openai" || result.Model != "gpt-4" {
|
||||
t.Errorf("provider/model = %s/%s, want openai/gpt-4", result.Provider, result.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_SecondCandidateSuccess(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude-opus"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
return nil, errors.New("rate limit exceeded")
|
||||
}
|
||||
return &LLMResponse{Content: "from claude", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
result, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", result.Provider)
|
||||
}
|
||||
if result.Response.Content != "from claude" {
|
||||
t.Errorf("content = %q, want 'from claude'", result.Response.Content)
|
||||
}
|
||||
if len(result.Attempts) != 1 {
|
||||
t.Errorf("attempts = %d, want 1 (failed attempt recorded)", len(result.Attempts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_AllFail(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
makeCandidate("groq", "llama"),
|
||||
}
|
||||
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return nil, errors.New("rate limit exceeded")
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when all candidates fail")
|
||||
}
|
||||
var exhausted *FallbackExhaustedError
|
||||
if !errors.As(err, &exhausted) {
|
||||
t.Errorf("expected FallbackExhaustedError, got %T: %v", err, err)
|
||||
}
|
||||
if len(exhausted.Attempts) != 3 {
|
||||
t.Errorf("attempts = %d, want 3", len(exhausted.Attempts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_ContextCanceled(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
cancel() // cancel context
|
||||
return nil, context.Canceled
|
||||
}
|
||||
t.Error("should not reach second candidate after cancel")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
_, err := fc.Execute(ctx, candidates, run)
|
||||
if err != context.Canceled {
|
||||
t.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_NonRetriableError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
return nil, errors.New("string should match pattern")
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-retriable")
|
||||
}
|
||||
var fe *FailoverError
|
||||
if !errors.As(err, &fe) {
|
||||
t.Fatalf("expected FailoverError, got %T", err)
|
||||
}
|
||||
if fe.Reason != FailoverFormat {
|
||||
t.Errorf("reason = %q, want format", fe.Reason)
|
||||
}
|
||||
if attempt != 1 {
|
||||
t.Errorf("attempt = %d, want 1 (non-retriable should not try next)", attempt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_CooldownSkip(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, _ := newTestTracker(now)
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
// Put openai in cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
if provider == "openai" {
|
||||
t.Error("should not call openai (in cooldown)")
|
||||
}
|
||||
return &LLMResponse{Content: "claude response", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
result, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", result.Provider)
|
||||
}
|
||||
// Should have 1 skipped attempt
|
||||
skipped := 0
|
||||
for _, a := range result.Attempts {
|
||||
if a.Skipped {
|
||||
skipped++
|
||||
}
|
||||
}
|
||||
if skipped != 1 {
|
||||
t.Errorf("skipped = %d, want 1", skipped)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_AllInCooldown(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
// Put all providers in cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("anthropic", FailoverBilling)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates,
|
||||
func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
t.Error("should not call any provider (all in cooldown)")
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error when all in cooldown")
|
||||
}
|
||||
var exhausted *FallbackExhaustedError
|
||||
if !errors.As(err, &exhausted) {
|
||||
t.Fatalf("expected FallbackExhaustedError, got %T", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_NoCandidates(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
_, err := fc.Execute(context.Background(), nil, successRun("ok"))
|
||||
if err == nil {
|
||||
t.Error("expected error for empty candidates")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_EmptyFallbacks(t *testing.T) {
|
||||
// Single primary, no fallbacks: should work like direct call
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
|
||||
result, err := fc.Execute(context.Background(), candidates, successRun("ok"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Response.Content != "ok" {
|
||||
t.Error("expected success with single candidate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_UnclassifiedError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
return nil, errors.New("completely unknown internal error")
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unclassified error")
|
||||
}
|
||||
if attempt != 1 {
|
||||
t.Errorf("attempt = %d, want 1 (should not fallback on unclassified)", attempt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_SuccessResetsCooldown(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere
|
||||
}
|
||||
return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("success should reset cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Image Fallback Tests ---
|
||||
|
||||
func TestImageFallback_Success(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4o")}
|
||||
result, err := fc.ExecuteImage(context.Background(), candidates, successRun("image result"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Response.Content != "image result" {
|
||||
t.Error("expected image result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageFallback_DimensionError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4o"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
return nil, errors.New("image dimensions exceed max 4096x4096")
|
||||
}
|
||||
|
||||
_, err := fc.ExecuteImage(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for image dimension error")
|
||||
}
|
||||
if attempt != 1 {
|
||||
t.Errorf("attempt = %d, want 1 (image dimension error should not retry)", attempt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageFallback_SizeError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4o"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
return nil, errors.New("image exceeds 20 mb")
|
||||
}
|
||||
|
||||
_, err := fc.ExecuteImage(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for image size error")
|
||||
}
|
||||
if attempt != 1 {
|
||||
t.Errorf("attempt = %d, want 1 (image size error should not retry)", attempt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageFallback_RetryOnOtherErrors(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4o"),
|
||||
makeCandidate("anthropic", "claude-sonnet"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
return nil, errors.New("rate limit exceeded")
|
||||
}
|
||||
return &LLMResponse{Content: "image ok", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
result, err := fc.ExecuteImage(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", result.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageFallback_NoCandidates(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
_, err := fc.ExecuteImage(context.Background(), nil, successRun("ok"))
|
||||
if err == nil {
|
||||
t.Error("expected error for empty candidates")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ResolveCandidates Tests ---
|
||||
|
||||
func TestResolveCandidates_Simple(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "gpt-4",
|
||||
Fallbacks: []string{"anthropic/claude-opus", "groq/llama-3"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "openai")
|
||||
if len(candidates) != 3 {
|
||||
t.Fatalf("candidates = %d, want 3", len(candidates))
|
||||
}
|
||||
|
||||
if candidates[0].Provider != "openai" || candidates[0].Model != "gpt-4" {
|
||||
t.Errorf("candidate[0] = %s/%s, want openai/gpt-4", candidates[0].Provider, candidates[0].Model)
|
||||
}
|
||||
if candidates[1].Provider != "anthropic" || candidates[1].Model != "claude-opus" {
|
||||
t.Errorf("candidate[1] = %s/%s, want anthropic/claude-opus", candidates[1].Provider, candidates[1].Model)
|
||||
}
|
||||
if candidates[2].Provider != "groq" || candidates[2].Model != "llama-3" {
|
||||
t.Errorf("candidate[2] = %s/%s, want groq/llama-3", candidates[2].Provider, candidates[2].Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCandidates_Deduplication(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "openai/gpt-4",
|
||||
Fallbacks: []string{"openai/gpt-4", "anthropic/claude"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "default")
|
||||
if len(candidates) != 2 {
|
||||
t.Errorf("candidates = %d, want 2 (duplicate removed)", len(candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCandidates_EmptyFallbacks(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "gpt-4",
|
||||
Fallbacks: nil,
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "openai")
|
||||
if len(candidates) != 1 {
|
||||
t.Errorf("candidates = %d, want 1", len(candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCandidates_EmptyPrimary(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "",
|
||||
Fallbacks: []string{"anthropic/claude"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "openai")
|
||||
if len(candidates) != 1 {
|
||||
t.Errorf("candidates = %d, want 1", len(candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallbackExhaustedError_Message(t *testing.T) {
|
||||
e := &FallbackExhaustedError{
|
||||
Attempts: []FallbackAttempt{
|
||||
{Provider: "openai", Model: "gpt-4", Error: errors.New("rate limited"), Reason: FailoverRateLimit, Duration: 500 * time.Millisecond},
|
||||
{Provider: "anthropic", Model: "claude", Skipped: true},
|
||||
},
|
||||
}
|
||||
msg := e.Error()
|
||||
if msg == "" {
|
||||
t.Error("expected non-empty error message")
|
||||
}
|
||||
}
|
||||
@@ -7,201 +7,29 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/openai_compat"
|
||||
)
|
||||
|
||||
type HTTPProvider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
|
||||
httpClient *http.Client
|
||||
delegate *openai_compat.Provider
|
||||
}
|
||||
|
||||
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
return NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, "")
|
||||
return &HTTPProvider{
|
||||
delegate: openai_compat.NewProvider(apiKey, apiBase, proxy),
|
||||
}
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
|
||||
client := &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
if proxy != "" {
|
||||
proxyURL, err := url.Parse(proxy)
|
||||
if err == nil {
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &HTTPProvider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
maxTokensField: maxTokensField,
|
||||
httpClient: client,
|
||||
delegate: openai_compat.NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
if p.apiBase == "" {
|
||||
return nil, fmt.Errorf("API base not configured")
|
||||
}
|
||||
|
||||
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5, groq/openai/gpt-oss-120b -> openai/gpt-oss-120b, ollama/qwen2.5:14b -> qwen2.5:14b)
|
||||
if idx := strings.Index(model, "/"); idx != -1 {
|
||||
prefix := model[:idx]
|
||||
if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" || prefix == "qwen" || prefix == "cerebras" {
|
||||
model = model[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = tools
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
|
||||
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||
// Use configured max_tokens_field if specified, otherwise fallback to model-based detection
|
||||
fieldName := p.maxTokensField
|
||||
if fieldName == "" {
|
||||
// Fallback: detect from model name for backward compatibility
|
||||
lowerModel := strings.ToLower(model)
|
||||
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
|
||||
fieldName = "max_completion_tokens"
|
||||
} else {
|
||||
fieldName = "max_tokens"
|
||||
}
|
||||
}
|
||||
requestBody[fieldName] = maxTokens
|
||||
}
|
||||
|
||||
if temperature, ok := options["temperature"].(float64); ok {
|
||||
lowerModel := strings.ToLower(model)
|
||||
// Kimi k2 models only support temperature=1
|
||||
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
|
||||
requestBody["temperature"] = 1.0
|
||||
} else {
|
||||
requestBody["temperature"] = temperature
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return p.parseResponse(body)
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
ThoughtSignature string `json:"thought_signature"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage *UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return &LLMResponse{
|
||||
Content: "",
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
choice := apiResponse.Choices[0]
|
||||
|
||||
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
arguments := make(map[string]interface{})
|
||||
name := ""
|
||||
thoughtSignature := ""
|
||||
argsStr := ""
|
||||
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
argsStr = tc.Function.Arguments
|
||||
if argsStr != "" {
|
||||
if err := json.Unmarshal([]byte(argsStr), &arguments); err != nil {
|
||||
arguments["raw"] = argsStr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Function: &FunctionCall{
|
||||
Name: name,
|
||||
Arguments: argsStr,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
})
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: choice.Message.Content,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
return p.delegate.Chat(ctx, messages, tools, model, options)
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) GetDefaultModel() string {
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
package providers
|
||||
|
||||
import "strings"
|
||||
|
||||
// ModelRef represents a parsed model reference with provider and model name.
|
||||
type ModelRef struct {
|
||||
Provider string
|
||||
Model string
|
||||
}
|
||||
|
||||
// ParseModelRef parses "anthropic/claude-opus" into {Provider: "anthropic", Model: "claude-opus"}.
|
||||
// If no slash present, uses defaultProvider.
|
||||
// Returns nil for empty input.
|
||||
func ParseModelRef(raw string, defaultProvider string) *ModelRef {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if idx := strings.Index(raw, "/"); idx > 0 {
|
||||
provider := NormalizeProvider(raw[:idx])
|
||||
model := strings.TrimSpace(raw[idx+1:])
|
||||
if model == "" {
|
||||
return nil
|
||||
}
|
||||
return &ModelRef{Provider: provider, Model: model}
|
||||
}
|
||||
|
||||
return &ModelRef{
|
||||
Provider: NormalizeProvider(defaultProvider),
|
||||
Model: raw,
|
||||
}
|
||||
}
|
||||
|
||||
// NormalizeProvider normalizes provider identifiers to canonical form.
|
||||
func NormalizeProvider(provider string) string {
|
||||
p := strings.ToLower(strings.TrimSpace(provider))
|
||||
|
||||
switch p {
|
||||
case "z.ai", "z-ai":
|
||||
return "zai"
|
||||
case "opencode-zen":
|
||||
return "opencode"
|
||||
case "qwen":
|
||||
return "qwen-portal"
|
||||
case "kimi-code":
|
||||
return "kimi-coding"
|
||||
case "gpt":
|
||||
return "openai"
|
||||
case "claude":
|
||||
return "anthropic"
|
||||
case "glm":
|
||||
return "zhipu"
|
||||
case "google":
|
||||
return "gemini"
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// ModelKey returns a canonical "provider/model" key for deduplication.
|
||||
func ModelKey(provider, model string) string {
|
||||
return NormalizeProvider(provider) + "/" + strings.ToLower(strings.TrimSpace(model))
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
package providers
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseModelRef_WithSlash(t *testing.T) {
|
||||
ref := ParseModelRef("anthropic/claude-opus", "openai")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", ref.Provider)
|
||||
}
|
||||
if ref.Model != "claude-opus" {
|
||||
t.Errorf("model = %q, want claude-opus", ref.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_WithoutSlash(t *testing.T) {
|
||||
ref := ParseModelRef("gpt-4", "openai")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "openai" {
|
||||
t.Errorf("provider = %q, want openai", ref.Provider)
|
||||
}
|
||||
if ref.Model != "gpt-4" {
|
||||
t.Errorf("model = %q, want gpt-4", ref.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_Empty(t *testing.T) {
|
||||
ref := ParseModelRef("", "openai")
|
||||
if ref != nil {
|
||||
t.Errorf("expected nil for empty string, got %+v", ref)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_EmptyModelAfterSlash(t *testing.T) {
|
||||
ref := ParseModelRef("openai/", "default")
|
||||
if ref != nil {
|
||||
t.Errorf("expected nil for empty model, got %+v", ref)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_WhitespaceHandling(t *testing.T) {
|
||||
ref := ParseModelRef(" anthropic / claude-opus ", "openai")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", ref.Provider)
|
||||
}
|
||||
if ref.Model != "claude-opus" {
|
||||
t.Errorf("model = %q, want claude-opus", ref.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"OpenAI", "openai"},
|
||||
{"ANTHROPIC", "anthropic"},
|
||||
{"z.ai", "zai"},
|
||||
{"z-ai", "zai"},
|
||||
{"Z.AI", "zai"},
|
||||
{"opencode-zen", "opencode"},
|
||||
{"qwen", "qwen-portal"},
|
||||
{"kimi-code", "kimi-coding"},
|
||||
{"gpt", "openai"},
|
||||
{"claude", "anthropic"},
|
||||
{"glm", "zhipu"},
|
||||
{"google", "gemini"},
|
||||
{"groq", "groq"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := NormalizeProvider(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("NormalizeProvider(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
provider string
|
||||
model string
|
||||
want string
|
||||
}{
|
||||
{"openai", "gpt-4", "openai/gpt-4"},
|
||||
{"Anthropic", "Claude-Opus", "anthropic/claude-opus"},
|
||||
{"claude", "sonnet", "anthropic/sonnet"},
|
||||
{"z.ai", "Model-X", "zai/model-x"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := ModelKey(tt.provider, tt.model)
|
||||
if got != tt.want {
|
||||
t.Errorf("ModelKey(%q, %q) = %q, want %q", tt.provider, tt.model, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_ProviderNormalization(t *testing.T) {
|
||||
ref := ParseModelRef("Z.AI/model-x", "default")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "zai" {
|
||||
t.Errorf("provider = %q, want zai", ref.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_DefaultProviderNormalization(t *testing.T) {
|
||||
ref := ParseModelRef("gpt-4o", "GPT")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "openai" {
|
||||
t.Errorf("provider = %q, want openai (normalized from GPT)", ref.Provider)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
package openai_compat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type ToolCall = protocoltypes.ToolCall
|
||||
type FunctionCall = protocoltypes.FunctionCall
|
||||
type LLMResponse = protocoltypes.LLMResponse
|
||||
type UsageInfo = protocoltypes.UsageInfo
|
||||
type Message = protocoltypes.Message
|
||||
type ToolDefinition = protocoltypes.ToolDefinition
|
||||
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
|
||||
type Provider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase, proxy string) *Provider {
|
||||
client := &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
if proxy != "" {
|
||||
parsed, err := url.Parse(proxy)
|
||||
if err == nil {
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(parsed),
|
||||
}
|
||||
} else {
|
||||
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
|
||||
}
|
||||
}
|
||||
|
||||
return &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
if p.apiBase == "" {
|
||||
return nil, fmt.Errorf("API base not configured")
|
||||
}
|
||||
|
||||
model = normalizeModel(model, p.apiBase)
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = tools
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
|
||||
if maxTokens, ok := asInt(options["max_tokens"]); ok {
|
||||
lowerModel := strings.ToLower(model)
|
||||
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") || strings.Contains(lowerModel, "gpt-5") {
|
||||
requestBody["max_completion_tokens"] = maxTokens
|
||||
} else {
|
||||
requestBody["max_tokens"] = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
if temperature, ok := asFloat(options["temperature"]); ok {
|
||||
lowerModel := strings.ToLower(model)
|
||||
// Kimi k2 models only support temperature=1.
|
||||
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
|
||||
requestBody["temperature"] = 1.0
|
||||
} else {
|
||||
requestBody["temperature"] = temperature
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return parseResponse(body)
|
||||
}
|
||||
|
||||
func parseResponse(body []byte) (*LLMResponse, error) {
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage *UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return &LLMResponse{
|
||||
Content: "",
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
choice := apiResponse.Choices[0]
|
||||
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
arguments := make(map[string]interface{})
|
||||
name := ""
|
||||
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
if tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
|
||||
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
|
||||
arguments["raw"] = tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
})
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: choice.Message.Content,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeModel(model, apiBase string) string {
|
||||
idx := strings.Index(model, "/")
|
||||
if idx == -1 {
|
||||
return model
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") {
|
||||
return model
|
||||
}
|
||||
|
||||
prefix := strings.ToLower(model[:idx])
|
||||
switch prefix {
|
||||
case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu":
|
||||
return model[idx+1:]
|
||||
default:
|
||||
return model
|
||||
}
|
||||
}
|
||||
|
||||
func asInt(v interface{}) (int, bool) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val, true
|
||||
case int64:
|
||||
return int(val), true
|
||||
case float64:
|
||||
return int(val), true
|
||||
case float32:
|
||||
return int(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func asFloat(v interface{}) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val, true
|
||||
case float32:
|
||||
return float64(val), true
|
||||
case int:
|
||||
return float64(val), true
|
||||
case int64:
|
||||
return float64(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,277 @@
|
||||
package openai_compat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
|
||||
var requestBody map[string]interface{}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"message": map[string]interface{}{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := requestBody["max_completion_tokens"]; !ok {
|
||||
t.Fatalf("expected max_completion_tokens in request body")
|
||||
}
|
||||
if _, ok := requestBody["max_tokens"]; ok {
|
||||
t.Fatalf("did not expect max_tokens key for glm model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_ParsesToolCalls(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"message": map[string]interface{}{
|
||||
"content": "",
|
||||
"tool_calls": []map[string]interface{}{
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"SF\"}",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].Name != "get_weather" {
|
||||
t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
|
||||
}
|
||||
if out.ToolCalls[0].Arguments["city"] != "SF" {
|
||||
t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_HTTPError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) {
|
||||
var requestBody map[string]interface{}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"message": map[string]interface{}{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"moonshot/kimi-k2.5",
|
||||
map[string]interface{}{"temperature": 0.3},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if requestBody["model"] != "kimi-k2.5" {
|
||||
t.Fatalf("model = %v, want kimi-k2.5", requestBody["model"])
|
||||
}
|
||||
if requestBody["temperature"] != 1.0 {
|
||||
t.Fatalf("temperature = %v, want 1.0", requestBody["temperature"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "strips groq prefix and keeps nested model",
|
||||
input: "groq/openai/gpt-oss-120b",
|
||||
wantModel: "openai/gpt-oss-120b",
|
||||
},
|
||||
{
|
||||
name: "strips ollama prefix",
|
||||
input: "ollama/qwen2.5:14b",
|
||||
wantModel: "qwen2.5:14b",
|
||||
},
|
||||
{
|
||||
name: "strips deepseek prefix",
|
||||
input: "deepseek/deepseek-chat",
|
||||
wantModel: "deepseek-chat",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var requestBody map[string]interface{}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"message": map[string]interface{}{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if requestBody["model"] != tt.wantModel {
|
||||
t.Fatalf("model = %v, want %s", requestBody["model"], tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_ProxyConfigured(t *testing.T) {
|
||||
proxyURL := "http://127.0.0.1:8080"
|
||||
p := NewProvider("key", "https://example.com", proxyURL)
|
||||
|
||||
transport, ok := p.httpClient.Transport.(*http.Transport)
|
||||
if !ok || transport == nil {
|
||||
t.Fatalf("expected http transport with proxy, got %T", p.httpClient.Transport)
|
||||
}
|
||||
|
||||
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
|
||||
gotProxy, err := transport.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("proxy function returned error: %v", err)
|
||||
}
|
||||
if gotProxy == nil || gotProxy.String() != proxyURL {
|
||||
t.Fatalf("proxy = %v, want %s", gotProxy, proxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) {
|
||||
var requestBody map[string]interface{}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"message": map[string]interface{}{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"gpt-4o",
|
||||
map[string]interface{}{"max_tokens": float64(512), "temperature": 1},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if requestBody["max_tokens"] != float64(512) {
|
||||
t.Fatalf("max_tokens = %v, want 512", requestBody["max_tokens"])
|
||||
}
|
||||
if requestBody["temperature"] != float64(1) {
|
||||
t.Fatalf("temperature = %v, want 1", requestBody["temperature"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeModel_UsesAPIBase(t *testing.T) {
|
||||
if got := normalizeModel("deepseek/deepseek-chat", "https://api.deepseek.com/v1"); got != "deepseek-chat" {
|
||||
t.Fatalf("normalizeModel(deepseek) = %q, want %q", got, "deepseek-chat")
|
||||
}
|
||||
if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" {
|
||||
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package protocoltypes
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function *FunctionCall `json:"function,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
type LLMResponse struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Usage *UsageInfo `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type UsageInfo struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type ToolDefinition struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunctionDefinition `json:"function"`
|
||||
}
|
||||
|
||||
type ToolFunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
+51
-40
@@ -1,53 +1,64 @@
|
||||
package providers
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function *FunctionCall `json:"function,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
}
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
ThoughtSignature string `json:"thought_signature,omitempty"`
|
||||
}
|
||||
|
||||
type LLMResponse struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Usage *UsageInfo `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type UsageInfo struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
type ToolCall = protocoltypes.ToolCall
|
||||
type FunctionCall = protocoltypes.FunctionCall
|
||||
type LLMResponse = protocoltypes.LLMResponse
|
||||
type UsageInfo = protocoltypes.UsageInfo
|
||||
type Message = protocoltypes.Message
|
||||
type ToolDefinition = protocoltypes.ToolDefinition
|
||||
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
|
||||
type LLMProvider interface {
|
||||
Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error)
|
||||
GetDefaultModel() string
|
||||
}
|
||||
|
||||
type ToolDefinition struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunctionDefinition `json:"function"`
|
||||
// FailoverReason classifies why an LLM request failed for fallback decisions.
|
||||
type FailoverReason string
|
||||
|
||||
const (
|
||||
FailoverAuth FailoverReason = "auth"
|
||||
FailoverRateLimit FailoverReason = "rate_limit"
|
||||
FailoverBilling FailoverReason = "billing"
|
||||
FailoverTimeout FailoverReason = "timeout"
|
||||
FailoverFormat FailoverReason = "format"
|
||||
FailoverOverloaded FailoverReason = "overloaded"
|
||||
FailoverUnknown FailoverReason = "unknown"
|
||||
)
|
||||
|
||||
// FailoverError wraps an LLM provider error with classification metadata.
|
||||
type FailoverError struct {
|
||||
Reason FailoverReason
|
||||
Provider string
|
||||
Model string
|
||||
Status int
|
||||
Wrapped error
|
||||
}
|
||||
|
||||
type ToolFunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
func (e *FailoverError) Error() string {
|
||||
return fmt.Sprintf("failover(%s): provider=%s model=%s status=%d: %v",
|
||||
e.Reason, e.Provider, e.Model, e.Status, e.Wrapped)
|
||||
}
|
||||
|
||||
func (e *FailoverError) Unwrap() error {
|
||||
return e.Wrapped
|
||||
}
|
||||
|
||||
// IsRetriable returns true if this error should trigger fallback to next candidate.
|
||||
// Non-retriable: Format errors (bad request structure, image dimension/size).
|
||||
func (e *FailoverError) IsRetriable() bool {
|
||||
return e.Reason != FailoverFormat
|
||||
}
|
||||
|
||||
// ModelConfig holds primary model and fallback list.
|
||||
type ModelConfig struct {
|
||||
Primary string
|
||||
Fallbacks []string
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user