mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(providers): reorganize provider packages and facades
This commit is contained in:
@@ -0,0 +1,205 @@
|
||||
package cliprovider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
)
|
||||
|
||||
// ClaudeCliProvider is an LLMProvider backed by the claude command-line tool,
// which it runs as a child process for every Chat call.
//
// command is the executable name or path to launch; workspace, when non-empty,
// is used as the child's working directory.
type ClaudeCliProvider struct {
	command   string
	workspace string
}
|
||||
|
||||
// NewClaudeCliProvider creates a new Claude CLI provider.
|
||||
func NewClaudeCliProvider(workspace string) *ClaudeCliProvider {
|
||||
return &ClaudeCliProvider{
|
||||
command: "claude",
|
||||
workspace: workspace,
|
||||
}
|
||||
}
|
||||
|
||||
// Chat implements LLMProvider.Chat by executing the claude CLI.
|
||||
func (p *ClaudeCliProvider) Chat(
|
||||
ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
systemPrompt := p.buildSystemPrompt(messages, tools)
|
||||
prompt := p.messagesToPrompt(messages)
|
||||
|
||||
args := []string{"-p", "--output-format", "json", "--dangerously-skip-permissions", "--no-chrome"}
|
||||
if systemPrompt != "" {
|
||||
args = append(args, "--system-prompt", systemPrompt)
|
||||
}
|
||||
if model != "" && model != "claude-code" {
|
||||
args = append(args, "--model", model)
|
||||
}
|
||||
args = append(args, "-") // read from stdin
|
||||
|
||||
cmd := exec.CommandContext(ctx, p.command, args...)
|
||||
if p.workspace != "" {
|
||||
cmd.Dir = p.workspace
|
||||
}
|
||||
cmd.Stdin = bytes.NewReader([]byte(prompt))
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
// Execute the CLI through the shared isolation wrapper so external provider
|
||||
// processes honor the configured isolation policy.
|
||||
if err := isolation.Run(cmd); err != nil {
|
||||
stderrStr := strings.TrimSpace(stderr.String())
|
||||
stdoutStr := strings.TrimSpace(stdout.String())
|
||||
switch {
|
||||
case stderrStr != "" && stdoutStr != "":
|
||||
return nil, fmt.Errorf("claude cli error: %w\nstderr: %s\nstdout: %s", err, stderrStr, stdoutStr)
|
||||
case stderrStr != "":
|
||||
return nil, fmt.Errorf("claude cli error: %s", stderrStr)
|
||||
case stdoutStr != "":
|
||||
return nil, fmt.Errorf("claude cli error: %w\noutput: %s", err, stdoutStr)
|
||||
default:
|
||||
return nil, fmt.Errorf("claude cli error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return p.parseClaudeCliResponse(stdout.String())
|
||||
}
|
||||
|
||||
// GetDefaultModel returns the default model identifier.
|
||||
func (p *ClaudeCliProvider) GetDefaultModel() string {
|
||||
return "claude-code"
|
||||
}
|
||||
|
||||
// messagesToPrompt converts messages to a CLI-compatible prompt string.
|
||||
func (p *ClaudeCliProvider) messagesToPrompt(messages []Message) string {
|
||||
var parts []string
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
// handled via --system-prompt flag
|
||||
case "user":
|
||||
parts = append(parts, "User: "+msg.Content)
|
||||
case "assistant":
|
||||
parts = append(parts, "Assistant: "+msg.Content)
|
||||
case "tool":
|
||||
parts = append(parts, fmt.Sprintf("[Tool Result for %s]: %s", msg.ToolCallID, msg.Content))
|
||||
}
|
||||
}
|
||||
|
||||
// Simplify single user message
|
||||
if len(parts) == 1 && strings.HasPrefix(parts[0], "User: ") {
|
||||
return strings.TrimPrefix(parts[0], "User: ")
|
||||
}
|
||||
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
// buildSystemPrompt combines system messages and tool definitions.
|
||||
func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDefinition) string {
|
||||
var parts []string
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
parts = append(parts, msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
parts = append(parts, buildCLIToolsPrompt(tools))
|
||||
}
|
||||
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// parseClaudeCliResponse parses the JSON output from the claude CLI.
|
||||
func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, error) {
|
||||
var resp claudeCliJSONResponse
|
||||
if err := json.Unmarshal([]byte(output), &resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse claude cli response: %w", err)
|
||||
}
|
||||
|
||||
if resp.IsError {
|
||||
return nil, fmt.Errorf("claude cli returned error: %s", resp.Result)
|
||||
}
|
||||
|
||||
toolCalls := p.extractToolCalls(resp.Result)
|
||||
|
||||
finishReason := "stop"
|
||||
content := resp.Result
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
content = p.stripToolCallsJSON(resp.Result)
|
||||
}
|
||||
|
||||
var usage *UsageInfo
|
||||
if resp.Usage.InputTokens > 0 || resp.Usage.OutputTokens > 0 {
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: resp.Usage.InputTokens + resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens,
|
||||
CompletionTokens: resp.Usage.OutputTokens,
|
||||
TotalTokens: resp.Usage.InputTokens + resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens + resp.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: strings.TrimSpace(content),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractToolCalls delegates to the shared extractToolCallsFromText function.
|
||||
func (p *ClaudeCliProvider) extractToolCalls(text string) []ToolCall {
|
||||
return extractToolCallsFromText(text)
|
||||
}
|
||||
|
||||
// stripToolCallsJSON delegates to the shared stripToolCallsFromText function.
|
||||
func (p *ClaudeCliProvider) stripToolCallsJSON(text string) string {
|
||||
return stripToolCallsFromText(text)
|
||||
}
|
||||
|
||||
// findMatchingBrace finds the index after the closing brace matching the
// opening brace at pos.
//
// Scanning starts at pos and tracks brace nesting. Braces that appear inside a
// double-quoted string (with backslash escapes honored) are treated as data,
// not structure, so a value such as {"a":"}"} is matched in full. If no
// matching closing brace exists, or pos is negative, pos itself is returned.
func findMatchingBrace(text string, pos int) int {
	if pos < 0 {
		return pos
	}

	depth := 0
	inString := false
	escaped := false
	for i := pos; i < len(text); i++ {
		c := text[i]

		if inString {
			switch {
			case escaped:
				// Character following a backslash is consumed verbatim.
				escaped = false
			case c == '\\':
				escaped = true
			case c == '"':
				inString = false
			}
			continue
		}

		switch c {
		case '"':
			// Only quotes inside an object open a string; text before the
			// first brace is not JSON and is scanned as before.
			inString = depth > 0
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return i + 1
			}
		}
	}
	return pos
}
|
||||
|
||||
// claudeCliJSONResponse represents the JSON output from the claude CLI.
|
||||
// Matches the real claude CLI v2.x output format.
|
||||
type claudeCliJSONResponse struct {
|
||||
Type string `json:"type"`
|
||||
Subtype string `json:"subtype"`
|
||||
IsError bool `json:"is_error"`
|
||||
Result string `json:"result"`
|
||||
SessionID string `json:"session_id"`
|
||||
TotalCostUSD float64 `json:"total_cost_usd"`
|
||||
DurationMS int `json:"duration_ms"`
|
||||
DurationAPI int `json:"duration_api_ms"`
|
||||
NumTurns int `json:"num_turns"`
|
||||
Usage claudeCliUsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
// claudeCliUsageInfo holds the token counters found under the "usage" key of
// the claude CLI's JSON output. Each field maps to the JSON key in its tag.
type claudeCliUsageInfo struct {
	InputTokens              int `json:"input_tokens"`
	OutputTokens             int `json:"output_tokens"`
	CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
	CacheReadInputTokens     int `json:"cache_read_input_tokens"`
}
|
||||
@@ -0,0 +1,124 @@
|
||||
//go:build integration
|
||||
|
||||
package cliprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
exec "os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestIntegration_RealClaudeCLI tests the ClaudeCliProvider with a real claude CLI.
|
||||
// Run with: go test -tags=integration ./pkg/providers/...
|
||||
func TestIntegration_RealClaudeCLI(t *testing.T) {
|
||||
// Check if claude CLI is available
|
||||
path, err := exec.LookPath("claude")
|
||||
if err != nil {
|
||||
t.Skip("claude CLI not found in PATH, skipping integration test")
|
||||
}
|
||||
t.Logf("Using claude CLI at: %s", path)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
|
||||
}, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() with real CLI error = %v", err)
|
||||
}
|
||||
|
||||
// Verify response structure
|
||||
if resp.Content == "" {
|
||||
t.Error("Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Error("Usage should not be nil from real CLI")
|
||||
} else {
|
||||
if resp.Usage.PromptTokens == 0 {
|
||||
t.Error("PromptTokens should be > 0")
|
||||
}
|
||||
if resp.Usage.CompletionTokens == 0 {
|
||||
t.Error("CompletionTokens should be > 0")
|
||||
}
|
||||
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
|
||||
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
t.Logf("Response content: %q", resp.Content)
|
||||
|
||||
// Loose check - should contain "pong" somewhere (model might capitalize or add punctuation)
|
||||
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
|
||||
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("claude CLI not found in PATH")
|
||||
}
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
}, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Response: %q", resp.Content)
|
||||
|
||||
if !strings.Contains(resp.Content, "4") {
|
||||
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) {
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("claude CLI not found in PATH")
|
||||
}
|
||||
|
||||
// Run claude directly and verify our parser handles real output
|
||||
cmd := exec.Command("claude", "-p", "--output-format", "json",
|
||||
"--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-")
|
||||
cmd.Stdin = strings.NewReader("Say hi")
|
||||
cmd.Dir = t.TempDir()
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
t.Fatalf("claude CLI failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Raw CLI output: %s", string(output))
|
||||
|
||||
// Verify our parser can handle real output
|
||||
p := NewClaudeCliProvider("")
|
||||
resp, err := p.parseClaudeCliResponse(string(output))
|
||||
if err != nil {
|
||||
t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err)
|
||||
}
|
||||
|
||||
if resp.Content == "" {
|
||||
t.Error("parsed Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Error("Usage should not be nil")
|
||||
}
|
||||
|
||||
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
|
||||
}
|
||||
@@ -0,0 +1,907 @@
|
||||
package cliprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Compile-time interface check ---
|
||||
|
||||
var _ LLMProvider = (*ClaudeCliProvider)(nil)
|
||||
|
||||
// --- Helper: create mock CLI scripts ---
|
||||
|
||||
// createMockCLI writes an executable shell script named "claude" into a fresh
// temporary directory and returns its path. The script replays the given
// stderr and stdout payloads (stderr first) and exits with exitCode.
//
// Payloads are stored in sibling files and emitted with cat, which avoids
// shell quoting issues with JSON. An empty payload produces neither a file nor
// a cat line. The calling test is skipped on Windows.
func createMockCLI(t *testing.T, stdout, stderr string, exitCode int) string {
	t.Helper()
	if runtime.GOOS == "windows" {
		t.Skip("mock CLI scripts not supported on Windows")
	}

	dir := t.TempDir()

	// Persist each non-empty payload next to the script.
	payloads := []struct{ file, data string }{
		{"stdout.txt", stdout},
		{"stderr.txt", stderr},
	}
	for _, pl := range payloads {
		if pl.data == "" {
			continue
		}
		if err := os.WriteFile(filepath.Join(dir, pl.file), []byte(pl.data), 0o644); err != nil {
			t.Fatal(err)
		}
	}

	var body strings.Builder
	body.WriteString("#!/bin/sh\n")
	if stderr != "" {
		fmt.Fprintf(&body, "cat '%s/stderr.txt' >&2\n", dir)
	}
	if stdout != "" {
		fmt.Fprintf(&body, "cat '%s/stdout.txt'\n", dir)
	}
	fmt.Fprintf(&body, "exit %d\n", exitCode)

	scriptPath := filepath.Join(dir, "claude")
	if err := os.WriteFile(scriptPath, []byte(body.String()), 0o755); err != nil {
		t.Fatal(err)
	}
	return scriptPath
}
|
||||
|
||||
// createSlowMockCLI writes an executable "claude" shell script that sleeps for
// sleepSeconds and then prints a small JSON result, and returns the script's
// path. It is used by context cancellation tests; the calling test is skipped
// on Windows.
func createSlowMockCLI(t *testing.T, sleepSeconds int) string {
	t.Helper()
	if runtime.GOOS == "windows" {
		t.Skip("mock CLI scripts not supported on Windows")
	}

	var body strings.Builder
	body.WriteString("#!/bin/sh\n")
	fmt.Fprintf(&body, "sleep %d\n", sleepSeconds)
	body.WriteString("echo '{\"type\":\"result\",\"result\":\"late\"}'\n")

	scriptPath := filepath.Join(t.TempDir(), "claude")
	if err := os.WriteFile(scriptPath, []byte(body.String()), 0o755); err != nil {
		t.Fatal(err)
	}
	return scriptPath
}
|
||||
|
||||
// createArgCaptureCLI writes an executable "claude" shell script that echoes
// its command-line arguments into argsFile and then prints a fixed JSON result
// via a here-document. It returns the script's path; the calling test is
// skipped on Windows.
func createArgCaptureCLI(t *testing.T, argsFile string) string {
	t.Helper()
	if runtime.GOOS == "windows" {
		t.Skip("mock CLI scripts not supported on Windows")
	}

	scriptLines := []string{
		"#!/bin/sh",
		fmt.Sprintf("echo \"$@\" > '%s'", argsFile),
		"cat <<'EOFMOCK'",
		`{"type":"result","result":"ok","session_id":"test"}`,
		"EOFMOCK",
		"", // yields the trailing newline after the here-document terminator
	}

	scriptPath := filepath.Join(t.TempDir(), "claude")
	if err := os.WriteFile(scriptPath, []byte(strings.Join(scriptLines, "\n")), 0o755); err != nil {
		t.Fatal(err)
	}
	return scriptPath
}
|
||||
|
||||
// --- Constructor tests ---
|
||||
|
||||
func TestNewClaudeCliProvider(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/test/workspace")
|
||||
if p == nil {
|
||||
t.Fatal("NewClaudeCliProvider returned nil")
|
||||
}
|
||||
if p.workspace != "/test/workspace" {
|
||||
t.Errorf("workspace = %q, want %q", p.workspace, "/test/workspace")
|
||||
}
|
||||
if p.command != "claude" {
|
||||
t.Errorf("command = %q, want %q", p.command, "claude")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClaudeCliProvider_EmptyWorkspace(t *testing.T) {
|
||||
p := NewClaudeCliProvider("")
|
||||
if p.workspace != "" {
|
||||
t.Errorf("workspace = %q, want empty", p.workspace)
|
||||
}
|
||||
}
|
||||
|
||||
// --- GetDefaultModel tests ---
|
||||
|
||||
func TestClaudeCliProvider_GetDefaultModel(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
if got := p.GetDefaultModel(); got != "claude-code" {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-code")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Chat() tests ---
|
||||
|
||||
func TestChat_Success(t *testing.T) {
|
||||
mockJSON := `{"type":"result","subtype":"success","is_error":false,"result":"Hello from mock!","session_id":"sess_123","total_cost_usd":0.005,"duration_ms":200,"duration_api_ms":150,"num_turns":1,"usage":{"input_tokens":10,"output_tokens":5,"cache_creation_input_tokens":100,"cache_read_input_tokens":0}}`
|
||||
script := createMockCLI(t, mockJSON, "", 0)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
p.command = script
|
||||
|
||||
resp, err := p.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if resp.Content != "Hello from mock!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hello from mock!")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if len(resp.ToolCalls) != 0 {
|
||||
t.Errorf("ToolCalls len = %d, want 0", len(resp.ToolCalls))
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Fatal("Usage should not be nil")
|
||||
}
|
||||
if resp.Usage.PromptTokens != 110 { // 10 + 100 + 0
|
||||
t.Errorf("PromptTokens = %d, want 110", resp.Usage.PromptTokens)
|
||||
}
|
||||
if resp.Usage.CompletionTokens != 5 {
|
||||
t.Errorf("CompletionTokens = %d, want 5", resp.Usage.CompletionTokens)
|
||||
}
|
||||
if resp.Usage.TotalTokens != 115 { // 110 + 5
|
||||
t.Errorf("TotalTokens = %d, want 115", resp.Usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_IsErrorResponse(t *testing.T) {
|
||||
mockJSON := `{"type":"result","subtype":"error","is_error":true,"result":"Rate limit exceeded","session_id":"s1","total_cost_usd":0}`
|
||||
script := createMockCLI(t, mockJSON, "", 0)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
p.command = script
|
||||
|
||||
_, err := p.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Chat() expected error when is_error=true")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "Rate limit exceeded") {
|
||||
t.Errorf("error = %q, want to contain 'Rate limit exceeded'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_WithToolCallsInResponse(t *testing.T) {
|
||||
mockJSON := `{"type":"result","subtype":"success","is_error":false,"result":"Checking weather.\n{\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"{\\\"location\\\":\\\"NYC\\\"}\"}}]}","session_id":"s1","total_cost_usd":0.01,"usage":{"input_tokens":5,"output_tokens":20,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}`
|
||||
script := createMockCLI(t, mockJSON, "", 0)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
p.command = script
|
||||
|
||||
resp, err := p.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
}, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if resp.FinishReason != "tool_calls" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls))
|
||||
}
|
||||
if resp.ToolCalls[0].Name != "get_weather" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "get_weather")
|
||||
}
|
||||
if resp.ToolCalls[0].Arguments["location"] != "NYC" {
|
||||
t.Errorf("ToolCalls[0].Arguments[location] = %v, want NYC", resp.ToolCalls[0].Arguments["location"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_StderrError(t *testing.T) {
|
||||
script := createMockCLI(t, "", "Error: rate limited", 1)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
p.command = script
|
||||
|
||||
_, err := p.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Chat() expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "rate limited") {
|
||||
t.Errorf("error = %q, want to contain 'rate limited'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_NonZeroExitNoStderr(t *testing.T) {
|
||||
script := createMockCLI(t, "", "", 1)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
p.command = script
|
||||
|
||||
_, err := p.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Chat() expected error for non-zero exit")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "claude cli error") {
|
||||
t.Errorf("error = %q, want to contain 'claude cli error'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_CommandNotFound(t *testing.T) {
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
p.command = "/nonexistent/claude-binary-that-does-not-exist"
|
||||
|
||||
_, err := p.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Chat() expected error for missing command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_InvalidResponseJSON(t *testing.T) {
|
||||
script := createMockCLI(t, "not valid json at all", "", 0)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
p.command = script
|
||||
|
||||
_, err := p.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Chat() expected error for invalid JSON")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to parse claude cli response") {
|
||||
t.Errorf("error = %q, want to contain 'failed to parse claude cli response'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_ContextCancellation(t *testing.T) {
|
||||
script := createSlowMockCLI(t, 2) // sleep 2s
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
p.command = script
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
_, err := p.Chat(ctx, []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}, nil, "", nil)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Chat() expected error on context cancellation")
|
||||
}
|
||||
// Should fail well before the full 2s sleep completes
|
||||
if elapsed > 3*time.Second {
|
||||
t.Errorf("Chat() took %v, expected to fail faster via context cancellation", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_PassesSystemPromptFlag(t *testing.T) {
|
||||
argsFile := filepath.Join(t.TempDir(), "args.txt")
|
||||
script := createArgCaptureCLI(t, argsFile)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
p.command = script
|
||||
|
||||
_, err := p.Chat(context.Background(), []Message{
|
||||
{Role: "system", Content: "Be helpful."},
|
||||
{Role: "user", Content: "Hi"},
|
||||
}, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
argsBytes, err := os.ReadFile(argsFile)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read args file: %v", err)
|
||||
}
|
||||
args := string(argsBytes)
|
||||
if !strings.Contains(args, "--system-prompt") {
|
||||
t.Errorf("CLI args missing --system-prompt, got: %s", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_PassesModelFlag(t *testing.T) {
|
||||
argsFile := filepath.Join(t.TempDir(), "args.txt")
|
||||
script := createArgCaptureCLI(t, argsFile)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
p.command = script
|
||||
|
||||
_, err := p.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
}, nil, "claude-sonnet-4.6", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
argsBytes, _ := os.ReadFile(argsFile)
|
||||
args := string(argsBytes)
|
||||
if !strings.Contains(args, "--model") {
|
||||
t.Errorf("CLI args missing --model, got: %s", args)
|
||||
}
|
||||
if !strings.Contains(args, "claude-sonnet-4.6") {
|
||||
t.Errorf("CLI args missing model name, got: %s", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_SkipsModelFlagForClaudeCode(t *testing.T) {
|
||||
argsFile := filepath.Join(t.TempDir(), "args.txt")
|
||||
script := createArgCaptureCLI(t, argsFile)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
p.command = script
|
||||
|
||||
_, err := p.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
}, nil, "claude-code", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
argsBytes, _ := os.ReadFile(argsFile)
|
||||
args := string(argsBytes)
|
||||
if strings.Contains(args, "--model") {
|
||||
t.Errorf("CLI args should NOT contain --model for claude-code, got: %s", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_SkipsModelFlagForEmptyModel(t *testing.T) {
|
||||
argsFile := filepath.Join(t.TempDir(), "args.txt")
|
||||
script := createArgCaptureCLI(t, argsFile)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
p.command = script
|
||||
|
||||
_, err := p.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
}, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
argsBytes, _ := os.ReadFile(argsFile)
|
||||
args := string(argsBytes)
|
||||
if strings.Contains(args, "--model") {
|
||||
t.Errorf("CLI args should NOT contain --model for empty model, got: %s", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_EmptyWorkspaceDoesNotSetDir(t *testing.T) {
|
||||
mockJSON := `{"type":"result","result":"ok","session_id":"s"}`
|
||||
script := createMockCLI(t, mockJSON, "", 0)
|
||||
|
||||
p := NewClaudeCliProvider("")
|
||||
p.command = script
|
||||
|
||||
resp, err := p.Chat(context.Background(), []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() with empty workspace error = %v", err)
|
||||
}
|
||||
if resp.Content != "ok" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "ok")
|
||||
}
|
||||
}
|
||||
|
||||
// --- messagesToPrompt tests ---
|
||||
|
||||
func TestMessagesToPrompt_SingleUser(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
got := p.messagesToPrompt(messages)
|
||||
want := "Hello"
|
||||
if got != want {
|
||||
t.Errorf("messagesToPrompt() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesToPrompt_Conversation(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "Hello!"},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
}
|
||||
got := p.messagesToPrompt(messages)
|
||||
want := "User: Hi\nAssistant: Hello!\nUser: How are you?"
|
||||
if got != want {
|
||||
t.Errorf("messagesToPrompt() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesToPrompt_WithSystemMessage(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
got := p.messagesToPrompt(messages)
|
||||
want := "Hello"
|
||||
if got != want {
|
||||
t.Errorf("messagesToPrompt() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesToPrompt_WithToolResults(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_123"},
|
||||
}
|
||||
got := p.messagesToPrompt(messages)
|
||||
if !strings.Contains(got, "[Tool Result for call_123]") {
|
||||
t.Errorf("messagesToPrompt() missing tool result marker, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, `{"temp": 72}`) {
|
||||
t.Errorf("messagesToPrompt() missing tool result content, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesToPrompt_EmptyMessages(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
got := p.messagesToPrompt(nil)
|
||||
if got != "" {
|
||||
t.Errorf("messagesToPrompt(nil) = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesToPrompt_OnlySystemMessages(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "System 1"},
|
||||
{Role: "system", Content: "System 2"},
|
||||
}
|
||||
got := p.messagesToPrompt(messages)
|
||||
if got != "" {
|
||||
t.Errorf("messagesToPrompt() with only system msgs = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- buildSystemPrompt tests ---
|
||||
|
||||
func TestBuildSystemPrompt_NoSystemNoTools(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Hi"},
|
||||
}
|
||||
got := p.buildSystemPrompt(messages, nil)
|
||||
if got != "" {
|
||||
t.Errorf("buildSystemPrompt() = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSystemPrompt_SystemOnly(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hi"},
|
||||
}
|
||||
got := p.buildSystemPrompt(messages, nil)
|
||||
if got != "You are helpful." {
|
||||
t.Errorf("buildSystemPrompt() = %q, want %q", got, "You are helpful.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSystemPrompt_MultipleSystemMessages(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "system", Content: "Be concise."},
|
||||
{Role: "user", Content: "Hi"},
|
||||
}
|
||||
got := p.buildSystemPrompt(messages, nil)
|
||||
if !strings.Contains(got, "You are helpful.") {
|
||||
t.Error("missing first system message")
|
||||
}
|
||||
if !strings.Contains(got, "Be concise.") {
|
||||
t.Error("missing second system message")
|
||||
}
|
||||
// Should be joined with double newline
|
||||
want := "You are helpful.\n\nBe concise."
|
||||
if got != want {
|
||||
t.Errorf("buildSystemPrompt() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSystemPrompt_WithTools(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
}
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather for a location",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"location": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := p.buildSystemPrompt(messages, tools)
|
||||
if !strings.Contains(got, "You are helpful.") {
|
||||
t.Error("buildSystemPrompt() missing system message")
|
||||
}
|
||||
if !strings.Contains(got, "get_weather") {
|
||||
t.Error("buildSystemPrompt() missing tool definition")
|
||||
}
|
||||
if !strings.Contains(got, "Available Tools") {
|
||||
t.Error("buildSystemPrompt() missing tools header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSystemPrompt_ToolsOnlyNoSystem(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "test_tool",
|
||||
Description: "A test tool",
|
||||
},
|
||||
},
|
||||
}
|
||||
got := p.buildSystemPrompt(nil, tools)
|
||||
if !strings.Contains(got, "test_tool") {
|
||||
t.Error("should include tool definitions even without system messages")
|
||||
}
|
||||
}
|
||||
|
||||
// --- buildToolsPrompt tests ---
|
||||
|
||||
func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{Type: "other", Function: ToolFunctionDefinition{Name: "skip_me"}},
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "include_me", Description: "Included"}},
|
||||
}
|
||||
got := buildCLIToolsPrompt(tools)
|
||||
if strings.Contains(got, "skip_me") {
|
||||
t.Error("buildToolsPrompt() should skip non-function tools")
|
||||
}
|
||||
if !strings.Contains(got, "include_me") {
|
||||
t.Error("buildToolsPrompt() should include function tools")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolsPrompt_NoDescription(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "bare_tool"}},
|
||||
}
|
||||
got := buildCLIToolsPrompt(tools)
|
||||
if !strings.Contains(got, "bare_tool") {
|
||||
t.Error("should include tool name")
|
||||
}
|
||||
if strings.Contains(got, "Description:") {
|
||||
t.Error("should not include Description: line when empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolsPrompt_NoParameters(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{
|
||||
Name: "no_params_tool",
|
||||
Description: "A tool with no parameters",
|
||||
}},
|
||||
}
|
||||
got := buildCLIToolsPrompt(tools)
|
||||
if strings.Contains(got, "Parameters:") {
|
||||
t.Error("should not include Parameters: section when nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- parseClaudeCliResponse tests ---
|
||||
|
||||
func TestParseClaudeCliResponse_TextOnly(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
output := `{"type":"result","subtype":"success","is_error":false,"result":"Hello, world!","session_id":"abc123","total_cost_usd":0.01,"duration_ms":500,"usage":{"input_tokens":10,"output_tokens":20,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}`
|
||||
|
||||
resp, err := p.parseClaudeCliResponse(output)
|
||||
if err != nil {
|
||||
t.Fatalf("parseClaudeCliResponse() error = %v", err)
|
||||
}
|
||||
if resp.Content != "Hello, world!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hello, world!")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if len(resp.ToolCalls) != 0 {
|
||||
t.Errorf("ToolCalls = %d, want 0", len(resp.ToolCalls))
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Fatal("Usage should not be nil")
|
||||
}
|
||||
if resp.Usage.PromptTokens != 10 {
|
||||
t.Errorf("PromptTokens = %d, want 10", resp.Usage.PromptTokens)
|
||||
}
|
||||
if resp.Usage.CompletionTokens != 20 {
|
||||
t.Errorf("CompletionTokens = %d, want 20", resp.Usage.CompletionTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeCliResponse_EmptyResult(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
output := `{"type":"result","subtype":"success","is_error":false,"result":"","session_id":"abc"}`
|
||||
|
||||
resp, err := p.parseClaudeCliResponse(output)
|
||||
if err != nil {
|
||||
t.Fatalf("error = %v", err)
|
||||
}
|
||||
if resp.Content != "" {
|
||||
t.Errorf("Content = %q, want empty", resp.Content)
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeCliResponse_IsError(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
output := `{"type":"result","subtype":"error","is_error":true,"result":"Something went wrong","session_id":"abc"}`
|
||||
|
||||
_, err := p.parseClaudeCliResponse(output)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when is_error=true")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "Something went wrong") {
|
||||
t.Errorf("error = %q, want to contain 'Something went wrong'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeCliResponse_NoUsage(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
output := `{"type":"result","subtype":"success","is_error":false,"result":"hi","session_id":"s"}`
|
||||
|
||||
resp, err := p.parseClaudeCliResponse(output)
|
||||
if err != nil {
|
||||
t.Fatalf("error = %v", err)
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
t.Errorf("Usage should be nil when no tokens, got %+v", resp.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeCliResponse_InvalidJSON(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
_, err := p.parseClaudeCliResponse("not json")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to parse claude cli response") {
|
||||
t.Errorf("error = %q, want to contain 'failed to parse claude cli response'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeCliResponse_WithToolCalls(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
output := `{"type":"result","subtype":"success","is_error":false,"result":"Let me check.\n{\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"{\\\"location\\\":\\\"Tokyo\\\"}\"}}]}","session_id":"abc123","total_cost_usd":0.01}`
|
||||
|
||||
resp, err := p.parseClaudeCliResponse(output)
|
||||
if err != nil {
|
||||
t.Fatalf("error = %v", err)
|
||||
}
|
||||
if resp.FinishReason != "tool_calls" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("ToolCalls = %d, want 1", len(resp.ToolCalls))
|
||||
}
|
||||
tc := resp.ToolCalls[0]
|
||||
if tc.Name != "get_weather" {
|
||||
t.Errorf("Name = %q, want %q", tc.Name, "get_weather")
|
||||
}
|
||||
if tc.Function == nil {
|
||||
t.Fatal("Function is nil")
|
||||
}
|
||||
if tc.Function.Name != "get_weather" {
|
||||
t.Errorf("Function.Name = %q, want %q", tc.Function.Name, "get_weather")
|
||||
}
|
||||
if tc.Arguments["location"] != "Tokyo" {
|
||||
t.Errorf("Arguments[location] = %v, want Tokyo", tc.Arguments["location"])
|
||||
}
|
||||
if strings.Contains(resp.Content, "tool_calls") {
|
||||
t.Errorf("Content should not contain tool_calls JSON, got %q", resp.Content)
|
||||
}
|
||||
if resp.Content != "Let me check." {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Let me check.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeCliResponse_WhitespaceResult(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
output := `{"type":"result","subtype":"success","is_error":false,"result":" hello \n ","session_id":"s"}`
|
||||
|
||||
resp, err := p.parseClaudeCliResponse(output)
|
||||
if err != nil {
|
||||
t.Fatalf("error = %v", err)
|
||||
}
|
||||
if resp.Content != "hello" {
|
||||
t.Errorf("Content = %q, want %q (should be trimmed)", resp.Content, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
// --- extractToolCalls tests ---
|
||||
|
||||
func TestExtractToolCalls_NoToolCalls(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
got := p.extractToolCalls("Just a regular response.")
|
||||
if len(got) != 0 {
|
||||
t.Errorf("extractToolCalls() = %d, want 0", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolCalls_WithToolCalls(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
text := `Here's the result:
|
||||
{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"test","arguments":"{}"}}]}`
|
||||
|
||||
got := p.extractToolCalls(text)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("extractToolCalls() = %d, want 1", len(got))
|
||||
}
|
||||
if got[0].ID != "call_1" {
|
||||
t.Errorf("ID = %q, want %q", got[0].ID, "call_1")
|
||||
}
|
||||
if got[0].Name != "test" {
|
||||
t.Errorf("Name = %q, want %q", got[0].Name, "test")
|
||||
}
|
||||
if got[0].Type != "function" {
|
||||
t.Errorf("Type = %q, want %q", got[0].Type, "function")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolCalls_InvalidJSON(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
got := p.extractToolCalls(`{"tool_calls":invalid}`)
|
||||
if len(got) != 0 {
|
||||
t.Errorf("extractToolCalls() with invalid JSON = %d, want 0", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolCalls_MultipleToolCalls(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
text := `{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"/tmp/test\"}"}},{"id":"call_2","type":"function","function":{"name":"write_file","arguments":"{\"path\":\"/tmp/out\",\"content\":\"hello\"}"}}]}`
|
||||
|
||||
got := p.extractToolCalls(text)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("extractToolCalls() = %d, want 2", len(got))
|
||||
}
|
||||
if got[0].Name != "read_file" {
|
||||
t.Errorf("[0].Name = %q, want %q", got[0].Name, "read_file")
|
||||
}
|
||||
if got[1].Name != "write_file" {
|
||||
t.Errorf("[1].Name = %q, want %q", got[1].Name, "write_file")
|
||||
}
|
||||
// Verify arguments were parsed
|
||||
if got[0].Arguments["path"] != "/tmp/test" {
|
||||
t.Errorf("[0].Arguments[path] = %v, want /tmp/test", got[0].Arguments["path"])
|
||||
}
|
||||
if got[1].Arguments["content"] != "hello" {
|
||||
t.Errorf("[1].Arguments[content] = %v, want hello", got[1].Arguments["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolCalls_UnmatchedBrace(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
got := p.extractToolCalls(`{"tool_calls":[{"id":"call_1"`)
|
||||
if len(got) != 0 {
|
||||
t.Errorf("extractToolCalls() with unmatched brace = %d, want 0", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolCalls_ToolCallArgumentsParsing(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
text := `{"tool_calls":[{"id":"c1","type":"function","function":{"name":"fn","arguments":"{\"num\":42,\"flag\":true,\"name\":\"test\"}"}}]}`
|
||||
|
||||
got := p.extractToolCalls(text)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(got))
|
||||
}
|
||||
// Verify different argument types
|
||||
if got[0].Arguments["num"] != float64(42) {
|
||||
t.Errorf("Arguments[num] = %v (%T), want 42", got[0].Arguments["num"], got[0].Arguments["num"])
|
||||
}
|
||||
if got[0].Arguments["flag"] != true {
|
||||
t.Errorf("Arguments[flag] = %v, want true", got[0].Arguments["flag"])
|
||||
}
|
||||
if got[0].Arguments["name"] != "test" {
|
||||
t.Errorf("Arguments[name] = %v, want test", got[0].Arguments["name"])
|
||||
}
|
||||
// Verify raw arguments string is preserved in FunctionCall
|
||||
if got[0].Function.Arguments == "" {
|
||||
t.Error("Function.Arguments should contain raw JSON string")
|
||||
}
|
||||
}
|
||||
|
||||
// --- stripToolCallsJSON tests ---
|
||||
|
||||
func TestStripToolCallsJSON(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
text := `Let me check the weather.
|
||||
{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"test","arguments":"{}"}}]}
|
||||
Done.`
|
||||
|
||||
got := p.stripToolCallsJSON(text)
|
||||
if strings.Contains(got, "tool_calls") {
|
||||
t.Errorf("should remove tool_calls JSON, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "Let me check the weather.") {
|
||||
t.Errorf("should keep text before, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "Done.") {
|
||||
t.Errorf("should keep text after, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripToolCallsJSON_NoToolCalls(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
text := "Just regular text."
|
||||
got := p.stripToolCallsJSON(text)
|
||||
if got != text {
|
||||
t.Errorf("stripToolCallsJSON() = %q, want %q", got, text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripToolCallsJSON_OnlyToolCalls(t *testing.T) {
|
||||
p := NewClaudeCliProvider("/workspace")
|
||||
text := `{"tool_calls":[{"id":"c1","type":"function","function":{"name":"fn","arguments":"{}"}}]}`
|
||||
got := p.stripToolCallsJSON(text)
|
||||
if got != "" {
|
||||
t.Errorf("stripToolCallsJSON() = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- findMatchingBrace tests ---
|
||||
|
||||
func TestFindMatchingBrace(t *testing.T) {
|
||||
tests := []struct {
|
||||
text string
|
||||
pos int
|
||||
want int
|
||||
}{
|
||||
{`{"a":1}`, 0, 7},
|
||||
{`{"a":{"b":2}}`, 0, 13},
|
||||
{`text {"a":1} more`, 5, 12},
|
||||
{`{unclosed`, 0, 0}, // no match returns pos
|
||||
{`{}`, 0, 2}, // empty object
|
||||
{`{{{}}}`, 0, 6}, // deeply nested
|
||||
{`{"a":"b{c}d"}`, 0, 13}, // braces in strings (simplified matcher)
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := findMatchingBrace(tt.text, tt.pos)
|
||||
if got != tt.want {
|
||||
t.Errorf("findMatchingBrace(%q, %d) = %d, want %d", tt.text, tt.pos, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package cliprovider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CodexHomeEnvVar names the environment variable consulted for the Codex CLI
// home directory when locating the auth.json credentials file. When it is
// unset or empty, the directory defaults to ~/.codex.
const CodexHomeEnvVar = "CODEX_HOME"
|
||||
|
||||
// CodexCliAuth mirrors the layout of the Codex CLI's auth.json file
// (~/.codex/auth.json by default). Only the "tokens" object is decoded.
type CodexCliAuth struct {
	// Tokens holds the credential fields stored under the "tokens" key.
	Tokens struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		AccountID    string `json:"account_id"`
	} `json:"tokens"`
}
|
||||
|
||||
// ReadCodexCliCredentials reads OAuth tokens from the Codex CLI's auth.json file.
|
||||
// Expiry is estimated as file modification time + 1 hour (same approach as moltbot).
|
||||
func ReadCodexCliCredentials() (accessToken, accountID string, expiresAt time.Time, err error) {
|
||||
authPath, err := resolveCodexAuthPath()
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(authPath)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, fmt.Errorf("reading %s: %w", authPath, err)
|
||||
}
|
||||
|
||||
var auth CodexCliAuth
|
||||
if err = json.Unmarshal(data, &auth); err != nil {
|
||||
return "", "", time.Time{}, fmt.Errorf("parsing %s: %w", authPath, err)
|
||||
}
|
||||
|
||||
if auth.Tokens.AccessToken == "" {
|
||||
return "", "", time.Time{}, fmt.Errorf("no access_token in %s", authPath)
|
||||
}
|
||||
|
||||
stat, err := os.Stat(authPath)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(time.Hour)
|
||||
} else {
|
||||
expiresAt = stat.ModTime().Add(time.Hour)
|
||||
}
|
||||
|
||||
return auth.Tokens.AccessToken, auth.Tokens.AccountID, expiresAt, nil
|
||||
}
|
||||
|
||||
// CreateCodexCliTokenSource creates a token source that reads from ~/.codex/auth.json.
|
||||
// This allows the existing CodexProvider to reuse Codex CLI credentials.
|
||||
func CreateCodexCliTokenSource() func() (string, string, error) {
|
||||
return func() (string, string, error) {
|
||||
token, accountID, expiresAt, err := ReadCodexCliCredentials()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("reading codex cli credentials: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().After(expiresAt) {
|
||||
return "", "", fmt.Errorf(
|
||||
"codex cli credentials expired (auth.json last modified > 1h ago). Run: codex login",
|
||||
)
|
||||
}
|
||||
|
||||
return token, accountID, nil
|
||||
}
|
||||
}
|
||||
|
||||
func resolveCodexAuthPath() (string, error) {
|
||||
codexHome := os.Getenv(CodexHomeEnvVar)
|
||||
if codexHome == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting home dir: %w", err)
|
||||
}
|
||||
codexHome = filepath.Join(home, ".codex")
|
||||
}
|
||||
return filepath.Join(codexHome, "auth.json"), nil
|
||||
}
|
||||
@@ -0,0 +1,185 @@
|
||||
package cliprovider
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestReadCodexCliCredentials_Valid(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{
|
||||
"tokens": {
|
||||
"access_token": "test-access-token",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"account_id": "org-test123"
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
token, accountID, expiresAt, err := ReadCodexCliCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadCodexCliCredentials() error: %v", err)
|
||||
}
|
||||
if token != "test-access-token" {
|
||||
t.Errorf("token = %q, want %q", token, "test-access-token")
|
||||
}
|
||||
if accountID != "org-test123" {
|
||||
t.Errorf("accountID = %q, want %q", accountID, "org-test123")
|
||||
}
|
||||
// Expiry should be within ~1 hour from now (file was just written)
|
||||
if expiresAt.Before(time.Now()) {
|
||||
t.Errorf("expiresAt = %v, should be in the future", expiresAt)
|
||||
}
|
||||
if expiresAt.After(time.Now().Add(2 * time.Hour)) {
|
||||
t.Errorf("expiresAt = %v, should be within ~1 hour", expiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
// readCodexCliCredentialsErr calls ReadCodexCliCredentials and returns only the
|
||||
// error, for tests that only need to assert on failure.
|
||||
func readCodexCliCredentialsErr() error {
|
||||
_, _, _, err := ReadCodexCliCredentials() //nolint:dogsled
|
||||
return err
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_MissingFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
if err := readCodexCliCredentialsErr(); err == nil {
|
||||
t.Fatal("expected error for missing auth.json")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_EmptyToken(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "", "refresh_token": "r", "account_id": "a"}}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
if err := readCodexCliCredentialsErr(); err == nil {
|
||||
t.Fatal("expected error for empty access_token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_InvalidJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
if err := os.WriteFile(authPath, []byte("not json"), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
if err := readCodexCliCredentialsErr(); err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_NoAccountID(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "tok123", "refresh_token": "ref456"}}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
token, accountID, _, err := ReadCodexCliCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if token != "tok123" {
|
||||
t.Errorf("token = %q, want %q", token, "tok123")
|
||||
}
|
||||
if accountID != "" {
|
||||
t.Errorf("accountID = %q, want empty", accountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_CodexHomeEnv(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
customDir := filepath.Join(tmpDir, "custom-codex")
|
||||
if err := os.MkdirAll(customDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "custom-token", "refresh_token": "r"}}`
|
||||
if err := os.WriteFile(filepath.Join(customDir, "auth.json"), []byte(authJSON), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", customDir)
|
||||
|
||||
token, _, _, err := ReadCodexCliCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if token != "custom-token" {
|
||||
t.Errorf("token = %q, want %q", token, "custom-token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCodexCliTokenSource_Valid(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "fresh-token", "refresh_token": "r", "account_id": "acc"}}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
source := CreateCodexCliTokenSource()
|
||||
token, accountID, err := source()
|
||||
if err != nil {
|
||||
t.Fatalf("token source error: %v", err)
|
||||
}
|
||||
if token != "fresh-token" {
|
||||
t.Errorf("token = %q, want %q", token, "fresh-token")
|
||||
}
|
||||
if accountID != "acc" {
|
||||
t.Errorf("accountID = %q, want %q", accountID, "acc")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCodexCliTokenSource_Expired(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "old-token", "refresh_token": "r"}}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Set file modification time to 2 hours ago
|
||||
oldTime := time.Now().Add(-2 * time.Hour)
|
||||
if err := os.Chtimes(authPath, oldTime, oldTime); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
source := CreateCodexCliTokenSource()
|
||||
_, _, err := source()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for expired credentials")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package cliprovider
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
)
|
||||
|
||||
// CodexCliProvider implements LLMProvider by running the codex CLI as a
// subprocess.
type CodexCliProvider struct {
	// command is the executable that Chat launches.
	command string
	// workspace, when non-empty, is handed to the CLI through its -C flag.
	workspace string
}
|
||||
|
||||
// NewCodexCliProvider creates a new Codex CLI provider.
|
||||
func NewCodexCliProvider(workspace string) *CodexCliProvider {
|
||||
return &CodexCliProvider{
|
||||
command: "codex",
|
||||
workspace: workspace,
|
||||
}
|
||||
}
|
||||
|
||||
// Chat implements LLMProvider.Chat by executing the codex CLI in non-interactive mode.
|
||||
func (p *CodexCliProvider) Chat(
|
||||
ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
if p.command == "" {
|
||||
return nil, fmt.Errorf("codex command not configured")
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, tools)
|
||||
|
||||
args := []string{
|
||||
"exec",
|
||||
"--json",
|
||||
"--dangerously-bypass-approvals-and-sandbox",
|
||||
"--skip-git-repo-check",
|
||||
"--color", "never",
|
||||
}
|
||||
if model != "" && model != "codex-cli" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
if p.workspace != "" {
|
||||
args = append(args, "-C", p.workspace)
|
||||
}
|
||||
args = append(args, "-") // read prompt from stdin
|
||||
|
||||
cmd := exec.CommandContext(ctx, p.command, args...)
|
||||
cmd.Stdin = bytes.NewReader([]byte(prompt))
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
// Execute the CLI through the shared isolation wrapper so external provider
|
||||
// processes honor the configured isolation policy.
|
||||
err := isolation.Run(cmd)
|
||||
|
||||
// Parse JSONL from stdout even if exit code is non-zero,
|
||||
// because codex writes diagnostic noise to stderr (e.g. rollout errors)
|
||||
// but still produces valid JSONL output.
|
||||
if stdoutStr := stdout.String(); stdoutStr != "" {
|
||||
resp, parseErr := p.parseJSONLEvents(stdoutStr)
|
||||
if parseErr == nil && resp != nil && (resp.Content != "" || len(resp.ToolCalls) > 0) {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() == context.Canceled {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
if stderrStr := stderr.String(); stderrStr != "" {
|
||||
return nil, fmt.Errorf("codex cli error: %s", stderrStr)
|
||||
}
|
||||
return nil, fmt.Errorf("codex cli error: %w", err)
|
||||
}
|
||||
|
||||
return p.parseJSONLEvents(stdout.String())
|
||||
}
|
||||
|
||||
// GetDefaultModel returns the default model identifier.
|
||||
func (p *CodexCliProvider) GetDefaultModel() string {
|
||||
return "codex-cli"
|
||||
}
|
||||
|
||||
// buildPrompt converts messages to a prompt string for the Codex CLI.
|
||||
// System messages are prepended as instructions since Codex CLI has no --system-prompt flag.
|
||||
func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinition) string {
|
||||
var systemParts []string
|
||||
var conversationParts []string
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
systemParts = append(systemParts, msg.Content)
|
||||
case "user":
|
||||
conversationParts = append(conversationParts, msg.Content)
|
||||
case "assistant":
|
||||
conversationParts = append(conversationParts, "Assistant: "+msg.Content)
|
||||
case "tool":
|
||||
conversationParts = append(conversationParts,
|
||||
fmt.Sprintf("[Tool Result for %s]: %s", msg.ToolCallID, msg.Content))
|
||||
}
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
if len(systemParts) > 0 {
|
||||
sb.WriteString("## System Instructions\n\n")
|
||||
sb.WriteString(strings.Join(systemParts, "\n\n"))
|
||||
sb.WriteString("\n\n## Task\n\n")
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString(buildCLIToolsPrompt(tools))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
// Simplify single user message (no prefix)
|
||||
if len(conversationParts) == 1 && len(systemParts) == 0 && len(tools) == 0 {
|
||||
return conversationParts[0]
|
||||
}
|
||||
|
||||
sb.WriteString(strings.Join(conversationParts, "\n"))
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
type (
	// codexEvent is one decoded JSONL line emitted by `codex exec --json`.
	// Type selects which of the optional fields are meaningful.
	codexEvent struct {
		Type     string          `json:"type"`
		ThreadID string          `json:"thread_id,omitempty"`
		Message  string          `json:"message,omitempty"`
		Item     *codexEventItem `json:"item,omitempty"`
		Usage    *codexUsage     `json:"usage,omitempty"`
		Error    *codexEventErr  `json:"error,omitempty"`
	}

	// codexEventItem is the "item" object attached to an event.
	codexEventItem struct {
		ID       string `json:"id"`
		Type     string `json:"type"`
		Text     string `json:"text,omitempty"`
		Command  string `json:"command,omitempty"`
		Status   string `json:"status,omitempty"`
		ExitCode *int   `json:"exit_code,omitempty"`
		Output   string `json:"output,omitempty"`
	}

	// codexUsage is the "usage" object attached to an event.
	codexUsage struct {
		InputTokens       int `json:"input_tokens"`
		CachedInputTokens int `json:"cached_input_tokens"`
		OutputTokens      int `json:"output_tokens"`
	}

	// codexEventErr is the "error" object attached to an event.
	codexEventErr struct {
		Message string `json:"message"`
	}
)
|
||||
|
||||
// parseJSONLEvents processes the JSONL output from codex exec --json.
|
||||
func (p *CodexCliProvider) parseJSONLEvents(output string) (*LLMResponse, error) {
|
||||
var contentParts []string
|
||||
var usage *UsageInfo
|
||||
var lastError string
|
||||
|
||||
scanner := bufio.NewScanner(strings.NewReader(output))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event codexEvent
|
||||
if err := json.Unmarshal([]byte(line), &event); err != nil {
|
||||
continue // skip malformed lines
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "item.completed":
|
||||
if event.Item != nil && event.Item.Type == "agent_message" && event.Item.Text != "" {
|
||||
contentParts = append(contentParts, event.Item.Text)
|
||||
}
|
||||
case "turn.completed":
|
||||
if event.Usage != nil {
|
||||
promptTokens := event.Usage.InputTokens + event.Usage.CachedInputTokens
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: event.Usage.OutputTokens,
|
||||
TotalTokens: promptTokens + event.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
case "error":
|
||||
lastError = event.Message
|
||||
case "turn.failed":
|
||||
if event.Error != nil {
|
||||
lastError = event.Error.Message
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastError != "" && len(contentParts) == 0 {
|
||||
return nil, fmt.Errorf("codex cli: %s", lastError)
|
||||
}
|
||||
|
||||
content := strings.Join(contentParts, "\n")
|
||||
|
||||
// Extract tool calls from response text (same pattern as ClaudeCliProvider)
|
||||
toolCalls := extractToolCallsFromText(content)
|
||||
|
||||
finishReason := "stop"
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
content = stripToolCallsFromText(content)
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: strings.TrimSpace(content),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
//go:build integration
|
||||
|
||||
package cliprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
exec "os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestIntegration_RealCodexCLI tests the CodexCliProvider with a real codex CLI.
|
||||
// Run with: go test -tags=integration ./pkg/providers/...
|
||||
func TestIntegration_RealCodexCLI(t *testing.T) {
|
||||
path, err := exec.LookPath("codex")
|
||||
if err != nil {
|
||||
t.Skip("codex CLI not found in PATH, skipping integration test")
|
||||
}
|
||||
t.Logf("Using codex CLI at: %s", path)
|
||||
|
||||
p := NewCodexCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
|
||||
}, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() with real CLI error = %v", err)
|
||||
}
|
||||
|
||||
if resp.Content == "" {
|
||||
t.Error("Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
|
||||
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
t.Logf("Response content: %q", resp.Content)
|
||||
|
||||
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
|
||||
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealCodexCLI_WithSystemPrompt(t *testing.T) {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
t.Skip("codex CLI not found in PATH")
|
||||
}
|
||||
|
||||
p := NewCodexCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
}, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Response: %q", resp.Content)
|
||||
|
||||
if !strings.Contains(resp.Content, "4") {
|
||||
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealCodexCLI_ParsesRealJSONL(t *testing.T) {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
t.Skip("codex CLI not found in PATH")
|
||||
}
|
||||
|
||||
// Run codex directly and verify our parser handles real output
|
||||
cmd := exec.Command("codex", "exec",
|
||||
"--json",
|
||||
"--dangerously-bypass-approvals-and-sandbox",
|
||||
"--skip-git-repo-check",
|
||||
"--color", "never",
|
||||
"-C", t.TempDir(),
|
||||
"-")
|
||||
cmd.Stdin = strings.NewReader("Say hi")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
// codex may write diagnostic noise to stderr but still produce valid output
|
||||
if len(output) == 0 {
|
||||
t.Fatalf("codex CLI failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Raw CLI output (first 500 chars): %s", string(output[:min(len(output), 500)]))
|
||||
|
||||
// Verify our parser can handle real output
|
||||
p := NewCodexCliProvider("")
|
||||
resp, err := p.parseJSONLEvents(string(output))
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() failed on real CLI output: %v", err)
|
||||
}
|
||||
|
||||
if resp.Content == "" {
|
||||
t.Error("parsed Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
|
||||
}
|
||||
|
||||
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
|
||||
}
|
||||
@@ -0,0 +1,595 @@
|
||||
package cliprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- JSONL Event Parsing Tests ---
|
||||
|
||||
func TestParseJSONLEvents_AgentMessage(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"thread.started","thread_id":"abc-123"}
|
||||
{"type":"turn.started"}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Hello from Codex!"}}
|
||||
{"type":"turn.completed","usage":{"input_tokens":100,"cached_input_tokens":50,"output_tokens":20}}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hello from Codex!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hello from Codex!")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Fatal("Usage should not be nil")
|
||||
}
|
||||
if resp.Usage.PromptTokens != 150 {
|
||||
t.Errorf("PromptTokens = %d, want 150", resp.Usage.PromptTokens)
|
||||
}
|
||||
if resp.Usage.CompletionTokens != 20 {
|
||||
t.Errorf("CompletionTokens = %d, want 20", resp.Usage.CompletionTokens)
|
||||
}
|
||||
if resp.Usage.TotalTokens != 170 {
|
||||
t.Errorf("TotalTokens = %d, want 170", resp.Usage.TotalTokens)
|
||||
}
|
||||
if len(resp.ToolCalls) != 0 {
|
||||
t.Errorf("ToolCalls should be empty, got %d", len(resp.ToolCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_ToolCallExtraction(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
toolCallText := `Let me read that file.
|
||||
{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"/tmp/test.txt\"}"}}]}`
|
||||
// Build valid JSONL by marshaling the event
|
||||
item := codexEvent{
|
||||
Type: "item.completed",
|
||||
Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText},
|
||||
}
|
||||
itemJSON, _ := json.Marshal(item)
|
||||
usageEvt := `{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":0,"output_tokens":20}}`
|
||||
events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + usageEvt
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if resp.FinishReason != "tool_calls" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("ToolCalls count = %d, want 1", len(resp.ToolCalls))
|
||||
}
|
||||
if resp.ToolCalls[0].Name != "read_file" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file")
|
||||
}
|
||||
if resp.ToolCalls[0].ID != "call_1" {
|
||||
t.Errorf("ToolCalls[0].ID = %q, want %q", resp.ToolCalls[0].ID, "call_1")
|
||||
}
|
||||
if resp.ToolCalls[0].Function.Arguments != `{"path":"/tmp/test.txt"}` {
|
||||
t.Errorf("ToolCalls[0].Function.Arguments = %q", resp.ToolCalls[0].Function.Arguments)
|
||||
}
|
||||
// Content should have the tool call JSON stripped
|
||||
if strings.Contains(resp.Content, "tool_calls") {
|
||||
t.Errorf("Content should not contain tool_calls JSON, got: %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_MultipleToolCalls(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
toolCallText := `{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"a.txt\"}"}},{"id":"call_2","type":"function","function":{"name":"write_file","arguments":"{\"path\":\"b.txt\",\"content\":\"hello\"}"}}]}`
|
||||
item := codexEvent{
|
||||
Type: "item.completed",
|
||||
Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText},
|
||||
}
|
||||
itemJSON, _ := json.Marshal(item)
|
||||
events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + `{"type":"turn.completed"}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if len(resp.ToolCalls) != 2 {
|
||||
t.Fatalf("ToolCalls count = %d, want 2", len(resp.ToolCalls))
|
||||
}
|
||||
if resp.ToolCalls[0].Name != "read_file" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file")
|
||||
}
|
||||
if resp.ToolCalls[1].Name != "write_file" {
|
||||
t.Errorf("ToolCalls[1].Name = %q, want %q", resp.ToolCalls[1].Name, "write_file")
|
||||
}
|
||||
if resp.FinishReason != "tool_calls" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_MultipleMessages(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"First part."}}
|
||||
{"type":"item.completed","item":{"id":"item_2","type":"command_execution","command":"ls","status":"completed"}}
|
||||
{"type":"item.completed","item":{"id":"item_3","type":"agent_message","text":"Second part."}}
|
||||
{"type":"turn.completed"}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if resp.Content != "First part.\nSecond part." {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "First part.\nSecond part.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_ErrorEvent(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"thread.started","thread_id":"abc"}
|
||||
{"type":"turn.started"}
|
||||
{"type":"error","message":"token expired"}
|
||||
{"type":"turn.failed","error":{"message":"token expired"}}`
|
||||
|
||||
_, err := p.parseJSONLEvents(events)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "token expired") {
|
||||
t.Errorf("error = %q, want to contain 'token expired'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_TurnFailed(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"turn.failed","error":{"message":"rate limit exceeded"}}`
|
||||
|
||||
_, err := p.parseJSONLEvents(events)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "rate limit exceeded") {
|
||||
t.Errorf("error = %q, want to contain 'rate limit exceeded'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_ErrorWithContent(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
// If there's an error but also content, return the content (partial success)
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Partial result."}}
|
||||
{"type":"error","message":"connection reset"}
|
||||
{"type":"turn.failed","error":{"message":"connection reset"}}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("should not error when content exists: %v", err)
|
||||
}
|
||||
if resp.Content != "Partial result." {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Partial result.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_EmptyOutput(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
resp, err := p.parseJSONLEvents("")
|
||||
if err != nil {
|
||||
t.Fatalf("empty output should not error: %v", err)
|
||||
}
|
||||
if resp.Content != "" {
|
||||
t.Errorf("Content = %q, want empty", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_MalformedLines(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `not json at all
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Good line."}}
|
||||
another bad line
|
||||
{"type":"turn.completed","usage":{"input_tokens":10,"output_tokens":5}}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("should skip malformed lines: %v", err)
|
||||
}
|
||||
if resp.Content != "Good line." {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Good line.")
|
||||
}
|
||||
if resp.Usage == nil || resp.Usage.TotalTokens != 15 {
|
||||
t.Errorf("Usage.TotalTokens = %v, want 15", resp.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_CommandExecution(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"item.started","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"in_progress"}}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"completed","exit_code":0,"output":"file1.go\nfile2.go"}}
|
||||
{"type":"item.completed","item":{"id":"item_2","type":"agent_message","text":"Found 2 files."}}
|
||||
{"type":"turn.completed"}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
// command_execution items should be skipped; only agent_message text is returned
|
||||
if resp.Content != "Found 2 files." {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Found 2 files.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_NoUsage(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"No usage info."}}
|
||||
{"type":"turn.completed"}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
t.Errorf("Usage should be nil when turn.completed has no usage, got %+v", resp.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Prompt Building Tests ---
|
||||
|
||||
func TestBuildPrompt_SystemAsInstructions(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hi there"},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, nil)
|
||||
|
||||
if !strings.Contains(prompt, "## System Instructions") {
|
||||
t.Error("prompt should contain '## System Instructions'")
|
||||
}
|
||||
if !strings.Contains(prompt, "You are helpful.") {
|
||||
t.Error("prompt should contain system content")
|
||||
}
|
||||
if !strings.Contains(prompt, "## Task") {
|
||||
t.Error("prompt should contain '## Task'")
|
||||
}
|
||||
if !strings.Contains(prompt, "Hi there") {
|
||||
t.Error("prompt should contain user message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_NoSystem(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Just a question"},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, nil)
|
||||
|
||||
if strings.Contains(prompt, "## System Instructions") {
|
||||
t.Error("prompt should not contain system instructions header")
|
||||
}
|
||||
if prompt != "Just a question" {
|
||||
t.Errorf("prompt = %q, want %q", prompt, "Just a question")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_WithTools(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Get weather"},
|
||||
}
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"city": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, tools)
|
||||
|
||||
if !strings.Contains(prompt, "## Available Tools") {
|
||||
t.Error("prompt should contain tools section")
|
||||
}
|
||||
if !strings.Contains(prompt, "get_weather") {
|
||||
t.Error("prompt should contain tool name")
|
||||
}
|
||||
if !strings.Contains(prompt, "Get current weather") {
|
||||
t.Error("prompt should contain tool description")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_MultipleMessages(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi! How can I help?"},
|
||||
{Role: "user", Content: "Tell me about Go"},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, nil)
|
||||
|
||||
if !strings.Contains(prompt, "Hello") {
|
||||
t.Error("prompt should contain first user message")
|
||||
}
|
||||
if !strings.Contains(prompt, "Assistant: Hi! How can I help?") {
|
||||
t.Error("prompt should contain assistant message with prefix")
|
||||
}
|
||||
if !strings.Contains(prompt, "Tell me about Go") {
|
||||
t.Error("prompt should contain second user message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_ToolResults(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, nil)
|
||||
|
||||
if !strings.Contains(prompt, "[Tool Result for call_1]") {
|
||||
t.Error("prompt should contain tool result")
|
||||
}
|
||||
if !strings.Contains(prompt, `{"temp": 72}`) {
|
||||
t.Error("prompt should contain tool result content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_SystemAndTools(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "Be concise."},
|
||||
{Role: "user", Content: "Do something"},
|
||||
}
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "my_tool",
|
||||
Description: "A tool",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, tools)
|
||||
|
||||
// System instructions should come first
|
||||
sysIdx := strings.Index(prompt, "## System Instructions")
|
||||
toolIdx := strings.Index(prompt, "## Available Tools")
|
||||
taskIdx := strings.Index(prompt, "## Task")
|
||||
|
||||
if sysIdx == -1 || toolIdx == -1 || taskIdx == -1 {
|
||||
t.Fatal("prompt should contain all sections")
|
||||
}
|
||||
if sysIdx >= taskIdx {
|
||||
t.Error("system instructions should come before task")
|
||||
}
|
||||
if taskIdx >= toolIdx {
|
||||
t.Error("task section should come before tools in the output")
|
||||
}
|
||||
}
|
||||
|
||||
// --- CLI Argument Tests ---
|
||||
|
||||
func TestCodexCliProvider_GetDefaultModel(t *testing.T) {
|
||||
p := NewCodexCliProvider("")
|
||||
if got := p.GetDefaultModel(); got != "codex-cli" {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, "codex-cli")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Mock CLI Integration Test ---
|
||||
|
||||
// createMockCodexCLI writes an executable bash script named "codex" into a
// fresh temporary directory and returns its path. When run, the script echoes
// each entry of events on its own line, in order, so it can stand in for the
// real CLI's JSONL output.
//
// The calling test is skipped on Windows, where the script cannot run. Any
// failure to write the script aborts the test.
func createMockCodexCLI(t *testing.T, events []string) string {
	t.Helper()
	if runtime.GOOS == "windows" {
		t.Skip("mock CLI scripts not supported on Windows")
	}
	scriptPath := filepath.Join(t.TempDir(), "codex")

	var sb strings.Builder
	sb.WriteString("#!/bin/bash\n")
	for _, event := range events {
		// Events are wrapped in single quotes; a literal ' inside one would
		// end the quoting early and corrupt the script. Replace it with the
		// close-quote / escaped-quote / reopen-quote sequence.
		quoted := strings.ReplaceAll(event, "'", `'\''`)
		fmt.Fprintf(&sb, "echo '%s'\n", quoted)
	}

	if err := os.WriteFile(scriptPath, []byte(sb.String()), 0o755); err != nil {
		t.Fatal(err)
	}
	return scriptPath
}
|
||||
|
||||
func TestCodexCliProvider_MockCLI_Success(t *testing.T) {
|
||||
scriptPath := createMockCodexCLI(t, []string{
|
||||
`{"type":"thread.started","thread_id":"test-123"}`,
|
||||
`{"type":"turn.started"}`,
|
||||
`{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Mock response from Codex CLI"}}`,
|
||||
`{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":10,"output_tokens":15}}`,
|
||||
})
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: scriptPath,
|
||||
workspace: "",
|
||||
}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := p.Chat(context.Background(), messages, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Mock response from Codex CLI" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Mock response from Codex CLI")
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Fatal("Usage should not be nil")
|
||||
}
|
||||
if resp.Usage.PromptTokens != 60 {
|
||||
t.Errorf("PromptTokens = %d, want 60", resp.Usage.PromptTokens)
|
||||
}
|
||||
if resp.Usage.CompletionTokens != 15 {
|
||||
t.Errorf("CompletionTokens = %d, want 15", resp.Usage.CompletionTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_MockCLI_Error(t *testing.T) {
|
||||
scriptPath := createMockCodexCLI(t, []string{
|
||||
`{"type":"thread.started","thread_id":"test-err"}`,
|
||||
`{"type":"turn.started"}`,
|
||||
`{"type":"error","message":"auth token expired"}`,
|
||||
`{"type":"turn.failed","error":{"message":"auth token expired"}}`,
|
||||
})
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: scriptPath,
|
||||
workspace: "",
|
||||
}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
_, err := p.Chat(context.Background(), messages, nil, "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "auth token expired") {
|
||||
t.Errorf("error = %q, want to contain 'auth token expired'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_MockCLI_WithModel(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("mock CLI scripts not supported on Windows")
|
||||
}
|
||||
// Mock script that captures args to verify model flag is passed
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "codex")
|
||||
script := `#!/bin/bash
|
||||
# Write args to a file for verification
|
||||
echo "$@" > "` + filepath.Join(tmpDir, "args.txt") + `"
|
||||
echo '{"type":"item.completed","item":{"id":"1","type":"agent_message","text":"ok"}}'
|
||||
echo '{"type":"turn.completed"}'`
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: scriptPath,
|
||||
workspace: "/tmp/test-workspace",
|
||||
}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "test"}}
|
||||
_, err := p.Chat(context.Background(), messages, nil, "gpt-5.3-codex", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the args
|
||||
argsData, err := os.ReadFile(filepath.Join(tmpDir, "args.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("reading args: %v", err)
|
||||
}
|
||||
args := string(argsData)
|
||||
|
||||
if !strings.Contains(args, "-m gpt-5.3-codex") {
|
||||
t.Errorf("args should contain model flag, got: %s", args)
|
||||
}
|
||||
if !strings.Contains(args, "-C /tmp/test-workspace") {
|
||||
t.Errorf("args should contain workspace flag, got: %s", args)
|
||||
}
|
||||
if !strings.Contains(args, "--json") {
|
||||
t.Errorf("args should contain --json, got: %s", args)
|
||||
}
|
||||
if !strings.Contains(args, "--dangerously-bypass-approvals-and-sandbox") {
|
||||
t.Errorf("args should contain bypass flag, got: %s", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_MockCLI_ContextCancel(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("mock CLI scripts not supported on Windows")
|
||||
}
|
||||
// Script that sleeps forever
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "codex")
|
||||
script := "#!/bin/bash\nsleep 60"
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: scriptPath,
|
||||
workspace: "",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // cancel immediately
|
||||
|
||||
messages := []Message{{Role: "user", Content: "test"}}
|
||||
_, err := p.Chat(ctx, messages, nil, "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error on canceled context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_EmptyCommand(t *testing.T) {
|
||||
p := &CodexCliProvider{command: ""}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "test"}}
|
||||
_, err := p.Chat(context.Background(), messages, nil, "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty command")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Integration Test (requires real codex CLI with valid auth) ---
|
||||
|
||||
func TestCodexCliProvider_Integration(t *testing.T) {
|
||||
if os.Getenv("PICOCLAW_INTEGRATION_TESTS") == "" {
|
||||
t.Skip("skipping integration test (set PICOCLAW_INTEGRATION_TESTS=1 to enable)")
|
||||
}
|
||||
|
||||
// Verify codex is available
|
||||
codexPath, err := exec.LookPath("codex")
|
||||
if err != nil {
|
||||
t.Skip("codex CLI not found in PATH")
|
||||
}
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: codexPath,
|
||||
workspace: "",
|
||||
}
|
||||
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Respond with just the word 'hello' and nothing else."},
|
||||
}
|
||||
|
||||
resp, err := p.Chat(context.Background(), messages, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
|
||||
lower := strings.ToLower(strings.TrimSpace(resp.Content))
|
||||
if !strings.Contains(lower, "hello") {
|
||||
t.Errorf("Content = %q, expected to contain 'hello'", resp.Content)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package cliprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
copilot "github.com/github/copilot-sdk/go"
|
||||
)
|
||||
|
||||
type GitHubCopilotProvider struct {
|
||||
uri string
|
||||
connectMode string // "stdio" or "grpc"
|
||||
|
||||
client *copilot.Client
|
||||
session *copilot.Session
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*GitHubCopilotProvider, error) {
|
||||
if connectMode == "" {
|
||||
connectMode = "grpc"
|
||||
}
|
||||
|
||||
switch connectMode {
|
||||
case "stdio":
|
||||
// TODO: Implement stdio mode for GitHub Copilot provider
|
||||
// See https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md for details
|
||||
return nil, fmt.Errorf("stdio mode not implemented for GitHub Copilot provider; please use 'grpc' mode instead")
|
||||
case "grpc":
|
||||
client := copilot.NewClient(&copilot.ClientOptions{
|
||||
CLIUrl: uri,
|
||||
})
|
||||
if err := client.Start(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"can't connect to Github Copilot: %w; `https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server` for details",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
session, err := client.CreateSession(context.Background(), &copilot.SessionConfig{
|
||||
Model: model,
|
||||
OnPermissionRequest: copilot.PermissionHandler.ApproveAll,
|
||||
Hooks: &copilot.SessionHooks{},
|
||||
})
|
||||
if err != nil {
|
||||
client.Stop()
|
||||
return nil, fmt.Errorf("create session failed: %w", err)
|
||||
}
|
||||
|
||||
return &GitHubCopilotProvider{
|
||||
uri: uri,
|
||||
connectMode: connectMode,
|
||||
client: client,
|
||||
session: session,
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown connect mode: %s", connectMode)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GitHubCopilotProvider) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.client != nil {
|
||||
p.client.Stop()
|
||||
p.client = nil
|
||||
p.session = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GitHubCopilotProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
type tempMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
out := make([]tempMessage, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
out = append(out, tempMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
fullcontent, err := json.Marshal(out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal messages: %w", err)
|
||||
}
|
||||
p.mu.Lock()
|
||||
session := p.session
|
||||
p.mu.Unlock()
|
||||
|
||||
if session == nil {
|
||||
return nil, fmt.Errorf("provider closed")
|
||||
}
|
||||
|
||||
resp, err := session.SendAndWait(ctx, copilot.MessageOptions{
|
||||
Prompt: string(fullcontent),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send message to copilot: %w", err)
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("empty response from copilot")
|
||||
}
|
||||
if resp.Data.Content == nil {
|
||||
return nil, fmt.Errorf("no content in copilot response")
|
||||
}
|
||||
content := *resp.Data.Content
|
||||
|
||||
return &LLMResponse{
|
||||
FinishReason: "stop",
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *GitHubCopilotProvider) GetDefaultModel() string {
|
||||
return "gpt-4.1"
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package cliprovider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// extractToolCallsFromText parses tool call JSON from response text.
|
||||
// Both ClaudeCliProvider and CodexCliProvider use this to extract
|
||||
// tool calls that the model outputs in its response text.
|
||||
func extractToolCallsFromText(text string) []ToolCall {
|
||||
start := strings.Index(text, `{"tool_calls"`)
|
||||
if start == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
end := findMatchingBrace(text, start)
|
||||
if end == start {
|
||||
return nil
|
||||
}
|
||||
|
||||
jsonStr := text[start:end]
|
||||
|
||||
var wrapper struct {
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []ToolCall
|
||||
for _, tc := range wrapper.ToolCalls {
|
||||
var args map[string]any
|
||||
json.Unmarshal([]byte(tc.Function.Arguments), &args)
|
||||
|
||||
result = append(result, ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
Function: &FunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// stripToolCallsFromText removes tool call JSON from response text.
|
||||
func stripToolCallsFromText(text string) string {
|
||||
start := strings.Index(text, `{"tool_calls"`)
|
||||
if start == -1 {
|
||||
return text
|
||||
}
|
||||
|
||||
end := findMatchingBrace(text, start)
|
||||
if end == start {
|
||||
return text
|
||||
}
|
||||
|
||||
return strings.TrimSpace(text[:start] + text[end:])
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package cliprovider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// buildCLIToolsPrompt creates the tool definitions section for a CLI provider system prompt.
|
||||
func buildCLIToolsPrompt(tools []ToolDefinition) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("## Available Tools\n\n")
|
||||
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
|
||||
sb.WriteString("```json\n")
|
||||
sb.WriteString(
|
||||
`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
|
||||
)
|
||||
sb.WriteString("\n```\n\n")
|
||||
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
|
||||
sb.WriteString("Escaping rules (what to type in `function.arguments`):\n")
|
||||
sb.WriteString("- Use `\\n` to represent a real newline character.\n")
|
||||
sb.WriteString("- Use `\\\\n` to represent a literal backslash+n sequence (`\\n`).\n")
|
||||
sb.WriteString(
|
||||
"- `function.arguments` is a JSON-encoded string, so quotes/backslashes must be escaped in the outer payload.\n\n",
|
||||
)
|
||||
sb.WriteString("### Tool Definitions:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
|
||||
}
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// NormalizeToolCall normalizes a ToolCall to ensure all fields are properly populated.
|
||||
// It handles cases where Name/Arguments might be in different locations (top-level vs Function)
|
||||
// and ensures both are populated consistently.
|
||||
func NormalizeToolCall(tc ToolCall) ToolCall {
|
||||
normalized := tc
|
||||
|
||||
// Ensure Name is populated from Function if not set
|
||||
if normalized.Name == "" && normalized.Function != nil {
|
||||
normalized.Name = normalized.Function.Name
|
||||
}
|
||||
|
||||
// Ensure Arguments is not nil
|
||||
if normalized.Arguments == nil {
|
||||
normalized.Arguments = map[string]any{}
|
||||
}
|
||||
|
||||
// Parse Arguments from Function.Arguments if not already set
|
||||
if len(normalized.Arguments) == 0 && normalized.Function != nil && normalized.Function.Arguments != "" {
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(normalized.Function.Arguments), &parsed); err == nil && parsed != nil {
|
||||
normalized.Arguments = parsed
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure Function is populated with consistent values
|
||||
argsJSON, _ := json.Marshal(normalized.Arguments)
|
||||
if normalized.Function == nil {
|
||||
normalized.Function = &FunctionCall{
|
||||
Name: normalized.Name,
|
||||
Arguments: string(argsJSON),
|
||||
}
|
||||
} else {
|
||||
if normalized.Function.Name == "" {
|
||||
normalized.Function.Name = normalized.Name
|
||||
}
|
||||
if normalized.Name == "" {
|
||||
normalized.Name = normalized.Function.Name
|
||||
}
|
||||
if normalized.Function.Arguments == "" {
|
||||
normalized.Function.Arguments = string(argsJSON)
|
||||
}
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package cliprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type (
|
||||
ToolCall = protocoltypes.ToolCall
|
||||
FunctionCall = protocoltypes.FunctionCall
|
||||
LLMResponse = protocoltypes.LLMResponse
|
||||
UsageInfo = protocoltypes.UsageInfo
|
||||
Message = protocoltypes.Message
|
||||
ToolDefinition = protocoltypes.ToolDefinition
|
||||
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
)
|
||||
|
||||
type LLMProvider interface {
|
||||
Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error)
|
||||
GetDefaultModel() string
|
||||
}
|
||||
Reference in New Issue
Block a user