diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go new file mode 100644 index 000000000..ca72f0180 --- /dev/null +++ b/pkg/providers/anthropic/provider.go @@ -0,0 +1,241 @@ +package anthropicprovider + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" +) + +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type,omitempty"` + Function *FunctionCall `json:"function,omitempty"` + Name string `json:"name,omitempty"` + Arguments map[string]interface{} `json:"arguments,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +type LLMResponse struct { + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + FinishReason string `json:"finish_reason"` + Usage *UsageInfo `json:"usage,omitempty"` +} + +type UsageInfo struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +type ToolDefinition struct { + Type string `json:"type"` + Function ToolFunctionDefinition `json:"function"` +} + +type ToolFunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` +} + +type Provider struct { + client *anthropic.Client + tokenSource func() (string, error) +} + +func NewProvider(token string) *Provider { + client := anthropic.NewClient( + option.WithAuthToken(token), + option.WithBaseURL("https://api.anthropic.com"), + ) + return &Provider{client: &client} +} + +func NewProviderWithClient(client *anthropic.Client) *Provider { + return &Provider{client: client} +} + +func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider { + p := NewProvider(token) + p.tokenSource = tokenSource + return p +} + +func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAuthToken(tok)) + } + + params, err := buildParams(messages, tools, model, options) + if err != nil { + return nil, err + } + + resp, err := p.client.Messages.New(ctx, params, opts...) + if err != nil { + return nil, fmt.Errorf("claude API call: %w", err) + } + + return parseResponse(resp), nil +} + +func (p *Provider) GetDefaultModel() string { + return "claude-sonnet-4-5-20250929" +} + +func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { + var system []anthropic.TextBlockParam + var anthropicMessages []anthropic.MessageParam + + for _, msg := range messages { + switch msg.Role { + case "system": + system = append(system, anthropic.TextBlockParam{Text: msg.Content}) + case "user": + if msg.ToolCallID != "" { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + var blocks []anthropic.ContentBlockParamUnion + if msg.Content != "" { + blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) + } + for _, tc := range msg.ToolCalls { + blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) + } + anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "tool": + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } + } + + maxTokens := int64(4096) + if mt, ok := options["max_tokens"].(int); ok { + maxTokens = int64(mt) + } + + params := anthropic.MessageNewParams{ + Model: anthropic.Model(model), + Messages: anthropicMessages, + MaxTokens: maxTokens, + } + + if len(system) > 0 { + params.System = system + } + + if temp, ok := options["temperature"].(float64); ok { + params.Temperature = anthropic.Float(temp) + } + + if len(tools) > 0 { + params.Tools = translateTools(tools) + } + + return params, nil +} + +func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam { + result := make([]anthropic.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + tool := anthropic.ToolParam{ + Name: t.Function.Name, + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: t.Function.Parameters["properties"], + }, + } + if desc := t.Function.Description; desc != "" { + tool.Description = anthropic.String(desc) + } + if req, ok := t.Function.Parameters["required"].([]interface{}); ok { + required := make([]string, 0, len(req)) + for _, r := range req { + if s, ok := r.(string); ok { + required = append(required, s) + } + } + tool.InputSchema.Required = required + } + result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) + } + return result +} + +func parseResponse(resp *anthropic.Message) *LLMResponse { + var content string + var toolCalls []ToolCall + + for _, block := range resp.Content { + switch block.Type { + case "text": + tb := block.AsText() + content += tb.Text + case "tool_use": + tu := block.AsToolUse() + var args map[string]interface{} + if err := json.Unmarshal(tu.Input, &args); err != nil { + args = map[string]interface{}{"raw": string(tu.Input)} + } + toolCalls = append(toolCalls, ToolCall{ + ID: tu.ID, + Name: tu.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + switch resp.StopReason { + case anthropic.StopReasonToolUse: + finishReason = "tool_calls" + case anthropic.StopReasonMaxTokens: + finishReason = "length" + case anthropic.StopReasonEndTurn: + finishReason = "stop" + } + + return &LLMResponse{ + Content: content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), + }, + } +} diff --git a/pkg/providers/anthropic/provider_test.go b/pkg/providers/anthropic/provider_test.go new file mode 100644 index 000000000..01b4fe663 --- /dev/null +++ b/pkg/providers/anthropic/provider_test.go @@ -0,0 +1,208 @@ +package anthropicprovider + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" +) + +func TestBuildParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ + "max_tokens": 1024, + }) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if string(params.Model) != "claude-sonnet-4-5-20250929" { + t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929") + } + if params.MaxTokens != 1024 { + t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildParams_SystemMessage(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.System) != 1 { + t.Fatalf("len(System) = %d, want 1", len(params.System)) + } + if params.System[0].Text != "You are helpful" { + t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildParams_ToolCallMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + ID: "call_1", + Name: "get_weather", + Arguments: map[string]interface{}{"city": "SF"}, + }, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) + } +} + +func TestBuildParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather for a city", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + "required": []interface{}{"city"}, + }, + }, + }, + } + params, err := buildParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } +} + +func TestParseResponse_TextOnly(t *testing.T) { + resp := &anthropic.Message{ + Content: []anthropic.ContentBlockUnion{}, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 20, + }, + } + result := parseResponse(resp) + if result.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) + } + if result.Usage.CompletionTokens != 20 { + t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } +} + +func TestParseResponse_StopReasons(t *testing.T) { + tests := []struct { + stopReason anthropic.StopReason + want string + }{ + {anthropic.StopReasonEndTurn, "stop"}, + {anthropic.StopReasonMaxTokens, "length"}, + {anthropic.StopReasonToolUse, "tool_calls"}, + } + for _, tt := range tests { + resp := &anthropic.Message{ + StopReason: tt.stopReason, + } + result := parseResponse(resp) + if result.FinishReason != tt.want { + t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) + } + } +} + +func TestProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]interface{}{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]interface{}{ + {"type": "text", "text": "Hello! How can I help you?"}, + }, + "usage": map[string]interface{}{ + "input_tokens": 15, + "output_tokens": 8, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token")) + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hello! How can I help you?" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.PromptTokens != 15 { + t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens) + } +} + +func TestProvider_GetDefaultModel(t *testing.T) { + p := NewProvider("test-token") + if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929") + } +} + +func createAnthropicTestClient(baseURL, token string) *anthropic.Client { + c := anthropic.NewClient( + anthropicoption.WithAuthToken(token), + anthropicoption.WithBaseURL(baseURL), + ) + return &c +} diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go index ae6aca96d..16f1884c5 100644 --- a/pkg/providers/claude_provider.go +++ b/pkg/providers/claude_provider.go @@ -2,195 +2,48 @@ package providers import ( "context" - "encoding/json" "fmt" - "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/option" "github.com/sipeed/picoclaw/pkg/auth" + anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic" ) type ClaudeProvider struct { - client *anthropic.Client - tokenSource func() (string, error) + delegate *anthropicprovider.Provider } func NewClaudeProvider(token string) *ClaudeProvider { - client := anthropic.NewClient( - option.WithAuthToken(token), - option.WithBaseURL("https://api.anthropic.com"), - ) - return &ClaudeProvider{client: &client} + return &ClaudeProvider{ + delegate: anthropicprovider.NewProvider(token), + } } func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider { - p := NewClaudeProvider(token) - p.tokenSource = tokenSource - return p + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource), + } +} + +func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider { + return &ClaudeProvider{delegate: delegate} } func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - var opts []option.RequestOption - if p.tokenSource != nil { - tok, err := p.tokenSource() - if err != nil { - return nil, fmt.Errorf("refreshing token: %w", err) - } - opts = append(opts, option.WithAuthToken(tok)) - } - - params, err := buildClaudeParams(messages, tools, model, options) + resp, err := p.delegate.Chat( + ctx, + toAnthropicProviderMessages(messages), + toAnthropicProviderTools(tools), + model, + options, + ) if err != nil { return nil, err } - - resp, err := p.client.Messages.New(ctx, params, opts...) - if err != nil { - return nil, fmt.Errorf("claude API call: %w", err) - } - - return parseClaudeResponse(resp), nil + return fromAnthropicProviderResponse(resp), nil } func (p *ClaudeProvider) GetDefaultModel() string { - return "claude-sonnet-4-5-20250929" -} - -func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { - var system []anthropic.TextBlockParam - var anthropicMessages []anthropic.MessageParam - - for _, msg := range messages { - switch msg.Role { - case "system": - system = append(system, anthropic.TextBlockParam{Text: msg.Content}) - case "user": - if msg.ToolCallID != "" { - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), - ) - } else { - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), - ) - } - case "assistant": - if len(msg.ToolCalls) > 0 { - var blocks []anthropic.ContentBlockParamUnion - if msg.Content != "" { - blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) - } - for _, tc := range msg.ToolCalls { - blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) - } - anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - } else { - anthropicMessages = append(anthropicMessages, - anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), - ) - } - case "tool": - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), - ) - } - } - - maxTokens := int64(4096) - if mt, ok := options["max_tokens"].(int); ok { - maxTokens = int64(mt) - } - - params := anthropic.MessageNewParams{ - Model: anthropic.Model(model), - Messages: anthropicMessages, - MaxTokens: maxTokens, - } - - if len(system) > 0 { - params.System = system - } - - if temp, ok := options["temperature"].(float64); ok { - params.Temperature = anthropic.Float(temp) - } - - if len(tools) > 0 { - params.Tools = translateToolsForClaude(tools) - } - - return params, nil -} - -func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam { - result := make([]anthropic.ToolUnionParam, 0, len(tools)) - for _, t := range tools { - tool := anthropic.ToolParam{ - Name: t.Function.Name, - InputSchema: anthropic.ToolInputSchemaParam{ - Properties: t.Function.Parameters["properties"], - }, - } - if desc := t.Function.Description; desc != "" { - tool.Description = anthropic.String(desc) - } - if req, ok := t.Function.Parameters["required"].([]interface{}); ok { - required := make([]string, 0, len(req)) - for _, r := range req { - if s, ok := r.(string); ok { - required = append(required, s) - } - } - tool.InputSchema.Required = required - } - result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) - } - return result -} - -func parseClaudeResponse(resp *anthropic.Message) *LLMResponse { - var content string - var toolCalls []ToolCall - - for _, block := range resp.Content { - switch block.Type { - case "text": - tb := block.AsText() - content += tb.Text - case "tool_use": - tu := block.AsToolUse() - var args map[string]interface{} - if err := json.Unmarshal(tu.Input, &args); err != nil { - args = map[string]interface{}{"raw": string(tu.Input)} - } - toolCalls = append(toolCalls, ToolCall{ - ID: tu.ID, - Name: tu.Name, - Arguments: args, - }) - } - } - - finishReason := "stop" - switch resp.StopReason { - case anthropic.StopReasonToolUse: - finishReason = "tool_calls" - case anthropic.StopReasonMaxTokens: - finishReason = "length" - case anthropic.StopReasonEndTurn: - finishReason = "stop" - } - - return &LLMResponse{ - Content: content, - ToolCalls: toolCalls, - FinishReason: finishReason, - Usage: &UsageInfo{ - PromptTokens: int(resp.Usage.InputTokens), - CompletionTokens: int(resp.Usage.OutputTokens), - TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), - }, - } + return p.delegate.GetDefaultModel() } func createClaudeTokenSource() func() (string, error) { @@ -205,3 +58,95 @@ func createClaudeTokenSource() func() (string, error) { return cred.AccessToken, nil } } + +func toAnthropicProviderMessages(messages []Message) []anthropicprovider.Message { + out := make([]anthropicprovider.Message, 0, len(messages)) + for _, msg := range messages { + out = append(out, anthropicprovider.Message{ + Role: msg.Role, + Content: msg.Content, + ToolCalls: toAnthropicProviderToolCalls(msg.ToolCalls), + ToolCallID: msg.ToolCallID, + }) + } + return out +} + +func toAnthropicProviderTools(tools []ToolDefinition) []anthropicprovider.ToolDefinition { + out := make([]anthropicprovider.ToolDefinition, 0, len(tools)) + for _, t := range tools { + out = append(out, anthropicprovider.ToolDefinition{ + Type: t.Type, + Function: anthropicprovider.ToolFunctionDefinition{ + Name: t.Function.Name, + Description: t.Function.Description, + Parameters: t.Function.Parameters, + }, + }) + } + return out +} + +func toAnthropicProviderToolCalls(toolCalls []ToolCall) []anthropicprovider.ToolCall { + out := make([]anthropicprovider.ToolCall, 0, len(toolCalls)) + for _, tc := range toolCalls { + var fn *anthropicprovider.FunctionCall + if tc.Function != nil { + fn = &anthropicprovider.FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + } + } + out = append(out, anthropicprovider.ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: fn, + Name: tc.Name, + Arguments: tc.Arguments, + }) + } + return out +} + +func fromAnthropicProviderResponse(resp *anthropicprovider.LLMResponse) *LLMResponse { + if resp == nil { + return &LLMResponse{} + } + + var usage *UsageInfo + if resp.Usage != nil { + usage = &UsageInfo{ + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + } + } + + return &LLMResponse{ + Content: resp.Content, + ToolCalls: fromAnthropicProviderToolCalls(resp.ToolCalls), + FinishReason: resp.FinishReason, + Usage: usage, + } +} + +func fromAnthropicProviderToolCalls(toolCalls []anthropicprovider.ToolCall) []ToolCall { + out := make([]ToolCall, 0, len(toolCalls)) + for _, tc := range toolCalls { + var fn *FunctionCall + if tc.Function != nil { + fn = &FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + } + } + out = append(out, ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: fn, + Name: tc.Name, + Arguments: tc.Arguments, + }) + } + return out +} diff --git a/pkg/providers/claude_provider_test.go b/pkg/providers/claude_provider_test.go index bbad2d269..13bbde1fc 100644 --- a/pkg/providers/claude_provider_test.go +++ b/pkg/providers/claude_provider_test.go @@ -8,140 +8,9 @@ import ( "github.com/anthropics/anthropic-sdk-go" anthropicoption "github.com/anthropics/anthropic-sdk-go/option" + anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic" ) -func TestBuildClaudeParams_BasicMessage(t *testing.T) { - messages := []Message{ - {Role: "user", Content: "Hello"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ - "max_tokens": 1024, - }) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if string(params.Model) != "claude-sonnet-4-5-20250929" { - t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929") - } - if params.MaxTokens != 1024 { - t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) - } - if len(params.Messages) != 1 { - t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) - } -} - -func TestBuildClaudeParams_SystemMessage(t *testing.T) { - messages := []Message{ - {Role: "system", Content: "You are helpful"}, - {Role: "user", Content: "Hi"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.System) != 1 { - t.Fatalf("len(System) = %d, want 1", len(params.System)) - } - if params.System[0].Text != "You are helpful" { - t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") - } - if len(params.Messages) != 1 { - t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) - } -} - -func TestBuildClaudeParams_ToolCallMessage(t *testing.T) { - messages := []Message{ - {Role: "user", Content: "What's the weather?"}, - { - Role: "assistant", - Content: "", - ToolCalls: []ToolCall{ - { - ID: "call_1", - Name: "get_weather", - Arguments: map[string]interface{}{"city": "SF"}, - }, - }, - }, - {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.Messages) != 3 { - t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) - } -} - -func TestBuildClaudeParams_WithTools(t *testing.T) { - tools := []ToolDefinition{ - { - Type: "function", - Function: ToolFunctionDefinition{ - Name: "get_weather", - Description: "Get weather for a city", - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "city": map[string]interface{}{"type": "string"}, - }, - "required": []interface{}{"city"}, - }, - }, - }, - } - params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.Tools) != 1 { - t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) - } -} - -func TestParseClaudeResponse_TextOnly(t *testing.T) { - resp := &anthropic.Message{ - Content: []anthropic.ContentBlockUnion{}, - Usage: anthropic.Usage{ - InputTokens: 10, - OutputTokens: 20, - }, - } - result := parseClaudeResponse(resp) - if result.Usage.PromptTokens != 10 { - t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) - } - if result.Usage.CompletionTokens != 20 { - t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) - } - if result.FinishReason != "stop" { - t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") - } -} - -func TestParseClaudeResponse_StopReasons(t *testing.T) { - tests := []struct { - stopReason anthropic.StopReason - want string - }{ - {anthropic.StopReasonEndTurn, "stop"}, - {anthropic.StopReasonMaxTokens, "length"}, - {anthropic.StopReasonToolUse, "tool_calls"}, - } - for _, tt := range tests { - resp := &anthropic.Message{ - StopReason: tt.stopReason, - } - result := parseClaudeResponse(resp) - if result.FinishReason != tt.want { - t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) - } - } -} - func TestClaudeProvider_ChatRoundTrip(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/messages" { @@ -175,8 +44,8 @@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) { })) defer server.Close() - provider := NewClaudeProvider("test-token") - provider.client = createAnthropicTestClient(server.URL, "test-token") + delegate := anthropicprovider.NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token")) + provider := newClaudeProviderWithDelegate(delegate) messages := []Message{{Role: "user", Content: "Hello"}} resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})