mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into fix/max-payload-size-in-web-fetch
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -249,10 +250,8 @@ func (cb *ContextBuilder) sourceFilesChangedLocked() bool {
|
||||
}
|
||||
|
||||
// Check tracked source files (bootstrap + memory).
|
||||
for _, p := range cb.sourcePaths() {
|
||||
if cb.fileChangedSince(p) {
|
||||
return true
|
||||
}
|
||||
if slices.ContainsFunc(cb.sourcePaths(), cb.fileChangedSince) {
|
||||
return true
|
||||
}
|
||||
|
||||
// --- Skills directory (handled separately from sourcePaths) ---
|
||||
|
||||
@@ -404,11 +404,11 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan string, goroutines*iterations)
|
||||
|
||||
for g := 0; g < goroutines; g++ {
|
||||
for g := range goroutines {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < iterations; i++ {
|
||||
for i := range iterations {
|
||||
result := cb.BuildSystemPromptWithCache()
|
||||
if result == "" {
|
||||
errs <- "empty prompt returned"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -51,7 +52,12 @@ func NewAgentInstance(
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewExecToolWithConfig(workspace, restrict, cfg))
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
|
||||
}
|
||||
toolsRegistry.Register(execTool)
|
||||
|
||||
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
|
||||
|
||||
|
||||
+67
-8
@@ -9,6 +9,7 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -98,7 +99,7 @@ func registerSharedTools(
|
||||
}
|
||||
|
||||
// Web tools
|
||||
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
@@ -112,10 +113,18 @@ func registerSharedTools(
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
Proxy: cfg.Tools.Web.Proxy,
|
||||
}); searchTool != nil {
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
|
||||
} else if searchTool != nil {
|
||||
agent.Tools.Register(searchTool)
|
||||
}
|
||||
agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes))
|
||||
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
agent.Tools.Register(fetchTool)
|
||||
}
|
||||
|
||||
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
||||
agent.Tools.Register(tools.NewI2CTool())
|
||||
@@ -574,11 +583,36 @@ func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, chan
|
||||
return
|
||||
}
|
||||
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
// Use a short timeout so the goroutine does not block indefinitely when
|
||||
// the outbound bus is full. Reasoning output is best-effort; dropping it
|
||||
// is acceptable to avoid goroutine accumulation.
|
||||
pubCtx, pubCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer pubCancel()
|
||||
|
||||
if err := al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channelName,
|
||||
ChatID: channelID,
|
||||
Content: reasoningContent,
|
||||
})
|
||||
}); err != nil {
|
||||
// Treat context.DeadlineExceeded / context.Canceled as expected
|
||||
// (bus full under load, or parent canceled). Check the error
|
||||
// itself rather than ctx.Err(), because pubCtx may time out
|
||||
// (5 s) while the parent ctx is still active.
|
||||
// Also treat ErrBusClosed as expected — it occurs during normal
|
||||
// shutdown when the bus is closed before all goroutines finish.
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) ||
|
||||
errors.Is(err, bus.ErrBusClosed) {
|
||||
logger.DebugCF("agent", "Reasoning publish skipped (timeout/cancel)", map[string]any{
|
||||
"channel": channelName,
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
logger.WarnCF("agent", "Failed to publish reasoning (best-effort)", map[string]any{
|
||||
"channel": channelName,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runLLMIteration executes the LLM call loop with tool handling.
|
||||
@@ -666,10 +700,35 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
isContextError := strings.Contains(errMsg, "token") ||
|
||||
strings.Contains(errMsg, "context") ||
|
||||
|
||||
// Check if this is a network/HTTP timeout — not a context window error.
|
||||
isTimeoutError := errors.Is(err, context.DeadlineExceeded) ||
|
||||
strings.Contains(errMsg, "deadline exceeded") ||
|
||||
strings.Contains(errMsg, "client.timeout") ||
|
||||
strings.Contains(errMsg, "timed out") ||
|
||||
strings.Contains(errMsg, "timeout exceeded")
|
||||
|
||||
// Detect real context window / token limit errors, excluding network timeouts.
|
||||
isContextError := !isTimeoutError && (strings.Contains(errMsg, "context_length_exceeded") ||
|
||||
strings.Contains(errMsg, "context window") ||
|
||||
strings.Contains(errMsg, "maximum context length") ||
|
||||
strings.Contains(errMsg, "token limit") ||
|
||||
strings.Contains(errMsg, "too many tokens") ||
|
||||
strings.Contains(errMsg, "max_tokens") ||
|
||||
strings.Contains(errMsg, "invalidparameter") ||
|
||||
strings.Contains(errMsg, "length")
|
||||
strings.Contains(errMsg, "prompt is too long") ||
|
||||
strings.Contains(errMsg, "request too large"))
|
||||
|
||||
if isTimeoutError && retry < maxRetries {
|
||||
backoff := time.Duration(retry+1) * 5 * time.Second
|
||||
logger.WarnCF("agent", "Timeout error, retrying after backoff", map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
"backoff": backoff.String(),
|
||||
})
|
||||
time.Sleep(backoff)
|
||||
continue
|
||||
}
|
||||
|
||||
if isContextError && retry < maxRetries {
|
||||
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{
|
||||
|
||||
+56
-14
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -187,13 +188,7 @@ func TestToolRegistry_ToolRegistration(t *testing.T) {
|
||||
toolsList := toolsInfo["names"].([]string)
|
||||
|
||||
// Check that our custom tool name is in the list
|
||||
found := false
|
||||
for _, name := range toolsList {
|
||||
if name == "mock_custom" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
found := slices.Contains(toolsList, "mock_custom")
|
||||
if !found {
|
||||
t.Error("Expected custom tool to be registered")
|
||||
}
|
||||
@@ -262,13 +257,7 @@ func TestToolRegistry_GetDefinitions(t *testing.T) {
|
||||
toolsList := toolsInfo["names"].([]string)
|
||||
|
||||
// Check that our custom tool name is in the list
|
||||
found := false
|
||||
for _, name := range toolsList {
|
||||
if name == "mock_custom" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
found := slices.Contains(toolsList, "mock_custom")
|
||||
if !found {
|
||||
t.Error("Expected custom tool to be registered")
|
||||
}
|
||||
@@ -797,4 +786,57 @@ func TestHandleReasoning(t *testing.T) {
|
||||
t.Fatalf("expected no outbound message, got %+v", msg)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns promptly when bus is full", func(t *testing.T) {
|
||||
al, msgBus := newLoop(t)
|
||||
|
||||
// Fill the outbound bus buffer until a publish would block.
|
||||
// Use a short timeout to detect when the buffer is full,
|
||||
// rather than hardcoding the buffer size.
|
||||
for i := 0; ; i++ {
|
||||
fillCtx, fillCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
err := msgBus.PublishOutbound(fillCtx, bus.OutboundMessage{
|
||||
Channel: "filler",
|
||||
ChatID: "filler",
|
||||
Content: fmt.Sprintf("filler-%d", i),
|
||||
})
|
||||
fillCancel()
|
||||
if err != nil {
|
||||
// Buffer is full (timed out trying to send).
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Use a short-deadline parent context to bound the test.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
al.handleReasoning(ctx, "should timeout", "slack", "channel-full")
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// handleReasoning uses a 5s internal timeout, but the parent ctx
|
||||
// expires in 500ms. It should return within ~500ms, not 5s.
|
||||
if elapsed > 2*time.Second {
|
||||
t.Fatalf("handleReasoning blocked too long (%v); expected prompt return", elapsed)
|
||||
}
|
||||
|
||||
// Drain the bus and verify the reasoning message was NOT published
|
||||
// (it should have been dropped due to timeout).
|
||||
drainCtx, drainCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer drainCancel()
|
||||
foundReasoning := false
|
||||
for {
|
||||
msg, ok := msgBus.SubscribeOutbound(drainCtx)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if msg.Content == "should timeout" {
|
||||
foundReasoning = true
|
||||
}
|
||||
}
|
||||
if foundReasoning {
|
||||
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+1
-1
@@ -111,7 +111,7 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
|
||||
var sb strings.Builder
|
||||
first := true
|
||||
|
||||
for i := 0; i < days; i++ {
|
||||
for i := range days {
|
||||
date := time.Now().AddDate(0, 0, -i)
|
||||
dateStr := date.Format("20060102") // YYYYMMDD
|
||||
monthDir := dateStr[:6] // YYYYMM
|
||||
|
||||
+64
-8
@@ -66,7 +66,8 @@ func decodeBase64(s string) string {
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func generateState() (string, error) {
|
||||
// GenerateState generates a random state string for OAuth CSRF protection.
|
||||
func GenerateState() (string, error) {
|
||||
buf := make([]byte, 32)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
@@ -80,7 +81,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||
return nil, fmt.Errorf("generating PKCE: %w", err)
|
||||
}
|
||||
|
||||
state, err := generateState()
|
||||
state, err := GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating state: %w", err)
|
||||
}
|
||||
@@ -127,7 +128,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||
|
||||
fmt.Printf("Open this URL to authenticate:\n\n%s\n\n", authURL)
|
||||
|
||||
if err := openBrowser(authURL); err != nil {
|
||||
if err := OpenBrowser(authURL); err != nil {
|
||||
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
|
||||
}
|
||||
|
||||
@@ -153,7 +154,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||
if result.err != nil {
|
||||
return nil, result.err
|
||||
}
|
||||
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
|
||||
return ExchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
|
||||
case manualInput := <-manualCh:
|
||||
if manualInput == "" {
|
||||
return nil, fmt.Errorf("manual input canceled")
|
||||
@@ -169,7 +170,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||
if code == "" {
|
||||
return nil, fmt.Errorf("could not find authorization code in input")
|
||||
}
|
||||
return exchangeCodeForTokens(cfg, code, pkce.CodeVerifier, redirectURI)
|
||||
return ExchangeCodeForTokens(cfg, code, pkce.CodeVerifier, redirectURI)
|
||||
case <-time.After(5 * time.Minute):
|
||||
return nil, fmt.Errorf("authentication timed out after 5 minutes")
|
||||
}
|
||||
@@ -186,6 +187,59 @@ type deviceCodeResponse struct {
|
||||
Interval int
|
||||
}
|
||||
|
||||
// DeviceCodeInfo holds the device code information returned by the OAuth provider.
|
||||
type DeviceCodeInfo struct {
|
||||
DeviceAuthID string `json:"device_auth_id"`
|
||||
UserCode string `json:"user_code"`
|
||||
VerifyURL string `json:"verify_url"`
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// RequestDeviceCode requests a device code from the OAuth provider.
|
||||
// Returns the info needed for the user to authenticate in a browser.
|
||||
func RequestDeviceCode(cfg OAuthProviderConfig) (*DeviceCodeInfo, error) {
|
||||
reqBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": cfg.ClientID,
|
||||
})
|
||||
|
||||
resp, err := http.Post(
|
||||
cfg.Issuer+"/api/accounts/deviceauth/usercode",
|
||||
"application/json",
|
||||
strings.NewReader(string(reqBody)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("requesting device code: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("device code request failed: %s", string(body))
|
||||
}
|
||||
|
||||
deviceResp, err := parseDeviceCodeResponse(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing device code response: %w", err)
|
||||
}
|
||||
|
||||
if deviceResp.Interval < 1 {
|
||||
deviceResp.Interval = 5
|
||||
}
|
||||
|
||||
return &DeviceCodeInfo{
|
||||
DeviceAuthID: deviceResp.DeviceAuthID,
|
||||
UserCode: deviceResp.UserCode,
|
||||
VerifyURL: cfg.Issuer + "/codex/device",
|
||||
Interval: deviceResp.Interval,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PollDeviceCodeOnce makes a single poll attempt to check if the user has authenticated.
|
||||
// Returns (credential, nil) on success, (nil, nil) if still pending, or (nil, err) on failure.
|
||||
func PollDeviceCodeOnce(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) {
|
||||
return pollDeviceCode(cfg, deviceAuthID, userCode)
|
||||
}
|
||||
|
||||
func parseDeviceCodeResponse(body []byte) (deviceCodeResponse, error) {
|
||||
var raw struct {
|
||||
DeviceAuthID string `json:"device_auth_id"`
|
||||
@@ -318,7 +372,7 @@ func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*Au
|
||||
}
|
||||
|
||||
redirectURI := cfg.Issuer + "/deviceauth/callback"
|
||||
return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
|
||||
return ExchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
|
||||
}
|
||||
|
||||
func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||
@@ -410,7 +464,8 @@ func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
|
||||
return cfg.Issuer + "/oauth/authorize?" + params.Encode()
|
||||
}
|
||||
|
||||
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
|
||||
// ExchangeCodeForTokens exchanges an authorization code for tokens.
|
||||
func ExchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {code},
|
||||
@@ -552,7 +607,8 @@ func base64URLDecode(s string) ([]byte, error) {
|
||||
return base64.StdEncoding.DecodeString(s)
|
||||
}
|
||||
|
||||
func openBrowser(url string) error {
|
||||
// OpenBrowser opens the given URL in the user's default browser.
|
||||
func OpenBrowser(url string) error {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return exec.Command("open", url).Start()
|
||||
|
||||
@@ -219,9 +219,9 @@ func TestExchangeCodeForTokens(t *testing.T) {
|
||||
Port: 1455,
|
||||
}
|
||||
|
||||
cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
|
||||
cred, err := ExchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
|
||||
if err != nil {
|
||||
t.Fatalf("exchangeCodeForTokens() error: %v", err)
|
||||
t.Fatalf("ExchangeCodeForTokens() error: %v", err)
|
||||
}
|
||||
|
||||
if cred.AccessToken != "mock-access-token" {
|
||||
|
||||
+3
-3
@@ -67,7 +67,7 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
|
||||
|
||||
// Fill the buffer
|
||||
ctx := context.Background()
|
||||
for i := 0; i < defaultBusBufferSize; i++ {
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
@@ -154,7 +154,7 @@ func TestConcurrentPublishClose(t *testing.T) {
|
||||
wg.Add(numGoroutines + 1)
|
||||
|
||||
// Spawn many goroutines trying to publish
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
for range numGoroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// Use a short timeout context so we don't block forever after close
|
||||
@@ -194,7 +194,7 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Fill the buffer
|
||||
for i := 0; i < defaultBusBufferSize; i++ {
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
|
||||
+80
-27
@@ -1,7 +1,5 @@
|
||||
# PicoClaw Channel System Refactor: Complete Development Guide
|
||||
# PicoClaw Channel System: Complete Development Guide
|
||||
|
||||
> **Branch**: `refactor/channel-system`
|
||||
> **Status**: Active development (~40 commits)
|
||||
> **Scope**: `pkg/channels/`, `pkg/bus/`, `pkg/media/`, `pkg/identity/`, `cmd/picoclaw/internal/gateway/`
|
||||
|
||||
---
|
||||
@@ -46,6 +44,8 @@ pkg/channels/
|
||||
pkg/channels/
|
||||
├── base.go # BaseChannel shared abstraction layer
|
||||
├── interfaces.go # Optional capability interfaces (TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder)
|
||||
├── README.md # English documentation
|
||||
├── README.zh.md # Chinese documentation
|
||||
├── media.go # MediaSender optional interface
|
||||
├── webhook.go # WebhookHandler, HealthChecker optional interfaces
|
||||
├── errors.go # Sentinel errors (ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed)
|
||||
@@ -60,7 +60,7 @@ pkg/channels/
|
||||
├── discord/
|
||||
│ ├── init.go
|
||||
│ └── discord.go
|
||||
├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ maixcam/ pico/
|
||||
├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ whatsapp_native/ maixcam/ pico/
|
||||
│ └── ...
|
||||
|
||||
pkg/bus/
|
||||
@@ -111,7 +111,7 @@ pkg/identity/
|
||||
|-----------|-------------|
|
||||
| **Sub-package Isolation** | Each channel is a standalone Go sub-package, depending on `BaseChannel` and interfaces from the `channels` parent package |
|
||||
| **Factory Registration** | Sub-packages self-register via `init()`, Manager looks up factories by name, eliminating import coupling |
|
||||
| **Capability Discovery** | Optional capabilities are declared via interfaces (`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`), discovered by Manager via runtime type assertions |
|
||||
| **Capability Discovery** | Optional capabilities are declared via interfaces (`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`, `HealthChecker`), discovered by Manager via runtime type assertions |
|
||||
| **Structured Messages** | Peer, MessageID, and SenderInfo promoted from Metadata to first-class fields on InboundMessage |
|
||||
| **Error Classification** | Channels return sentinel errors (`ErrRateLimit`, `ErrTemporary`, etc.), Manager uses these to determine retry strategy |
|
||||
| **Centralized Orchestration** | Rate limiting, message splitting, retries, and Typing/Reaction/Placeholder management are all handled by Manager and BaseChannel; channels only need to implement Send |
|
||||
@@ -145,6 +145,7 @@ After refactoring, these files have been removed and code moved to corresponding
|
||||
| _(did not exist)_ | `pkg/channels/interfaces.go` | New optional capability interfaces |
|
||||
| _(did not exist)_ | `pkg/channels/media.go` | New MediaSender interface |
|
||||
| _(did not exist)_ | `pkg/channels/webhook.go` | New WebhookHandler/HealthChecker |
|
||||
| _(did not exist)_ | `pkg/channels/whatsapp_native/` | New WhatsApp native mode (whatsmeow) |
|
||||
| _(did not exist)_ | `pkg/channels/split.go` | New message splitting (migrated from utils) |
|
||||
| _(did not exist)_ | `pkg/bus/types.go` | New structured message types |
|
||||
| _(did not exist)_ | `pkg/media/store.go` | New media file lifecycle management |
|
||||
@@ -220,6 +221,7 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
|
||||
cfg.Channels.Telegram.AllowFrom, // Allow list
|
||||
channels.WithMaxMessageLength(4096), // Platform message length limit
|
||||
channels.WithGroupTrigger(cfg.Channels.Telegram.GroupTrigger), // Group trigger config
|
||||
channels.WithReasoningChannelID(cfg.Channels.Telegram.ReasoningChannelID), // Reasoning chain routing
|
||||
)
|
||||
return &TelegramChannel{
|
||||
BaseChannel: base,
|
||||
@@ -466,6 +468,7 @@ func NewMatrixChannel(cfg *config.Config, msgBus *bus.MessageBus) (*MatrixChanne
|
||||
matrixCfg.AllowFrom, // Allow list
|
||||
channels.WithMaxMessageLength(65536), // Matrix message length limit
|
||||
channels.WithGroupTrigger(matrixCfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(matrixCfg.ReasoningChannelID), // Reasoning chain routing (optional)
|
||||
)
|
||||
|
||||
return &MatrixChannel{
|
||||
@@ -666,6 +669,32 @@ func (c *MatrixChannel) EditMessage(ctx context.Context, chatID, messageID, cont
|
||||
}
|
||||
```
|
||||
|
||||
#### PlaceholderCapable — Placeholder Messages
|
||||
|
||||
```go
|
||||
// If the platform supports sending placeholder messages (e.g. "Thinking... 💭"),
|
||||
// and the channel also implements MessageEditor, then Manager's preSend will
|
||||
// automatically edit the placeholder into the final response on outbound.
|
||||
// SendPlaceholder checks PlaceholderConfig.Enabled internally;
|
||||
// returning ("", nil) means skip.
|
||||
func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
cfg := c.config.Channels.Matrix.Placeholder
|
||||
if !cfg.Enabled {
|
||||
return "", nil
|
||||
}
|
||||
text := cfg.Text
|
||||
if text == "" {
|
||||
text = "Thinking... 💭"
|
||||
}
|
||||
// Call Matrix API to send placeholder message
|
||||
msg, err := c.sendText(ctx, chatID, text)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return msg.ID, nil
|
||||
}
|
||||
```
|
||||
|
||||
#### WebhookHandler — HTTP Webhook Reception
|
||||
|
||||
```go
|
||||
@@ -746,15 +775,17 @@ When the Agent finishes processing a message, Manager's `preSend` automatically:
|
||||
```go
|
||||
type ChannelsConfig struct {
|
||||
// ... existing channels
|
||||
Matrix MatrixChannelConfig `yaml:"matrix" json:"matrix"`
|
||||
Matrix MatrixChannelConfig `json:"matrix"`
|
||||
}
|
||||
|
||||
type MatrixChannelConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
HomeServer string `yaml:"home_server" json:"home_server"`
|
||||
Token string `yaml:"token" json:"token"`
|
||||
AllowFrom []string `yaml:"allow_from" json:"allow_from"`
|
||||
GroupTrigger GroupTriggerConfig `yaml:"group_trigger" json:"group_trigger"`
|
||||
Enabled bool `json:"enabled"`
|
||||
HomeServer string `json:"home_server"`
|
||||
Token string `json:"token"`
|
||||
AllowFrom []string `json:"allow_from"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id"`
|
||||
}
|
||||
```
|
||||
|
||||
@@ -767,6 +798,15 @@ if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
|
||||
}
|
||||
```
|
||||
|
||||
> **Note**: If your channel has multiple modes (like WhatsApp Bridge vs Native), branch in initChannels based on config:
|
||||
> ```go
|
||||
> if cfg.UseNative {
|
||||
> m.initChannel("whatsapp_native", "WhatsApp Native")
|
||||
> } else {
|
||||
> m.initChannel("whatsapp", "WhatsApp")
|
||||
> }
|
||||
> ```
|
||||
|
||||
#### Add blank import in Gateway
|
||||
|
||||
```go
|
||||
@@ -882,19 +922,21 @@ BaseChannel is the shared abstraction layer for all channels, providing the foll
|
||||
| `IsRunning() bool` | Atomically read running state |
|
||||
| `SetRunning(bool)` | Atomically set running state |
|
||||
| `MaxMessageLength() int` | Message length limit (rune count), 0 = unlimited |
|
||||
| `ReasoningChannelID() string` | Reasoning chain routing target channel ID (empty = no routing) |
|
||||
| `IsAllowed(senderID string) bool` | Legacy allow-list check (supports `"id\|username"` and `"@username"` formats) |
|
||||
| `IsAllowedSender(sender SenderInfo) bool` | New allow-list check (delegates to `identity.MatchAllowed`) |
|
||||
| `ShouldRespondInGroup(isMentioned, content) (bool, string)` | Unified group chat trigger filtering logic |
|
||||
| `HandleMessage(...)` | Unified inbound message handling: permission check → build MediaScope → auto-trigger Typing/Reaction → publish to Bus |
|
||||
| `HandleMessage(...)` | Unified inbound message handling: permission check → build MediaScope → auto-trigger Typing/Reaction/Placeholder → publish to Bus |
|
||||
| `SetMediaStore(s) / GetMediaStore()` | MediaStore injected by Manager |
|
||||
| `SetPlaceholderRecorder(r) / GetPlaceholderRecorder()` | PlaceholderRecorder injected by Manager |
|
||||
| `SetOwner(ch)` | Concrete channel reference injected by Manager (used for Typing/Reaction type assertions in HandleMessage) |
|
||||
| `SetOwner(ch)` | Concrete channel reference injected by Manager (used for Typing/Reaction/Placeholder type assertions in HandleMessage) |
|
||||
|
||||
**Functional Options**:
|
||||
|
||||
```go
|
||||
channels.WithMaxMessageLength(4096) // Set platform message length limit
|
||||
channels.WithGroupTrigger(groupTriggerCfg) // Set group trigger configuration
|
||||
channels.WithReasoningChannelID(id) // Set reasoning chain routing target channel
|
||||
```
|
||||
|
||||
### 4.4 Factory Registry
|
||||
@@ -998,7 +1040,7 @@ StartAll:
|
||||
- runMediaWorker (per-channel outbound media)
|
||||
- dispatchOutbound (route from bus to worker queues)
|
||||
- dispatchOutboundMedia (route from bus to media worker queues)
|
||||
- runTTLJanitor (every 10s clean up expired typing/placeholder)
|
||||
- runTTLJanitor (every 10s clean up expired typing/reaction/placeholder)
|
||||
4. Start shared HTTP server (if configured)
|
||||
|
||||
StopAll:
|
||||
@@ -1206,18 +1248,20 @@ make test # Full test suite
|
||||
|
||||
| Sub-package | Registered Name | Optional Interfaces |
|
||||
|-------------|----------------|-------------------|
|
||||
| `pkg/channels/telegram/` | `"telegram"` | MessageEditor, MediaSender, TypingCapable, PlaceholderCapable |
|
||||
| `pkg/channels/discord/` | `"discord"` | MessageEditor, TypingCapable, PlaceholderCapable |
|
||||
| `pkg/channels/slack/` | `"slack"` | ReactionCapable |
|
||||
| `pkg/channels/line/` | `"line"` | WebhookHandler, HealthChecker, TypingCapable |
|
||||
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable |
|
||||
| `pkg/channels/dingtalk/` | `"dingtalk"` | WebhookHandler |
|
||||
| `pkg/channels/feishu/` | `"feishu"` | WebhookHandler (architecture-specific build tags) |
|
||||
| `pkg/channels/wecom/` | `"wecom"` + `"wecom_app"` | WebhookHandler |
|
||||
| `pkg/channels/telegram/` | `"telegram"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
|
||||
| `pkg/channels/discord/` | `"discord"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
|
||||
| `pkg/channels/slack/` | `"slack"` | ReactionCapable, MediaSender |
|
||||
| `pkg/channels/line/` | `"line"` | TypingCapable, MediaSender, WebhookHandler |
|
||||
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender |
|
||||
| `pkg/channels/dingtalk/` | `"dingtalk"` | — |
|
||||
| `pkg/channels/feishu/` | `"feishu"` | — (architecture-specific build tags: `feishu_32.go` / `feishu_64.go`) |
|
||||
| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker |
|
||||
| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker |
|
||||
| `pkg/channels/qq/` | `"qq"` | — |
|
||||
| `pkg/channels/whatsapp/` | `"whatsapp"` | — |
|
||||
| `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge mode) |
|
||||
| `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (Native whatsmeow mode) |
|
||||
| `pkg/channels/maixcam/` | `"maixcam"` | — |
|
||||
| `pkg/channels/pico/` | `"pico"` | WebhookHandler (Pico Protocol), TypingCapable, PlaceholderCapable |
|
||||
| `pkg/channels/pico/` | `"pico"` | TypingCapable, PlaceholderCapable, MessageEditor, WebhookHandler |
|
||||
|
||||
### A.3 Interface Quick Reference
|
||||
|
||||
@@ -1231,6 +1275,7 @@ type Channel interface {
|
||||
IsRunning() bool
|
||||
IsAllowed(senderID string) bool
|
||||
IsAllowedSender(sender bus.SenderInfo) bool
|
||||
ReasoningChannelID() string
|
||||
}
|
||||
|
||||
// ===== Optional =====
|
||||
@@ -1324,8 +1369,16 @@ agentLoop.Stop() // Stop Agent
|
||||
|
||||
1. **Media cleanup temporarily disabled**: The `ReleaseAll` call in the Agent loop is commented out (`refactor(loop): disable media cleanup to prevent premature file deletion`) because session boundaries are not yet clearly defined. TTL cleanup remains active.
|
||||
|
||||
2. **Feishu architecture-specific compilation**: The Feishu channel uses build tags to distinguish 32-bit and 64-bit architectures (`feishu_32.go` / `feishu_64.go`).
|
||||
2. **Feishu architecture-specific compilation**: The Feishu channel uses build tags to distinguish 32-bit and 64-bit architectures (`feishu_32.go` / `feishu_64.go`). Feishu uses the SDK's WebSocket mode (not HTTP webhook), so it does not implement `WebhookHandler`.
|
||||
|
||||
3. **WeCom has two factories**: `"wecom"` (Bot mode) and `"wecom_app"` (App mode) are registered separately.
|
||||
3. **WeCom has two factories**: `"wecom"` (Bot mode, webhook only) and `"wecom_app"` (App mode, supports MediaSender) are registered separately. Both implement `WebhookHandler` and `HealthChecker`.
|
||||
|
||||
4. **Pico Protocol**: `pkg/channels/pico/` implements a custom PicoClaw native protocol channel that receives messages via webhook.
|
||||
4. **Pico Protocol**: `pkg/channels/pico/` implements a custom PicoClaw native protocol channel that receives messages via WebSocket webhook (`/pico/ws`).
|
||||
|
||||
5. **WhatsApp has two modes**: `"whatsapp"` (Bridge mode, communicates via external bridge URL) and `"whatsapp_native"` (native whatsmeow mode, connects directly to WhatsApp). Manager selects which to initialize based on `WhatsAppConfig.UseNative`.
|
||||
|
||||
6. **DingTalk uses Stream mode**: DingTalk uses the SDK's Stream/WebSocket mode (not HTTP webhook), so it does not implement `WebhookHandler`.
|
||||
|
||||
7. **PlaceholderConfig vs implementation**: `PlaceholderConfig` appears in 6 channel configs (Telegram, Discord, Slack, LINE, OneBot, Pico), but only channels that implement both `PlaceholderCapable` + `MessageEditor` (Telegram, Discord, Pico) can actually use placeholder message editing. The rest are reserved fields.
|
||||
|
||||
8. **ReasoningChannelID**: Most channel configs include a `reasoning_channel_id` field to route LLM reasoning/thinking output to a designated channel (WhatsApp, Telegram, Feishu, Discord, MaixCam, QQ, DingTalk, Slack, LINE, OneBot, WeCom, WeComApp). Note: `PicoConfig` does not currently expose this field. `BaseChannel` exposes this via the `WithReasoningChannelID` option and `ReasoningChannelID()` method.
|
||||
+79
-27
@@ -1,7 +1,5 @@
|
||||
# PicoClaw Channel System 重构:完整开发指南
|
||||
# PicoClaw Channel System:完整开发指南
|
||||
|
||||
> **分支**: `refactor/channel-system`
|
||||
> **状态**: 活跃开发中(约 40 commits)
|
||||
> **影响范围**: `pkg/channels/`, `pkg/bus/`, `pkg/media/`, `pkg/identity/`, `cmd/picoclaw/internal/gateway/`
|
||||
|
||||
---
|
||||
@@ -46,6 +44,8 @@ pkg/channels/
|
||||
pkg/channels/
|
||||
├── base.go # BaseChannel 共享抽象层
|
||||
├── interfaces.go # 可选能力接口(TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder)
|
||||
├── README.md # 英文文档
|
||||
├── README.zh.md # 中文文档
|
||||
├── media.go # MediaSender 可选接口
|
||||
├── webhook.go # WebhookHandler, HealthChecker 可选接口
|
||||
├── errors.go # 错误哨兵值(ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed)
|
||||
@@ -60,7 +60,7 @@ pkg/channels/
|
||||
├── discord/
|
||||
│ ├── init.go
|
||||
│ └── discord.go
|
||||
├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ maixcam/ pico/
|
||||
├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ whatsapp_native/ maixcam/ pico/
|
||||
│ └── ...
|
||||
|
||||
pkg/bus/
|
||||
@@ -111,7 +111,7 @@ pkg/identity/
|
||||
|------|------|
|
||||
| **子包隔离** | 每个 channel 一个独立 Go 子包,依赖 `channels` 父包提供的 `BaseChannel` 和接口 |
|
||||
| **工厂注册** | 各子包通过 `init()` 自注册,Manager 通过名字查找工厂,消除 import 耦合 |
|
||||
| **能力发现** | 可选能力通过接口(`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`)声明,Manager 运行时类型断言发现 |
|
||||
| **能力发现** | 可选能力通过接口(`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`, `HealthChecker`)声明,Manager 运行时类型断言发现 |
|
||||
| **结构化消息** | Peer、MessageID、SenderInfo 从 Metadata 提升为 InboundMessage 的一等字段 |
|
||||
| **错误分类** | Channel 返回哨兵错误(`ErrRateLimit`, `ErrTemporary` 等),Manager 据此决定重试策略 |
|
||||
| **集中编排** | 速率限制、消息分割、重试、Typing/Reaction/Placeholder 全部由 Manager 和 BaseChannel 统一处理,Channel 只负责 Send |
|
||||
@@ -145,6 +145,7 @@ pkg/identity/
|
||||
| _(不存在)_ | `pkg/channels/interfaces.go` | 新增可选能力接口 |
|
||||
| _(不存在)_ | `pkg/channels/media.go` | 新增 MediaSender 接口 |
|
||||
| _(不存在)_ | `pkg/channels/webhook.go` | 新增 WebhookHandler/HealthChecker |
|
||||
| _(不存在)_ | `pkg/channels/whatsapp_native/` | 新增 WhatsApp 原生模式(whatsmeow) |
|
||||
| _(不存在)_ | `pkg/channels/split.go` | 新增消息分割(从 utils 迁入) |
|
||||
| _(不存在)_ | `pkg/bus/types.go` | 新增结构化消息类型 |
|
||||
| _(不存在)_ | `pkg/media/store.go` | 新增媒体文件生命周期管理 |
|
||||
@@ -220,6 +221,7 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
|
||||
cfg.Channels.Telegram.AllowFrom, // 允许列表
|
||||
channels.WithMaxMessageLength(4096), // 平台消息长度上限
|
||||
channels.WithGroupTrigger(cfg.Channels.Telegram.GroupTrigger), // 群聊触发配置
|
||||
channels.WithReasoningChannelID(cfg.Channels.Telegram.ReasoningChannelID), // 思维链路由
|
||||
)
|
||||
return &TelegramChannel{
|
||||
BaseChannel: base,
|
||||
@@ -466,6 +468,7 @@ func NewMatrixChannel(cfg *config.Config, msgBus *bus.MessageBus) (*MatrixChanne
|
||||
matrixCfg.AllowFrom, // 允许列表
|
||||
channels.WithMaxMessageLength(65536), // Matrix 消息长度限制
|
||||
channels.WithGroupTrigger(matrixCfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(matrixCfg.ReasoningChannelID), // 思维链路由(可选)
|
||||
)
|
||||
|
||||
return &MatrixChannel{
|
||||
@@ -666,6 +669,31 @@ func (c *MatrixChannel) EditMessage(ctx context.Context, chatID, messageID, cont
|
||||
}
|
||||
```
|
||||
|
||||
#### PlaceholderCapable — 占位消息
|
||||
|
||||
```go
|
||||
// 如果平台支持发送占位消息(如 "Thinking... 💭"),并且实现了 MessageEditor,
|
||||
// 则 Manager 的 preSend 会在出站时自动将占位消息编辑为最终回复。
|
||||
// SendPlaceholder 内部根据 PlaceholderConfig.Enabled 决定是否发送;
|
||||
// 返回 ("", nil) 表示跳过。
|
||||
func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
cfg := c.config.Channels.Matrix.Placeholder
|
||||
if !cfg.Enabled {
|
||||
return "", nil
|
||||
}
|
||||
text := cfg.Text
|
||||
if text == "" {
|
||||
text = "Thinking... 💭"
|
||||
}
|
||||
// 调用 Matrix API 发送占位消息
|
||||
msg, err := c.sendText(ctx, chatID, text)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return msg.ID, nil
|
||||
}
|
||||
```
|
||||
|
||||
#### WebhookHandler — HTTP Webhook 接收
|
||||
|
||||
```go
|
||||
@@ -746,15 +774,17 @@ if c.owner != nil && c.placeholderRecorder != nil {
|
||||
```go
|
||||
type ChannelsConfig struct {
|
||||
// ... 现有 channels
|
||||
Matrix MatrixChannelConfig `yaml:"matrix" json:"matrix"`
|
||||
Matrix MatrixChannelConfig `json:"matrix"`
|
||||
}
|
||||
|
||||
type MatrixChannelConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
HomeServer string `yaml:"home_server" json:"home_server"`
|
||||
Token string `yaml:"token" json:"token"`
|
||||
AllowFrom []string `yaml:"allow_from" json:"allow_from"`
|
||||
GroupTrigger GroupTriggerConfig `yaml:"group_trigger" json:"group_trigger"`
|
||||
Enabled bool `json:"enabled"`
|
||||
HomeServer string `json:"home_server"`
|
||||
Token string `json:"token"`
|
||||
AllowFrom []string `json:"allow_from"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id"`
|
||||
}
|
||||
```
|
||||
|
||||
@@ -767,6 +797,15 @@ if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
|
||||
}
|
||||
```
|
||||
|
||||
> **注意**:如果你的 channel 有多种模式(如 WhatsApp Bridge vs Native),需要在 initChannels 中根据配置分支:
|
||||
> ```go
|
||||
> if cfg.UseNative {
|
||||
> m.initChannel("whatsapp_native", "WhatsApp Native")
|
||||
> } else {
|
||||
> m.initChannel("whatsapp", "WhatsApp")
|
||||
> }
|
||||
> ```
|
||||
|
||||
#### 在 Gateway 中添加 blank import
|
||||
|
||||
```go
|
||||
@@ -882,19 +921,21 @@ BaseChannel 是所有 channel 的共享抽象层,提供以下能力:
|
||||
| `IsRunning() bool` | 原子读取运行状态 |
|
||||
| `SetRunning(bool)` | 原子设置运行状态 |
|
||||
| `MaxMessageLength() int` | 消息长度限制(rune 计数),0 = 无限制 |
|
||||
| `ReasoningChannelID() string` | 思维链路由目标 channel ID(空 = 不路由) |
|
||||
| `IsAllowed(senderID string) bool` | 旧格式允许列表检查(支持 `"id\|username"` 和 `"@username"` 格式) |
|
||||
| `IsAllowedSender(sender SenderInfo) bool` | 新格式允许列表检查(委托给 `identity.MatchAllowed`) |
|
||||
| `ShouldRespondInGroup(isMentioned, content) (bool, string)` | 统一群聊触发过滤逻辑 |
|
||||
| `HandleMessage(...)` | 统一入站消息处理:权限检查 → 构建 MediaScope → 自动触发 Typing/Reaction → 发布到 Bus |
|
||||
| `HandleMessage(...)` | 统一入站消息处理:权限检查 → 构建 MediaScope → 自动触发 Typing/Reaction/Placeholder → 发布到 Bus |
|
||||
| `SetMediaStore(s) / GetMediaStore()` | Manager 注入的媒体存储 |
|
||||
| `SetPlaceholderRecorder(r) / GetPlaceholderRecorder()` | Manager 注入的占位符记录器 |
|
||||
| `SetOwner(ch) ` | Manager 注入的具体 channel 引用(用于 HandleMessage 内部的 Typing/Reaction 类型断言) |
|
||||
| `SetOwner(ch) ` | Manager 注入的具体 channel 引用(用于 HandleMessage 内部的 Typing/Reaction/Placeholder 类型断言) |
|
||||
|
||||
**功能选项**:
|
||||
|
||||
```go
|
||||
channels.WithMaxMessageLength(4096) // 设置平台消息长度限制
|
||||
channels.WithGroupTrigger(groupTriggerCfg) // 设置群聊触发配置
|
||||
channels.WithReasoningChannelID(id) // 设置思维链路由目标 channel
|
||||
```
|
||||
|
||||
### 4.4 工厂注册表
|
||||
@@ -998,7 +1039,7 @@ StartAll:
|
||||
- runMediaWorker (per-channel 出站媒体)
|
||||
- dispatchOutbound (从 bus 路由到 worker 队列)
|
||||
- dispatchOutboundMedia (从 bus 路由到 media worker 队列)
|
||||
- runTTLJanitor (每 10s 清理过期 typing/placeholder)
|
||||
- runTTLJanitor (每 10s 清理过期 typing/reaction/placeholder)
|
||||
4. 启动共享 HTTP 服务器(如已配置)
|
||||
|
||||
StopAll:
|
||||
@@ -1206,18 +1247,20 @@ make test # 全量测试
|
||||
|
||||
| 子包 | 注册名 | 可选接口 |
|
||||
|------|--------|----------|
|
||||
| `pkg/channels/telegram/` | `"telegram"` | MessageEditor, MediaSender, TypingCapable, PlaceholderCapable |
|
||||
| `pkg/channels/discord/` | `"discord"` | MessageEditor, TypingCapable, PlaceholderCapable |
|
||||
| `pkg/channels/slack/` | `"slack"` | ReactionCapable |
|
||||
| `pkg/channels/line/` | `"line"` | WebhookHandler, HealthChecker, TypingCapable |
|
||||
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable |
|
||||
| `pkg/channels/dingtalk/` | `"dingtalk"` | WebhookHandler |
|
||||
| `pkg/channels/feishu/` | `"feishu"` | WebhookHandler (架构特定 build tags) |
|
||||
| `pkg/channels/wecom/` | `"wecom"` + `"wecom_app"` | WebhookHandler |
|
||||
| `pkg/channels/telegram/` | `"telegram"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
|
||||
| `pkg/channels/discord/` | `"discord"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
|
||||
| `pkg/channels/slack/` | `"slack"` | ReactionCapable, MediaSender |
|
||||
| `pkg/channels/line/` | `"line"` | TypingCapable, MediaSender, WebhookHandler |
|
||||
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender |
|
||||
| `pkg/channels/dingtalk/` | `"dingtalk"` | — |
|
||||
| `pkg/channels/feishu/` | `"feishu"` | — (架构特定 build tags: `feishu_32.go` / `feishu_64.go`) |
|
||||
| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker |
|
||||
| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker |
|
||||
| `pkg/channels/qq/` | `"qq"` | — |
|
||||
| `pkg/channels/whatsapp/` | `"whatsapp"` | — |
|
||||
| `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge 模式) |
|
||||
| `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (原生 whatsmeow 模式) |
|
||||
| `pkg/channels/maixcam/` | `"maixcam"` | — |
|
||||
| `pkg/channels/pico/` | `"pico"` | WebhookHandler (Pico Protocol), TypingCapable, PlaceholderCapable |
|
||||
| `pkg/channels/pico/` | `"pico"` | TypingCapable, PlaceholderCapable, MessageEditor, WebhookHandler |
|
||||
|
||||
### A.3 接口速查表
|
||||
|
||||
@@ -1231,6 +1274,7 @@ type Channel interface {
|
||||
IsRunning() bool
|
||||
IsAllowed(senderID string) bool
|
||||
IsAllowedSender(sender bus.SenderInfo) bool
|
||||
ReasoningChannelID() string
|
||||
}
|
||||
|
||||
// ===== 可选实现 =====
|
||||
@@ -1324,8 +1368,16 @@ agentLoop.Stop() // 停止 Agent
|
||||
|
||||
1. **媒体清理暂时禁用**:Agent loop 中的 `ReleaseAll` 调用被注释掉了(`refactor(loop): disable media cleanup to prevent premature file deletion`),因为会话边界尚未明确定义。TTL 清理仍然有效。
|
||||
|
||||
2. **Feishu 架构特定编译**:Feishu channel 使用 build tags 区分 32 位和 64 位架构(`feishu_32.go` / `feishu_64.go`)。
|
||||
2. **Feishu 架构特定编译**:Feishu channel 使用 build tags 区分 32 位和 64 位架构(`feishu_32.go` / `feishu_64.go`)。Feishu 使用 SDK 的 WebSocket 模式(非 HTTP webhook),因此不实现 `WebhookHandler`。
|
||||
|
||||
3. **WeCom 有两个工厂**:`"wecom"`(Bot 模式)和 `"wecom_app"`(应用模式)分别注册。
|
||||
3. **WeCom 有两个工厂**:`"wecom"`(Bot 模式,纯 webhook)和 `"wecom_app"`(应用模式,支持 MediaSender)分别注册。两者都实现了 `WebhookHandler` 和 `HealthChecker`。
|
||||
|
||||
4. **Pico Protocol**:`pkg/channels/pico/` 实现了一个自定义的 PicoClaw 原生协议 channel,通过 webhook 接收消息。
|
||||
4. **Pico Protocol**:`pkg/channels/pico/` 实现了一个自定义的 PicoClaw 原生协议 channel,通过 WebSocket webhook (`/pico/ws`) 接收消息。
|
||||
|
||||
5. **WhatsApp 有两种模式**:`"whatsapp"`(Bridge 模式,通过外部 bridge URL 通信)和 `"whatsapp_native"`(原生 whatsmeow 模式,直接连接 WhatsApp)。Manager 根据 `WhatsAppConfig.UseNative` 决定初始化哪个。
|
||||
|
||||
6. **DingTalk 使用 Stream 模式**:DingTalk 使用 SDK 的 Stream/WebSocket 模式(非 HTTP webhook),因此不实现 `WebhookHandler`。
|
||||
|
||||
7. **PlaceholderConfig 的配置与实现**:`PlaceholderConfig` 出现在 6 个 channel config 中(Telegram、Discord、Slack、LINE、OneBot、Pico),但只有实现了 `PlaceholderCapable` + `MessageEditor` 的 channel(Telegram、Discord、Pico)能真正使用占位消息编辑功能。其余 channel 的 `PlaceholderConfig` 为预留字段。
|
||||
|
||||
8. **ReasoningChannelID**:大多数 channel config 都包含 `reasoning_channel_id` 字段,用于将 LLM 的思维链(reasoning/thinking)路由到指定 channel(WhatsApp、Telegram、Feishu、Discord、MaixCam、QQ、DingTalk、Slack、LINE、OneBot、WeCom、WeComApp)。注意:`PicoConfig` 目前不包含该字段。`BaseChannel` 通过 `WithReasoningChannelID` 选项和 `ReasoningChannelID()` 方法暴露此配置。
|
||||
@@ -45,11 +45,13 @@ type replyTokenEntry struct {
|
||||
type LINEChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.LINEConfig
|
||||
botUserID string // Bot's user ID
|
||||
botBasicID string // Bot's basic ID (e.g. @216ru...)
|
||||
botDisplayName string // Bot's display name for text-based mention detection
|
||||
replyTokens sync.Map // chatID -> replyTokenEntry
|
||||
quoteTokens sync.Map // chatID -> quoteToken (string)
|
||||
infoClient *http.Client // for bot info lookups (short timeout)
|
||||
apiClient *http.Client // for messaging API calls
|
||||
botUserID string // Bot's user ID
|
||||
botBasicID string // Bot's basic ID (e.g. @216ru...)
|
||||
botDisplayName string // Bot's display name for text-based mention detection
|
||||
replyTokens sync.Map // chatID -> replyTokenEntry
|
||||
quoteTokens sync.Map // chatID -> quoteToken (string)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
@@ -69,6 +71,8 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha
|
||||
return &LINEChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
infoClient: &http.Client{Timeout: 10 * time.Second},
|
||||
apiClient: &http.Client{Timeout: 30 * time.Second},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -104,8 +108,7 @@ func (c *LINEChannel) fetchBotInfo() error {
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := c.infoClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -644,8 +647,7 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := c.apiClient.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
|
||||
@@ -313,7 +313,7 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
|
||||
if len(m.channels) == 0 {
|
||||
logger.WarnC("channels", "No channels enabled")
|
||||
return nil
|
||||
return errors.New("no channels enabled")
|
||||
}
|
||||
|
||||
logger.InfoC("channels", "Starting all channels")
|
||||
|
||||
@@ -274,13 +274,12 @@ func TestWorkerRateLimiter(t *testing.T) {
|
||||
limiter: rate.NewLimiter(2, 1),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ctx := t.Context()
|
||||
|
||||
go m.runWorker(ctx, "test", w)
|
||||
|
||||
// Enqueue 4 messages
|
||||
for i := 0; i < 4; i++ {
|
||||
for i := range 4 {
|
||||
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}
|
||||
}
|
||||
|
||||
@@ -352,8 +351,7 @@ func TestRunWorker_MessageSplitting(t *testing.T) {
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ctx := t.Context()
|
||||
|
||||
go m.runWorker(ctx, "test", w)
|
||||
|
||||
@@ -576,7 +574,7 @@ func TestRecordPlaceholder_ConcurrentSafe(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
for i := range 100 {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
@@ -591,7 +589,7 @@ func TestRecordTypingStop_ConcurrentSafe(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
for i := range 100 {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
@@ -834,7 +832,7 @@ func TestLazyWorkerCreation(t *testing.T) {
|
||||
func TestBuildMediaScope_FastIDUniqueness(t *testing.T) {
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
for range 1000 {
|
||||
scope := BuildMediaScope("test", "chat1", "")
|
||||
if seen[scope] {
|
||||
t.Fatalf("duplicate scope generated: %s", scope)
|
||||
|
||||
@@ -337,10 +337,7 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) reconnectLoop() {
|
||||
interval := time.Duration(c.config.ReconnectInterval) * time.Second
|
||||
if interval < 5*time.Second {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
interval := max(time.Duration(c.config.ReconnectInterval)*time.Second, 5*time.Second)
|
||||
|
||||
for {
|
||||
select {
|
||||
|
||||
@@ -292,8 +292,8 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
|
||||
// Check Authorization header
|
||||
auth := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
if strings.TrimPrefix(auth, "Bearer ") == token {
|
||||
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
|
||||
if after == token {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
+8
-24
@@ -23,10 +23,7 @@ func SplitMessage(content string, maxLen int) []string {
|
||||
var messages []string
|
||||
|
||||
// Dynamic buffer: 10% of maxLen, but at least 50 chars if possible
|
||||
codeBlockBuffer := maxLen / 10
|
||||
if codeBlockBuffer < 50 {
|
||||
codeBlockBuffer = 50
|
||||
}
|
||||
codeBlockBuffer := max(maxLen/10, 50)
|
||||
if codeBlockBuffer > maxLen/2 {
|
||||
codeBlockBuffer = maxLen / 2
|
||||
}
|
||||
@@ -40,10 +37,7 @@ func SplitMessage(content string, maxLen int) []string {
|
||||
}
|
||||
|
||||
// Effective split point: maxLen minus buffer, to leave room for code blocks
|
||||
effectiveLimit := maxLen - codeBlockBuffer
|
||||
if effectiveLimit < maxLen/2 {
|
||||
effectiveLimit = maxLen / 2
|
||||
}
|
||||
effectiveLimit := max(maxLen-codeBlockBuffer, maxLen/2)
|
||||
|
||||
end := start + effectiveLimit
|
||||
|
||||
@@ -85,10 +79,9 @@ func SplitMessage(content string, maxLen int) []string {
|
||||
// If we have a reasonable amount of content after the header, split inside
|
||||
if msgEnd > headerEndIdx+20 {
|
||||
// Find a better split point closer to maxLen
|
||||
innerLimit := start + maxLen - 5 // Leave room for "\n```"
|
||||
if innerLimit > totalLen {
|
||||
innerLimit = totalLen
|
||||
}
|
||||
innerLimit := min(
|
||||
// Leave room for "\n```"
|
||||
start+maxLen-5, totalLen)
|
||||
betterEnd := findLastNewlineInRange(runes, start, innerLimit, 200)
|
||||
if betterEnd > headerEndIdx {
|
||||
msgEnd = betterEnd
|
||||
@@ -117,10 +110,7 @@ func SplitMessage(content string, maxLen int) []string {
|
||||
if unclosedIdx-start > 20 {
|
||||
msgEnd = unclosedIdx
|
||||
} else {
|
||||
splitAt := start + maxLen - 5
|
||||
if splitAt > totalLen {
|
||||
splitAt = totalLen
|
||||
}
|
||||
splitAt := min(start+maxLen-5, totalLen)
|
||||
chunk := strings.TrimRight(string(runes[start:splitAt]), " \t\n\r") + "\n```"
|
||||
messages = append(messages, chunk)
|
||||
remaining := strings.TrimSpace(header + "\n" + string(runes[splitAt:totalLen]))
|
||||
@@ -196,10 +186,7 @@ func findNewlineFrom(runes []rune, from int) int {
|
||||
// findLastNewlineInRange finds the last newline within the last searchWindow runes
|
||||
// of the range runes[start:end]. Returns the absolute index or start-1 (indicating not found).
|
||||
func findLastNewlineInRange(runes []rune, start, end, searchWindow int) int {
|
||||
searchStart := end - searchWindow
|
||||
if searchStart < start {
|
||||
searchStart = start
|
||||
}
|
||||
searchStart := max(end-searchWindow, start)
|
||||
for i := end - 1; i >= searchStart; i-- {
|
||||
if runes[i] == '\n' {
|
||||
return i
|
||||
@@ -211,10 +198,7 @@ func findLastNewlineInRange(runes []rune, start, end, searchWindow int) int {
|
||||
// findLastSpaceInRange finds the last space/tab within the last searchWindow runes
|
||||
// of the range runes[start:end]. Returns the absolute index or start-1 (indicating not found).
|
||||
func findLastSpaceInRange(runes []rune, start, end, searchWindow int) int {
|
||||
searchStart := end - searchWindow
|
||||
if searchStart < start {
|
||||
searchStart = start
|
||||
}
|
||||
searchStart := max(end-searchWindow, start)
|
||||
for i := end - 1; i >= searchStart; i-- {
|
||||
if runes[i] == ' ' || runes[i] == '\t' {
|
||||
return i
|
||||
|
||||
+27
-13
@@ -32,6 +32,7 @@ const (
|
||||
type WeComAppChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.WeComAppConfig
|
||||
client *http.Client
|
||||
accessToken string
|
||||
tokenExpiry time.Time
|
||||
tokenMu sync.RWMutex
|
||||
@@ -129,9 +130,20 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
// Client timeout must be >= the configured ReplyTimeout so the
|
||||
// per-request context deadline is always the effective limit.
|
||||
clientTimeout := 30 * time.Second
|
||||
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
|
||||
clientTimeout = d
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &WeComAppChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: &http.Client{Timeout: clientTimeout},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
processedMsgs: make(map[string]bool),
|
||||
}, nil
|
||||
}
|
||||
@@ -145,6 +157,10 @@ func (c *WeComAppChannel) Name() string {
|
||||
func (c *WeComAppChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("wecom_app", "Starting WeCom App channel...")
|
||||
|
||||
// Cancel the context created in the constructor to avoid a resource leak.
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Get initial access token
|
||||
@@ -299,8 +315,7 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return "", channels.ClassifyNetError(err)
|
||||
}
|
||||
@@ -357,8 +372,7 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
@@ -567,8 +581,9 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp
|
||||
return
|
||||
}
|
||||
|
||||
// Process the message with context
|
||||
go c.processMessage(ctx, msg)
|
||||
// Process the message with the channel's long-lived context (not the HTTP
|
||||
// request context, which is canceled as soon as we return the response).
|
||||
go c.processMessage(c.ctx, msg)
|
||||
|
||||
// Return success response immediately
|
||||
// WeCom App requires response within configured timeout (default 5 seconds)
|
||||
@@ -597,14 +612,14 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag
|
||||
return
|
||||
}
|
||||
c.processedMsgs[msgID] = true
|
||||
c.msgMu.Unlock()
|
||||
|
||||
// Clean up old messages periodically (keep last 1000)
|
||||
// Clean up old messages while still holding the lock to avoid a data race
|
||||
// on len(). Reset the map but re-insert the current msgID so it remains
|
||||
// deduplicated.
|
||||
if len(c.processedMsgs) > 1000 {
|
||||
c.msgMu.Lock()
|
||||
c.processedMsgs = make(map[string]bool)
|
||||
c.msgMu.Unlock()
|
||||
c.processedMsgs[msgID] = true
|
||||
}
|
||||
c.msgMu.Unlock()
|
||||
|
||||
senderID := msg.FromUserName
|
||||
chatID := senderID // WeCom App uses user ID as chat ID for direct messages
|
||||
@@ -738,8 +753,7 @@ func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, user
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ func encryptTestMessageApp(message, aesKey string) (string, error) {
|
||||
|
||||
// Prepare message: random(16) + msg_len(4) + msg + corp_id
|
||||
random := make([]byte, 0, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
for i := range 16 {
|
||||
random = append(random, byte(i+1))
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
type WeComBotChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.WeComConfig
|
||||
client *http.Client
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
processedMsgs map[string]bool // Message deduplication: msg_id -> processed
|
||||
@@ -93,9 +94,20 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
// Client timeout must be >= the configured ReplyTimeout so the
|
||||
// per-request context deadline is always the effective limit.
|
||||
clientTimeout := 30 * time.Second
|
||||
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
|
||||
clientTimeout = d
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &WeComBotChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: &http.Client{Timeout: clientTimeout},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
processedMsgs: make(map[string]bool),
|
||||
}, nil
|
||||
}
|
||||
@@ -109,6 +121,10 @@ func (c *WeComBotChannel) Name() string {
|
||||
func (c *WeComBotChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("wecom", "Starting WeCom Bot channel...")
|
||||
|
||||
// Cancel the context created in the constructor to avoid a resource leak.
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
c.SetRunning(true)
|
||||
@@ -292,8 +308,9 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp
|
||||
return
|
||||
}
|
||||
|
||||
// Process the message asynchronously with context
|
||||
go c.processMessage(ctx, msg)
|
||||
// Process the message with the channel's long-lived context (not the HTTP
|
||||
// request context, which is canceled as soon as we return the response).
|
||||
go c.processMessage(c.ctx, msg)
|
||||
|
||||
// Return success response immediately
|
||||
// WeCom Bot requires response within configured timeout (default 5 seconds)
|
||||
@@ -322,14 +339,14 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag
|
||||
return
|
||||
}
|
||||
c.processedMsgs[msgID] = true
|
||||
c.msgMu.Unlock()
|
||||
|
||||
// Clean up old messages periodically (keep last 1000)
|
||||
// Clean up old messages while still holding the lock to avoid a data race
|
||||
// on len(). Reset the map but re-insert the current msgID so it remains
|
||||
// deduplicated.
|
||||
if len(c.processedMsgs) > 1000 {
|
||||
c.msgMu.Lock()
|
||||
c.processedMsgs = make(map[string]bool)
|
||||
c.msgMu.Unlock()
|
||||
c.processedMsgs[msgID] = true
|
||||
}
|
||||
c.msgMu.Unlock()
|
||||
|
||||
senderID := msg.From.UserID
|
||||
|
||||
@@ -442,8 +459,7 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func encryptTestMessage(message, aesKey string) (string, error) {
|
||||
|
||||
// Prepare message: random(16) + msg_len(4) + msg + receiveid
|
||||
random := make([]byte, 0, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
for i := range 16 {
|
||||
random = append(random, byte(i))
|
||||
}
|
||||
|
||||
|
||||
@@ -125,7 +125,7 @@ func pkcs7Unpad(data []byte) ([]byte, error) {
|
||||
return nil, fmt.Errorf("padding size larger than data")
|
||||
}
|
||||
// Verify all padding bytes
|
||||
for i := 0; i < padding; i++ {
|
||||
for i := range padding {
|
||||
if data[len(data)-1-i] != byte(padding) {
|
||||
return nil, fmt.Errorf("invalid padding byte at position %d", i)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mdp/qrterminal/v3"
|
||||
@@ -56,6 +57,8 @@ type WhatsAppNativeChannel struct {
|
||||
runCancel context.CancelFunc
|
||||
reconnectMu sync.Mutex
|
||||
reconnecting bool
|
||||
stopping atomic.Bool // set once Stop begins; prevents new wg.Add calls
|
||||
wg sync.WaitGroup // tracks background goroutines (QR handler, reconnect)
|
||||
}
|
||||
|
||||
// NewWhatsAppNativeChannel creates a WhatsApp channel that uses whatsmeow for connection.
|
||||
@@ -80,6 +83,14 @@ func NewWhatsAppNativeChannel(
|
||||
func (c *WhatsAppNativeChannel) Start(ctx context.Context) error {
|
||||
logger.InfoCF("whatsapp", "Starting WhatsApp native channel (whatsmeow)", map[string]any{"store": c.storePath})
|
||||
|
||||
// Reset lifecycle state from any previous Stop() so a restarted channel
|
||||
// behaves correctly. Use reconnectMu to be consistent with eventHandler
|
||||
// and Stop() which coordinate under the same lock.
|
||||
c.reconnectMu.Lock()
|
||||
c.stopping.Store(false)
|
||||
c.reconnecting = false
|
||||
c.reconnectMu.Unlock()
|
||||
|
||||
if err := os.MkdirAll(c.storePath, 0o700); err != nil {
|
||||
return fmt.Errorf("create session store dir: %w", err)
|
||||
}
|
||||
@@ -112,6 +123,12 @@ func (c *WhatsAppNativeChannel) Start(ctx context.Context) error {
|
||||
}
|
||||
|
||||
client := whatsmeow.NewClient(deviceStore, waLogger)
|
||||
|
||||
// Create runCtx/runCancel BEFORE registering event handler and starting
|
||||
// goroutines so that Stop() can cancel them at any time, including during
|
||||
// the QR-login flow.
|
||||
c.runCtx, c.runCancel = context.WithCancel(ctx)
|
||||
|
||||
client.AddEventHandler(c.eventHandler)
|
||||
|
||||
c.mu.Lock()
|
||||
@@ -119,36 +136,75 @@ func (c *WhatsAppNativeChannel) Start(ctx context.Context) error {
|
||||
c.client = client
|
||||
c.mu.Unlock()
|
||||
|
||||
// cleanupOnError clears struct references and releases resources when
|
||||
// Start() fails after fields are already assigned. This prevents
|
||||
// Stop() from operating on stale references (double-close, disconnect
|
||||
// of a partially-initialized client, or stray event handler callbacks).
|
||||
startOK := false
|
||||
defer func() {
|
||||
if startOK {
|
||||
return
|
||||
}
|
||||
c.runCancel()
|
||||
client.Disconnect()
|
||||
c.mu.Lock()
|
||||
c.client = nil
|
||||
c.container = nil
|
||||
c.mu.Unlock()
|
||||
_ = container.Close()
|
||||
}()
|
||||
|
||||
if client.Store.ID == nil {
|
||||
qrChan, err := client.GetQRChannel(ctx)
|
||||
qrChan, err := client.GetQRChannel(c.runCtx)
|
||||
if err != nil {
|
||||
_ = container.Close()
|
||||
return fmt.Errorf("get QR channel: %w", err)
|
||||
}
|
||||
if err := client.Connect(); err != nil {
|
||||
_ = container.Close()
|
||||
return fmt.Errorf("connect: %w", err)
|
||||
}
|
||||
for evt := range qrChan {
|
||||
if evt.Event == "code" {
|
||||
logger.InfoCF("whatsapp", "Scan this QR code with WhatsApp (Linked Devices):", nil)
|
||||
qrterminal.GenerateWithConfig(evt.Code, qrterminal.Config{
|
||||
Level: qrterminal.L,
|
||||
Writer: os.Stdout,
|
||||
HalfBlocks: true,
|
||||
})
|
||||
} else {
|
||||
logger.InfoCF("whatsapp", "WhatsApp login event", map[string]any{"event": evt.Event})
|
||||
}
|
||||
// Handle QR events in a background goroutine so Start() returns
|
||||
// promptly. The goroutine is tracked via c.wg and respects
|
||||
// c.runCtx for cancellation.
|
||||
// Guard wg.Add with reconnectMu + stopping check (same protocol
|
||||
// as eventHandler) so a concurrent Stop() cannot enter wg.Wait()
|
||||
// while we call wg.Add(1).
|
||||
c.reconnectMu.Lock()
|
||||
if c.stopping.Load() {
|
||||
c.reconnectMu.Unlock()
|
||||
return fmt.Errorf("channel stopped during QR setup")
|
||||
}
|
||||
c.wg.Add(1)
|
||||
c.reconnectMu.Unlock()
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-c.runCtx.Done():
|
||||
return
|
||||
case evt, ok := <-qrChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if evt.Event == "code" {
|
||||
logger.InfoCF("whatsapp", "Scan this QR code with WhatsApp (Linked Devices):", nil)
|
||||
qrterminal.GenerateWithConfig(evt.Code, qrterminal.Config{
|
||||
Level: qrterminal.L,
|
||||
Writer: os.Stdout,
|
||||
HalfBlocks: true,
|
||||
})
|
||||
} else {
|
||||
logger.InfoCF("whatsapp", "WhatsApp login event", map[string]any{"event": evt.Event})
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
if err := client.Connect(); err != nil {
|
||||
_ = container.Close()
|
||||
return fmt.Errorf("connect: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.runCtx, c.runCancel = context.WithCancel(ctx)
|
||||
startOK = true
|
||||
c.SetRunning(true)
|
||||
logger.InfoC("whatsapp", "WhatsApp native channel connected")
|
||||
return nil
|
||||
@@ -156,19 +212,53 @@ func (c *WhatsAppNativeChannel) Start(ctx context.Context) error {
|
||||
|
||||
func (c *WhatsAppNativeChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("whatsapp", "Stopping WhatsApp native channel")
|
||||
|
||||
// Mark as stopping under reconnectMu so the flag is visible to
|
||||
// eventHandler atomically with respect to its wg.Add(1) call.
|
||||
// This closes the TOCTOU window where eventHandler could check
|
||||
// stopping (false), then Stop sets it true + enters wg.Wait,
|
||||
// then eventHandler calls wg.Add(1) — causing a panic.
|
||||
c.reconnectMu.Lock()
|
||||
c.stopping.Store(true)
|
||||
c.reconnectMu.Unlock()
|
||||
|
||||
if c.runCancel != nil {
|
||||
c.runCancel()
|
||||
}
|
||||
|
||||
// Disconnect the client first so any blocking Connect()/reconnect loops
|
||||
// can be interrupted before we wait on the goroutines.
|
||||
c.mu.Lock()
|
||||
client := c.client
|
||||
container := c.container
|
||||
c.client = nil
|
||||
c.container = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
if client != nil {
|
||||
client.Disconnect()
|
||||
}
|
||||
|
||||
// Wait for background goroutines (QR handler, reconnect) to finish in a
|
||||
// context-aware way so Stop can be bounded by ctx.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// All goroutines have finished.
|
||||
case <-ctx.Done():
|
||||
// Context canceled or timed out; log and proceed with best-effort cleanup.
|
||||
logger.WarnC("whatsapp", fmt.Sprintf("Stop context canceled before all goroutines finished: %v", ctx.Err()))
|
||||
}
|
||||
|
||||
// Now it is safe to clear and close resources.
|
||||
c.mu.Lock()
|
||||
c.client = nil
|
||||
c.container = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
if container != nil {
|
||||
_ = container.Close()
|
||||
}
|
||||
@@ -187,9 +277,20 @@ func (c *WhatsAppNativeChannel) eventHandler(evt any) {
|
||||
c.reconnectMu.Unlock()
|
||||
return
|
||||
}
|
||||
// Check stopping while holding the lock so the check and wg.Add
|
||||
// are atomic with respect to Stop() setting the flag + calling
|
||||
// wg.Wait(). This prevents the TOCTOU race.
|
||||
if c.stopping.Load() {
|
||||
c.reconnectMu.Unlock()
|
||||
return
|
||||
}
|
||||
c.reconnecting = true
|
||||
c.wg.Add(1)
|
||||
c.reconnectMu.Unlock()
|
||||
go c.reconnectWithBackoff()
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
c.reconnectWithBackoff()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -313,6 +414,12 @@ func (c *WhatsAppNativeChannel) Send(ctx context.Context, msg bus.OutboundMessag
|
||||
return fmt.Errorf("whatsapp connection not established: %w", channels.ErrTemporary)
|
||||
}
|
||||
|
||||
// Detect unpaired state: the client is connected (to WhatsApp servers)
|
||||
// but has not completed QR-login yet, so sending would fail.
|
||||
if client.Store.ID == nil {
|
||||
return fmt.Errorf("whatsapp not yet paired (QR login pending): %w", channels.ErrTemporary)
|
||||
}
|
||||
|
||||
to, err := parseJID(msg.ChatID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid chat id %q: %w", msg.ChatID, err)
|
||||
|
||||
@@ -442,3 +442,28 @@ func TestDefaultConfig_DMScope(t *testing.T) {
|
||||
t.Errorf("Session.DMScope = %q, want 'per-channel-peer'", cfg.Session.DMScope)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_WorkspacePath_Default(t *testing.T) {
|
||||
// Unset to ensure we test the default
|
||||
t.Setenv("PICOCLAW_HOME", "")
|
||||
// Set a known home for consistent test results
|
||||
t.Setenv("HOME", "/tmp/home")
|
||||
|
||||
cfg := DefaultConfig()
|
||||
want := filepath.Join("/tmp/home", ".picoclaw", "workspace")
|
||||
|
||||
if cfg.Agents.Defaults.Workspace != want {
|
||||
t.Errorf("Default workspace path = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw/home")
|
||||
|
||||
cfg := DefaultConfig()
|
||||
want := "/custom/picoclaw/home/workspace"
|
||||
|
||||
if cfg.Agents.Defaults.Workspace != want {
|
||||
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
}
|
||||
}
|
||||
|
||||
+17
-1
@@ -5,12 +5,28 @@
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// DefaultConfig returns the default configuration for PicoClaw.
|
||||
func DefaultConfig() *Config {
|
||||
// Determine the base path for the workspace.
|
||||
// Priority: $PICOCLAW_HOME > ~/.picoclaw
|
||||
var homePath string
|
||||
if picoclawHome := os.Getenv("PICOCLAW_HOME"); picoclawHome != "" {
|
||||
homePath = picoclawHome
|
||||
} else {
|
||||
userHome, _ := os.UserHomeDir()
|
||||
homePath = filepath.Join(userHome, ".picoclaw")
|
||||
}
|
||||
workspacePath := filepath.Join(homePath, "workspace")
|
||||
|
||||
return &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Workspace: "~/.picoclaw/workspace",
|
||||
Workspace: workspacePath,
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
Model: "",
|
||||
|
||||
@@ -64,7 +64,7 @@ func TestGetModelConfig_RoundRobin(t *testing.T) {
|
||||
|
||||
// Test round-robin distribution
|
||||
results := make(map[string]int)
|
||||
for i := 0; i < 30; i++ {
|
||||
for range 30 {
|
||||
result, err := cfg.GetModelConfig("lb-model")
|
||||
if err != nil {
|
||||
t.Fatalf("GetModelConfig() error = %v", err)
|
||||
@@ -94,17 +94,15 @@ func TestGetModelConfig_Concurrent(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, goroutines*iterations)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
for range goroutines {
|
||||
wg.Go(func() {
|
||||
for range iterations {
|
||||
_, err := cfg.GetModelConfig("concurrent-model")
|
||||
if err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -122,9 +123,7 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.RLock()
|
||||
ready := s.ready
|
||||
checks := make(map[string]Check)
|
||||
for k, v := range s.checks {
|
||||
checks[k] = v
|
||||
}
|
||||
maps.Copy(checks, s.checks)
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !ready {
|
||||
|
||||
+11
-11
@@ -49,7 +49,7 @@ func TestReleaseAll(t *testing.T) {
|
||||
|
||||
paths := make([]string, 3)
|
||||
refs := make([]string, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
for i := range 3 {
|
||||
paths[i] = createTempFile(t, dir, strings.Repeat("a", i+1)+".jpg")
|
||||
var err error
|
||||
refs[i], err = store.Store(paths[i], MediaMeta{Source: "test"}, "scope1")
|
||||
@@ -228,12 +228,12 @@ func TestConcurrentSafety(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for g := 0; g < goroutines; g++ {
|
||||
for g := range goroutines {
|
||||
go func(gIdx int) {
|
||||
defer wg.Done()
|
||||
scope := strings.Repeat("s", gIdx+1)
|
||||
|
||||
for i := 0; i < filesPerGoroutine; i++ {
|
||||
for i := range filesPerGoroutine {
|
||||
path := createTempFile(t, dir, strings.Repeat("f", gIdx*filesPerGoroutine+i+1)+".tmp")
|
||||
ref, err := store.Store(path, MediaMeta{Source: "test"}, scope)
|
||||
if err != nil {
|
||||
@@ -448,11 +448,11 @@ func TestConcurrentCleanupSafety(t *testing.T) {
|
||||
wg.Add(workers * 4)
|
||||
|
||||
// Store workers
|
||||
for w := 0; w < workers; w++ {
|
||||
for w := range workers {
|
||||
go func(wIdx int) {
|
||||
defer wg.Done()
|
||||
scope := fmt.Sprintf("scope-%d", wIdx)
|
||||
for i := 0; i < ops; i++ {
|
||||
for i := range ops {
|
||||
p := createTempFile(t, dir, fmt.Sprintf("w%d-f%d.tmp", wIdx, i))
|
||||
store.Store(p, MediaMeta{Source: "test"}, scope)
|
||||
}
|
||||
@@ -460,30 +460,30 @@ func TestConcurrentCleanupSafety(t *testing.T) {
|
||||
}
|
||||
|
||||
// Resolve workers
|
||||
for w := 0; w < workers; w++ {
|
||||
for range workers {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < ops; i++ {
|
||||
for range ops {
|
||||
store.Resolve("media://nonexistent")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ReleaseAll workers
|
||||
for w := 0; w < workers; w++ {
|
||||
for w := range workers {
|
||||
go func(wIdx int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < ops; i++ {
|
||||
for range ops {
|
||||
store.ReleaseAll(fmt.Sprintf("scope-%d", wIdx))
|
||||
}
|
||||
}(w)
|
||||
}
|
||||
|
||||
// CleanExpired workers
|
||||
for w := 0; w < workers; w++ {
|
||||
for range workers {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < ops; i++ {
|
||||
for range ops {
|
||||
store.CleanExpired()
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -1,414 +0,0 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
var supportedProviders = map[string]bool{
|
||||
"anthropic": true,
|
||||
"openai": true,
|
||||
"openrouter": true,
|
||||
"groq": true,
|
||||
"zhipu": true,
|
||||
"vllm": true,
|
||||
"gemini": true,
|
||||
"qwen": true,
|
||||
"deepseek": true,
|
||||
"github_copilot": true,
|
||||
"mistral": true,
|
||||
}
|
||||
|
||||
var supportedChannels = map[string]bool{
|
||||
"telegram": true,
|
||||
"discord": true,
|
||||
"whatsapp": true,
|
||||
"feishu": true,
|
||||
"qq": true,
|
||||
"dingtalk": true,
|
||||
"maixcam": true,
|
||||
}
|
||||
|
||||
func findOpenClawConfig(openclawHome string) (string, error) {
|
||||
candidates := []string{
|
||||
filepath.Join(openclawHome, "openclaw.json"),
|
||||
filepath.Join(openclawHome, "config.json"),
|
||||
}
|
||||
for _, p := range candidates {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome)
|
||||
}
|
||||
|
||||
func LoadOpenClawConfig(configPath string) (map[string]any, error) {
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading OpenClaw config: %w", err)
|
||||
}
|
||||
|
||||
var raw map[string]any
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil, fmt.Errorf("parsing OpenClaw config: %w", err)
|
||||
}
|
||||
|
||||
converted := convertKeysToSnake(raw)
|
||||
result, ok := converted.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected config format")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func ConvertConfig(data map[string]any) (*config.Config, []string, error) {
|
||||
cfg := config.DefaultConfig()
|
||||
var warnings []string
|
||||
|
||||
if agents, ok := getMap(data, "agents"); ok {
|
||||
if defaults, ok := getMap(agents, "defaults"); ok {
|
||||
// Prefer model_name, fallback to model for backward compatibility
|
||||
if v, ok := getString(defaults, "model_name"); ok {
|
||||
cfg.Agents.Defaults.ModelName = v
|
||||
} else if v, ok := getString(defaults, "model"); ok {
|
||||
cfg.Agents.Defaults.Model = v
|
||||
}
|
||||
if v, ok := getFloat(defaults, "max_tokens"); ok {
|
||||
cfg.Agents.Defaults.MaxTokens = int(v)
|
||||
}
|
||||
if v, ok := getFloat(defaults, "temperature"); ok {
|
||||
cfg.Agents.Defaults.Temperature = &v
|
||||
}
|
||||
if v, ok := getFloat(defaults, "max_tool_iterations"); ok {
|
||||
cfg.Agents.Defaults.MaxToolIterations = int(v)
|
||||
}
|
||||
if v, ok := getString(defaults, "workspace"); ok {
|
||||
cfg.Agents.Defaults.Workspace = rewriteWorkspacePath(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if providers, ok := getMap(data, "providers"); ok {
|
||||
for name, val := range providers {
|
||||
pMap, ok := val.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
apiKey, _ := getString(pMap, "api_key")
|
||||
apiBase, _ := getString(pMap, "api_base")
|
||||
|
||||
if !supportedProviders[name] {
|
||||
if apiKey != "" || apiBase != "" {
|
||||
warnings = append(warnings, fmt.Sprintf("Provider '%s' not supported in PicoClaw, skipping", name))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
pc := config.ProviderConfig{APIKey: apiKey, APIBase: apiBase}
|
||||
switch name {
|
||||
case "anthropic":
|
||||
cfg.Providers.Anthropic = pc
|
||||
case "openai":
|
||||
cfg.Providers.OpenAI = config.OpenAIProviderConfig{
|
||||
ProviderConfig: pc,
|
||||
WebSearch: getBoolOrDefault(pMap, "web_search", true),
|
||||
}
|
||||
case "openrouter":
|
||||
cfg.Providers.OpenRouter = pc
|
||||
case "groq":
|
||||
cfg.Providers.Groq = pc
|
||||
case "zhipu":
|
||||
cfg.Providers.Zhipu = pc
|
||||
case "vllm":
|
||||
cfg.Providers.VLLM = pc
|
||||
case "gemini":
|
||||
cfg.Providers.Gemini = pc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if channels, ok := getMap(data, "channels"); ok {
|
||||
for name, val := range channels {
|
||||
cMap, ok := val.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if !supportedChannels[name] {
|
||||
warnings = append(warnings, fmt.Sprintf("Channel '%s' not supported in PicoClaw, skipping", name))
|
||||
continue
|
||||
}
|
||||
enabled, _ := getBool(cMap, "enabled")
|
||||
allowFrom := getStringSlice(cMap, "allow_from")
|
||||
|
||||
switch name {
|
||||
case "telegram":
|
||||
cfg.Channels.Telegram.Enabled = enabled
|
||||
cfg.Channels.Telegram.AllowFrom = allowFrom
|
||||
if v, ok := getString(cMap, "token"); ok {
|
||||
cfg.Channels.Telegram.Token = v
|
||||
}
|
||||
case "discord":
|
||||
cfg.Channels.Discord.Enabled = enabled
|
||||
cfg.Channels.Discord.AllowFrom = allowFrom
|
||||
if v, ok := getString(cMap, "token"); ok {
|
||||
cfg.Channels.Discord.Token = v
|
||||
}
|
||||
case "whatsapp":
|
||||
cfg.Channels.WhatsApp.Enabled = enabled
|
||||
cfg.Channels.WhatsApp.AllowFrom = allowFrom
|
||||
if v, ok := getString(cMap, "bridge_url"); ok {
|
||||
cfg.Channels.WhatsApp.BridgeURL = v
|
||||
}
|
||||
if v, ok := getBool(cMap, "use_native"); ok {
|
||||
cfg.Channels.WhatsApp.UseNative = v
|
||||
}
|
||||
if v, ok := getString(cMap, "session_store_path"); ok {
|
||||
cfg.Channels.WhatsApp.SessionStorePath = v
|
||||
}
|
||||
case "feishu":
|
||||
cfg.Channels.Feishu.Enabled = enabled
|
||||
cfg.Channels.Feishu.AllowFrom = allowFrom
|
||||
if v, ok := getString(cMap, "app_id"); ok {
|
||||
cfg.Channels.Feishu.AppID = v
|
||||
}
|
||||
if v, ok := getString(cMap, "app_secret"); ok {
|
||||
cfg.Channels.Feishu.AppSecret = v
|
||||
}
|
||||
if v, ok := getString(cMap, "encrypt_key"); ok {
|
||||
cfg.Channels.Feishu.EncryptKey = v
|
||||
}
|
||||
if v, ok := getString(cMap, "verification_token"); ok {
|
||||
cfg.Channels.Feishu.VerificationToken = v
|
||||
}
|
||||
case "qq":
|
||||
cfg.Channels.QQ.Enabled = enabled
|
||||
cfg.Channels.QQ.AllowFrom = allowFrom
|
||||
if v, ok := getString(cMap, "app_id"); ok {
|
||||
cfg.Channels.QQ.AppID = v
|
||||
}
|
||||
if v, ok := getString(cMap, "app_secret"); ok {
|
||||
cfg.Channels.QQ.AppSecret = v
|
||||
}
|
||||
case "dingtalk":
|
||||
cfg.Channels.DingTalk.Enabled = enabled
|
||||
cfg.Channels.DingTalk.AllowFrom = allowFrom
|
||||
if v, ok := getString(cMap, "client_id"); ok {
|
||||
cfg.Channels.DingTalk.ClientID = v
|
||||
}
|
||||
if v, ok := getString(cMap, "client_secret"); ok {
|
||||
cfg.Channels.DingTalk.ClientSecret = v
|
||||
}
|
||||
case "maixcam":
|
||||
cfg.Channels.MaixCam.Enabled = enabled
|
||||
cfg.Channels.MaixCam.AllowFrom = allowFrom
|
||||
if v, ok := getString(cMap, "host"); ok {
|
||||
cfg.Channels.MaixCam.Host = v
|
||||
}
|
||||
if v, ok := getFloat(cMap, "port"); ok {
|
||||
cfg.Channels.MaixCam.Port = int(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if gateway, ok := getMap(data, "gateway"); ok {
|
||||
if v, ok := getString(gateway, "host"); ok {
|
||||
cfg.Gateway.Host = v
|
||||
}
|
||||
if v, ok := getFloat(gateway, "port"); ok {
|
||||
cfg.Gateway.Port = int(v)
|
||||
}
|
||||
}
|
||||
|
||||
if tools, ok := getMap(data, "tools"); ok {
|
||||
if web, ok := getMap(tools, "web"); ok {
|
||||
// Migrate old "search" config to "brave" if api_key is present
|
||||
if search, ok := getMap(web, "search"); ok {
|
||||
if v, ok := getString(search, "api_key"); ok {
|
||||
cfg.Tools.Web.Brave.APIKey = v
|
||||
if v != "" {
|
||||
cfg.Tools.Web.Brave.Enabled = true
|
||||
}
|
||||
}
|
||||
if v, ok := getFloat(search, "max_results"); ok {
|
||||
cfg.Tools.Web.Brave.MaxResults = int(v)
|
||||
cfg.Tools.Web.DuckDuckGo.MaxResults = int(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, warnings, nil
|
||||
}
|
||||
|
||||
func MergeConfig(existing, incoming *config.Config) *config.Config {
|
||||
if existing.Providers.Anthropic.APIKey == "" {
|
||||
existing.Providers.Anthropic = incoming.Providers.Anthropic
|
||||
}
|
||||
if existing.Providers.OpenAI.APIKey == "" {
|
||||
existing.Providers.OpenAI = incoming.Providers.OpenAI
|
||||
}
|
||||
if existing.Providers.OpenRouter.APIKey == "" {
|
||||
existing.Providers.OpenRouter = incoming.Providers.OpenRouter
|
||||
}
|
||||
if existing.Providers.Groq.APIKey == "" {
|
||||
existing.Providers.Groq = incoming.Providers.Groq
|
||||
}
|
||||
if existing.Providers.Zhipu.APIKey == "" {
|
||||
existing.Providers.Zhipu = incoming.Providers.Zhipu
|
||||
}
|
||||
if existing.Providers.VLLM.APIKey == "" && existing.Providers.VLLM.APIBase == "" {
|
||||
existing.Providers.VLLM = incoming.Providers.VLLM
|
||||
}
|
||||
if existing.Providers.Gemini.APIKey == "" {
|
||||
existing.Providers.Gemini = incoming.Providers.Gemini
|
||||
}
|
||||
if existing.Providers.DeepSeek.APIKey == "" {
|
||||
existing.Providers.DeepSeek = incoming.Providers.DeepSeek
|
||||
}
|
||||
if existing.Providers.GitHubCopilot.APIBase == "" {
|
||||
existing.Providers.GitHubCopilot = incoming.Providers.GitHubCopilot
|
||||
}
|
||||
if existing.Providers.Qwen.APIKey == "" {
|
||||
existing.Providers.Qwen = incoming.Providers.Qwen
|
||||
}
|
||||
|
||||
if !existing.Channels.Telegram.Enabled && incoming.Channels.Telegram.Enabled {
|
||||
existing.Channels.Telegram = incoming.Channels.Telegram
|
||||
}
|
||||
if !existing.Channels.Discord.Enabled && incoming.Channels.Discord.Enabled {
|
||||
existing.Channels.Discord = incoming.Channels.Discord
|
||||
}
|
||||
if !existing.Channels.WhatsApp.Enabled && incoming.Channels.WhatsApp.Enabled {
|
||||
existing.Channels.WhatsApp = incoming.Channels.WhatsApp
|
||||
}
|
||||
if !existing.Channels.Feishu.Enabled && incoming.Channels.Feishu.Enabled {
|
||||
existing.Channels.Feishu = incoming.Channels.Feishu
|
||||
}
|
||||
if !existing.Channels.QQ.Enabled && incoming.Channels.QQ.Enabled {
|
||||
existing.Channels.QQ = incoming.Channels.QQ
|
||||
}
|
||||
if !existing.Channels.DingTalk.Enabled && incoming.Channels.DingTalk.Enabled {
|
||||
existing.Channels.DingTalk = incoming.Channels.DingTalk
|
||||
}
|
||||
if !existing.Channels.MaixCam.Enabled && incoming.Channels.MaixCam.Enabled {
|
||||
existing.Channels.MaixCam = incoming.Channels.MaixCam
|
||||
}
|
||||
|
||||
if existing.Tools.Web.Brave.APIKey == "" {
|
||||
existing.Tools.Web.Brave = incoming.Tools.Web.Brave
|
||||
}
|
||||
|
||||
return existing
|
||||
}
|
||||
|
||||
func camelToSnake(s string) string {
|
||||
var result strings.Builder
|
||||
for i, r := range s {
|
||||
if unicode.IsUpper(r) {
|
||||
if i > 0 {
|
||||
prev := rune(s[i-1])
|
||||
if unicode.IsLower(prev) || unicode.IsDigit(prev) {
|
||||
result.WriteRune('_')
|
||||
} else if unicode.IsUpper(prev) && i+1 < len(s) && unicode.IsLower(rune(s[i+1])) {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
}
|
||||
result.WriteRune(unicode.ToLower(r))
|
||||
} else {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
func convertKeysToSnake(data any) any {
|
||||
switch v := data.(type) {
|
||||
case map[string]any:
|
||||
result := make(map[string]any, len(v))
|
||||
for key, val := range v {
|
||||
result[camelToSnake(key)] = convertKeysToSnake(val)
|
||||
}
|
||||
return result
|
||||
case []any:
|
||||
result := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
result[i] = convertKeysToSnake(val)
|
||||
}
|
||||
return result
|
||||
default:
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteWorkspacePath(path string) string {
|
||||
path = strings.Replace(path, ".openclaw", ".picoclaw", 1)
|
||||
return path
|
||||
}
|
||||
|
||||
func getMap(data map[string]any, key string) (map[string]any, bool) {
|
||||
v, ok := data[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
m, ok := v.(map[string]any)
|
||||
return m, ok
|
||||
}
|
||||
|
||||
func getString(data map[string]any, key string) (string, bool) {
|
||||
v, ok := data[key]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
s, ok := v.(string)
|
||||
return s, ok
|
||||
}
|
||||
|
||||
func getFloat(data map[string]any, key string) (float64, bool) {
|
||||
v, ok := data[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
f, ok := v.(float64)
|
||||
return f, ok
|
||||
}
|
||||
|
||||
func getBool(data map[string]any, key string) (bool, bool) {
|
||||
v, ok := data[key]
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
b, ok := v.(bool)
|
||||
return b, ok
|
||||
}
|
||||
|
||||
func getBoolOrDefault(data map[string]any, key string, defaultVal bool) bool {
|
||||
if v, ok := getBool(data, key); ok {
|
||||
return v
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
func getStringSlice(data map[string]any, key string) []string {
|
||||
v, ok := data[key]
|
||||
if !ok {
|
||||
return []string{}
|
||||
}
|
||||
arr, ok := v.([]any)
|
||||
if !ok {
|
||||
return []string{}
|
||||
}
|
||||
result := make([]string, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
if s, ok := item.(string); ok {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -1,24 +1,50 @@
|
||||
package migrate
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
var migrateableFiles = []string{
|
||||
"AGENTS.md",
|
||||
"SOUL.md",
|
||||
"USER.md",
|
||||
"TOOLS.md",
|
||||
"HEARTBEAT.md",
|
||||
func ResolveTargetHome(override string) (string, error) {
|
||||
if override != "" {
|
||||
return ExpandHome(override), nil
|
||||
}
|
||||
if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
|
||||
return ExpandHome(envHome), nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolving home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".picoclaw"), nil
|
||||
}
|
||||
|
||||
var migrateableDirs = []string{
|
||||
"memory",
|
||||
"skills",
|
||||
func ExpandHome(path string) string {
|
||||
if path == "" {
|
||||
return path
|
||||
}
|
||||
if path[0] == '~' {
|
||||
home, _ := os.UserHomeDir()
|
||||
if len(path) > 1 && path[1] == '/' {
|
||||
return home + path[1:]
|
||||
}
|
||||
return home
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func PlanWorkspaceMigration(srcWorkspace, dstWorkspace string, force bool) ([]Action, error) {
|
||||
func ResolveWorkspace(homeDir string) string {
|
||||
return filepath.Join(homeDir, "workspace")
|
||||
}
|
||||
|
||||
func PlanWorkspaceMigration(
|
||||
srcWorkspace, dstWorkspace string,
|
||||
migrateableFiles []string,
|
||||
migrateableDirs []string,
|
||||
force bool,
|
||||
) ([]Action, error) {
|
||||
var actions []Action
|
||||
|
||||
for _, filename := range migrateableFiles {
|
||||
@@ -50,7 +76,7 @@ func planFileCopy(src, dst string, force bool) Action {
|
||||
return Action{
|
||||
Type: ActionSkip,
|
||||
Source: src,
|
||||
Destination: dst,
|
||||
Target: dst,
|
||||
Description: "source file not found",
|
||||
}
|
||||
}
|
||||
@@ -60,7 +86,7 @@ func planFileCopy(src, dst string, force bool) Action {
|
||||
return Action{
|
||||
Type: ActionBackup,
|
||||
Source: src,
|
||||
Destination: dst,
|
||||
Target: dst,
|
||||
Description: "destination exists, will backup and overwrite",
|
||||
}
|
||||
}
|
||||
@@ -68,7 +94,7 @@ func planFileCopy(src, dst string, force bool) Action {
|
||||
return Action{
|
||||
Type: ActionCopy,
|
||||
Source: src,
|
||||
Destination: dst,
|
||||
Target: dst,
|
||||
Description: "copy file",
|
||||
}
|
||||
}
|
||||
@@ -91,7 +117,7 @@ func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) {
|
||||
if info.IsDir() {
|
||||
actions = append(actions, Action{
|
||||
Type: ActionCreateDir,
|
||||
Destination: dst,
|
||||
Target: dst,
|
||||
Description: "create directory",
|
||||
})
|
||||
return nil
|
||||
@@ -104,3 +130,33 @@ func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) {
|
||||
|
||||
return actions, err
|
||||
}
|
||||
|
||||
func RelPath(path, base string) string {
|
||||
rel, err := filepath.Rel(base, path)
|
||||
if err != nil {
|
||||
return filepath.Base(path)
|
||||
}
|
||||
return rel
|
||||
}
|
||||
|
||||
func CopyFile(src, dst string) error {
|
||||
srcFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer srcFile.Close()
|
||||
|
||||
info, err := srcFile.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dstFile.Close()
|
||||
|
||||
_, err = io.Copy(dstFile, srcFile)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExpandHome(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"/absolute/path", "/absolute/path"},
|
||||
{"relative/path", "relative/path"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := ExpandHome(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandHomeWithTilde(t *testing.T) {
|
||||
home, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
|
||||
result := ExpandHome("~/path")
|
||||
assert.Equal(t, home+"/path", result)
|
||||
|
||||
result = ExpandHome("~")
|
||||
assert.Equal(t, home, result)
|
||||
}
|
||||
|
||||
func TestResolveWorkspace(t *testing.T) {
|
||||
result := ResolveWorkspace("/home/user/.picoclaw")
|
||||
assert.Equal(t, "/home/user/.picoclaw/workspace", result)
|
||||
}
|
||||
|
||||
func TestRelPath(t *testing.T) {
|
||||
result := RelPath("/home/user/.picoclaw/workspace/file.txt", "/home/user/.picoclaw")
|
||||
assert.Equal(t, "workspace/file.txt", result)
|
||||
}
|
||||
|
||||
func TestRelPathError(t *testing.T) {
|
||||
result := RelPath("relative/path", "/different/base")
|
||||
assert.Equal(t, "path", result)
|
||||
}
|
||||
|
||||
func TestResolveTargetHome(t *testing.T) {
|
||||
home, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := ResolveTargetHome("")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(home, ".picoclaw"), result)
|
||||
}
|
||||
|
||||
func TestResolveTargetHomeWithOverride(t *testing.T) {
|
||||
result, err := ResolveTargetHome("/custom/path")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/custom/path", result)
|
||||
}
|
||||
|
||||
func TestCopyFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
sourceFile := filepath.Join(tmpDir, "source.txt")
|
||||
err := os.WriteFile(sourceFile, []byte("test content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
dstFile := filepath.Join(tmpDir, "dest.txt")
|
||||
err = CopyFile(sourceFile, dstFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(dstFile)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test content", string(content))
|
||||
}
|
||||
|
||||
func TestCopyFileSourceNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
err := CopyFile(filepath.Join(tmpDir, "nonexistent.txt"), filepath.Join(tmpDir, "dest.txt"))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestPlanWorkspaceMigration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
|
||||
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
|
||||
|
||||
err := os.MkdirAll(srcWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.MkdirAll(filepath.Join(srcWorkspace, "subdir"), 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(srcWorkspace, "subdir", "file2.txt"), []byte("content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
actions, err := PlanWorkspaceMigration(
|
||||
srcWorkspace,
|
||||
dstWorkspace,
|
||||
[]string{"file1.txt"},
|
||||
[]string{"subdir"},
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.GreaterOrEqual(t, len(actions), 1)
|
||||
}
|
||||
|
||||
func TestPlanWorkspaceMigrationWithExistingDestination(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
|
||||
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
|
||||
|
||||
err := os.MkdirAll(srcWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.MkdirAll(dstWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
actions, err := PlanWorkspaceMigration(
|
||||
srcWorkspace,
|
||||
dstWorkspace,
|
||||
[]string{"file1.txt"},
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.GreaterOrEqual(t, len(actions), 1)
|
||||
assert.Equal(t, ActionBackup, actions[0].Type)
|
||||
}
|
||||
|
||||
func TestPlanWorkspaceMigrationForce(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
|
||||
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
|
||||
|
||||
err := os.MkdirAll(srcWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.MkdirAll(dstWorkspace, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
actions, err := PlanWorkspaceMigration(
|
||||
srcWorkspace,
|
||||
dstWorkspace,
|
||||
[]string{"file1.txt"},
|
||||
[]string{},
|
||||
true,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.GreaterOrEqual(t, len(actions), 1)
|
||||
assert.Equal(t, ActionCopy, actions[0].Type)
|
||||
}
|
||||
|
||||
func TestPlanWorkspaceMigrationNonExistentSource(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
actions, err := PlanWorkspaceMigration(
|
||||
filepath.Join(tmpDir, "nonexistent"),
|
||||
filepath.Join(tmpDir, "dst", "workspace"),
|
||||
[]string{"file1.txt"},
|
||||
[]string{},
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, actions, 1)
|
||||
assert.Equal(t, ActionSkip, actions[0].Type)
|
||||
assert.Contains(t, actions[0].Description, "source file not found")
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package internal
|
||||
|
||||
type Options struct {
|
||||
DryRun bool
|
||||
ConfigOnly bool
|
||||
WorkspaceOnly bool
|
||||
Force bool
|
||||
Refresh bool
|
||||
Source string
|
||||
SourceHome string
|
||||
TargetHome string
|
||||
}
|
||||
|
||||
type Operation interface {
|
||||
GetSourceName() string
|
||||
GetSourceHome() (string, error)
|
||||
GetSourceWorkspace() (string, error)
|
||||
GetSourceConfigFile() (string, error)
|
||||
ExecuteConfigMigration(srcConfigPath, dstConfigPath string) error
|
||||
GetMigrateableFiles() []string
|
||||
GetMigrateableDirs() []string
|
||||
}
|
||||
|
||||
type HandlerFactory func(opts Options) Operation
|
||||
|
||||
type ActionType int
|
||||
|
||||
const (
|
||||
ActionCopy ActionType = iota
|
||||
ActionSkip
|
||||
ActionBackup
|
||||
ActionConvertConfig
|
||||
ActionCreateDir
|
||||
ActionMergeConfig
|
||||
)
|
||||
|
||||
type Action struct {
|
||||
Type ActionType
|
||||
Source string
|
||||
Target string
|
||||
Description string
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
FilesCopied int
|
||||
FilesSkipped int
|
||||
BackupsCreated int
|
||||
ConfigMigrated bool
|
||||
DirsCreated int
|
||||
Warnings []string
|
||||
Errors []error
|
||||
}
|
||||
+136
-214
@@ -2,53 +2,73 @@ package migrate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/migrate/internal"
|
||||
"github.com/sipeed/picoclaw/pkg/migrate/sources/openclaw"
|
||||
)
|
||||
|
||||
type ActionType int
|
||||
type (
|
||||
Options = internal.Options
|
||||
Operation = internal.Operation
|
||||
ActionType = internal.ActionType
|
||||
Action = internal.Action
|
||||
Result = internal.Result
|
||||
HandlerFactory = internal.HandlerFactory
|
||||
)
|
||||
|
||||
const (
|
||||
ActionCopy ActionType = iota
|
||||
ActionSkip
|
||||
ActionBackup
|
||||
ActionConvertConfig
|
||||
ActionCreateDir
|
||||
ActionMergeConfig
|
||||
ActionCopy = internal.ActionCopy
|
||||
ActionSkip = internal.ActionSkip
|
||||
ActionBackup = internal.ActionBackup
|
||||
ActionConvertConfig = internal.ActionConvertConfig
|
||||
ActionCreateDir = internal.ActionCreateDir
|
||||
ActionMergeConfig = internal.ActionMergeConfig
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
DryRun bool
|
||||
ConfigOnly bool
|
||||
WorkspaceOnly bool
|
||||
Force bool
|
||||
Refresh bool
|
||||
OpenClawHome string
|
||||
PicoClawHome string
|
||||
type MigrateInstance struct {
|
||||
options Options
|
||||
handlers map[string]Operation
|
||||
}
|
||||
|
||||
type Action struct {
|
||||
Type ActionType
|
||||
Source string
|
||||
Destination string
|
||||
Description string
|
||||
func NewMigrateInstance(opts Options) *MigrateInstance {
|
||||
instance := &MigrateInstance{
|
||||
options: opts,
|
||||
handlers: make(map[string]Operation),
|
||||
}
|
||||
|
||||
openclaw_handler, err := openclaw.NewOpenclawHandler(opts)
|
||||
if err == nil {
|
||||
instance.Register(openclaw_handler.GetSourceName(), openclaw_handler)
|
||||
}
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
FilesCopied int
|
||||
FilesSkipped int
|
||||
BackupsCreated int
|
||||
ConfigMigrated bool
|
||||
DirsCreated int
|
||||
Warnings []string
|
||||
Errors []error
|
||||
func (m *MigrateInstance) Register(moduleName string, module Operation) {
|
||||
m.handlers[moduleName] = module
|
||||
}
|
||||
|
||||
func Run(opts Options) (*Result, error) {
|
||||
func (m *MigrateInstance) getCurrentHandler() (Operation, error) {
|
||||
source := m.options.Source
|
||||
if source == "" {
|
||||
source = "openclaw"
|
||||
}
|
||||
handler, ok := m.handlers[source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Source '%s' not found", source)
|
||||
}
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
func (m *MigrateInstance) Run(opts Options) (*Result, error) {
|
||||
handler, err := m.getCurrentHandler()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.ConfigOnly && opts.WorkspaceOnly {
|
||||
return nil, fmt.Errorf("--config-only and --workspace-only are mutually exclusive")
|
||||
}
|
||||
@@ -57,28 +77,28 @@ func Run(opts Options) (*Result, error) {
|
||||
opts.WorkspaceOnly = true
|
||||
}
|
||||
|
||||
openclawHome, err := resolveOpenClawHome(opts.OpenClawHome)
|
||||
sourceHome, err := handler.GetSourceHome()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
picoClawHome, err := resolvePicoClawHome(opts.PicoClawHome)
|
||||
targetHome, err := internal.ResolveTargetHome(opts.TargetHome)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err = os.Stat(openclawHome); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome)
|
||||
if _, err = os.Stat(sourceHome); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("Source installation not found at %s", sourceHome)
|
||||
}
|
||||
|
||||
actions, warnings, err := Plan(opts, openclawHome, picoClawHome)
|
||||
actions, warnings, err := m.Plan(opts, sourceHome, targetHome)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Println("Migrating from OpenClaw to PicoClaw")
|
||||
fmt.Printf(" Source: %s\n", openclawHome)
|
||||
fmt.Printf(" Destination: %s\n", picoClawHome)
|
||||
fmt.Println("Migrating from Source to PicoClaw")
|
||||
fmt.Printf(" Source: %s\n", sourceHome)
|
||||
fmt.Printf(" Target: %s\n", targetHome)
|
||||
fmt.Println()
|
||||
|
||||
if opts.DryRun {
|
||||
@@ -95,19 +115,23 @@ func Run(opts Options) (*Result, error) {
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
result := Execute(actions, openclawHome, picoClawHome)
|
||||
result := m.Execute(actions, sourceHome, targetHome)
|
||||
result.Warnings = warnings
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string, error) {
|
||||
func (m *MigrateInstance) Plan(opts Options, sourceHome, targetHome string) ([]Action, []string, error) {
|
||||
var actions []Action
|
||||
var warnings []string
|
||||
handler, err := m.getCurrentHandler()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
force := opts.Force || opts.Refresh
|
||||
|
||||
if !opts.WorkspaceOnly {
|
||||
configPath, err := findOpenClawConfig(openclawHome)
|
||||
configPath, err := handler.GetSourceConfigFile()
|
||||
if err != nil {
|
||||
if opts.ConfigOnly {
|
||||
return nil, nil, err
|
||||
@@ -117,91 +141,95 @@ func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string,
|
||||
actions = append(actions, Action{
|
||||
Type: ActionConvertConfig,
|
||||
Source: configPath,
|
||||
Destination: filepath.Join(picoClawHome, "config.json"),
|
||||
Description: "convert OpenClaw config to PicoClaw format",
|
||||
Target: filepath.Join(targetHome, "config.json"),
|
||||
Description: "convert Source config to PicoClaw format",
|
||||
})
|
||||
|
||||
data, err := LoadOpenClawConfig(configPath)
|
||||
if err == nil {
|
||||
_, configWarnings, _ := ConvertConfig(data)
|
||||
warnings = append(warnings, configWarnings...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !opts.ConfigOnly {
|
||||
srcWorkspace := resolveWorkspace(openclawHome)
|
||||
dstWorkspace := resolveWorkspace(picoClawHome)
|
||||
srcWorkspace, err := handler.GetSourceWorkspace()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("getting source workspace: %w", err)
|
||||
}
|
||||
dstWorkspace := internal.ResolveWorkspace(targetHome)
|
||||
|
||||
if _, err := os.Stat(srcWorkspace); err == nil {
|
||||
wsActions, err := PlanWorkspaceMigration(srcWorkspace, dstWorkspace, force)
|
||||
wsActions, err := internal.PlanWorkspaceMigration(srcWorkspace, dstWorkspace,
|
||||
handler.GetMigrateableFiles(),
|
||||
handler.GetMigrateableDirs(),
|
||||
force)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("planning workspace migration: %w", err)
|
||||
}
|
||||
actions = append(actions, wsActions...)
|
||||
} else {
|
||||
warnings = append(warnings, "OpenClaw workspace directory not found, skipping workspace migration")
|
||||
warnings = append(warnings, "Source workspace directory not found, skipping workspace migration")
|
||||
}
|
||||
}
|
||||
|
||||
return actions, warnings, nil
|
||||
}
|
||||
|
||||
func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
|
||||
func (m *MigrateInstance) Execute(actions []Action, sourceHome, targetHome string) *Result {
|
||||
result := &Result{}
|
||||
handler, err := m.getCurrentHandler()
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
switch action.Type {
|
||||
case ActionConvertConfig:
|
||||
if err := executeConfigMigration(action.Source, action.Destination, picoClawHome); err != nil {
|
||||
if err := handler.ExecuteConfigMigration(action.Source, action.Target); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Errorf("config migration: %w", err))
|
||||
fmt.Printf(" ✗ Config migration failed: %v\n", err)
|
||||
} else {
|
||||
result.ConfigMigrated = true
|
||||
fmt.Printf(" ✓ Converted config: %s\n", action.Destination)
|
||||
fmt.Printf(" ✓ Converted config: %s\n", action.Target)
|
||||
}
|
||||
case ActionCreateDir:
|
||||
if err := os.MkdirAll(action.Destination, 0o755); err != nil {
|
||||
if err := os.MkdirAll(action.Target, 0o755); err != nil {
|
||||
result.Errors = append(result.Errors, err)
|
||||
} else {
|
||||
result.DirsCreated++
|
||||
}
|
||||
case ActionBackup:
|
||||
bakPath := action.Destination + ".bak"
|
||||
if err := copyFile(action.Destination, bakPath); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Destination, err))
|
||||
fmt.Printf(" ✗ Backup failed: %s\n", action.Destination)
|
||||
bakPath := action.Target + ".bak"
|
||||
if err := internal.CopyFile(action.Target, bakPath); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Target, err))
|
||||
fmt.Printf(" ✗ Backup failed: %s\n", action.Target)
|
||||
continue
|
||||
}
|
||||
result.BackupsCreated++
|
||||
fmt.Printf(
|
||||
" ✓ Backed up %s -> %s.bak\n",
|
||||
filepath.Base(action.Destination),
|
||||
filepath.Base(action.Destination),
|
||||
filepath.Base(action.Target),
|
||||
filepath.Base(action.Target),
|
||||
)
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(action.Destination), 0o755); err != nil {
|
||||
if err := os.MkdirAll(filepath.Dir(action.Target), 0o755); err != nil {
|
||||
result.Errors = append(result.Errors, err)
|
||||
continue
|
||||
}
|
||||
if err := copyFile(action.Source, action.Destination); err != nil {
|
||||
if err := internal.CopyFile(action.Source, action.Target); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
|
||||
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
|
||||
} else {
|
||||
result.FilesCopied++
|
||||
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
|
||||
fmt.Printf(" ✓ Copied %s\n", internal.RelPath(action.Source, sourceHome))
|
||||
}
|
||||
case ActionCopy:
|
||||
if err := os.MkdirAll(filepath.Dir(action.Destination), 0o755); err != nil {
|
||||
if err := os.MkdirAll(filepath.Dir(action.Target), 0o755); err != nil {
|
||||
result.Errors = append(result.Errors, err)
|
||||
continue
|
||||
}
|
||||
if err := copyFile(action.Source, action.Destination); err != nil {
|
||||
if err := internal.CopyFile(action.Source, action.Target); err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
|
||||
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
|
||||
} else {
|
||||
result.FilesCopied++
|
||||
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
|
||||
fmt.Printf(" ✓ Copied %s\n", internal.RelPath(action.Source, sourceHome))
|
||||
}
|
||||
case ActionSkip:
|
||||
result.FilesSkipped++
|
||||
@@ -211,31 +239,6 @@ func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
|
||||
return result
|
||||
}
|
||||
|
||||
func executeConfigMigration(srcConfigPath, dstConfigPath, picoClawHome string) error {
|
||||
data, err := LoadOpenClawConfig(srcConfigPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
incoming, _, err := ConvertConfig(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := os.Stat(dstConfigPath); err == nil {
|
||||
existing, err := config.LoadConfig(dstConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading existing PicoClaw config: %w", err)
|
||||
}
|
||||
incoming = MergeConfig(existing, incoming)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return config.SaveConfig(dstConfigPath, incoming)
|
||||
}
|
||||
|
||||
func Confirm() bool {
|
||||
fmt.Print("Proceed with migration? (y/n): ")
|
||||
var response string
|
||||
@@ -243,49 +246,7 @@ func Confirm() bool {
|
||||
return strings.ToLower(strings.TrimSpace(response)) == "y"
|
||||
}
|
||||
|
||||
func PrintPlan(actions []Action, warnings []string) {
|
||||
fmt.Println("Planned actions:")
|
||||
copies := 0
|
||||
skips := 0
|
||||
backups := 0
|
||||
configCount := 0
|
||||
|
||||
for _, action := range actions {
|
||||
switch action.Type {
|
||||
case ActionConvertConfig:
|
||||
fmt.Printf(" [config] %s -> %s\n", action.Source, action.Destination)
|
||||
configCount++
|
||||
case ActionCopy:
|
||||
fmt.Printf(" [copy] %s\n", filepath.Base(action.Source))
|
||||
copies++
|
||||
case ActionBackup:
|
||||
fmt.Printf(" [backup] %s (exists, will backup and overwrite)\n", filepath.Base(action.Destination))
|
||||
backups++
|
||||
copies++
|
||||
case ActionSkip:
|
||||
if action.Description != "" {
|
||||
fmt.Printf(" [skip] %s (%s)\n", filepath.Base(action.Source), action.Description)
|
||||
}
|
||||
skips++
|
||||
case ActionCreateDir:
|
||||
fmt.Printf(" [mkdir] %s\n", action.Destination)
|
||||
}
|
||||
}
|
||||
|
||||
if len(warnings) > 0 {
|
||||
fmt.Println()
|
||||
fmt.Println("Warnings:")
|
||||
for _, w := range warnings {
|
||||
fmt.Printf(" - %s\n", w)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Printf("%d files to copy, %d configs to convert, %d backups needed, %d skipped\n",
|
||||
copies, configCount, backups, skips)
|
||||
}
|
||||
|
||||
func PrintSummary(result *Result) {
|
||||
func (m *MigrateInstance) PrintSummary(result *Result) {
|
||||
fmt.Println()
|
||||
parts := []string{}
|
||||
if result.FilesCopied > 0 {
|
||||
@@ -316,83 +277,44 @@ func PrintSummary(result *Result) {
|
||||
}
|
||||
}
|
||||
|
||||
func resolveOpenClawHome(override string) (string, error) {
|
||||
if override != "" {
|
||||
return expandHome(override), nil
|
||||
}
|
||||
if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
|
||||
return expandHome(envHome), nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolving home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".openclaw"), nil
|
||||
}
|
||||
func PrintPlan(actions []Action, warnings []string) {
|
||||
fmt.Println("Planned actions:")
|
||||
copies := 0
|
||||
skips := 0
|
||||
backups := 0
|
||||
configCount := 0
|
||||
|
||||
func resolvePicoClawHome(override string) (string, error) {
|
||||
if override != "" {
|
||||
return expandHome(override), nil
|
||||
}
|
||||
if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
|
||||
return expandHome(envHome), nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolving home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".picoclaw"), nil
|
||||
}
|
||||
|
||||
func resolveWorkspace(homeDir string) string {
|
||||
return filepath.Join(homeDir, "workspace")
|
||||
}
|
||||
|
||||
func expandHome(path string) string {
|
||||
if path == "" {
|
||||
return path
|
||||
}
|
||||
if path[0] == '~' {
|
||||
home, _ := os.UserHomeDir()
|
||||
if len(path) > 1 && path[1] == '/' {
|
||||
return home + path[1:]
|
||||
for _, action := range actions {
|
||||
switch action.Type {
|
||||
case ActionConvertConfig:
|
||||
fmt.Printf(" [config] %s -> %s\n", action.Source, action.Target)
|
||||
configCount++
|
||||
case ActionCopy:
|
||||
fmt.Printf(" [copy] %s\n", filepath.Base(action.Source))
|
||||
copies++
|
||||
case ActionBackup:
|
||||
fmt.Printf(" [backup] %s (exists, will backup and overwrite)\n", filepath.Base(action.Target))
|
||||
backups++
|
||||
copies++
|
||||
case ActionSkip:
|
||||
if action.Description != "" {
|
||||
fmt.Printf(" [skip] %s (%s)\n", filepath.Base(action.Source), action.Description)
|
||||
}
|
||||
skips++
|
||||
case ActionCreateDir:
|
||||
fmt.Printf(" [mkdir] %s\n", action.Target)
|
||||
}
|
||||
return home
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func backupFile(path string) error {
|
||||
bakPath := path + ".bak"
|
||||
return copyFile(path, bakPath)
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
srcFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer srcFile.Close()
|
||||
|
||||
info, err := srcFile.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dstFile.Close()
|
||||
|
||||
_, err = io.Copy(dstFile, srcFile)
|
||||
return err
|
||||
}
|
||||
|
||||
func relPath(path, base string) string {
|
||||
rel, err := filepath.Rel(base, path)
|
||||
if err != nil {
|
||||
return filepath.Base(path)
|
||||
}
|
||||
return rel
|
||||
|
||||
if len(warnings) > 0 {
|
||||
fmt.Println()
|
||||
fmt.Println("Warnings:")
|
||||
for _, w := range warnings {
|
||||
fmt.Printf(" - %s\n", w)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Printf("%d files to copy, %d configs to convert, %d backups needed, %d skipped\n",
|
||||
copies, configCount, backups, skips)
|
||||
}
|
||||
|
||||
+355
-819
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,29 @@
|
||||
package openclaw
|
||||
|
||||
var migrateableFiles = []string{
|
||||
"AGENTS.md",
|
||||
"SOUL.md",
|
||||
"USER.md",
|
||||
"TOOLS.md",
|
||||
"HEARTBEAT.md",
|
||||
}
|
||||
|
||||
var migrateableDirs = []string{
|
||||
"memory",
|
||||
"skills",
|
||||
}
|
||||
|
||||
var supportedChannels = map[string]bool{
|
||||
"whatsapp": true,
|
||||
"telegram": true,
|
||||
"feishu": true,
|
||||
"discord": true,
|
||||
"maixcam": true,
|
||||
"qq": true,
|
||||
"dingtalk": true,
|
||||
"slack": true,
|
||||
"line": true,
|
||||
"onebot": true,
|
||||
"wecom": true,
|
||||
"wecom_app": true,
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,714 @@
|
||||
package openclaw
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadOpenClawConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
|
||||
testConfig := `{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": {
|
||||
"primary": "anthropic/claude-sonnet-4-20250514"
|
||||
},
|
||||
"workspace": "~/.openclaw/workspace"
|
||||
},
|
||||
"list": [
|
||||
{
|
||||
"id": "main",
|
||||
"name": "Main Agent",
|
||||
"model": {
|
||||
"primary": "openai/gpt-4o",
|
||||
"fallbacks": ["claude-3-opus"]
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"botToken": "test-token",
|
||||
"allowFrom": ["user1", "user2"]
|
||||
},
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
"token": "discord-token"
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
"providers": {
|
||||
"anthropic": {
|
||||
"api_key": "sk-ant-test",
|
||||
"base_url": "https://api.anthropic.com"
|
||||
},
|
||||
"openai": {
|
||||
"api_key": "sk-test"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadOpenClawConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Agents == nil {
|
||||
t.Error("agents should not be nil")
|
||||
}
|
||||
|
||||
if cfg.Agents.Defaults == nil {
|
||||
t.Error("agents.defaults should not be nil")
|
||||
}
|
||||
|
||||
provider, model := cfg.GetDefaultModel()
|
||||
if provider != "anthropic" {
|
||||
t.Errorf("expected provider 'anthropic', got '%s'", provider)
|
||||
}
|
||||
if model != "claude-sonnet-4-20250514" {
|
||||
t.Errorf("expected model 'claude-sonnet-4-20250514', got '%s'", model)
|
||||
}
|
||||
|
||||
workspace := cfg.GetDefaultWorkspace()
|
||||
if workspace != "~/.picoclaw/workspace" {
|
||||
t.Errorf("expected workspace '~/.picoclaw/workspace', got '%s'", workspace)
|
||||
}
|
||||
|
||||
agents := cfg.GetAgents()
|
||||
if len(agents) != 1 {
|
||||
t.Errorf("expected 1 agent, got %d", len(agents))
|
||||
}
|
||||
if agents[0].ID != "main" {
|
||||
t.Errorf("expected agent id 'main', got '%s'", agents[0].ID)
|
||||
}
|
||||
|
||||
if cfg.Channels == nil {
|
||||
t.Error("channels should not be nil")
|
||||
}
|
||||
if cfg.Channels.Telegram == nil {
|
||||
t.Error("telegram channel should not be nil")
|
||||
}
|
||||
if cfg.Channels.Telegram.BotToken == nil || *cfg.Channels.Telegram.BotToken != "test-token" {
|
||||
t.Error("telegram bot token not parsed correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProviderConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
|
||||
testConfig := `{
|
||||
"models": {
|
||||
"providers": {
|
||||
"anthropic": {
|
||||
"api_key": "sk-ant-test",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"max_tokens": 4096
|
||||
},
|
||||
"openai": {
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://api.openai.com"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadOpenClawConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
providers := GetProviderConfig(cfg.Models)
|
||||
if len(providers) != 2 {
|
||||
t.Errorf("expected 2 providers, got %d", len(providers))
|
||||
}
|
||||
|
||||
if anthropic, ok := providers["anthropic"]; ok {
|
||||
if anthropic.APIKey != "sk-ant-test" {
|
||||
t.Errorf("expected anthropic api_key 'sk-ant-test', got '%s'", anthropic.APIKey)
|
||||
}
|
||||
if anthropic.BaseURL != "https://api.anthropic.com" {
|
||||
t.Errorf("expected anthropic base_url 'https://api.anthropic.com', got '%s'", anthropic.BaseURL)
|
||||
}
|
||||
} else {
|
||||
t.Error("anthropic provider not found")
|
||||
}
|
||||
|
||||
if openai, ok := providers["openai"]; ok {
|
||||
if openai.APIKey != "sk-test" {
|
||||
t.Errorf("expected openai api_key 'sk-test', got '%s'", openai.APIKey)
|
||||
}
|
||||
} else {
|
||||
t.Error("openai provider not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToPicoClaw(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
|
||||
testConfig := `{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": {
|
||||
"primary": "anthropic/claude-sonnet-4-20250514"
|
||||
},
|
||||
"workspace": "~/.openclaw/workspace"
|
||||
},
|
||||
"list": [
|
||||
{
|
||||
"id": "main",
|
||||
"name": "Main Agent"
|
||||
},
|
||||
{
|
||||
"id": "assistant",
|
||||
"name": "Assistant",
|
||||
"skills": ["skill1", "skill2"]
|
||||
}
|
||||
]
|
||||
},
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"botToken": "test-token",
|
||||
"allowFrom": ["user1", "user2"]
|
||||
},
|
||||
"discord": {
|
||||
"enabled": false,
|
||||
"token": "discord-token"
|
||||
},
|
||||
"whatsapp": {
|
||||
"enabled": true,
|
||||
"bridgeUrl": "http://localhost:3000"
|
||||
},
|
||||
"feishu": {
|
||||
"enabled": true,
|
||||
"appId": "app-id",
|
||||
"appSecret": "app-secret",
|
||||
"allowFrom": ["user3"]
|
||||
},
|
||||
"signal": {
|
||||
"enabled": true
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
"providers": {
|
||||
"anthropic": {
|
||||
"api_key": "sk-ant-test"
|
||||
},
|
||||
"openai": {
|
||||
"api_key": "sk-test"
|
||||
}
|
||||
}
|
||||
},
|
||||
"skills": {
|
||||
"entries": {
|
||||
"skill1": {}
|
||||
}
|
||||
},
|
||||
"memory": {"enabled": true},
|
||||
"cron": {"enabled": true}
|
||||
}`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadOpenClawConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
picoCfg, warnings, err := cfg.ConvertToPicoClaw("")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to convert config: %v", err)
|
||||
}
|
||||
|
||||
if picoCfg.Agents.Defaults.ModelName != "claude-sonnet-4-20250514" {
|
||||
t.Errorf("expected model 'claude-sonnet-4-20250514', got '%s'", picoCfg.Agents.Defaults.ModelName)
|
||||
}
|
||||
if picoCfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
|
||||
t.Errorf("expected workspace '~/.picoclaw/workspace', got '%s'", picoCfg.Agents.Defaults.Workspace)
|
||||
}
|
||||
|
||||
if len(picoCfg.Agents.List) != 2 {
|
||||
t.Errorf("expected 2 agents, got %d", len(picoCfg.Agents.List))
|
||||
}
|
||||
if picoCfg.Agents.List[0].ID != "main" {
|
||||
t.Errorf("expected first agent id 'main', got '%s'", picoCfg.Agents.List[0].ID)
|
||||
}
|
||||
if picoCfg.Agents.List[1].Skills == nil || len(picoCfg.Agents.List[1].Skills) != 2 {
|
||||
t.Errorf("expected 2 skills for assistant agent")
|
||||
}
|
||||
|
||||
if !picoCfg.Channels.Telegram.Enabled {
|
||||
t.Error("telegram should be enabled")
|
||||
}
|
||||
if picoCfg.Channels.Telegram.Token != "test-token" {
|
||||
t.Errorf("expected telegram token 'test-token', got '%s'", picoCfg.Channels.Telegram.Token)
|
||||
}
|
||||
|
||||
if picoCfg.Channels.WhatsApp.BridgeURL != "http://localhost:3000" {
|
||||
t.Errorf("expected whatsapp bridge URL 'http://localhost:3000', got '%s'", picoCfg.Channels.WhatsApp.BridgeURL)
|
||||
}
|
||||
|
||||
if picoCfg.Channels.Feishu.AppID != "app-id" {
|
||||
t.Errorf("expected feishu app ID 'app-id', got '%s'", picoCfg.Channels.Feishu.AppID)
|
||||
}
|
||||
|
||||
if len(picoCfg.ModelList) != 1 {
|
||||
t.Errorf("expected 1 model config (no models.json provided), got %d", len(picoCfg.ModelList))
|
||||
}
|
||||
|
||||
foundWarning := false
|
||||
for _, w := range warnings {
|
||||
if len(w) > 0 {
|
||||
foundWarning = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundWarning {
|
||||
t.Log("warnings should be generated for skills, memory, cron, and unsupported channels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToPicoClawWithQQAndDingTalk(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
|
||||
testConfig := `{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": {
|
||||
"primary": "anthropic/claude-sonnet-4-20250514"
|
||||
}
|
||||
}
|
||||
},
|
||||
"channels": {
|
||||
"qq": {
|
||||
"enabled": true,
|
||||
"appId": "qq-app-id",
|
||||
"appSecret": "qq-app-secret"
|
||||
},
|
||||
"dingtalk": {
|
||||
"enabled": true,
|
||||
"appId": "ding-app-id",
|
||||
"appSecret": "ding-app-secret"
|
||||
},
|
||||
"maixcam": {
|
||||
"enabled": true,
|
||||
"host": "192.168.1.100",
|
||||
"port": 9000
|
||||
},
|
||||
"slack": {
|
||||
"enabled": true,
|
||||
"botToken": "xoxb-test",
|
||||
"appToken": "xapp-test"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadOpenClawConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
picoCfg, _, err := cfg.ConvertToPicoClaw("")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to convert config: %v", err)
|
||||
}
|
||||
|
||||
if !picoCfg.Channels.QQ.Enabled {
|
||||
t.Error("qq should be enabled")
|
||||
}
|
||||
if picoCfg.Channels.QQ.AppID != "qq-app-id" {
|
||||
t.Errorf("expected qq app ID 'qq-app-id', got '%s'", picoCfg.Channels.QQ.AppID)
|
||||
}
|
||||
|
||||
if !picoCfg.Channels.DingTalk.Enabled {
|
||||
t.Error("dingtalk should be enabled")
|
||||
}
|
||||
if picoCfg.Channels.DingTalk.ClientID != "ding-app-id" {
|
||||
t.Errorf("expected dingtalk client ID 'ding-app-id', got '%s'", picoCfg.Channels.DingTalk.ClientID)
|
||||
}
|
||||
|
||||
if !picoCfg.Channels.MaixCam.Enabled {
|
||||
t.Error("maixcam should be enabled")
|
||||
}
|
||||
if picoCfg.Channels.MaixCam.Host != "192.168.1.100" {
|
||||
t.Errorf("expected maixcam host '192.168.1.100', got '%s'", picoCfg.Channels.MaixCam.Host)
|
||||
}
|
||||
if picoCfg.Channels.MaixCam.Port != 9000 {
|
||||
t.Errorf("expected maixcam port 9000, got %d", picoCfg.Channels.MaixCam.Port)
|
||||
}
|
||||
|
||||
if !picoCfg.Channels.Slack.Enabled {
|
||||
t.Error("slack should be enabled")
|
||||
}
|
||||
if picoCfg.Channels.Slack.BotToken != "xoxb-test" {
|
||||
t.Errorf("expected slack bot token 'xoxb-test', got '%s'", picoCfg.Channels.Slack.BotToken)
|
||||
}
|
||||
if picoCfg.Channels.Slack.AppToken != "xapp-test" {
|
||||
t.Errorf("expected slack app token 'xapp-test', got '%s'", picoCfg.Channels.Slack.AppToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenClawAgentModel(t *testing.T) {
|
||||
model := &OpenClawAgentModel{
|
||||
Primary: strPtr("anthropic/claude-3-opus"),
|
||||
Fallbacks: []string{"claude-3-sonnet", "claude-3-haiku"},
|
||||
}
|
||||
|
||||
primary := model.GetPrimary()
|
||||
if primary != "anthropic/claude-3-opus" {
|
||||
t.Errorf("expected primary 'anthropic/claude-3-opus', got '%s'", primary)
|
||||
}
|
||||
|
||||
fallbacks := model.GetFallbacks()
|
||||
if len(fallbacks) != 2 {
|
||||
t.Errorf("expected 2 fallbacks, got %d", len(fallbacks))
|
||||
}
|
||||
|
||||
model2 := &OpenClawAgentModel{
|
||||
Simple: "claude-3-opus",
|
||||
}
|
||||
|
||||
primary2 := model2.GetPrimary()
|
||||
if primary2 != "claude-3-opus" {
|
||||
t.Errorf("expected primary 'claude-3-opus' from Simple, got '%s'", primary2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelEnabled(t *testing.T) {
|
||||
cfg := &OpenClawConfig{
|
||||
Channels: &OpenClawChannels{
|
||||
Telegram: &OpenClawTelegramConfig{
|
||||
Enabled: boolPtr(true),
|
||||
},
|
||||
Discord: &OpenClawDiscordConfig{
|
||||
Enabled: boolPtr(false),
|
||||
},
|
||||
Slack: &OpenClawSlackConfig{
|
||||
Enabled: boolPtr(true),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if !cfg.IsChannelEnabled("telegram") {
|
||||
t.Error("telegram should be enabled")
|
||||
}
|
||||
if cfg.IsChannelEnabled("discord") {
|
||||
t.Error("discord should be disabled")
|
||||
}
|
||||
if !cfg.IsChannelEnabled("slack") {
|
||||
t.Error("slack should be enabled (explicitly set)")
|
||||
}
|
||||
if cfg.IsChannelEnabled("line") {
|
||||
t.Error("line should return false (not in switch cases)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultModel(t *testing.T) {
|
||||
cfg := &OpenClawConfig{
|
||||
Agents: &OpenClawAgents{
|
||||
Defaults: &OpenClawAgentDefaults{
|
||||
Model: &OpenClawAgentModel{
|
||||
Primary: strPtr("openai/gpt-4"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider, model := cfg.GetDefaultModel()
|
||||
if provider != "openai" {
|
||||
t.Errorf("expected provider 'openai', got '%s'", provider)
|
||||
}
|
||||
if model != "gpt-4" {
|
||||
t.Errorf("expected model 'gpt-4', got '%s'", model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultModelWithNoDefaults(t *testing.T) {
|
||||
cfg := &OpenClawConfig{}
|
||||
|
||||
provider, model := cfg.GetDefaultModel()
|
||||
if provider != "anthropic" {
|
||||
t.Errorf("expected default provider 'anthropic', got '%s'", provider)
|
||||
}
|
||||
if model != "claude-sonnet-4-20250514" {
|
||||
t.Errorf("expected default model 'claude-sonnet-4-20250514', got '%s'", model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasFunctions(t *testing.T) {
|
||||
cfg := &OpenClawConfig{
|
||||
Skills: &OpenClawSkills{Entries: map[string]json.RawMessage{"skill1": nil}},
|
||||
Memory: json.RawMessage(`{"enabled": true}`),
|
||||
Cron: json.RawMessage(`{"enabled": true}`),
|
||||
Hooks: json.RawMessage(`{"enabled": true}`),
|
||||
Session: json.RawMessage(`{"enabled": true}`),
|
||||
Auth: &OpenClawAuth{Profiles: json.RawMessage(`{"profile1": {}}`)},
|
||||
}
|
||||
|
||||
if !cfg.HasSkills() {
|
||||
t.Error("should have skills")
|
||||
}
|
||||
if !cfg.HasMemory() {
|
||||
t.Error("should have memory")
|
||||
}
|
||||
if !cfg.HasCron() {
|
||||
t.Error("should have cron")
|
||||
}
|
||||
if !cfg.HasHooks() {
|
||||
t.Error("should have hooks")
|
||||
}
|
||||
if !cfg.HasSession() {
|
||||
t.Error("should have session")
|
||||
}
|
||||
if !cfg.HasAuthProfiles() {
|
||||
t.Error("should have auth profiles")
|
||||
}
|
||||
|
||||
cfg2 := &OpenClawConfig{}
|
||||
if cfg2.HasSkills() {
|
||||
t.Error("should not have skills")
|
||||
}
|
||||
if cfg2.HasMemory() {
|
||||
t.Error("should not have memory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadOpenClawConfigFromDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
|
||||
testConfig := `{"agents": {}}`
|
||||
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadOpenClawConfigFromDir(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config from dir: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Agents == nil {
|
||||
t.Error("agents should not be nil")
|
||||
}
|
||||
|
||||
_, err = LoadOpenClawConfigFromDir("/nonexistent/dir")
|
||||
if err == nil {
|
||||
t.Error("should return error for nonexistent dir")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToStandardConfig(t *testing.T) {
|
||||
picoCfg := &PicoClawConfig{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Provider: "anthropic",
|
||||
ModelName: "claude-sonnet-4-20250514",
|
||||
Workspace: "~/.picoclaw/workspace",
|
||||
},
|
||||
List: []AgentConfig{
|
||||
{
|
||||
ID: "main",
|
||||
Name: "Main Agent",
|
||||
Default: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
ModelList: []ModelConfig{
|
||||
{
|
||||
ModelName: "claude-sonnet-4-20250514",
|
||||
Model: "anthropic/claude-sonnet-4-20250514",
|
||||
APIKey: "sk-ant-test",
|
||||
},
|
||||
},
|
||||
Channels: ChannelsConfig{
|
||||
Telegram: TelegramConfig{
|
||||
Enabled: true,
|
||||
Token: "test-token",
|
||||
AllowFrom: []string{"user1"},
|
||||
},
|
||||
WhatsApp: WhatsAppConfig{
|
||||
Enabled: true,
|
||||
BridgeURL: "http://localhost:3000",
|
||||
},
|
||||
},
|
||||
Gateway: GatewayConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
},
|
||||
}
|
||||
|
||||
stdCfg := picoCfg.ToStandardConfig()
|
||||
|
||||
if stdCfg.Agents.Defaults.Provider != "anthropic" {
|
||||
t.Errorf("expected provider 'anthropic', got '%s'", stdCfg.Agents.Defaults.Provider)
|
||||
}
|
||||
if stdCfg.Agents.Defaults.ModelName != "claude-sonnet-4-20250514" {
|
||||
t.Errorf("expected model name 'claude-sonnet-4-20250514', got '%s'", stdCfg.Agents.Defaults.ModelName)
|
||||
}
|
||||
if stdCfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
|
||||
t.Errorf("expected workspace '~/.picoclaw/workspace', got '%s'", stdCfg.Agents.Defaults.Workspace)
|
||||
}
|
||||
|
||||
if len(stdCfg.Agents.List) != 1 {
|
||||
t.Errorf("expected 1 agent, got %d", len(stdCfg.Agents.List))
|
||||
}
|
||||
if stdCfg.Agents.List[0].ID != "main" {
|
||||
t.Errorf("expected agent id 'main', got '%s'", stdCfg.Agents.List[0].ID)
|
||||
}
|
||||
|
||||
foundModel := false
|
||||
var foundAPIKey string
|
||||
for _, m := range stdCfg.ModelList {
|
||||
if m.ModelName == "claude-sonnet-4-20250514" {
|
||||
foundModel = true
|
||||
foundAPIKey = m.APIKey
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundModel {
|
||||
t.Error("expected to find claude-sonnet-4-20250514 model config")
|
||||
}
|
||||
if foundAPIKey != "sk-ant-test" {
|
||||
t.Errorf("expected api key 'sk-ant-test', got '%s'", foundAPIKey)
|
||||
}
|
||||
|
||||
if !stdCfg.Channels.Telegram.Enabled {
|
||||
t.Error("telegram should be enabled")
|
||||
}
|
||||
if stdCfg.Channels.Telegram.Token != "test-token" {
|
||||
t.Errorf("expected token 'test-token', got '%s'", stdCfg.Channels.Telegram.Token)
|
||||
}
|
||||
|
||||
if stdCfg.Gateway.Port != 8080 {
|
||||
t.Errorf("expected gateway port 8080, got %d", stdCfg.Gateway.Port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadProviderConfigFromAgentsDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
agentsDir := filepath.Join(tmpDir, "agents", "main", "agent")
|
||||
err := os.MkdirAll(agentsDir, 0o755)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create agents dir: %v", err)
|
||||
}
|
||||
|
||||
modelsJSON := `{
|
||||
"providers": {
|
||||
"anthropic": {
|
||||
"baseUrl": "https://api.anthropic.com",
|
||||
"api": "anthropic",
|
||||
"apiKey": "sk-ant-from-models",
|
||||
"models": [
|
||||
{
|
||||
"id": "claude-sonnet-4-20250514",
|
||||
"name": "Claude Sonnet 4"
|
||||
}
|
||||
]
|
||||
},
|
||||
"openai": {
|
||||
"baseUrl": "https://api.openai.com",
|
||||
"api": "openai",
|
||||
"apiKey": "sk-from-models",
|
||||
"models": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"name": "GPT-4o"
|
||||
}
|
||||
]
|
||||
},
|
||||
"zhipu": {
|
||||
"baseUrl": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"api": "openai",
|
||||
"apiKey": "zhipu-key",
|
||||
"models": []
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
err = os.WriteFile(filepath.Join(agentsDir, "models.json"), []byte(modelsJSON), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write models.json: %v", err)
|
||||
}
|
||||
|
||||
providers := GetProviderConfigFromDir(tmpDir)
|
||||
if len(providers) != 3 {
|
||||
t.Errorf("expected 3 providers, got %d", len(providers))
|
||||
}
|
||||
|
||||
if anthropic, ok := providers["anthropic"]; ok {
|
||||
if anthropic.ApiKey != "sk-ant-from-models" {
|
||||
t.Errorf("expected anthropic apiKey 'sk-ant-from-models', got '%s'", anthropic.ApiKey)
|
||||
}
|
||||
if anthropic.BaseUrl != "https://api.anthropic.com" {
|
||||
t.Errorf("expected anthropic baseUrl 'https://api.anthropic.com', got '%s'", anthropic.BaseUrl)
|
||||
}
|
||||
} else {
|
||||
t.Error("anthropic provider not found")
|
||||
}
|
||||
|
||||
if openai, ok := providers["openai"]; ok {
|
||||
if openai.ApiKey != "sk-from-models" {
|
||||
t.Errorf("expected openai apiKey 'sk-from-models', got '%s'", openai.ApiKey)
|
||||
}
|
||||
if openai.BaseUrl != "https://api.openai.com" {
|
||||
t.Errorf("expected openai baseUrl 'https://api.openai.com', got '%s'", openai.BaseUrl)
|
||||
}
|
||||
} else {
|
||||
t.Error("openai provider not found")
|
||||
}
|
||||
|
||||
if zhipu, ok := providers["zhipu"]; ok {
|
||||
if zhipu.ApiKey != "zhipu-key" {
|
||||
t.Errorf("expected zhipu apiKey 'zhipu-key', got '%s'", zhipu.ApiKey)
|
||||
}
|
||||
if zhipu.BaseUrl != "https://open.bigmodel.cn/api/paas/v4" {
|
||||
t.Errorf("expected zhipu baseUrl 'https://open.bigmodel.cn/api/paas/v4', got '%s'", zhipu.BaseUrl)
|
||||
}
|
||||
} else {
|
||||
t.Error("zhipu provider not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProviderConfigFromDirNotExist(t *testing.T) {
|
||||
providers := GetProviderConfigFromDir("/nonexistent/path")
|
||||
if len(providers) != 0 {
|
||||
t.Errorf("expected 0 providers for nonexistent path, got %d", len(providers))
|
||||
}
|
||||
}
|
||||
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package openclaw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/migrate/internal"
|
||||
)
|
||||
|
||||
var providerMapping = map[string]string{
|
||||
"anthropic": "anthropic",
|
||||
"claude": "anthropic",
|
||||
"openai": "openai",
|
||||
"gpt": "openai",
|
||||
"groq": "groq",
|
||||
"ollama": "ollama",
|
||||
"openrouter": "openrouter",
|
||||
"deepseek": "deepseek",
|
||||
"together": "together",
|
||||
"mistral": "mistral",
|
||||
"fireworks": "fireworks",
|
||||
"google": "google",
|
||||
"gemini": "google",
|
||||
"xai": "xai",
|
||||
"grok": "xai",
|
||||
"cerebras": "cerebras",
|
||||
"sambanova": "sambanova",
|
||||
}
|
||||
|
||||
type OpenclawHandler struct {
|
||||
opts Options
|
||||
sourceConfigFile string
|
||||
sourceWorkspace string
|
||||
}
|
||||
|
||||
type (
|
||||
Options = internal.Options
|
||||
Action = internal.Action
|
||||
Result = internal.Result
|
||||
Operation = internal.Operation
|
||||
)
|
||||
|
||||
func NewOpenclawHandler(opts Options) (Operation, error) {
|
||||
home, err := resolveSourceHome(opts.SourceHome)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts.SourceHome = home
|
||||
|
||||
configFile, err := findSourceConfig(home)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OpenclawHandler{
|
||||
opts: opts,
|
||||
sourceWorkspace: filepath.Join(opts.SourceHome, "workspace"),
|
||||
sourceConfigFile: configFile,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (o *OpenclawHandler) GetSourceName() string {
|
||||
return "openclaw"
|
||||
}
|
||||
|
||||
func (o *OpenclawHandler) GetSourceHome() (string, error) {
|
||||
return o.opts.SourceHome, nil
|
||||
}
|
||||
|
||||
func (o *OpenclawHandler) GetSourceWorkspace() (string, error) {
|
||||
return o.sourceWorkspace, nil
|
||||
}
|
||||
|
||||
func (o *OpenclawHandler) GetSourceConfigFile() (string, error) {
|
||||
return o.sourceConfigFile, nil
|
||||
}
|
||||
|
||||
func (o *OpenclawHandler) GetMigrateableFiles() []string {
|
||||
return migrateableFiles
|
||||
}
|
||||
|
||||
func (o *OpenclawHandler) GetMigrateableDirs() []string {
|
||||
return migrateableDirs
|
||||
}
|
||||
|
||||
func (o *OpenclawHandler) ExecuteConfigMigration(srcConfigPath, dstConfigPath string) error {
|
||||
openclawCfg, err := LoadOpenClawConfig(srcConfigPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
picoCfg, warnings, err := openclawCfg.ConvertToPicoClaw(o.opts.SourceHome)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, w := range warnings {
|
||||
fmt.Printf(" Warning: %s\n", w)
|
||||
}
|
||||
|
||||
incoming := picoCfg.ToStandardConfig()
|
||||
if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return config.SaveConfig(dstConfigPath, incoming)
|
||||
}
|
||||
|
||||
func resolveSourceHome(override string) (string, error) {
|
||||
if override != "" {
|
||||
return internal.ExpandHome(override), nil
|
||||
}
|
||||
if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
|
||||
return internal.ExpandHome(envHome), nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolving home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".openclaw"), nil
|
||||
}
|
||||
|
||||
func findSourceConfig(sourceHome string) (string, error) {
|
||||
candidates := []string{
|
||||
filepath.Join(sourceHome, "openclaw.json"),
|
||||
filepath.Join(sourceHome, "config.json"),
|
||||
}
|
||||
for _, p := range candidates {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", sourceHome)
|
||||
}
|
||||
|
||||
func rewriteWorkspacePath(path string) string {
|
||||
path = strings.Replace(path, ".openclaw", ".picoclaw", 1)
|
||||
return path
|
||||
}
|
||||
|
||||
func mapProvider(provider string) string {
|
||||
if mapped, ok := providerMapping[strings.ToLower(provider)]; ok {
|
||||
return mapped
|
||||
}
|
||||
return strings.ToLower(provider)
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
package openclaw
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewOpenclawHandler(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler, err := NewOpenclawHandler(Options{
|
||||
SourceHome: tmpDir,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, handler)
|
||||
}
|
||||
|
||||
func TestNewOpenclawHandlerNoConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
_, err := NewOpenclawHandler(Options{
|
||||
SourceHome: tmpDir,
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOpenclawHandlerGetSourceName(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler, err := NewOpenclawHandler(Options{
|
||||
SourceHome: tmpDir,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "openclaw", handler.GetSourceName())
|
||||
}
|
||||
|
||||
func TestOpenclawHandlerGetSourceHome(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler, err := NewOpenclawHandler(Options{
|
||||
SourceHome: tmpDir,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
home, err := handler.GetSourceHome()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tmpDir, home)
|
||||
}
|
||||
|
||||
func TestOpenclawHandlerGetSourceWorkspace(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler, err := NewOpenclawHandler(Options{
|
||||
SourceHome: tmpDir,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
workspace, err := handler.GetSourceWorkspace()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(tmpDir, "workspace"), workspace)
|
||||
}
|
||||
|
||||
func TestOpenclawHandlerGetSourceConfigFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler, err := NewOpenclawHandler(Options{
|
||||
SourceHome: tmpDir,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
configFile, err := handler.GetSourceConfigFile()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, configPath, configFile)
|
||||
}
|
||||
|
||||
func TestOpenclawHandlerGetSourceConfigFileWithConfigJson(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler, err := NewOpenclawHandler(Options{
|
||||
SourceHome: tmpDir,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
configFile, err := handler.GetSourceConfigFile()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, configPath, configFile)
|
||||
}
|
||||
|
||||
func TestOpenclawHandlerGetMigrateableFiles(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler, err := NewOpenclawHandler(Options{
|
||||
SourceHome: tmpDir,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
files := handler.GetMigrateableFiles()
|
||||
assert.NotEmpty(t, files)
|
||||
assert.Contains(t, files, "AGENTS.md")
|
||||
assert.Contains(t, files, "SOUL.md")
|
||||
assert.Contains(t, files, "USER.md")
|
||||
}
|
||||
|
||||
func TestOpenclawHandlerGetMigrateableDirs(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler, err := NewOpenclawHandler(Options{
|
||||
SourceHome: tmpDir,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
dirs := handler.GetMigrateableDirs()
|
||||
assert.NotEmpty(t, dirs)
|
||||
assert.Contains(t, dirs, "memory")
|
||||
assert.Contains(t, dirs, "skills")
|
||||
}
|
||||
|
||||
func TestResolveSourceHome(t *testing.T) {
|
||||
result, err := resolveSourceHome("/custom/path")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/custom/path", result)
|
||||
}
|
||||
|
||||
func TestResolveSourceHomeWithEnvVar(t *testing.T) {
|
||||
t.Setenv("OPENCLAW_HOME", "/env/path")
|
||||
|
||||
result, err := resolveSourceHome("")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/env/path", result)
|
||||
}
|
||||
|
||||
func TestResolveSourceHomeWithTilde(t *testing.T) {
|
||||
home, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := resolveSourceHome("~/openclaw")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(home, "openclaw"), result)
|
||||
}
|
||||
|
||||
func TestFindSourceConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := findSourceConfig(tmpDir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, configPath, result)
|
||||
}
|
||||
|
||||
func TestFindSourceConfigWithConfigJson(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.json")
|
||||
err := os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := findSourceConfig(tmpDir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, configPath, result)
|
||||
}
|
||||
|
||||
func TestFindSourceConfigNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
_, err := findSourceConfig(tmpDir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no config file found")
|
||||
}
|
||||
|
||||
func TestMapProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"anthropic", "anthropic"},
|
||||
{"claude", "anthropic"},
|
||||
{"openai", "openai"},
|
||||
{"gpt", "openai"},
|
||||
{"groq", "groq"},
|
||||
{"ollama", "ollama"},
|
||||
{"openrouter", "openrouter"},
|
||||
{"deepseek", "deepseek"},
|
||||
{"together", "together"},
|
||||
{"mistral", "mistral"},
|
||||
{"fireworks", "fireworks"},
|
||||
{"google", "google"},
|
||||
{"gemini", "google"},
|
||||
{"xai", "xai"},
|
||||
{"grok", "xai"},
|
||||
{"cerebras", "cerebras"},
|
||||
{"sambanova", "sambanova"},
|
||||
{"unknown", "unknown"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := mapProvider(tt.input)
|
||||
assert.Equal(t, tt.expected, result, "mapProvider(%q)", tt.input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteWorkspacePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"~/.openclaw/workspace", "~/.picoclaw/workspace"},
|
||||
{"/home/user/.openclaw/workspace", "/home/user/.picoclaw/workspace"},
|
||||
{"/path/without/openclaw/change", "/path/without/openclaw/change"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := rewriteWorkspacePath(tt.input)
|
||||
assert.Equal(t, tt.expected, result, "rewriteWorkspacePath(%q)", tt.input)
|
||||
}
|
||||
}
|
||||
@@ -212,14 +212,14 @@ func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
|
||||
}
|
||||
|
||||
func parseResponse(resp *anthropic.Message) *LLMResponse {
|
||||
var content string
|
||||
var content strings.Builder
|
||||
var toolCalls []ToolCall
|
||||
|
||||
for _, block := range resp.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
tb := block.AsText()
|
||||
content += tb.Text
|
||||
content.WriteString(tb.Text)
|
||||
case "tool_use":
|
||||
tu := block.AsToolUse()
|
||||
var args map[string]any
|
||||
@@ -246,7 +246,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: content,
|
||||
Content: content.String(),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: &UsageInfo{
|
||||
@@ -264,8 +264,8 @@ func normalizeBaseURL(apiBase string) string {
|
||||
}
|
||||
|
||||
base = strings.TrimRight(base, "/")
|
||||
if strings.HasSuffix(base, "/v1") {
|
||||
base = strings.TrimSuffix(base, "/v1")
|
||||
if before, ok := strings.CutSuffix(base, "/v1"); ok {
|
||||
base = before
|
||||
}
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
|
||||
@@ -163,8 +163,8 @@ func resolveCodexModel(model string) (string, string) {
|
||||
return codexDefaultModel, "empty model"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(m, "openai/") {
|
||||
m = strings.TrimPrefix(m, "openai/")
|
||||
if after, ok := strings.CutPrefix(m, "openai/"); ok {
|
||||
m = after
|
||||
} else if strings.Contains(m, "/") {
|
||||
return codexDefaultModel, "non-openai model namespace"
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ func TestCooldown_FailureWindowReset(t *testing.T) {
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// 4 errors → 1h cooldown
|
||||
for i := 0; i < 4; i++ {
|
||||
for range 4 {
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
*current = current.Add(2 * time.Second) // small advance between errors
|
||||
}
|
||||
@@ -230,7 +230,7 @@ func TestCooldown_ConcurrentAccess(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
for range 100 {
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
@@ -6,6 +6,13 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Common patterns in Go HTTP error messages
|
||||
var httpStatusPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`status[:\s]+(\d{3})`),
|
||||
regexp.MustCompile(`http[/\s]+\d*\.?\d*\s+(\d{3})`),
|
||||
regexp.MustCompile(`\b([3-5]\d{2})\b`),
|
||||
}
|
||||
|
||||
// errorPattern defines a single pattern (string or regex) for error classification.
|
||||
type errorPattern struct {
|
||||
substring string
|
||||
@@ -198,20 +205,13 @@ func classifyByMessage(msg string) FailoverReason {
|
||||
}
|
||||
|
||||
// extractHTTPStatus extracts an HTTP status code from an error message.
|
||||
// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429".
|
||||
// Looks for patterns like "status: 429", "status 429", "http/1.1 429", "http 429", or standalone "429".
|
||||
func extractHTTPStatus(msg string) int {
|
||||
// Common patterns in Go HTTP error messages
|
||||
patterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(`status[:\s]+(\d{3})`),
|
||||
regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`),
|
||||
}
|
||||
|
||||
for _, p := range patterns {
|
||||
for _, p := range httpStatusPatterns {
|
||||
if m := p.FindStringSubmatch(msg); len(m) > 1 {
|
||||
return parseDigits(m[1])
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
@@ -305,7 +305,8 @@ func TestExtractHTTPStatus(t *testing.T) {
|
||||
}{
|
||||
{"status: 429 rate limited", 429},
|
||||
{"status 401 unauthorized", 401},
|
||||
{"HTTP/1.1 502 Bad Gateway", 502},
|
||||
{"http/1.1 502 bad gateway", 502},
|
||||
{"error 429", 429},
|
||||
{"no status code here", 0},
|
||||
{"random number 12345", 0},
|
||||
}
|
||||
|
||||
@@ -26,8 +26,9 @@ func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*Gi
|
||||
|
||||
switch connectMode {
|
||||
case "stdio":
|
||||
// TODO:
|
||||
return nil, fmt.Errorf("stdio mode not implemented")
|
||||
// TODO: Implement stdio mode for GitHub Copilot provider
|
||||
// See https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md for details
|
||||
return nil, fmt.Errorf("stdio mode not implemented for GitHub Copilot provider; please use 'grpc' mode instead")
|
||||
case "grpc":
|
||||
client := copilot.NewClient(&copilot.ClientOptions{
|
||||
CLIUrl: uri,
|
||||
@@ -100,9 +101,12 @@ func (p *GitHubCopilotProvider) Chat(
|
||||
return nil, fmt.Errorf("provider closed")
|
||||
}
|
||||
|
||||
resp, _ := session.SendAndWait(ctx, copilot.MessageOptions{
|
||||
resp, err := session.SendAndWait(ctx, copilot.MessageOptions{
|
||||
Prompt: string(fullcontent),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send message to copilot: %w", err)
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("empty response from copilot")
|
||||
|
||||
@@ -312,8 +312,8 @@ func stripSystemParts(messages []Message) []openaiMessage {
|
||||
}
|
||||
|
||||
func normalizeModel(model, apiBase string) string {
|
||||
idx := strings.Index(model, "/")
|
||||
if idx == -1 {
|
||||
before, after, ok := strings.Cut(model, "/")
|
||||
if !ok {
|
||||
return model
|
||||
}
|
||||
|
||||
@@ -321,10 +321,10 @@ func normalizeModel(model, apiBase string) string {
|
||||
return model
|
||||
}
|
||||
|
||||
prefix := strings.ToLower(model[:idx])
|
||||
prefix := strings.ToLower(before)
|
||||
switch prefix {
|
||||
case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
|
||||
return model[idx+1:]
|
||||
return after
|
||||
default:
|
||||
return model
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package routing
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeAgentID_Empty(t *testing.T) {
|
||||
if got := NormalizeAgentID(""); got != DefaultAgentID {
|
||||
@@ -57,11 +60,11 @@ func TestNormalizeAgentID_AllInvalid(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNormalizeAgentID_TruncatesAt64(t *testing.T) {
|
||||
long := ""
|
||||
for i := 0; i < 100; i++ {
|
||||
long += "a"
|
||||
var long strings.Builder
|
||||
for range 100 {
|
||||
long.WriteString("a")
|
||||
}
|
||||
got := NormalizeAgentID(long)
|
||||
got := NormalizeAgentID(long.String())
|
||||
if len(got) > MaxAgentIDLength {
|
||||
t.Errorf("length = %d, want <= %d", len(got), MaxAgentIDLength)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -18,14 +17,6 @@ type SkillInstaller struct {
|
||||
workspace string
|
||||
}
|
||||
|
||||
type AvailableSkill struct {
|
||||
Name string `json:"name"`
|
||||
Repository string `json:"repository"`
|
||||
Description string `json:"description"`
|
||||
Author string `json:"author"`
|
||||
Tags []string `json:"tags"`
|
||||
}
|
||||
|
||||
func NewSkillInstaller(workspace string) *SkillInstaller {
|
||||
return &SkillInstaller{
|
||||
workspace: workspace,
|
||||
@@ -89,35 +80,3 @@ func (si *SkillInstaller) Uninstall(skillName string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (si *SkillInstaller) ListAvailableSkills(ctx context.Context) ([]AvailableSkill, error) {
|
||||
url := "https://raw.githubusercontent.com/sipeed/picoclaw-skills/main/skills.json"
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := utils.DoRequestWithRetry(client, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch skills list: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed to fetch skills list: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var skills []AvailableSkill
|
||||
if err := json.Unmarshal(body, &skills); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse skills list: %w", err)
|
||||
}
|
||||
|
||||
return skills, nil
|
||||
}
|
||||
|
||||
@@ -240,7 +240,7 @@ func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string {
|
||||
normalized := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
normalized = strings.ReplaceAll(normalized, "\r", "\n")
|
||||
|
||||
for _, line := range strings.Split(normalized, "\n") {
|
||||
for line := range strings.SplitSeq(normalized, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -183,7 +183,7 @@ func buildTrigrams(s string) []uint32 {
|
||||
}
|
||||
|
||||
// Sort and Deduplication
|
||||
sort.Slice(trigrams, func(i, j int) bool { return trigrams[i] < trigrams[j] })
|
||||
slices.Sort(trigrams)
|
||||
n := 1
|
||||
for i := 1; i < len(trigrams); i++ {
|
||||
if trigrams[i] != trigrams[i-1] {
|
||||
|
||||
@@ -153,7 +153,7 @@ func TestSearchCacheConcurrency(t *testing.T) {
|
||||
|
||||
// Concurrent writes
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
for i := range 100 {
|
||||
cache.Put("query-write-"+string(rune('a'+i%26)), []SearchResult{{Slug: "x"}})
|
||||
}
|
||||
done <- struct{}{}
|
||||
@@ -161,7 +161,7 @@ func TestSearchCacheConcurrency(t *testing.T) {
|
||||
|
||||
// Concurrent reads
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
for range 100 {
|
||||
cache.Get("query-write-a")
|
||||
}
|
||||
done <- struct{}{}
|
||||
|
||||
@@ -135,7 +135,7 @@ func TestConcurrentAccess(t *testing.T) {
|
||||
|
||||
// Test concurrent writes
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
go func(idx int) {
|
||||
channel := fmt.Sprintf("channel-%d", idx)
|
||||
sm.SetLastChannel(channel)
|
||||
@@ -144,7 +144,7 @@ func TestConcurrentAccess(t *testing.T) {
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
<-done
|
||||
}
|
||||
|
||||
|
||||
+12
-6
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -33,15 +34,19 @@ type CronTool struct {
|
||||
func NewCronTool(
|
||||
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
|
||||
execTimeout time.Duration, config *config.Config,
|
||||
) *CronTool {
|
||||
execTool := NewExecToolWithConfig(workspace, restrict, config)
|
||||
) (*CronTool, error) {
|
||||
execTool, err := NewExecToolWithConfig(workspace, restrict, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
|
||||
}
|
||||
|
||||
execTool.SetTimeout(execTimeout)
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: execTool,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the tool name
|
||||
@@ -218,7 +223,8 @@ func (t *CronTool) listJobs() *ToolResult {
|
||||
return SilentResult("No scheduled jobs")
|
||||
}
|
||||
|
||||
result := "Scheduled jobs:\n"
|
||||
var result strings.Builder
|
||||
result.WriteString("Scheduled jobs:\n")
|
||||
for _, j := range jobs {
|
||||
var scheduleInfo string
|
||||
if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil {
|
||||
@@ -230,10 +236,10 @@ func (t *CronTool) listJobs() *ToolResult {
|
||||
} else {
|
||||
scheduleInfo = "unknown"
|
||||
}
|
||||
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
|
||||
result.WriteString(fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo))
|
||||
}
|
||||
|
||||
return SilentResult(result)
|
||||
return SilentResult(result.String())
|
||||
}
|
||||
|
||||
func (t *CronTool) removeJob(args map[string]any) *ToolResult {
|
||||
|
||||
@@ -329,7 +329,7 @@ func TestToolRegistry_ConcurrentAccess(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
for i := range 50 {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
|
||||
+57
-51
@@ -24,56 +24,64 @@ type ExecTool struct {
|
||||
restrictToWorkspace bool
|
||||
}
|
||||
|
||||
var defaultDenyPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
|
||||
regexp.MustCompile(`\bdel\s+/[fq]\b`),
|
||||
regexp.MustCompile(`\brmdir\s+/s\b`),
|
||||
regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args)
|
||||
regexp.MustCompile(`\bdd\s+if=`),
|
||||
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
|
||||
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
|
||||
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
|
||||
regexp.MustCompile(`\$\([^)]+\)`),
|
||||
regexp.MustCompile(`\$\{[^}]+\}`),
|
||||
regexp.MustCompile("`[^`]+`"),
|
||||
regexp.MustCompile(`\|\s*sh\b`),
|
||||
regexp.MustCompile(`\|\s*bash\b`),
|
||||
regexp.MustCompile(`;\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
|
||||
regexp.MustCompile(`<<\s*EOF`),
|
||||
regexp.MustCompile(`\$\(\s*cat\s+`),
|
||||
regexp.MustCompile(`\$\(\s*curl\s+`),
|
||||
regexp.MustCompile(`\$\(\s*wget\s+`),
|
||||
regexp.MustCompile(`\$\(\s*which\s+`),
|
||||
regexp.MustCompile(`\bsudo\b`),
|
||||
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
|
||||
regexp.MustCompile(`\bchown\b`),
|
||||
regexp.MustCompile(`\bpkill\b`),
|
||||
regexp.MustCompile(`\bkillall\b`),
|
||||
regexp.MustCompile(`\bkill\s+-[9]\b`),
|
||||
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
|
||||
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
|
||||
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
|
||||
regexp.MustCompile(`\byum\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdocker\s+run\b`),
|
||||
regexp.MustCompile(`\bdocker\s+exec\b`),
|
||||
regexp.MustCompile(`\bgit\s+push\b`),
|
||||
regexp.MustCompile(`\bgit\s+force\b`),
|
||||
regexp.MustCompile(`\bssh\b.*@`),
|
||||
regexp.MustCompile(`\beval\b`),
|
||||
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
|
||||
}
|
||||
var (
|
||||
defaultDenyPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
|
||||
regexp.MustCompile(`\bdel\s+/[fq]\b`),
|
||||
regexp.MustCompile(`\brmdir\s+/s\b`),
|
||||
// Match disk wiping commands (must be followed by space/args)
|
||||
regexp.MustCompile(
|
||||
`\b(format|mkfs|diskpart)\b\s`,
|
||||
),
|
||||
regexp.MustCompile(`\bdd\s+if=`),
|
||||
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
|
||||
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
|
||||
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
|
||||
regexp.MustCompile(`\$\([^)]+\)`),
|
||||
regexp.MustCompile(`\$\{[^}]+\}`),
|
||||
regexp.MustCompile("`[^`]+`"),
|
||||
regexp.MustCompile(`\|\s*sh\b`),
|
||||
regexp.MustCompile(`\|\s*bash\b`),
|
||||
regexp.MustCompile(`;\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
|
||||
regexp.MustCompile(`<<\s*EOF`),
|
||||
regexp.MustCompile(`\$\(\s*cat\s+`),
|
||||
regexp.MustCompile(`\$\(\s*curl\s+`),
|
||||
regexp.MustCompile(`\$\(\s*wget\s+`),
|
||||
regexp.MustCompile(`\$\(\s*which\s+`),
|
||||
regexp.MustCompile(`\bsudo\b`),
|
||||
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
|
||||
regexp.MustCompile(`\bchown\b`),
|
||||
regexp.MustCompile(`\bpkill\b`),
|
||||
regexp.MustCompile(`\bkillall\b`),
|
||||
regexp.MustCompile(`\bkill\s+-[9]\b`),
|
||||
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
|
||||
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
|
||||
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
|
||||
regexp.MustCompile(`\byum\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdocker\s+run\b`),
|
||||
regexp.MustCompile(`\bdocker\s+exec\b`),
|
||||
regexp.MustCompile(`\bgit\s+push\b`),
|
||||
regexp.MustCompile(`\bgit\s+force\b`),
|
||||
regexp.MustCompile(`\bssh\b.*@`),
|
||||
regexp.MustCompile(`\beval\b`),
|
||||
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
|
||||
}
|
||||
|
||||
func NewExecTool(workingDir string, restrict bool) *ExecTool {
|
||||
// absolutePathPattern matches absolute file paths in commands (Unix and Windows).
|
||||
absolutePathPattern = regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
|
||||
)
|
||||
|
||||
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil)
|
||||
}
|
||||
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) *ExecTool {
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
|
||||
denyPatterns := make([]*regexp.Regexp, 0)
|
||||
|
||||
if config != nil {
|
||||
@@ -86,8 +94,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
for _, pattern := range execConfig.CustomDenyPatterns {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
fmt.Printf("Invalid custom deny pattern %q: %v\n", pattern, err)
|
||||
continue
|
||||
return nil, fmt.Errorf("invalid custom deny pattern %q: %w", pattern, err)
|
||||
}
|
||||
denyPatterns = append(denyPatterns, re)
|
||||
}
|
||||
@@ -106,7 +113,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
denyPatterns: denyPatterns,
|
||||
allowPatterns: nil,
|
||||
restrictToWorkspace: restrict,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *ExecTool) Name() string {
|
||||
@@ -288,8 +295,7 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
pathPattern := regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
|
||||
matches := pathPattern.FindAllString(cmd, -1)
|
||||
matches := absolutePathPattern.FindAllString(cmd, -1)
|
||||
|
||||
for _, raw := range matches {
|
||||
p, err := filepath.Abs(raw)
|
||||
|
||||
+48
-11
@@ -11,7 +11,10 @@ import (
|
||||
|
||||
// TestShellTool_Success verifies successful command execution
|
||||
func TestShellTool_Success(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
@@ -38,7 +41,10 @@ func TestShellTool_Success(t *testing.T) {
|
||||
|
||||
// TestShellTool_Failure verifies failed command execution
|
||||
func TestShellTool_Failure(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
@@ -65,7 +71,11 @@ func TestShellTool_Failure(t *testing.T) {
|
||||
|
||||
// TestShellTool_Timeout verifies command timeout handling
|
||||
func TestShellTool_Timeout(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
tool.SetTimeout(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -93,7 +103,10 @@ func TestShellTool_WorkingDir(t *testing.T) {
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test content"), 0o644)
|
||||
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
@@ -114,7 +127,10 @@ func TestShellTool_WorkingDir(t *testing.T) {
|
||||
|
||||
// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands
|
||||
func TestShellTool_DangerousCommand(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
@@ -135,7 +151,10 @@ func TestShellTool_DangerousCommand(t *testing.T) {
|
||||
|
||||
// TestShellTool_MissingCommand verifies error handling for missing command
|
||||
func TestShellTool_MissingCommand(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{}
|
||||
@@ -150,7 +169,10 @@ func TestShellTool_MissingCommand(t *testing.T) {
|
||||
|
||||
// TestShellTool_StderrCapture verifies stderr is captured and included
|
||||
func TestShellTool_StderrCapture(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
@@ -170,7 +192,10 @@ func TestShellTool_StderrCapture(t *testing.T) {
|
||||
|
||||
// TestShellTool_OutputTruncation verifies long output is truncated
|
||||
func TestShellTool_OutputTruncation(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
// Generate long output (>10000 chars)
|
||||
@@ -198,7 +223,11 @@ func TestShellTool_WorkingDir_OutsideWorkspace(t *testing.T) {
|
||||
t.Fatalf("failed to create outside dir: %v", err)
|
||||
}
|
||||
|
||||
tool := NewExecTool(workspace, true)
|
||||
tool, err := NewExecTool(workspace, true)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"command": "pwd",
|
||||
"working_dir": outsideDir,
|
||||
@@ -232,7 +261,11 @@ func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) {
|
||||
t.Skipf("symlinks not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
tool := NewExecTool(workspace, true)
|
||||
tool, err := NewExecTool(workspace, true)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"command": "cat secret.txt",
|
||||
"working_dir": link,
|
||||
@@ -249,7 +282,11 @@ func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) {
|
||||
// TestShellTool_RestrictToWorkspace verifies workspace restriction
|
||||
func TestShellTool_RestrictToWorkspace(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tool := NewExecTool(tmpDir, false)
|
||||
tool, err := NewExecTool(tmpDir, false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
tool.SetRestrictToWorkspace(true)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -22,7 +22,11 @@ func processExists(pid int) bool {
|
||||
}
|
||||
|
||||
func TestShellTool_TimeoutKillsChildProcess(t *testing.T) {
|
||||
tool := NewExecTool(t.TempDir(), false)
|
||||
tool, err := NewExecTool(t.TempDir(), false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
tool.SetTimeout(500 * time.Millisecond)
|
||||
|
||||
args := map[string]any{
|
||||
|
||||
+58
-48
@@ -16,6 +16,14 @@ import (
|
||||
|
||||
const (
|
||||
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
|
||||
// HTTP client timeouts for web tool providers.
|
||||
searchTimeout = 10 * time.Second // Brave, Tavily, DuckDuckGo
|
||||
perplexityTimeout = 30 * time.Second // Perplexity (LLM-based, slower)
|
||||
fetchTimeout = 60 * time.Second // WebFetchTool
|
||||
|
||||
defaultMaxChars = 50000
|
||||
maxRedirects = 5
|
||||
)
|
||||
|
||||
// Pre-compiled regexes for HTML text extraction
|
||||
@@ -75,6 +83,7 @@ type SearchProvider interface {
|
||||
type BraveSearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -89,11 +98,7 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("X-Subscription-Token", p.apiKey)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -144,6 +149,7 @@ type TavilySearchProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -175,11 +181,7 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -227,7 +229,8 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
|
||||
}
|
||||
|
||||
type DuckDuckGoSearchProvider struct {
|
||||
proxy string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -240,11 +243,7 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 10*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -286,7 +285,7 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
|
||||
|
||||
maxItems := min(len(matches), count)
|
||||
|
||||
for i := 0; i < maxItems; i++ {
|
||||
for i := range maxItems {
|
||||
urlStr := matches[i][1]
|
||||
title := stripTags(matches[i][2])
|
||||
title = strings.TrimSpace(title)
|
||||
@@ -294,9 +293,9 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
|
||||
// URL decoding if needed
|
||||
if strings.Contains(urlStr, "uddg=") {
|
||||
if u, err := url.QueryUnescape(urlStr); err == nil {
|
||||
idx := strings.Index(u, "uddg=")
|
||||
if idx != -1 {
|
||||
urlStr = u[idx+5:]
|
||||
_, after, ok := strings.Cut(u, "uddg=")
|
||||
if ok {
|
||||
urlStr = after
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -323,6 +322,7 @@ func stripTags(content string) string {
|
||||
type PerplexitySearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -357,11 +357,7 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(p.proxy, 30*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
@@ -416,43 +412,60 @@ type WebSearchToolOptions struct {
|
||||
Proxy string
|
||||
}
|
||||
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Perplexity > Brave > Tavily > DuckDuckGo
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy}
|
||||
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err)
|
||||
}
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client}
|
||||
if opts.PerplexityMaxResults > 0 {
|
||||
maxResults = opts.PerplexityMaxResults
|
||||
}
|
||||
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy}
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err)
|
||||
}
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client}
|
||||
if opts.BraveMaxResults > 0 {
|
||||
maxResults = opts.BraveMaxResults
|
||||
}
|
||||
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err)
|
||||
}
|
||||
provider = &TavilySearchProvider{
|
||||
apiKey: opts.TavilyAPIKey,
|
||||
baseURL: opts.TavilyBaseURL,
|
||||
proxy: opts.Proxy,
|
||||
client: client,
|
||||
}
|
||||
if opts.TavilyMaxResults > 0 {
|
||||
maxResults = opts.TavilyMaxResults
|
||||
}
|
||||
} else if opts.DuckDuckGoEnabled {
|
||||
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy}
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for DuckDuckGo: %w", err)
|
||||
}
|
||||
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy, client: client}
|
||||
if opts.DuckDuckGoMaxResults > 0 {
|
||||
maxResults = opts.DuckDuckGoMaxResults
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &WebSearchTool{
|
||||
provider: provider,
|
||||
maxResults: maxResults,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *WebSearchTool) Name() string {
|
||||
@@ -527,7 +540,17 @@ func NewWebFetchTool(maxChars int, fetchLimitBytes int64) *WebFetchTool {
|
||||
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) *WebFetchTool {
|
||||
if maxChars <= 0 {
|
||||
maxChars = 50000
|
||||
maxChars = defaultMaxChars
|
||||
}
|
||||
client, err := createHTTPClient(proxy, fetchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
|
||||
}
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("stopped after %d redirects", maxRedirects)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if fetchLimitBytes <= 0 {
|
||||
fetchLimitBytes = 10 * 1024 * 1024 // Security Fallback
|
||||
@@ -598,20 +621,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client, err := createHTTPClient(t.proxy, 60*time.Second)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err))
|
||||
}
|
||||
|
||||
// Configure redirect handling
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 5 {
|
||||
return fmt.Errorf("stopped after 5 redirects")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("request failed: %v", err))
|
||||
}
|
||||
@@ -669,14 +679,14 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
resultJSON, _ := json.MarshalIndent(result, "", " ")
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf(
|
||||
ForLLM: string(resultJSON),
|
||||
ForUser: fmt.Sprintf(
|
||||
"Fetched %d bytes from %s (extractor: %s, truncated: %v)",
|
||||
len(text),
|
||||
urlStr,
|
||||
extractor,
|
||||
truncated,
|
||||
),
|
||||
ForUser: string(resultJSON),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+45
-24
@@ -36,14 +36,14 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain the fetched content
|
||||
if !strings.Contains(result.ForUser, "Test Page") {
|
||||
t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
|
||||
// ForLLM should contain the fetched content (full JSON result)
|
||||
if !strings.Contains(result.ForLLM, "Test Page") {
|
||||
t.Errorf("Expected ForLLM to contain 'Test Page', got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForLLM should contain summary
|
||||
if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
|
||||
t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
|
||||
// ForUser should contain summary
|
||||
if !strings.Contains(result.ForUser, "bytes") && !strings.Contains(result.ForUser, "extractor") {
|
||||
t.Errorf("Expected ForUser to contain summary, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,9 +72,9 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain formatted JSON
|
||||
if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
|
||||
t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
|
||||
// ForLLM should contain formatted JSON
|
||||
if !strings.Contains(result.ForLLM, "key") && !strings.Contains(result.ForLLM, "value") {
|
||||
t.Errorf("Expected ForLLM to contain JSON data, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,9 +163,9 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain truncated content (not the full 20000 chars)
|
||||
// ForLLM should contain truncated content (not the full 20000 chars)
|
||||
resultMap := make(map[string]any)
|
||||
json.Unmarshal([]byte(result.ForUser), &resultMap)
|
||||
json.Unmarshal([]byte(result.ForLLM), &resultMap)
|
||||
if text, ok := resultMap["text"].(string); ok {
|
||||
if len(text) > 1100 { // Allow some margin
|
||||
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
|
||||
@@ -220,13 +220,19 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
|
||||
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if tool != nil {
|
||||
t.Errorf("Expected nil tool when Brave API key is empty")
|
||||
}
|
||||
|
||||
// Also nil when nothing is enabled
|
||||
tool = NewWebSearchTool(WebSearchToolOptions{})
|
||||
tool, err = NewWebSearchTool(WebSearchToolOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if tool != nil {
|
||||
t.Errorf("Expected nil tool when no provider is enabled")
|
||||
}
|
||||
@@ -234,7 +240,10 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
|
||||
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
args := map[string]any{}
|
||||
|
||||
@@ -272,14 +281,14 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain extracted text (without script/style tags)
|
||||
if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
|
||||
t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
|
||||
// ForLLM should contain extracted text (without script/style tags)
|
||||
if !strings.Contains(result.ForLLM, "Title") && !strings.Contains(result.ForLLM, "Content") {
|
||||
t.Errorf("Expected ForLLM to contain extracted text, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Should NOT contain script or style tags
|
||||
if strings.Contains(result.ForUser, "<script>") || strings.Contains(result.ForUser, "<style>") {
|
||||
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForUser)
|
||||
// Should NOT contain script or style tags in ForLLM
|
||||
if strings.Contains(result.ForLLM, "<script>") || strings.Contains(result.ForLLM, "<style>") {
|
||||
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -498,12 +507,15 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
|
||||
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("perplexity", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
PerplexityEnabled: true,
|
||||
PerplexityAPIKey: "k",
|
||||
PerplexityMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*PerplexitySearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider)
|
||||
@@ -514,12 +526,15 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("brave", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
BraveEnabled: true,
|
||||
BraveAPIKey: "k",
|
||||
BraveMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*BraveSearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider)
|
||||
@@ -530,11 +545,14 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("duckduckgo", func(t *testing.T) {
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
DuckDuckGoEnabled: true,
|
||||
DuckDuckGoMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
p, ok := tool.provider.(*DuckDuckGoSearchProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider)
|
||||
@@ -586,12 +604,15 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
TavilyEnabled: true,
|
||||
TavilyAPIKey: "test-key",
|
||||
TavilyBaseURL: server.URL,
|
||||
TavilyMaxResults: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
|
||||
@@ -37,6 +37,9 @@ func DoRequestWithRetry(client *http.Client, req *http.Request) (*http.Response,
|
||||
|
||||
if i < maxRetries-1 {
|
||||
if err = sleepWithCtx(req.Context(), retryDelayUnit*time.Duration(i+1)); err != nil {
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to sleep: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -77,6 +80,91 @@ func TestDoRequestWithRetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoRequestWithRetry_ContextCancel(t *testing.T) {
|
||||
// Use a long retry delay so cancellation always hits during sleepWithCtx.
|
||||
retryDelayUnit = 10 * time.Second
|
||||
t.Cleanup(func() { retryDelayUnit = time.Second })
|
||||
|
||||
bodyClosed := false
|
||||
firstRoundTripDone := make(chan struct{}, 1)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("error"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := server.Client()
|
||||
client.Timeout = 30 * time.Second
|
||||
client.Transport = &bodyCloseTracker{
|
||||
rt: client.Transport,
|
||||
onClose: func() { bodyClosed = true },
|
||||
// Signal after the first round-trip response is fully constructed on the client side.
|
||||
onRoundTrip: func() {
|
||||
select {
|
||||
case firstRoundTripDone <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
trackURL: server.URL,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Cancel the context after the first round-trip completes on the client side.
|
||||
// This ensures client.Do has returned a valid resp (with body) and the retry
|
||||
// loop is about to enter sleepWithCtx, where the cancel will be detected.
|
||||
go func() {
|
||||
<-firstRoundTripDone
|
||||
cancel()
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := DoRequestWithRetry(client, req)
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
require.Error(t, err, "expected error from context cancellation")
|
||||
assert.Nil(t, resp, "expected nil response when context is canceled")
|
||||
assert.True(t, bodyClosed, "expected resp.Body to be closed on context cancellation")
|
||||
}
|
||||
|
||||
// bodyCloseTracker wraps an http.RoundTripper and records when response bodies are closed.
|
||||
type bodyCloseTracker struct {
|
||||
rt http.RoundTripper
|
||||
onClose func()
|
||||
onRoundTrip func() // called after each successful round-trip
|
||||
trackURL string
|
||||
}
|
||||
|
||||
func (t *bodyCloseTracker) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp, err := t.rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
if strings.HasPrefix(req.URL.String(), t.trackURL) {
|
||||
resp.Body = &closeNotifier{ReadCloser: resp.Body, onClose: t.onClose}
|
||||
if t.onRoundTrip != nil {
|
||||
t.onRoundTrip()
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// closeNotifier wraps an io.ReadCloser to detect Close calls.
|
||||
type closeNotifier struct {
|
||||
io.ReadCloser
|
||||
onClose func()
|
||||
}
|
||||
|
||||
func (c *closeNotifier) Close() error {
|
||||
c.onClose()
|
||||
return c.ReadCloser.Close()
|
||||
}
|
||||
|
||||
func TestDoRequestWithRetry_Delay(t *testing.T) {
|
||||
retryDelayUnit = time.Millisecond
|
||||
t.Cleanup(func() { retryDelayUnit = time.Second })
|
||||
|
||||
Reference in New Issue
Block a user