mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge PR #343: Add Google Antigravity provider and harden tool-call compatibility
This commit is contained in:
+164
-3
@@ -10,6 +10,7 @@ import (
|
||||
"bufio"
|
||||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
@@ -373,6 +374,7 @@ func migrateHelp() {
|
||||
func agentCmd() {
|
||||
message := ""
|
||||
sessionKey := "cli:default"
|
||||
modelOverride := ""
|
||||
|
||||
args := os.Args[2:]
|
||||
for i := 0; i < len(args); i++ {
|
||||
@@ -390,6 +392,11 @@ func agentCmd() {
|
||||
sessionKey = args[i+1]
|
||||
i++
|
||||
}
|
||||
case "--model", "-model":
|
||||
if i+1 < len(args) {
|
||||
modelOverride = args[i+1]
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -399,6 +406,10 @@ func agentCmd() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if modelOverride != "" {
|
||||
cfg.Agents.Defaults.Model = modelOverride
|
||||
}
|
||||
|
||||
provider, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
fmt.Printf("Error creating provider: %v\n", err)
|
||||
@@ -793,6 +804,8 @@ func authCmd() {
|
||||
authLogoutCmd()
|
||||
case "status":
|
||||
authStatusCmd()
|
||||
case "models":
|
||||
authModelsCmd()
|
||||
default:
|
||||
fmt.Printf("Unknown auth command: %s\n", os.Args[2])
|
||||
authHelp()
|
||||
@@ -804,15 +817,18 @@ func authHelp() {
|
||||
fmt.Println(" login Login via OAuth or paste token")
|
||||
fmt.Println(" logout Remove stored credentials")
|
||||
fmt.Println(" status Show current auth status")
|
||||
fmt.Println(" models List available Antigravity models")
|
||||
fmt.Println()
|
||||
fmt.Println("Login options:")
|
||||
fmt.Println(" --provider <name> Provider to login with (openai, anthropic)")
|
||||
fmt.Println(" --provider <name> Provider to login with (openai, anthropic, google-antigravity)")
|
||||
fmt.Println(" --device-code Use device code flow (for headless environments)")
|
||||
fmt.Println()
|
||||
fmt.Println("Examples:")
|
||||
fmt.Println(" picoclaw auth login --provider openai")
|
||||
fmt.Println(" picoclaw auth login --provider openai --device-code")
|
||||
fmt.Println(" picoclaw auth login --provider anthropic")
|
||||
fmt.Println(" picoclaw auth login --provider google-antigravity")
|
||||
fmt.Println(" picoclaw auth models")
|
||||
fmt.Println(" picoclaw auth logout --provider openai")
|
||||
fmt.Println(" picoclaw auth status")
|
||||
}
|
||||
@@ -836,7 +852,7 @@ func authLoginCmd() {
|
||||
|
||||
if provider == "" {
|
||||
fmt.Println("Error: --provider is required")
|
||||
fmt.Println("Supported providers: openai, anthropic")
|
||||
fmt.Println("Supported providers: openai, anthropic, google-antigravity")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -845,9 +861,11 @@ func authLoginCmd() {
|
||||
authLoginOpenAI(useDeviceCode)
|
||||
case "anthropic":
|
||||
authLoginPasteToken(provider)
|
||||
case "google-antigravity", "antigravity":
|
||||
authLoginGoogleAntigravity()
|
||||
default:
|
||||
fmt.Printf("Unsupported provider: %s\n", provider)
|
||||
fmt.Println("Supported providers: openai, anthropic")
|
||||
fmt.Println("Supported providers: openai, anthropic, google-antigravity")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -887,6 +905,88 @@ func authLoginOpenAI(useDeviceCode bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func authLoginGoogleAntigravity() {
|
||||
cfg := auth.GoogleAntigravityOAuthConfig()
|
||||
|
||||
cred, err := auth.LoginBrowser(cfg)
|
||||
if err != nil {
|
||||
fmt.Printf("Login failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cred.Provider = "google-antigravity"
|
||||
|
||||
// Fetch user email from Google userinfo
|
||||
email, err := fetchGoogleUserEmail(cred.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: could not fetch email: %v\n", err)
|
||||
} else {
|
||||
cred.Email = email
|
||||
fmt.Printf("Email: %s\n", email)
|
||||
}
|
||||
|
||||
// Fetch Cloud Code Assist project ID
|
||||
projectID, err := providers.FetchAntigravityProjectID(cred.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: could not fetch project ID: %v\n", err)
|
||||
fmt.Println("You may need Google Cloud Code Assist enabled on your account.")
|
||||
} else {
|
||||
cred.ProjectID = projectID
|
||||
fmt.Printf("Project: %s\n", projectID)
|
||||
}
|
||||
|
||||
if err := auth.SetCredential("google-antigravity", cred); err != nil {
|
||||
fmt.Printf("Failed to save credentials: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
appCfg, err := loadConfig()
|
||||
if err == nil {
|
||||
appCfg.Providers.Antigravity.AuthMethod = "oauth"
|
||||
if appCfg.Agents.Defaults.Provider == "" {
|
||||
appCfg.Agents.Defaults.Provider = "antigravity"
|
||||
}
|
||||
if appCfg.Agents.Defaults.Provider == "antigravity" || appCfg.Agents.Defaults.Provider == "google-antigravity" {
|
||||
appCfg.Agents.Defaults.Model = "gemini-3-flash"
|
||||
}
|
||||
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
|
||||
fmt.Printf("Warning: could not update config: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Google Antigravity login successful!")
|
||||
fmt.Println("Config updated: provider=antigravity, model=gemini-3-flash")
|
||||
fmt.Println("Try it: picoclaw agent -m \"Hello world\"")
|
||||
}
|
||||
|
||||
func fetchGoogleUserEmail(accessToken string) (string, error) {
|
||||
req, err := http.NewRequest("GET", "https://www.googleapis.com/oauth2/v2/userinfo", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("userinfo request failed: %s", string(body))
|
||||
}
|
||||
|
||||
var userInfo struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return userInfo.Email, nil
|
||||
}
|
||||
|
||||
func authLoginPasteToken(provider string) {
|
||||
cred, err := auth.LoginPasteToken(provider, os.Stdin)
|
||||
if err != nil {
|
||||
@@ -942,6 +1042,8 @@ func authLogoutCmd() {
|
||||
appCfg.Providers.OpenAI.AuthMethod = ""
|
||||
case "anthropic":
|
||||
appCfg.Providers.Anthropic.AuthMethod = ""
|
||||
case "google-antigravity", "antigravity":
|
||||
appCfg.Providers.Antigravity.AuthMethod = ""
|
||||
}
|
||||
config.SaveConfig(getConfigPath(), appCfg)
|
||||
}
|
||||
@@ -957,6 +1059,7 @@ func authLogoutCmd() {
|
||||
if err == nil {
|
||||
appCfg.Providers.OpenAI.AuthMethod = ""
|
||||
appCfg.Providers.Anthropic.AuthMethod = ""
|
||||
appCfg.Providers.Antigravity.AuthMethod = ""
|
||||
config.SaveConfig(getConfigPath(), appCfg)
|
||||
}
|
||||
|
||||
@@ -993,12 +1096,70 @@ func authStatusCmd() {
|
||||
if cred.AccountID != "" {
|
||||
fmt.Printf(" Account: %s\n", cred.AccountID)
|
||||
}
|
||||
if cred.Email != "" {
|
||||
fmt.Printf(" Email: %s\n", cred.Email)
|
||||
}
|
||||
if cred.ProjectID != "" {
|
||||
fmt.Printf(" Project: %s\n", cred.ProjectID)
|
||||
}
|
||||
if !cred.ExpiresAt.IsZero() {
|
||||
fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func authModelsCmd() {
|
||||
cred, err := auth.GetCredential("google-antigravity")
|
||||
if err != nil || cred == nil {
|
||||
fmt.Println("Not logged in to Google Antigravity.")
|
||||
fmt.Println("Run: picoclaw auth login --provider google-antigravity")
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh token if needed
|
||||
if cred.NeedsRefresh() && cred.RefreshToken != "" {
|
||||
oauthCfg := auth.GoogleAntigravityOAuthConfig()
|
||||
refreshed, refreshErr := auth.RefreshAccessToken(cred, oauthCfg)
|
||||
if refreshErr == nil {
|
||||
cred = refreshed
|
||||
_ = auth.SetCredential("google-antigravity", cred)
|
||||
}
|
||||
}
|
||||
|
||||
projectID := cred.ProjectID
|
||||
if projectID == "" {
|
||||
fmt.Println("No project ID stored. Try logging in again.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Fetching models for project: %s\n\n", projectID)
|
||||
|
||||
models, err := providers.FetchAntigravityModels(cred.AccessToken, projectID)
|
||||
if err != nil {
|
||||
fmt.Printf("Error fetching models: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(models) == 0 {
|
||||
fmt.Println("No models available.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Available Antigravity Models:")
|
||||
fmt.Println("-----------------------------")
|
||||
for _, m := range models {
|
||||
status := "✓"
|
||||
if m.IsExhausted {
|
||||
status = "✗ (quota exhausted)"
|
||||
}
|
||||
name := m.ID
|
||||
if m.DisplayName != "" {
|
||||
name = fmt.Sprintf("%s (%s)", m.ID, m.DisplayName)
|
||||
}
|
||||
fmt.Printf(" %s %s\n", status, name)
|
||||
}
|
||||
}
|
||||
|
||||
func getConfigPath() string {
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".picoclaw", "config.json")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,72 @@
|
||||
# Using Antigravity Provider in PicoClaw
|
||||
|
||||
This guide explains how to set up and use the **Antigravity** (Google Cloud Code Assist) provider in PicoClaw.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. A Google account.
|
||||
2. Google Cloud Code Assist enabled (usually available via the "Gemini for Google Cloud" onboarding).
|
||||
|
||||
## 1. Authentication
|
||||
|
||||
To authenticate with Antigravity, run the following command:
|
||||
|
||||
```bash
|
||||
picoclaw auth login --provider antigravity
|
||||
```
|
||||
|
||||
### Manual Authentication (Headless/VPS)
|
||||
If you are running on a server (Coolify/Docker) and cannot reach `localhost`, follow these steps:
|
||||
1. Run the command above.
|
||||
2. Copy the URL provided and open it in your local browser.
|
||||
3. Complete the login.
|
||||
4. Your browser will redirect to a `localhost:51121` URL (which will fail to load).
|
||||
5. **Copy that final URL** from your browser's address bar.
|
||||
6. **Paste it back into the terminal** where PicoClaw is waiting.
|
||||
|
||||
PicoClaw will extract the authorization code and complete the process automatically.
|
||||
|
||||
## 2. Managing Models
|
||||
|
||||
### List Available Models
|
||||
To see which models your project has access to and check their quotas:
|
||||
|
||||
```bash
|
||||
picoclaw auth models
|
||||
```
|
||||
|
||||
### Switch Models
|
||||
You can change the default model in `~/.picoclaw/config.json` or override it via the CLI:
|
||||
|
||||
```bash
|
||||
# Override for a single command
|
||||
picoclaw agent -m "Hello" --model claude-opus-4-6-thinking
|
||||
```
|
||||
|
||||
## 3. Real-world Usage (Coolify/Docker)
|
||||
|
||||
If you are deploying via Coolify or Docker, follow these steps to test:
|
||||
|
||||
1. **Branch**: Use the `feat/antigravity-provider` branch.
|
||||
2. **Environment Variables**:
|
||||
* `PICOCLAW_AGENTS_DEFAULTS_PROVIDER=antigravity`
|
||||
* `PICOCLAW_AGENTS_DEFAULTS_MODEL=gemini-3-flash`
|
||||
3. **Authentication persistence**:
|
||||
If you've logged in locally, you can copy your credentials to the server:
|
||||
```bash
|
||||
scp ~/.picoclaw/auth-profiles.json user@your-server:~/.picoclaw/
|
||||
```
|
||||
*Alternatively*, run the `auth login` command once on the server if you have terminal access.
|
||||
|
||||
## 4. Troubleshooting
|
||||
|
||||
* **Empty Response**: If a model returns an empty reply, it may be restricted for your project. Try `gemini-3-flash` or `claude-opus-4-6-thinking`.
|
||||
* **429 Rate Limit**: Antigravity has strict quotas. PicoClaw will display the "reset time" in the error message if you hit a limit.
|
||||
* **404 Not Found**: Ensure you are using a model ID from the `picoclaw auth models` list. Use the short ID (e.g., `gemini-3-flash`) not the full path.
|
||||
|
||||
## 5. Summary of Working Models
|
||||
|
||||
Based on testing, the following models are most reliable:
|
||||
* `gemini-3-flash` (Fast, highly available)
|
||||
* `gemini-2.5-flash-lite` (Lightweight)
|
||||
* `claude-opus-4-6-thinking` (Powerful, includes reasoning)
|
||||
+49
-14
@@ -189,16 +189,7 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
|
||||
systemPrompt += "\n\n## Summary of Previous Conversation\n\n" + summary
|
||||
}
|
||||
|
||||
//This fix prevents the session memory from LLM failure due to elimination of toolu_IDs required from LLM
|
||||
// --- INICIO DEL FIX ---
|
||||
//Diegox-17
|
||||
for len(history) > 0 && (history[0].Role == "tool") {
|
||||
logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error",
|
||||
map[string]interface{}{"role": history[0].Role})
|
||||
history = history[1:]
|
||||
}
|
||||
//Diegox-17
|
||||
// --- FIN DEL FIX ---
|
||||
history = sanitizeHistoryForProvider(history)
|
||||
|
||||
messages = append(messages, providers.Message{
|
||||
Role: "system",
|
||||
@@ -207,14 +198,58 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
|
||||
|
||||
messages = append(messages, history...)
|
||||
|
||||
messages = append(messages, providers.Message{
|
||||
Role: "user",
|
||||
Content: currentMessage,
|
||||
})
|
||||
if strings.TrimSpace(currentMessage) != "" {
|
||||
messages = append(messages, providers.Message{
|
||||
Role: "user",
|
||||
Content: currentMessage,
|
||||
})
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func sanitizeHistoryForProvider(history []providers.Message) []providers.Message {
|
||||
if len(history) == 0 {
|
||||
return history
|
||||
}
|
||||
|
||||
sanitized := make([]providers.Message, 0, len(history))
|
||||
for _, msg := range history {
|
||||
switch msg.Role {
|
||||
case "tool":
|
||||
if len(sanitized) == 0 {
|
||||
logger.DebugCF("agent", "Dropping orphaned leading tool message", map[string]interface{}{})
|
||||
continue
|
||||
}
|
||||
last := sanitized[len(sanitized)-1]
|
||||
if last.Role != "assistant" || len(last.ToolCalls) == 0 {
|
||||
logger.DebugCF("agent", "Dropping orphaned tool message", map[string]interface{}{})
|
||||
continue
|
||||
}
|
||||
sanitized = append(sanitized, msg)
|
||||
|
||||
case "assistant":
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
if len(sanitized) == 0 {
|
||||
logger.DebugCF("agent", "Dropping assistant tool-call turn at history start", map[string]interface{}{})
|
||||
continue
|
||||
}
|
||||
prev := sanitized[len(sanitized)-1]
|
||||
if prev.Role != "user" && prev.Role != "tool" {
|
||||
logger.DebugCF("agent", "Dropping assistant tool-call turn with invalid predecessor", map[string]interface{}{"prev_role": prev.Role})
|
||||
continue
|
||||
}
|
||||
}
|
||||
sanitized = append(sanitized, msg)
|
||||
|
||||
default:
|
||||
sanitized = append(sanitized, msg)
|
||||
}
|
||||
}
|
||||
|
||||
return sanitized
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) AddToolResult(messages []providers.Message, toolCallID, toolName, result string) []providers.Message {
|
||||
messages = append(messages, providers.Message{
|
||||
Role: "tool",
|
||||
|
||||
+61
-9
@@ -605,15 +605,20 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
break
|
||||
}
|
||||
|
||||
// Log tool calls
|
||||
toolNames := make([]string, 0, len(response.ToolCalls))
|
||||
normalizedToolCalls := make([]providers.ToolCall, 0, len(response.ToolCalls))
|
||||
for _, tc := range response.ToolCalls {
|
||||
normalizedToolCalls = append(normalizedToolCalls, normalizeProviderToolCall(tc))
|
||||
}
|
||||
|
||||
// Log tool calls
|
||||
toolNames := make([]string, 0, len(normalizedToolCalls))
|
||||
for _, tc := range normalizedToolCalls {
|
||||
toolNames = append(toolNames, tc.Name)
|
||||
}
|
||||
logger.InfoCF("agent", "LLM requested tool calls",
|
||||
map[string]interface{}{
|
||||
"tools": toolNames,
|
||||
"count": len(response.ToolCalls),
|
||||
"count": len(normalizedToolCalls),
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
@@ -622,14 +627,22 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
Role: "assistant",
|
||||
Content: response.Content,
|
||||
}
|
||||
for _, tc := range response.ToolCalls {
|
||||
for _, tc := range normalizedToolCalls {
|
||||
argumentsJSON, _ := json.Marshal(tc.Arguments)
|
||||
thoughtSignature := ""
|
||||
if tc.Function != nil {
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
}
|
||||
|
||||
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
Name: tc.Name,
|
||||
Arguments: tc.Arguments,
|
||||
Function: &providers.FunctionCall{
|
||||
Name: tc.Name,
|
||||
Arguments: string(argumentsJSON),
|
||||
Name: tc.Name,
|
||||
Arguments: string(argumentsJSON),
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -639,7 +652,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
al.sessions.AddFullMessage(opts.SessionKey, assistantMsg)
|
||||
|
||||
// Execute tool calls
|
||||
for _, tc := range response.ToolCalls {
|
||||
for _, tc := range normalizedToolCalls {
|
||||
// Log tool call with arguments preview
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
@@ -702,6 +715,45 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
return finalContent, iteration, nil
|
||||
}
|
||||
|
||||
func normalizeProviderToolCall(tc providers.ToolCall) providers.ToolCall {
|
||||
normalized := tc
|
||||
|
||||
if normalized.Name == "" && normalized.Function != nil {
|
||||
normalized.Name = normalized.Function.Name
|
||||
}
|
||||
|
||||
if normalized.Arguments == nil {
|
||||
normalized.Arguments = map[string]interface{}{}
|
||||
}
|
||||
|
||||
if len(normalized.Arguments) == 0 && normalized.Function != nil && normalized.Function.Arguments != "" {
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(normalized.Function.Arguments), &parsed); err == nil && parsed != nil {
|
||||
normalized.Arguments = parsed
|
||||
}
|
||||
}
|
||||
|
||||
argsJSON, _ := json.Marshal(normalized.Arguments)
|
||||
if normalized.Function == nil {
|
||||
normalized.Function = &providers.FunctionCall{
|
||||
Name: normalized.Name,
|
||||
Arguments: string(argsJSON),
|
||||
}
|
||||
} else {
|
||||
if normalized.Function.Name == "" {
|
||||
normalized.Function.Name = normalized.Name
|
||||
}
|
||||
if normalized.Name == "" {
|
||||
normalized.Name = normalized.Function.Name
|
||||
}
|
||||
if normalized.Function.Arguments == "" {
|
||||
normalized.Function.Arguments = string(argsJSON)
|
||||
}
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
||||
// updateToolContexts updates the context for tools that need channel/chatID info.
|
||||
func (al *AgentLoop) updateToolContexts(channel, chatID string) {
|
||||
// Use ContextualTool interface instead of type assertions
|
||||
|
||||
+118
-23
@@ -1,6 +1,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -19,11 +21,13 @@ import (
|
||||
)
|
||||
|
||||
type OAuthProviderConfig struct {
|
||||
Issuer string
|
||||
ClientID string
|
||||
Scopes string
|
||||
Originator string
|
||||
Port int
|
||||
Issuer string
|
||||
ClientID string
|
||||
ClientSecret string // Required for Google OAuth (confidential client)
|
||||
TokenURL string // Override token endpoint (Google uses a different URL than issuer)
|
||||
Scopes string
|
||||
Originator string
|
||||
Port int
|
||||
}
|
||||
|
||||
func OpenAIOAuthConfig() OAuthProviderConfig {
|
||||
@@ -36,6 +40,30 @@ func OpenAIOAuthConfig() OAuthProviderConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// GoogleAntigravityOAuthConfig returns the OAuth configuration for Google Cloud Code Assist (Antigravity).
|
||||
// Client credentials are the same ones used by OpenCode/pi-ai for Cloud Code Assist access.
|
||||
func GoogleAntigravityOAuthConfig() OAuthProviderConfig {
|
||||
// These are the same client credentials used by the OpenCode antigravity plugin.
|
||||
clientID := decodeBase64("MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==")
|
||||
clientSecret := decodeBase64("R09DU1BYLUs1OEZXUjQ4NkxkTEoxbUxCOHNYQzR6NnFEQWY=")
|
||||
return OAuthProviderConfig{
|
||||
Issuer: "https://accounts.google.com/o/oauth2/v2",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/cclog https://www.googleapis.com/auth/experimentsandconfigs",
|
||||
Port: 51121,
|
||||
}
|
||||
}
|
||||
|
||||
func decodeBase64(s string) string {
|
||||
data, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return s
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func generateState() (string, error) {
|
||||
buf := make([]byte, 32)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
@@ -101,8 +129,17 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
|
||||
}
|
||||
|
||||
fmt.Println("If you're running in a headless environment, use: picoclaw auth login --provider openai --device-code")
|
||||
fmt.Println("Waiting for authentication in browser...")
|
||||
fmt.Printf("Wait! If you are in a headless environment (like Coolify/VPS) and cannot reach localhost:%d,\n", cfg.Port)
|
||||
fmt.Println("please complete the login in your local browser and then PASTE the final redirect URL (or just the code) here.")
|
||||
fmt.Println("Waiting for authentication (browser or manual paste)...")
|
||||
|
||||
// Start manual input in a goroutine
|
||||
manualCh := make(chan string)
|
||||
go func() {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
input, _ := reader.ReadString('\n')
|
||||
manualCh <- strings.TrimSpace(input)
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
@@ -110,6 +147,22 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||
return nil, result.err
|
||||
}
|
||||
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
|
||||
case manualInput := <-manualCh:
|
||||
if manualInput == "" {
|
||||
return nil, fmt.Errorf("manual input cancelled")
|
||||
}
|
||||
// Extract code from URL if it's a full URL
|
||||
code := manualInput
|
||||
if strings.Contains(manualInput, "?") {
|
||||
u, err := url.Parse(manualInput)
|
||||
if err == nil {
|
||||
code = u.Query().Get("code")
|
||||
}
|
||||
}
|
||||
if code == "" {
|
||||
return nil, fmt.Errorf("could not find authorization code in input")
|
||||
}
|
||||
return exchangeCodeForTokens(cfg, code, pkce.CodeVerifier, redirectURI)
|
||||
case <-time.After(5 * time.Minute):
|
||||
return nil, fmt.Errorf("authentication timed out after 5 minutes")
|
||||
}
|
||||
@@ -269,8 +322,16 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
|
||||
"refresh_token": {cred.RefreshToken},
|
||||
"scope": {"openid profile email"},
|
||||
}
|
||||
if cfg.ClientSecret != "" {
|
||||
data.Set("client_secret", cfg.ClientSecret)
|
||||
}
|
||||
|
||||
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
|
||||
tokenURL := cfg.Issuer + "/oauth/token"
|
||||
if cfg.TokenURL != "" {
|
||||
tokenURL = cfg.TokenURL
|
||||
}
|
||||
|
||||
resp, err := http.PostForm(tokenURL, data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
@@ -291,6 +352,12 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
|
||||
if refreshed.AccountID == "" {
|
||||
refreshed.AccountID = cred.AccountID
|
||||
}
|
||||
if cred.Email != "" && refreshed.Email == "" {
|
||||
refreshed.Email = cred.Email
|
||||
}
|
||||
if cred.ProjectID != "" && refreshed.ProjectID == "" {
|
||||
refreshed.ProjectID = cred.ProjectID
|
||||
}
|
||||
return refreshed, nil
|
||||
}
|
||||
|
||||
@@ -300,21 +367,35 @@ func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
|
||||
|
||||
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||
params := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {cfg.ClientID},
|
||||
"redirect_uri": {redirectURI},
|
||||
"scope": {cfg.Scopes},
|
||||
"code_challenge": {pkce.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"id_token_add_organizations": {"true"},
|
||||
"codex_cli_simplified_flow": {"true"},
|
||||
"state": {state},
|
||||
"response_type": {"code"},
|
||||
"client_id": {cfg.ClientID},
|
||||
"redirect_uri": {redirectURI},
|
||||
"scope": {cfg.Scopes},
|
||||
"code_challenge": {pkce.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
}
|
||||
if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") {
|
||||
params.Set("originator", "picoclaw")
|
||||
|
||||
isGoogle := strings.Contains(strings.ToLower(cfg.Issuer), "accounts.google.com")
|
||||
if isGoogle {
|
||||
// Google OAuth requires these for refresh token support
|
||||
params.Set("access_type", "offline")
|
||||
params.Set("prompt", "consent")
|
||||
} else {
|
||||
// OpenAI-specific parameters
|
||||
params.Set("id_token_add_organizations", "true")
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") {
|
||||
params.Set("originator", "picoclaw")
|
||||
}
|
||||
if cfg.Originator != "" {
|
||||
params.Set("originator", cfg.Originator)
|
||||
}
|
||||
}
|
||||
if cfg.Originator != "" {
|
||||
params.Set("originator", cfg.Originator)
|
||||
|
||||
// Google uses /auth path, OpenAI uses /oauth/authorize
|
||||
if isGoogle {
|
||||
return cfg.Issuer + "/auth?" + params.Encode()
|
||||
}
|
||||
return cfg.Issuer + "/oauth/authorize?" + params.Encode()
|
||||
}
|
||||
@@ -327,8 +408,22 @@ func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirect
|
||||
"client_id": {cfg.ClientID},
|
||||
"code_verifier": {codeVerifier},
|
||||
}
|
||||
if cfg.ClientSecret != "" {
|
||||
data.Set("client_secret", cfg.ClientSecret)
|
||||
}
|
||||
|
||||
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
|
||||
tokenURL := cfg.Issuer + "/oauth/token"
|
||||
if cfg.TokenURL != "" {
|
||||
tokenURL = cfg.TokenURL
|
||||
}
|
||||
|
||||
// Determine provider name from config
|
||||
provider := "openai"
|
||||
if cfg.TokenURL != "" && strings.Contains(cfg.TokenURL, "googleapis.com") {
|
||||
provider = "google-antigravity"
|
||||
}
|
||||
|
||||
resp, err := http.PostForm(tokenURL, data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("exchanging code for tokens: %w", err)
|
||||
}
|
||||
@@ -339,7 +434,7 @@ func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirect
|
||||
return nil, fmt.Errorf("token exchange failed: %s", string(body))
|
||||
}
|
||||
|
||||
return parseTokenResponse(body, "openai")
|
||||
return parseTokenResponse(body, provider)
|
||||
}
|
||||
|
||||
func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
||||
|
||||
@@ -14,6 +14,8 @@ type AuthCredential struct {
|
||||
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
||||
Provider string `json:"provider"`
|
||||
AuthMethod string `json:"auth_method"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
}
|
||||
|
||||
type AuthStore struct {
|
||||
|
||||
@@ -182,6 +182,7 @@ type ProvidersConfig struct {
|
||||
Cerebras ProviderConfig `json:"cerebras"`
|
||||
VolcEngine ProviderConfig `json:"volcengine"`
|
||||
GitHubCopilot ProviderConfig `json:"github_copilot"`
|
||||
Antigravity ProviderConfig `json:"antigravity"`
|
||||
Qwen ProviderConfig `json:"qwen"`
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,827 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityBaseURL = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityDefaultModel = "gemini-3-flash"
|
||||
antigravityUserAgent = "antigravity"
|
||||
antigravityXGoogClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
|
||||
antigravityVersion = "1.15.8"
|
||||
)
|
||||
|
||||
// AntigravityProvider implements LLMProvider using Google's Cloud Code Assist (Antigravity) API.
|
||||
// This provider authenticates via Google OAuth and provides access to models like Claude and Gemini
|
||||
// through Google's infrastructure.
|
||||
type AntigravityProvider struct {
|
||||
tokenSource func() (string, string, error) // Returns (accessToken, projectID, error)
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewAntigravityProvider creates a new Antigravity provider using stored auth credentials.
|
||||
func NewAntigravityProvider() *AntigravityProvider {
|
||||
return &AntigravityProvider{
|
||||
tokenSource: createAntigravityTokenSource(),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Chat implements LLMProvider.Chat using the Cloud Code Assist v1internal API.
|
||||
// The v1internal endpoint wraps the standard Gemini request in an envelope with
|
||||
// project, model, request, requestType, userAgent, and requestId fields.
|
||||
func (p *AntigravityProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
accessToken, projectID, err := p.tokenSource()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("antigravity auth: %w", err)
|
||||
}
|
||||
|
||||
if model == "" || model == "antigravity" || model == "google-antigravity" {
|
||||
model = antigravityDefaultModel
|
||||
}
|
||||
// Strip provider prefixes if present
|
||||
model = strings.TrimPrefix(model, "google-antigravity/")
|
||||
model = strings.TrimPrefix(model, "antigravity/")
|
||||
|
||||
logger.DebugCF("provider.antigravity", "Starting chat", map[string]interface{}{
|
||||
"model": model,
|
||||
"project": projectID,
|
||||
"requestId": fmt.Sprintf("agent-%d-%s", time.Now().UnixMilli(), randomString(9)),
|
||||
})
|
||||
|
||||
// Build the inner Gemini-format request
|
||||
innerRequest := p.buildRequest(messages, tools, model, options)
|
||||
|
||||
// Wrap in v1internal envelope (matches pi-ai SDK format)
|
||||
envelope := map[string]interface{}{
|
||||
"project": projectID,
|
||||
"model": model,
|
||||
"request": innerRequest,
|
||||
"requestType": "agent",
|
||||
"userAgent": antigravityUserAgent,
|
||||
"requestId": fmt.Sprintf("agent-%d-%s", time.Now().UnixMilli(), randomString(9)),
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(envelope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling request: %w", err)
|
||||
}
|
||||
|
||||
// Build API URL — uses Cloud Code Assist v1internal streaming endpoint
|
||||
apiURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", antigravityBaseURL)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
|
||||
// Headers matching the pi-ai SDK antigravity format
|
||||
clientMetadata, _ := json.Marshal(map[string]string{
|
||||
"ideType": "IDE_UNSPECIFIED",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
})
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("antigravity/%s linux/amd64", antigravityVersion))
|
||||
req.Header.Set("X-Goog-Api-Client", antigravityXGoogClient)
|
||||
req.Header.Set("Client-Metadata", string(clientMetadata))
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("antigravity API call: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.ErrorCF("provider.antigravity", "API call failed", map[string]interface{}{
|
||||
"status_code": resp.StatusCode,
|
||||
"response": string(respBody),
|
||||
"model": model,
|
||||
})
|
||||
|
||||
return nil, p.parseAntigravityError(resp.StatusCode, respBody)
|
||||
}
|
||||
|
||||
// Response is always SSE from streamGenerateContent — each line is "data: {...}"
|
||||
// with a "response" wrapper containing the standard Gemini response
|
||||
llmResp, err := p.parseSSEResponse(string(respBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for empty response (some models might return valid success but empty text)
|
||||
if llmResp.Content == "" && len(llmResp.ToolCalls) == 0 {
|
||||
return nil, fmt.Errorf("antigravity: model returned an empty response (this model might be invalid or restricted)")
|
||||
}
|
||||
|
||||
return llmResp, nil
|
||||
}
|
||||
|
||||
// GetDefaultModel returns the default model identifier.
|
||||
func (p *AntigravityProvider) GetDefaultModel() string {
|
||||
return antigravityDefaultModel
|
||||
}
|
||||
|
||||
// --- Request building ---
|
||||
|
||||
type antigravityRequest struct {
|
||||
Contents []antigravityContent `json:"contents"`
|
||||
Tools []antigravityTool `json:"tools,omitempty"`
|
||||
SystemPrompt *antigravitySystemPrompt `json:"systemInstruction,omitempty"`
|
||||
Config *antigravityGenConfig `json:"generationConfig,omitempty"`
|
||||
}
|
||||
|
||||
type antigravityContent struct {
|
||||
Role string `json:"role"`
|
||||
Parts []antigravityPart `json:"parts"`
|
||||
}
|
||||
|
||||
type antigravityPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
ThoughtSignature string `json:"thoughtSignature,omitempty"`
|
||||
ThoughtSignatureSnake string `json:"thought_signature,omitempty"`
|
||||
FunctionCall *antigravityFunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *antigravityFunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
type antigravityFunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Args map[string]interface{} `json:"args"`
|
||||
}
|
||||
|
||||
type antigravityFunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response map[string]interface{} `json:"response"`
|
||||
}
|
||||
|
||||
type antigravityTool struct {
|
||||
FunctionDeclarations []antigravityFuncDecl `json:"functionDeclarations"`
|
||||
}
|
||||
|
||||
type antigravityFuncDecl struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters interface{} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type antigravitySystemPrompt struct {
|
||||
Parts []antigravityPart `json:"parts"`
|
||||
}
|
||||
|
||||
type antigravityGenConfig struct {
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) buildRequest(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) antigravityRequest {
|
||||
req := antigravityRequest{}
|
||||
toolCallNames := make(map[string]string)
|
||||
|
||||
// Build contents from messages
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
req.SystemPrompt = &antigravitySystemPrompt{
|
||||
Parts: []antigravityPart{{Text: msg.Content}},
|
||||
}
|
||||
case "user":
|
||||
if msg.ToolCallID != "" {
|
||||
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
|
||||
// Tool result
|
||||
req.Contents = append(req.Contents, antigravityContent{
|
||||
Role: "user",
|
||||
Parts: []antigravityPart{{
|
||||
FunctionResponse: &antigravityFunctionResponse{
|
||||
Name: toolName,
|
||||
Response: map[string]interface{}{
|
||||
"result": msg.Content,
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
} else {
|
||||
req.Contents = append(req.Contents, antigravityContent{
|
||||
Role: "user",
|
||||
Parts: []antigravityPart{{Text: msg.Content}},
|
||||
})
|
||||
}
|
||||
case "assistant":
|
||||
content := antigravityContent{
|
||||
Role: "model",
|
||||
}
|
||||
if msg.Content != "" {
|
||||
content.Parts = append(content.Parts, antigravityPart{Text: msg.Content})
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc)
|
||||
if toolName == "" {
|
||||
logger.WarnCF("provider.antigravity", "Skipping tool call with empty name in history", map[string]interface{}{
|
||||
"tool_call_id": tc.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if tc.ID != "" {
|
||||
toolCallNames[tc.ID] = toolName
|
||||
}
|
||||
content.Parts = append(content.Parts, antigravityPart{
|
||||
ThoughtSignature: thoughtSignature,
|
||||
ThoughtSignatureSnake: thoughtSignature,
|
||||
FunctionCall: &antigravityFunctionCall{
|
||||
Name: toolName,
|
||||
Args: toolArgs,
|
||||
},
|
||||
})
|
||||
}
|
||||
if len(content.Parts) > 0 {
|
||||
req.Contents = append(req.Contents, content)
|
||||
}
|
||||
case "tool":
|
||||
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
|
||||
req.Contents = append(req.Contents, antigravityContent{
|
||||
Role: "user",
|
||||
Parts: []antigravityPart{{
|
||||
FunctionResponse: &antigravityFunctionResponse{
|
||||
Name: toolName,
|
||||
Response: map[string]interface{}{
|
||||
"result": msg.Content,
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Build tools (sanitize schemas for Gemini compatibility)
|
||||
if len(tools) > 0 {
|
||||
var funcDecls []antigravityFuncDecl
|
||||
for _, t := range tools {
|
||||
if t.Type != "function" {
|
||||
continue
|
||||
}
|
||||
params := sanitizeSchemaForGemini(t.Function.Parameters)
|
||||
funcDecls = append(funcDecls, antigravityFuncDecl{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: params,
|
||||
})
|
||||
}
|
||||
if len(funcDecls) > 0 {
|
||||
req.Tools = []antigravityTool{{FunctionDeclarations: funcDecls}}
|
||||
}
|
||||
}
|
||||
|
||||
// Generation config
|
||||
config := &antigravityGenConfig{}
|
||||
if val, ok := options["max_tokens"]; ok {
|
||||
if maxTokens, ok := val.(int); ok && maxTokens > 0 {
|
||||
config.MaxOutputTokens = maxTokens
|
||||
} else if maxTokens, ok := val.(float64); ok && maxTokens > 0 {
|
||||
config.MaxOutputTokens = int(maxTokens)
|
||||
}
|
||||
}
|
||||
if temp, ok := options["temperature"].(float64); ok {
|
||||
config.Temperature = temp
|
||||
}
|
||||
if config.MaxOutputTokens > 0 || config.Temperature > 0 {
|
||||
req.Config = config
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func normalizeStoredToolCall(tc ToolCall) (string, map[string]interface{}, string) {
|
||||
name := tc.Name
|
||||
args := tc.Arguments
|
||||
thoughtSignature := ""
|
||||
|
||||
if name == "" && tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
} else if tc.Function != nil {
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
}
|
||||
|
||||
if args == nil {
|
||||
args = map[string]interface{}{}
|
||||
}
|
||||
|
||||
if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" {
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil {
|
||||
args = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return name, args, thoughtSignature
|
||||
}
|
||||
|
||||
func resolveToolResponseName(toolCallID string, toolCallNames map[string]string) string {
|
||||
if toolCallID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if name, ok := toolCallNames[toolCallID]; ok && name != "" {
|
||||
return name
|
||||
}
|
||||
|
||||
return inferToolNameFromCallID(toolCallID)
|
||||
}
|
||||
|
||||
func inferToolNameFromCallID(toolCallID string) string {
|
||||
if !strings.HasPrefix(toolCallID, "call_") {
|
||||
return toolCallID
|
||||
}
|
||||
|
||||
rest := strings.TrimPrefix(toolCallID, "call_")
|
||||
if idx := strings.LastIndex(rest, "_"); idx > 0 {
|
||||
candidate := rest[:idx]
|
||||
if candidate != "" {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
|
||||
return toolCallID
|
||||
}
|
||||
|
||||
// --- Response parsing ---
|
||||
|
||||
type antigravityJSONResponse struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
ThoughtSignature string `json:"thoughtSignature,omitempty"`
|
||||
ThoughtSignatureSnake string `json:"thought_signature,omitempty"`
|
||||
FunctionCall *antigravityFunctionCall `json:"functionCall,omitempty"`
|
||||
} `json:"parts"`
|
||||
Role string `json:"role"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) parseJSONResponse(body []byte) (*LLMResponse, error) {
|
||||
var resp antigravityJSONResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, fmt.Errorf("parsing antigravity response: %w", err)
|
||||
}
|
||||
|
||||
if len(resp.Candidates) == 0 {
|
||||
return nil, fmt.Errorf("antigravity: no candidates in response")
|
||||
}
|
||||
|
||||
candidate := resp.Candidates[0]
|
||||
var contentParts []string
|
||||
var toolCalls []ToolCall
|
||||
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
contentParts = append(contentParts, part.Text)
|
||||
}
|
||||
if part.FunctionCall != nil {
|
||||
argumentsJSON, _ := json.Marshal(part.FunctionCall.Args)
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: fmt.Sprintf("call_%s_%d", part.FunctionCall.Name, time.Now().UnixNano()),
|
||||
Name: part.FunctionCall.Name,
|
||||
Arguments: part.FunctionCall.Args,
|
||||
Function: &FunctionCall{
|
||||
Name: part.FunctionCall.Name,
|
||||
Arguments: string(argumentsJSON),
|
||||
ThoughtSignature: extractPartThoughtSignature(part.ThoughtSignature, part.ThoughtSignatureSnake),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
if candidate.FinishReason == "MAX_TOKENS" {
|
||||
finishReason = "length"
|
||||
}
|
||||
|
||||
var usage *UsageInfo
|
||||
if resp.UsageMetadata.TotalTokenCount > 0 {
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: resp.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: resp.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: resp.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: strings.Join(contentParts, ""),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) parseSSEResponse(body string) (*LLMResponse, error) {
|
||||
var contentParts []string
|
||||
var toolCalls []ToolCall
|
||||
var usage *UsageInfo
|
||||
var finishReason string
|
||||
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
// v1internal SSE wraps the Gemini response in a "response" field
|
||||
var sseChunk struct {
|
||||
Response antigravityJSONResponse `json:"response"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &sseChunk); err != nil {
|
||||
continue
|
||||
}
|
||||
resp := sseChunk.Response
|
||||
|
||||
for _, candidate := range resp.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
contentParts = append(contentParts, part.Text)
|
||||
}
|
||||
if part.FunctionCall != nil {
|
||||
argumentsJSON, _ := json.Marshal(part.FunctionCall.Args)
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: fmt.Sprintf("call_%s_%d", part.FunctionCall.Name, time.Now().UnixNano()),
|
||||
Name: part.FunctionCall.Name,
|
||||
Arguments: part.FunctionCall.Args,
|
||||
Function: &FunctionCall{
|
||||
Name: part.FunctionCall.Name,
|
||||
Arguments: string(argumentsJSON),
|
||||
ThoughtSignature: extractPartThoughtSignature(part.ThoughtSignature, part.ThoughtSignatureSnake),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
if candidate.FinishReason != "" {
|
||||
finishReason = candidate.FinishReason
|
||||
}
|
||||
}
|
||||
|
||||
if resp.UsageMetadata.TotalTokenCount > 0 {
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: resp.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: resp.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: resp.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mappedFinish := "stop"
|
||||
if len(toolCalls) > 0 {
|
||||
mappedFinish = "tool_calls"
|
||||
}
|
||||
if finishReason == "MAX_TOKENS" {
|
||||
mappedFinish = "length"
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: strings.Join(contentParts, ""),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: mappedFinish,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake string) string {
|
||||
if thoughtSignature != "" {
|
||||
return thoughtSignature
|
||||
}
|
||||
if thoughtSignatureSnake != "" {
|
||||
return thoughtSignatureSnake
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// --- Schema sanitization ---
|
||||
|
||||
// Google/Gemini doesn't support many JSON Schema keywords that other providers accept.
|
||||
var geminiUnsupportedKeywords = map[string]bool{
|
||||
"patternProperties": true,
|
||||
"additionalProperties": true,
|
||||
"$schema": true,
|
||||
"$id": true,
|
||||
"$ref": true,
|
||||
"$defs": true,
|
||||
"definitions": true,
|
||||
"examples": true,
|
||||
"minLength": true,
|
||||
"maxLength": true,
|
||||
"minimum": true,
|
||||
"maximum": true,
|
||||
"multipleOf": true,
|
||||
"pattern": true,
|
||||
"format": true,
|
||||
"minItems": true,
|
||||
"maxItems": true,
|
||||
"uniqueItems": true,
|
||||
"minProperties": true,
|
||||
"maxProperties": true,
|
||||
}
|
||||
|
||||
func sanitizeSchemaForGemini(schema map[string]interface{}) map[string]interface{} {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[string]interface{})
|
||||
for k, v := range schema {
|
||||
if geminiUnsupportedKeywords[k] {
|
||||
continue
|
||||
}
|
||||
// Recursively sanitize nested objects
|
||||
switch val := v.(type) {
|
||||
case map[string]interface{}:
|
||||
result[k] = sanitizeSchemaForGemini(val)
|
||||
case []interface{}:
|
||||
sanitized := make([]interface{}, len(val))
|
||||
for i, item := range val {
|
||||
if m, ok := item.(map[string]interface{}); ok {
|
||||
sanitized[i] = sanitizeSchemaForGemini(m)
|
||||
} else {
|
||||
sanitized[i] = item
|
||||
}
|
||||
}
|
||||
result[k] = sanitized
|
||||
default:
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure top-level has type: "object" if properties are present
|
||||
if _, hasProps := result["properties"]; hasProps {
|
||||
if _, hasType := result["type"]; !hasType {
|
||||
result["type"] = "object"
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// --- Token source ---
|
||||
|
||||
func createAntigravityTokenSource() func() (string, string, error) {
|
||||
return func() (string, string, error) {
|
||||
cred, err := auth.GetCredential("google-antigravity")
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return "", "", fmt.Errorf("no credentials for google-antigravity. Run: picoclaw auth login --provider google-antigravity")
|
||||
}
|
||||
|
||||
// Refresh if needed
|
||||
if cred.NeedsRefresh() && cred.RefreshToken != "" {
|
||||
oauthCfg := auth.GoogleAntigravityOAuthConfig()
|
||||
refreshed, err := auth.RefreshAccessToken(cred, oauthCfg)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
refreshed.Email = cred.Email
|
||||
if refreshed.ProjectID == "" {
|
||||
refreshed.ProjectID = cred.ProjectID
|
||||
}
|
||||
if err := auth.SetCredential("google-antigravity", refreshed); err != nil {
|
||||
return "", "", fmt.Errorf("saving refreshed token: %w", err)
|
||||
}
|
||||
cred = refreshed
|
||||
}
|
||||
|
||||
if cred.IsExpired() {
|
||||
return "", "", fmt.Errorf("antigravity credentials expired. Run: picoclaw auth login --provider google-antigravity")
|
||||
}
|
||||
|
||||
projectID := cred.ProjectID
|
||||
if projectID == "" {
|
||||
// Try to fetch project ID from API
|
||||
fetchedID, err := FetchAntigravityProjectID(cred.AccessToken)
|
||||
if err != nil {
|
||||
logger.WarnCF("provider.antigravity", "Could not fetch project ID, using fallback", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
projectID = "rising-fact-p41fc" // Default fallback (same as OpenCode)
|
||||
} else {
|
||||
projectID = fetchedID
|
||||
cred.ProjectID = projectID
|
||||
_ = auth.SetCredential("google-antigravity", cred)
|
||||
}
|
||||
}
|
||||
|
||||
return cred.AccessToken, projectID, nil
|
||||
}
|
||||
}
|
||||
|
||||
// FetchAntigravityProjectID retrieves the Google Cloud project ID from the loadCodeAssist endpoint.
|
||||
func FetchAntigravityProjectID(accessToken string) (string, error) {
|
||||
reqBody, _ := json.Marshal(map[string]interface{}{
|
||||
"metadata": map[string]interface{}{
|
||||
"ideType": "IDE_UNSPECIFIED",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
},
|
||||
})
|
||||
|
||||
req, err := http.NewRequest("POST", antigravityBaseURL+"/v1internal:loadCodeAssist", bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", antigravityUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", antigravityXGoogClient)
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("loadCodeAssist failed: %s", string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
CloudAICompanionProject string `json:"cloudaicompanionProject"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if result.CloudAICompanionProject == "" {
|
||||
return "", fmt.Errorf("no project ID in loadCodeAssist response")
|
||||
}
|
||||
|
||||
return result.CloudAICompanionProject, nil
|
||||
}
|
||||
|
||||
// FetchAntigravityModels fetches available models from the Cloud Code Assist API.
|
||||
func FetchAntigravityModels(accessToken, projectID string) ([]AntigravityModelInfo, error) {
|
||||
reqBody, _ := json.Marshal(map[string]interface{}{
|
||||
"project": projectID,
|
||||
})
|
||||
|
||||
req, err := http.NewRequest("POST", antigravityBaseURL+"/v1internal:fetchAvailableModels", bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", antigravityUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", antigravityXGoogClient)
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("fetchAvailableModels failed (HTTP %d): %s", resp.StatusCode, truncateString(string(body), 200))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Models map[string]struct {
|
||||
DisplayName string `json:"displayName"`
|
||||
QuotaInfo struct {
|
||||
RemainingFraction interface{} `json:"remainingFraction"`
|
||||
ResetTime string `json:"resetTime"`
|
||||
IsExhausted bool `json:"isExhausted"`
|
||||
} `json:"quotaInfo"`
|
||||
} `json:"models"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parsing models response: %w", err)
|
||||
}
|
||||
|
||||
var models []AntigravityModelInfo
|
||||
for id, info := range result.Models {
|
||||
models = append(models, AntigravityModelInfo{
|
||||
ID: id,
|
||||
DisplayName: info.DisplayName,
|
||||
IsExhausted: info.QuotaInfo.IsExhausted,
|
||||
})
|
||||
}
|
||||
|
||||
// Ensure gemini-3-flash-preview and gemini-3-flash are in the list if they aren't already
|
||||
hasFlashPreview := false
|
||||
hasFlash := false
|
||||
for _, m := range models {
|
||||
if m.ID == "gemini-3-flash-preview" {
|
||||
hasFlashPreview = true
|
||||
}
|
||||
if m.ID == "gemini-3-flash" {
|
||||
hasFlash = true
|
||||
}
|
||||
}
|
||||
if !hasFlashPreview {
|
||||
models = append(models, AntigravityModelInfo{
|
||||
ID: "gemini-3-flash-preview",
|
||||
DisplayName: "Gemini 3 Flash (Preview)",
|
||||
})
|
||||
}
|
||||
if !hasFlash {
|
||||
models = append(models, AntigravityModelInfo{
|
||||
ID: "gemini-3-flash",
|
||||
DisplayName: "Gemini 3 Flash",
|
||||
})
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
type AntigravityModelInfo struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"display_name"`
|
||||
IsExhausted bool `json:"is_exhausted"`
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func truncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
func randomString(n int) string {
|
||||
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letters[rand.Intn(len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) parseAntigravityError(statusCode int, body []byte) error {
|
||||
var errResp struct {
|
||||
Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
Details []map[string]interface{} `json:"details"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &errResp); err != nil {
|
||||
return fmt.Errorf("antigravity API error (HTTP %d): %s", statusCode, truncateString(string(body), 500))
|
||||
}
|
||||
|
||||
msg := errResp.Error.Message
|
||||
if statusCode == 429 {
|
||||
// Try to extract quota reset info
|
||||
for _, detail := range errResp.Error.Details {
|
||||
if typeVal, ok := detail["@type"].(string); ok && strings.HasSuffix(typeVal, "ErrorInfo") {
|
||||
if metadata, ok := detail["metadata"].(map[string]interface{}); ok {
|
||||
if delay, ok := metadata["quotaResetDelay"].(string); ok {
|
||||
return fmt.Errorf("antigravity rate limit exceeded: %s (reset in %s)", msg, delay)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("antigravity rate limit exceeded: %s", msg)
|
||||
}
|
||||
|
||||
return fmt.Errorf("antigravity API error (%s): %s", errResp.Error.Status, msg)
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package providers
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildRequestUsesFunctionFieldsWhenToolCallNameMissing(t *testing.T) {
|
||||
p := &AntigravityProvider{}
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call_read_file_123",
|
||||
Function: &FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"README.md"}`,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
ToolCallID: "call_read_file_123",
|
||||
Content: "ok",
|
||||
},
|
||||
}
|
||||
|
||||
req := p.buildRequest(messages, nil, "", nil)
|
||||
if len(req.Contents) != 2 {
|
||||
t.Fatalf("expected 2 contents, got %d", len(req.Contents))
|
||||
}
|
||||
|
||||
modelPart := req.Contents[0].Parts[0]
|
||||
if modelPart.FunctionCall == nil {
|
||||
t.Fatal("expected functionCall in assistant message")
|
||||
}
|
||||
if modelPart.FunctionCall.Name != "read_file" {
|
||||
t.Fatalf("expected functionCall name read_file, got %q", modelPart.FunctionCall.Name)
|
||||
}
|
||||
if got := modelPart.FunctionCall.Args["path"]; got != "README.md" {
|
||||
t.Fatalf("expected functionCall args[path] to be README.md, got %v", got)
|
||||
}
|
||||
|
||||
toolPart := req.Contents[1].Parts[0]
|
||||
if toolPart.FunctionResponse == nil {
|
||||
t.Fatal("expected functionResponse in tool message")
|
||||
}
|
||||
if toolPart.FunctionResponse.Name != "read_file" {
|
||||
t.Fatalf("expected functionResponse name read_file, got %q", toolPart.FunctionResponse.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToolResponseNameInfersNameFromGeneratedCallID(t *testing.T) {
|
||||
got := resolveToolResponseName("call_search_docs_999", map[string]string{})
|
||||
if got != "search_docs" {
|
||||
t.Fatalf("expected inferred tool name search_docs, got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -132,8 +132,9 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
ThoughtSignature string `json:"thought_signature"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
@@ -159,18 +160,11 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
arguments := make(map[string]interface{})
|
||||
name := ""
|
||||
thoughtSignature := ""
|
||||
|
||||
// Handle OpenAI format with nested function object
|
||||
if tc.Type == "function" && tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
if tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
|
||||
arguments["raw"] = tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
} else if tc.Function != nil {
|
||||
// Legacy format without type field
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
if tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
|
||||
arguments["raw"] = tc.Function.Arguments
|
||||
@@ -179,7 +173,13 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: tc.ID,
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Function: &FunctionCall{
|
||||
Name: name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
})
|
||||
@@ -347,6 +347,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
apiBase = "localhost:4321"
|
||||
}
|
||||
return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
|
||||
case "antigravity", "google-antigravity":
|
||||
return NewAntigravityProvider(), nil
|
||||
|
||||
case "volcengine", "doubao":
|
||||
if cfg.Providers.VolcEngine.APIKey != "" {
|
||||
|
||||
@@ -11,8 +11,9 @@ type ToolCall struct {
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
ThoughtSignature string `json:"thought_signature,omitempty"`
|
||||
}
|
||||
|
||||
type LLMResponse struct {
|
||||
|
||||
+53
-7
@@ -83,15 +83,20 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider
|
||||
break
|
||||
}
|
||||
|
||||
// 5. Log tool calls
|
||||
toolNames := make([]string, 0, len(response.ToolCalls))
|
||||
normalizedToolCalls := make([]providers.ToolCall, 0, len(response.ToolCalls))
|
||||
for _, tc := range response.ToolCalls {
|
||||
normalizedToolCalls = append(normalizedToolCalls, normalizeProviderToolCall(tc))
|
||||
}
|
||||
|
||||
// 5. Log tool calls
|
||||
toolNames := make([]string, 0, len(normalizedToolCalls))
|
||||
for _, tc := range normalizedToolCalls {
|
||||
toolNames = append(toolNames, tc.Name)
|
||||
}
|
||||
logger.InfoCF("toolloop", "LLM requested tool calls",
|
||||
map[string]any{
|
||||
"tools": toolNames,
|
||||
"count": len(response.ToolCalls),
|
||||
"count": len(normalizedToolCalls),
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
@@ -100,11 +105,13 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider
|
||||
Role: "assistant",
|
||||
Content: response.Content,
|
||||
}
|
||||
for _, tc := range response.ToolCalls {
|
||||
for _, tc := range normalizedToolCalls {
|
||||
argumentsJSON, _ := json.Marshal(tc.Arguments)
|
||||
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
Name: tc.Name,
|
||||
Arguments: tc.Arguments,
|
||||
Function: &providers.FunctionCall{
|
||||
Name: tc.Name,
|
||||
Arguments: string(argumentsJSON),
|
||||
@@ -114,7 +121,7 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// 7. Execute tool calls
|
||||
for _, tc := range response.ToolCalls {
|
||||
for _, tc := range normalizedToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
@@ -152,3 +159,42 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider
|
||||
Iterations: iteration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeProviderToolCall(tc providers.ToolCall) providers.ToolCall {
|
||||
normalized := tc
|
||||
|
||||
if normalized.Name == "" && normalized.Function != nil {
|
||||
normalized.Name = normalized.Function.Name
|
||||
}
|
||||
|
||||
if normalized.Arguments == nil {
|
||||
normalized.Arguments = map[string]interface{}{}
|
||||
}
|
||||
|
||||
if len(normalized.Arguments) == 0 && normalized.Function != nil && normalized.Function.Arguments != "" {
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(normalized.Function.Arguments), &parsed); err == nil && parsed != nil {
|
||||
normalized.Arguments = parsed
|
||||
}
|
||||
}
|
||||
|
||||
argsJSON, _ := json.Marshal(normalized.Arguments)
|
||||
if normalized.Function == nil {
|
||||
normalized.Function = &providers.FunctionCall{
|
||||
Name: normalized.Name,
|
||||
Arguments: string(argsJSON),
|
||||
}
|
||||
} else {
|
||||
if normalized.Function.Name == "" {
|
||||
normalized.Function.Name = normalized.Name
|
||||
}
|
||||
if normalized.Name == "" {
|
||||
normalized.Name = normalized.Function.Name
|
||||
}
|
||||
if normalized.Function.Arguments == "" {
|
||||
normalized.Function.Arguments = string(argsJSON)
|
||||
}
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user