Merge branch 'main' into fix/max-payload-size-in-web-fetch

This commit is contained in:
Mauro
2026-03-01 22:38:16 +01:00
committed by GitHub
113 changed files with 10380 additions and 2018 deletions
+3 -4
View File
@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"sync"
"time"
@@ -249,10 +250,8 @@ func (cb *ContextBuilder) sourceFilesChangedLocked() bool {
}
// Check tracked source files (bootstrap + memory).
for _, p := range cb.sourcePaths() {
if cb.fileChangedSince(p) {
return true
}
if slices.ContainsFunc(cb.sourcePaths(), cb.fileChangedSince) {
return true
}
// --- Skills directory (handled separately from sourcePaths) ---
+2 -2
View File
@@ -404,11 +404,11 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
var wg sync.WaitGroup
errs := make(chan string, goroutines*iterations)
for g := 0; g < goroutines; g++ {
for g := range goroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
for i := 0; i < iterations; i++ {
for i := range iterations {
result := cb.BuildSystemPromptWithCache()
if result == "" {
errs <- "empty prompt returned"
+7 -1
View File
@@ -1,6 +1,7 @@
package agent
import (
"log"
"os"
"path/filepath"
"strings"
@@ -51,7 +52,12 @@ func NewAgentInstance(
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
toolsRegistry.Register(tools.NewExecToolWithConfig(workspace, restrict, cfg))
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
}
toolsRegistry.Register(execTool)
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
+67 -8
View File
@@ -9,6 +9,7 @@ package agent
import (
"context"
"encoding/json"
"errors"
"fmt"
"path/filepath"
"strings"
@@ -98,7 +99,7 @@ func registerSharedTools(
}
// Web tools
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
@@ -112,10 +113,18 @@ func registerSharedTools(
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
Proxy: cfg.Tools.Web.Proxy,
}); searchTool != nil {
})
if err != nil {
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
} else if searchTool != nil {
agent.Tools.Register(searchTool)
}
agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes))
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
agent.Tools.Register(fetchTool)
}
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
agent.Tools.Register(tools.NewI2CTool())
@@ -574,11 +583,36 @@ func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, chan
return
}
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
// Use a short timeout so the goroutine does not block indefinitely when
// the outbound bus is full. Reasoning output is best-effort; dropping it
// is acceptable to avoid goroutine accumulation.
pubCtx, pubCancel := context.WithTimeout(ctx, 5*time.Second)
defer pubCancel()
if err := al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channelName,
ChatID: channelID,
Content: reasoningContent,
})
}); err != nil {
// Treat context.DeadlineExceeded / context.Canceled as expected
// (bus full under load, or parent canceled). Check the error
// itself rather than ctx.Err(), because pubCtx may time out
// (5 s) while the parent ctx is still active.
// Also treat ErrBusClosed as expected — it occurs during normal
// shutdown when the bus is closed before all goroutines finish.
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) ||
errors.Is(err, bus.ErrBusClosed) {
logger.DebugCF("agent", "Reasoning publish skipped (timeout/cancel)", map[string]any{
"channel": channelName,
"error": err.Error(),
})
} else {
logger.WarnCF("agent", "Failed to publish reasoning (best-effort)", map[string]any{
"channel": channelName,
"error": err.Error(),
})
}
}
}
// runLLMIteration executes the LLM call loop with tool handling.
@@ -666,10 +700,35 @@ func (al *AgentLoop) runLLMIteration(
}
errMsg := strings.ToLower(err.Error())
isContextError := strings.Contains(errMsg, "token") ||
strings.Contains(errMsg, "context") ||
// Check if this is a network/HTTP timeout — not a context window error.
isTimeoutError := errors.Is(err, context.DeadlineExceeded) ||
strings.Contains(errMsg, "deadline exceeded") ||
strings.Contains(errMsg, "client.timeout") ||
strings.Contains(errMsg, "timed out") ||
strings.Contains(errMsg, "timeout exceeded")
// Detect real context window / token limit errors, excluding network timeouts.
isContextError := !isTimeoutError && (strings.Contains(errMsg, "context_length_exceeded") ||
strings.Contains(errMsg, "context window") ||
strings.Contains(errMsg, "maximum context length") ||
strings.Contains(errMsg, "token limit") ||
strings.Contains(errMsg, "too many tokens") ||
strings.Contains(errMsg, "max_tokens") ||
strings.Contains(errMsg, "invalidparameter") ||
strings.Contains(errMsg, "length")
strings.Contains(errMsg, "prompt is too long") ||
strings.Contains(errMsg, "request too large"))
if isTimeoutError && retry < maxRetries {
backoff := time.Duration(retry+1) * 5 * time.Second
logger.WarnCF("agent", "Timeout error, retrying after backoff", map[string]any{
"error": err.Error(),
"retry": retry,
"backoff": backoff.String(),
})
time.Sleep(backoff)
continue
}
if isContextError && retry < maxRetries {
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{
+56 -14
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
"slices"
"testing"
"time"
@@ -187,13 +188,7 @@ func TestToolRegistry_ToolRegistration(t *testing.T) {
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
found := slices.Contains(toolsList, "mock_custom")
if !found {
t.Error("Expected custom tool to be registered")
}
@@ -262,13 +257,7 @@ func TestToolRegistry_GetDefinitions(t *testing.T) {
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
found := slices.Contains(toolsList, "mock_custom")
if !found {
t.Error("Expected custom tool to be registered")
}
@@ -797,4 +786,57 @@ func TestHandleReasoning(t *testing.T) {
t.Fatalf("expected no outbound message, got %+v", msg)
}
})
t.Run("returns promptly when bus is full", func(t *testing.T) {
al, msgBus := newLoop(t)
// Fill the outbound bus buffer until a publish would block.
// Use a short timeout to detect when the buffer is full,
// rather than hardcoding the buffer size.
for i := 0; ; i++ {
fillCtx, fillCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
err := msgBus.PublishOutbound(fillCtx, bus.OutboundMessage{
Channel: "filler",
ChatID: "filler",
Content: fmt.Sprintf("filler-%d", i),
})
fillCancel()
if err != nil {
// Buffer is full (timed out trying to send).
break
}
}
// Use a short-deadline parent context to bound the test.
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
start := time.Now()
al.handleReasoning(ctx, "should timeout", "slack", "channel-full")
elapsed := time.Since(start)
// handleReasoning uses a 5s internal timeout, but the parent ctx
// expires in 500ms. It should return within ~500ms, not 5s.
if elapsed > 2*time.Second {
t.Fatalf("handleReasoning blocked too long (%v); expected prompt return", elapsed)
}
// Drain the bus and verify the reasoning message was NOT published
// (it should have been dropped due to timeout).
drainCtx, drainCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer drainCancel()
foundReasoning := false
for {
msg, ok := msgBus.SubscribeOutbound(drainCtx)
if !ok {
break
}
if msg.Content == "should timeout" {
foundReasoning = true
}
}
if foundReasoning {
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
}
})
}
+1 -1
View File
@@ -111,7 +111,7 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
var sb strings.Builder
first := true
for i := 0; i < days; i++ {
for i := range days {
date := time.Now().AddDate(0, 0, -i)
dateStr := date.Format("20060102") // YYYYMMDD
monthDir := dateStr[:6] // YYYYMM
+64 -8
View File
@@ -66,7 +66,8 @@ func decodeBase64(s string) string {
return string(data)
}
func generateState() (string, error) {
// GenerateState generates a random state string for OAuth CSRF protection.
func GenerateState() (string, error) {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", err
@@ -80,7 +81,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
return nil, fmt.Errorf("generating PKCE: %w", err)
}
state, err := generateState()
state, err := GenerateState()
if err != nil {
return nil, fmt.Errorf("generating state: %w", err)
}
@@ -127,7 +128,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
fmt.Printf("Open this URL to authenticate:\n\n%s\n\n", authURL)
if err := openBrowser(authURL); err != nil {
if err := OpenBrowser(authURL); err != nil {
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
}
@@ -153,7 +154,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
if result.err != nil {
return nil, result.err
}
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
return ExchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
case manualInput := <-manualCh:
if manualInput == "" {
return nil, fmt.Errorf("manual input canceled")
@@ -169,7 +170,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
if code == "" {
return nil, fmt.Errorf("could not find authorization code in input")
}
return exchangeCodeForTokens(cfg, code, pkce.CodeVerifier, redirectURI)
return ExchangeCodeForTokens(cfg, code, pkce.CodeVerifier, redirectURI)
case <-time.After(5 * time.Minute):
return nil, fmt.Errorf("authentication timed out after 5 minutes")
}
@@ -186,6 +187,59 @@ type deviceCodeResponse struct {
Interval int
}
// DeviceCodeInfo holds the device code information returned by the OAuth provider.
type DeviceCodeInfo struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
VerifyURL string `json:"verify_url"`
Interval int `json:"interval"`
}
// RequestDeviceCode requests a device code from the OAuth provider.
// Returns the info needed for the user to authenticate in a browser.
func RequestDeviceCode(cfg OAuthProviderConfig) (*DeviceCodeInfo, error) {
reqBody, _ := json.Marshal(map[string]string{
"client_id": cfg.ClientID,
})
resp, err := http.Post(
cfg.Issuer+"/api/accounts/deviceauth/usercode",
"application/json",
strings.NewReader(string(reqBody)),
)
if err != nil {
return nil, fmt.Errorf("requesting device code: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("device code request failed: %s", string(body))
}
deviceResp, err := parseDeviceCodeResponse(body)
if err != nil {
return nil, fmt.Errorf("parsing device code response: %w", err)
}
if deviceResp.Interval < 1 {
deviceResp.Interval = 5
}
return &DeviceCodeInfo{
DeviceAuthID: deviceResp.DeviceAuthID,
UserCode: deviceResp.UserCode,
VerifyURL: cfg.Issuer + "/codex/device",
Interval: deviceResp.Interval,
}, nil
}
// PollDeviceCodeOnce makes a single poll attempt to check if the user has authenticated.
// Returns (credential, nil) on success, (nil, nil) if still pending, or (nil, err) on failure.
func PollDeviceCodeOnce(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) {
return pollDeviceCode(cfg, deviceAuthID, userCode)
}
func parseDeviceCodeResponse(body []byte) (deviceCodeResponse, error) {
var raw struct {
DeviceAuthID string `json:"device_auth_id"`
@@ -318,7 +372,7 @@ func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*Au
}
redirectURI := cfg.Issuer + "/deviceauth/callback"
return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
return ExchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
}
func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) {
@@ -410,7 +464,8 @@ func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
return cfg.Issuer + "/oauth/authorize?" + params.Encode()
}
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
// ExchangeCodeForTokens exchanges an authorization code for tokens.
func ExchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
data := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
@@ -552,7 +607,8 @@ func base64URLDecode(s string) ([]byte, error) {
return base64.StdEncoding.DecodeString(s)
}
func openBrowser(url string) error {
// OpenBrowser opens the given URL in the user's default browser.
func OpenBrowser(url string) error {
switch runtime.GOOS {
case "darwin":
return exec.Command("open", url).Start()
+2 -2
View File
@@ -219,9 +219,9 @@ func TestExchangeCodeForTokens(t *testing.T) {
Port: 1455,
}
cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
cred, err := ExchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
if err != nil {
t.Fatalf("exchangeCodeForTokens() error: %v", err)
t.Fatalf("ExchangeCodeForTokens() error: %v", err)
}
if cred.AccessToken != "mock-access-token" {
+3 -3
View File
@@ -67,7 +67,7 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
// Fill the buffer
ctx := context.Background()
for i := 0; i < defaultBusBufferSize; i++ {
for i := range defaultBusBufferSize {
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
@@ -154,7 +154,7 @@ func TestConcurrentPublishClose(t *testing.T) {
wg.Add(numGoroutines + 1)
// Spawn many goroutines trying to publish
for i := 0; i < numGoroutines; i++ {
for range numGoroutines {
go func() {
defer wg.Done()
// Use a short timeout context so we don't block forever after close
@@ -194,7 +194,7 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
ctx := context.Background()
// Fill the buffer
for i := 0; i < defaultBusBufferSize; i++ {
for i := range defaultBusBufferSize {
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
+80 -27
View File
@@ -1,7 +1,5 @@
# PicoClaw Channel System Refactor: Complete Development Guide
# PicoClaw Channel System: Complete Development Guide
> **Branch**: `refactor/channel-system`
> **Status**: Active development (~40 commits)
> **Scope**: `pkg/channels/`, `pkg/bus/`, `pkg/media/`, `pkg/identity/`, `cmd/picoclaw/internal/gateway/`
---
@@ -46,6 +44,8 @@ pkg/channels/
pkg/channels/
├── base.go # BaseChannel shared abstraction layer
├── interfaces.go # Optional capability interfaces (TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder)
├── README.md # English documentation
├── README.zh.md # Chinese documentation
├── media.go # MediaSender optional interface
├── webhook.go # WebhookHandler, HealthChecker optional interfaces
├── errors.go # Sentinel errors (ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed)
@@ -60,7 +60,7 @@ pkg/channels/
├── discord/
│ ├── init.go
│ └── discord.go
├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ maixcam/ pico/
├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ whatsapp_native/ maixcam/ pico/
│ └── ...
pkg/bus/
@@ -111,7 +111,7 @@ pkg/identity/
|-----------|-------------|
| **Sub-package Isolation** | Each channel is a standalone Go sub-package, depending on `BaseChannel` and interfaces from the `channels` parent package |
| **Factory Registration** | Sub-packages self-register via `init()`, Manager looks up factories by name, eliminating import coupling |
| **Capability Discovery** | Optional capabilities are declared via interfaces (`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`), discovered by Manager via runtime type assertions |
| **Capability Discovery** | Optional capabilities are declared via interfaces (`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`, `HealthChecker`), discovered by Manager via runtime type assertions |
| **Structured Messages** | Peer, MessageID, and SenderInfo promoted from Metadata to first-class fields on InboundMessage |
| **Error Classification** | Channels return sentinel errors (`ErrRateLimit`, `ErrTemporary`, etc.), Manager uses these to determine retry strategy |
| **Centralized Orchestration** | Rate limiting, message splitting, retries, and Typing/Reaction/Placeholder management are all handled by Manager and BaseChannel; channels only need to implement Send |
@@ -145,6 +145,7 @@ After refactoring, these files have been removed and code moved to corresponding
| _(did not exist)_ | `pkg/channels/interfaces.go` | New optional capability interfaces |
| _(did not exist)_ | `pkg/channels/media.go` | New MediaSender interface |
| _(did not exist)_ | `pkg/channels/webhook.go` | New WebhookHandler/HealthChecker |
| _(did not exist)_ | `pkg/channels/whatsapp_native/` | New WhatsApp native mode (whatsmeow) |
| _(did not exist)_ | `pkg/channels/split.go` | New message splitting (migrated from utils) |
| _(did not exist)_ | `pkg/bus/types.go` | New structured message types |
| _(did not exist)_ | `pkg/media/store.go` | New media file lifecycle management |
@@ -220,6 +221,7 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
cfg.Channels.Telegram.AllowFrom, // Allow list
channels.WithMaxMessageLength(4096), // Platform message length limit
channels.WithGroupTrigger(cfg.Channels.Telegram.GroupTrigger), // Group trigger config
channels.WithReasoningChannelID(cfg.Channels.Telegram.ReasoningChannelID), // Reasoning chain routing
)
return &TelegramChannel{
BaseChannel: base,
@@ -466,6 +468,7 @@ func NewMatrixChannel(cfg *config.Config, msgBus *bus.MessageBus) (*MatrixChanne
matrixCfg.AllowFrom, // Allow list
channels.WithMaxMessageLength(65536), // Matrix message length limit
channels.WithGroupTrigger(matrixCfg.GroupTrigger),
channels.WithReasoningChannelID(matrixCfg.ReasoningChannelID), // Reasoning chain routing (optional)
)
return &MatrixChannel{
@@ -666,6 +669,32 @@ func (c *MatrixChannel) EditMessage(ctx context.Context, chatID, messageID, cont
}
```
#### PlaceholderCapable — Placeholder Messages
```go
// If the platform supports sending placeholder messages (e.g. "Thinking... 💭"),
// and the channel also implements MessageEditor, then Manager's preSend will
// automatically edit the placeholder into the final response on outbound.
// SendPlaceholder checks PlaceholderConfig.Enabled internally;
// returning ("", nil) means skip.
func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
cfg := c.config.Channels.Matrix.Placeholder
if !cfg.Enabled {
return "", nil
}
text := cfg.Text
if text == "" {
text = "Thinking... 💭"
}
// Call Matrix API to send placeholder message
msg, err := c.sendText(ctx, chatID, text)
if err != nil {
return "", err
}
return msg.ID, nil
}
```
#### WebhookHandler — HTTP Webhook Reception
```go
@@ -746,15 +775,17 @@ When the Agent finishes processing a message, Manager's `preSend` automatically:
```go
type ChannelsConfig struct {
// ... existing channels
Matrix MatrixChannelConfig `yaml:"matrix" json:"matrix"`
Matrix MatrixChannelConfig `json:"matrix"`
}
type MatrixChannelConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
HomeServer string `yaml:"home_server" json:"home_server"`
Token string `yaml:"token" json:"token"`
AllowFrom []string `yaml:"allow_from" json:"allow_from"`
GroupTrigger GroupTriggerConfig `yaml:"group_trigger" json:"group_trigger"`
Enabled bool `json:"enabled"`
HomeServer string `json:"home_server"`
Token string `json:"token"`
AllowFrom []string `json:"allow_from"`
GroupTrigger GroupTriggerConfig `json:"group_trigger"`
Placeholder PlaceholderConfig `json:"placeholder"`
ReasoningChannelID string `json:"reasoning_channel_id"`
}
```
@@ -767,6 +798,15 @@ if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
}
```
> **Note**: If your channel has multiple modes (like WhatsApp Bridge vs Native), branch in initChannels based on config:
> ```go
> if cfg.UseNative {
> m.initChannel("whatsapp_native", "WhatsApp Native")
> } else {
> m.initChannel("whatsapp", "WhatsApp")
> }
> ```
#### Add blank import in Gateway
```go
@@ -882,19 +922,21 @@ BaseChannel is the shared abstraction layer for all channels, providing the foll
| `IsRunning() bool` | Atomically read running state |
| `SetRunning(bool)` | Atomically set running state |
| `MaxMessageLength() int` | Message length limit (rune count), 0 = unlimited |
| `ReasoningChannelID() string` | Reasoning chain routing target channel ID (empty = no routing) |
| `IsAllowed(senderID string) bool` | Legacy allow-list check (supports `"id\|username"` and `"@username"` formats) |
| `IsAllowedSender(sender SenderInfo) bool` | New allow-list check (delegates to `identity.MatchAllowed`) |
| `ShouldRespondInGroup(isMentioned, content) (bool, string)` | Unified group chat trigger filtering logic |
| `HandleMessage(...)` | Unified inbound message handling: permission check → build MediaScope → auto-trigger Typing/Reaction → publish to Bus |
| `HandleMessage(...)` | Unified inbound message handling: permission check → build MediaScope → auto-trigger Typing/Reaction/Placeholder → publish to Bus |
| `SetMediaStore(s) / GetMediaStore()` | MediaStore injected by Manager |
| `SetPlaceholderRecorder(r) / GetPlaceholderRecorder()` | PlaceholderRecorder injected by Manager |
| `SetOwner(ch)` | Concrete channel reference injected by Manager (used for Typing/Reaction type assertions in HandleMessage) |
| `SetOwner(ch)` | Concrete channel reference injected by Manager (used for Typing/Reaction/Placeholder type assertions in HandleMessage) |
**Functional Options**:
```go
channels.WithMaxMessageLength(4096) // Set platform message length limit
channels.WithGroupTrigger(groupTriggerCfg) // Set group trigger configuration
channels.WithReasoningChannelID(id) // Set reasoning chain routing target channel
```
### 4.4 Factory Registry
@@ -998,7 +1040,7 @@ StartAll:
- runMediaWorker (per-channel outbound media)
- dispatchOutbound (route from bus to worker queues)
- dispatchOutboundMedia (route from bus to media worker queues)
- runTTLJanitor (every 10s clean up expired typing/placeholder)
- runTTLJanitor (every 10s clean up expired typing/reaction/placeholder)
4. Start shared HTTP server (if configured)
StopAll:
@@ -1206,18 +1248,20 @@ make test # Full test suite
| Sub-package | Registered Name | Optional Interfaces |
|-------------|----------------|-------------------|
| `pkg/channels/telegram/` | `"telegram"` | MessageEditor, MediaSender, TypingCapable, PlaceholderCapable |
| `pkg/channels/discord/` | `"discord"` | MessageEditor, TypingCapable, PlaceholderCapable |
| `pkg/channels/slack/` | `"slack"` | ReactionCapable |
| `pkg/channels/line/` | `"line"` | WebhookHandler, HealthChecker, TypingCapable |
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable |
| `pkg/channels/dingtalk/` | `"dingtalk"` | WebhookHandler |
| `pkg/channels/feishu/` | `"feishu"` | WebhookHandler (architecture-specific build tags) |
| `pkg/channels/wecom/` | `"wecom"` + `"wecom_app"` | WebhookHandler |
| `pkg/channels/telegram/` | `"telegram"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
| `pkg/channels/discord/` | `"discord"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
| `pkg/channels/slack/` | `"slack"` | ReactionCapable, MediaSender |
| `pkg/channels/line/` | `"line"` | TypingCapable, MediaSender, WebhookHandler |
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender |
| `pkg/channels/dingtalk/` | `"dingtalk"` | |
| `pkg/channels/feishu/` | `"feishu"` | (architecture-specific build tags: `feishu_32.go` / `feishu_64.go`) |
| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker |
| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker |
| `pkg/channels/qq/` | `"qq"` | — |
| `pkg/channels/whatsapp/` | `"whatsapp"` | — |
| `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge mode) |
| `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (Native whatsmeow mode) |
| `pkg/channels/maixcam/` | `"maixcam"` | — |
| `pkg/channels/pico/` | `"pico"` | WebhookHandler (Pico Protocol), TypingCapable, PlaceholderCapable |
| `pkg/channels/pico/` | `"pico"` | TypingCapable, PlaceholderCapable, MessageEditor, WebhookHandler |
### A.3 Interface Quick Reference
@@ -1231,6 +1275,7 @@ type Channel interface {
IsRunning() bool
IsAllowed(senderID string) bool
IsAllowedSender(sender bus.SenderInfo) bool
ReasoningChannelID() string
}
// ===== Optional =====
@@ -1324,8 +1369,16 @@ agentLoop.Stop() // Stop Agent
1. **Media cleanup temporarily disabled**: The `ReleaseAll` call in the Agent loop is commented out (`refactor(loop): disable media cleanup to prevent premature file deletion`) because session boundaries are not yet clearly defined. TTL cleanup remains active.
2. **Feishu architecture-specific compilation**: The Feishu channel uses build tags to distinguish 32-bit and 64-bit architectures (`feishu_32.go` / `feishu_64.go`).
2. **Feishu architecture-specific compilation**: The Feishu channel uses build tags to distinguish 32-bit and 64-bit architectures (`feishu_32.go` / `feishu_64.go`). Feishu uses the SDK's WebSocket mode (not HTTP webhook), so it does not implement `WebhookHandler`.
3. **WeCom has two factories**: `"wecom"` (Bot mode) and `"wecom_app"` (App mode) are registered separately.
3. **WeCom has two factories**: `"wecom"` (Bot mode, webhook only) and `"wecom_app"` (App mode, supports MediaSender) are registered separately. Both implement `WebhookHandler` and `HealthChecker`.
4. **Pico Protocol**: `pkg/channels/pico/` implements a custom PicoClaw native protocol channel that receives messages via webhook.
4. **Pico Protocol**: `pkg/channels/pico/` implements a custom PicoClaw native protocol channel that receives messages via WebSocket webhook (`/pico/ws`).
5. **WhatsApp has two modes**: `"whatsapp"` (Bridge mode, communicates via external bridge URL) and `"whatsapp_native"` (native whatsmeow mode, connects directly to WhatsApp). Manager selects which to initialize based on `WhatsAppConfig.UseNative`.
6. **DingTalk uses Stream mode**: DingTalk uses the SDK's Stream/WebSocket mode (not HTTP webhook), so it does not implement `WebhookHandler`.
7. **PlaceholderConfig vs implementation**: `PlaceholderConfig` appears in 6 channel configs (Telegram, Discord, Slack, LINE, OneBot, Pico), but only channels that implement both `PlaceholderCapable` + `MessageEditor` (Telegram, Discord, Pico) can actually use placeholder message editing. The rest are reserved fields.
8. **ReasoningChannelID**: Most channel configs include a `reasoning_channel_id` field to route LLM reasoning/thinking output to a designated channel (WhatsApp, Telegram, Feishu, Discord, MaixCam, QQ, DingTalk, Slack, LINE, OneBot, WeCom, WeComApp). Note: `PicoConfig` does not currently expose this field. `BaseChannel` exposes this via the `WithReasoningChannelID` option and `ReasoningChannelID()` method.
+79 -27
View File
@@ -1,7 +1,5 @@
# PicoClaw Channel System 重构:完整开发指南
# PicoClaw Channel System:完整开发指南
> **分支**: `refactor/channel-system`
> **状态**: 活跃开发中(约 40 commits
> **影响范围**: `pkg/channels/`, `pkg/bus/`, `pkg/media/`, `pkg/identity/`, `cmd/picoclaw/internal/gateway/`
---
@@ -46,6 +44,8 @@ pkg/channels/
pkg/channels/
├── base.go # BaseChannel 共享抽象层
├── interfaces.go # 可选能力接口(TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder
├── README.md # 英文文档
├── README.zh.md # 中文文档
├── media.go # MediaSender 可选接口
├── webhook.go # WebhookHandler, HealthChecker 可选接口
├── errors.go # 错误哨兵值(ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed
@@ -60,7 +60,7 @@ pkg/channels/
├── discord/
│ ├── init.go
│ └── discord.go
├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ maixcam/ pico/
├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ whatsapp_native/ maixcam/ pico/
│ └── ...
pkg/bus/
@@ -111,7 +111,7 @@ pkg/identity/
|------|------|
| **子包隔离** | 每个 channel 一个独立 Go 子包,依赖 `channels` 父包提供的 `BaseChannel` 和接口 |
| **工厂注册** | 各子包通过 `init()` 自注册,Manager 通过名字查找工厂,消除 import 耦合 |
| **能力发现** | 可选能力通过接口(`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`)声明,Manager 运行时类型断言发现 |
| **能力发现** | 可选能力通过接口(`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`, `HealthChecker`)声明,Manager 运行时类型断言发现 |
| **结构化消息** | Peer、MessageID、SenderInfo 从 Metadata 提升为 InboundMessage 的一等字段 |
| **错误分类** | Channel 返回哨兵错误(`ErrRateLimit`, `ErrTemporary` 等),Manager 据此决定重试策略 |
| **集中编排** | 速率限制、消息分割、重试、Typing/Reaction/Placeholder 全部由 Manager 和 BaseChannel 统一处理,Channel 只负责 Send |
@@ -145,6 +145,7 @@ pkg/identity/
| _(不存在)_ | `pkg/channels/interfaces.go` | 新增可选能力接口 |
| _(不存在)_ | `pkg/channels/media.go` | 新增 MediaSender 接口 |
| _(不存在)_ | `pkg/channels/webhook.go` | 新增 WebhookHandler/HealthChecker |
| _(不存在)_ | `pkg/channels/whatsapp_native/` | 新增 WhatsApp 原生模式(whatsmeow |
| _(不存在)_ | `pkg/channels/split.go` | 新增消息分割(从 utils 迁入) |
| _(不存在)_ | `pkg/bus/types.go` | 新增结构化消息类型 |
| _(不存在)_ | `pkg/media/store.go` | 新增媒体文件生命周期管理 |
@@ -220,6 +221,7 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
cfg.Channels.Telegram.AllowFrom, // 允许列表
channels.WithMaxMessageLength(4096), // 平台消息长度上限
channels.WithGroupTrigger(cfg.Channels.Telegram.GroupTrigger), // 群聊触发配置
channels.WithReasoningChannelID(cfg.Channels.Telegram.ReasoningChannelID), // 思维链路由
)
return &TelegramChannel{
BaseChannel: base,
@@ -466,6 +468,7 @@ func NewMatrixChannel(cfg *config.Config, msgBus *bus.MessageBus) (*MatrixChanne
matrixCfg.AllowFrom, // 允许列表
channels.WithMaxMessageLength(65536), // Matrix 消息长度限制
channels.WithGroupTrigger(matrixCfg.GroupTrigger),
channels.WithReasoningChannelID(matrixCfg.ReasoningChannelID), // 思维链路由(可选)
)
return &MatrixChannel{
@@ -666,6 +669,31 @@ func (c *MatrixChannel) EditMessage(ctx context.Context, chatID, messageID, cont
}
```
#### PlaceholderCapable — 占位消息
```go
// 如果平台支持发送占位消息(如 "Thinking... 💭"),并且实现了 MessageEditor
// 则 Manager 的 preSend 会在出站时自动将占位消息编辑为最终回复。
// SendPlaceholder 内部根据 PlaceholderConfig.Enabled 决定是否发送;
// 返回 ("", nil) 表示跳过。
func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
cfg := c.config.Channels.Matrix.Placeholder
if !cfg.Enabled {
return "", nil
}
text := cfg.Text
if text == "" {
text = "Thinking... 💭"
}
// 调用 Matrix API 发送占位消息
msg, err := c.sendText(ctx, chatID, text)
if err != nil {
return "", err
}
return msg.ID, nil
}
```
#### WebhookHandler — HTTP Webhook 接收
```go
@@ -746,15 +774,17 @@ if c.owner != nil && c.placeholderRecorder != nil {
```go
type ChannelsConfig struct {
// ... 现有 channels
Matrix MatrixChannelConfig `yaml:"matrix" json:"matrix"`
Matrix MatrixChannelConfig `json:"matrix"`
}
type MatrixChannelConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
HomeServer string `yaml:"home_server" json:"home_server"`
Token string `yaml:"token" json:"token"`
AllowFrom []string `yaml:"allow_from" json:"allow_from"`
GroupTrigger GroupTriggerConfig `yaml:"group_trigger" json:"group_trigger"`
Enabled bool `json:"enabled"`
HomeServer string `json:"home_server"`
Token string `json:"token"`
AllowFrom []string `json:"allow_from"`
GroupTrigger GroupTriggerConfig `json:"group_trigger"`
Placeholder PlaceholderConfig `json:"placeholder"`
ReasoningChannelID string `json:"reasoning_channel_id"`
}
```
@@ -767,6 +797,15 @@ if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
}
```
> **注意**:如果你的 channel 有多种模式(如 WhatsApp Bridge vs Native),需要在 initChannels 中根据配置分支:
> ```go
> if cfg.UseNative {
> m.initChannel("whatsapp_native", "WhatsApp Native")
> } else {
> m.initChannel("whatsapp", "WhatsApp")
> }
> ```
#### 在 Gateway 中添加 blank import
```go
@@ -882,19 +921,21 @@ BaseChannel 是所有 channel 的共享抽象层,提供以下能力:
| `IsRunning() bool` | 原子读取运行状态 |
| `SetRunning(bool)` | 原子设置运行状态 |
| `MaxMessageLength() int` | 消息长度限制(rune 计数),0 = 无限制 |
| `ReasoningChannelID() string` | 思维链路由目标 channel ID(空 = 不路由) |
| `IsAllowed(senderID string) bool` | 旧格式允许列表检查(支持 `"id\|username"``"@username"` 格式) |
| `IsAllowedSender(sender SenderInfo) bool` | 新格式允许列表检查(委托给 `identity.MatchAllowed` |
| `ShouldRespondInGroup(isMentioned, content) (bool, string)` | 统一群聊触发过滤逻辑 |
| `HandleMessage(...)` | 统一入站消息处理:权限检查 → 构建 MediaScope → 自动触发 Typing/Reaction → 发布到 Bus |
| `HandleMessage(...)` | 统一入站消息处理:权限检查 → 构建 MediaScope → 自动触发 Typing/Reaction/Placeholder → 发布到 Bus |
| `SetMediaStore(s) / GetMediaStore()` | Manager 注入的媒体存储 |
| `SetPlaceholderRecorder(r) / GetPlaceholderRecorder()` | Manager 注入的占位符记录器 |
| `SetOwner(ch) ` | Manager 注入的具体 channel 引用(用于 HandleMessage 内部的 Typing/Reaction 类型断言) |
| `SetOwner(ch) ` | Manager 注入的具体 channel 引用(用于 HandleMessage 内部的 Typing/Reaction/Placeholder 类型断言) |
**功能选项**
```go
channels.WithMaxMessageLength(4096) // 设置平台消息长度限制
channels.WithGroupTrigger(groupTriggerCfg) // 设置群聊触发配置
channels.WithReasoningChannelID(id) // 设置思维链路由目标 channel
```
### 4.4 工厂注册表
@@ -998,7 +1039,7 @@ StartAll:
- runMediaWorker (per-channel 出站媒体)
- dispatchOutbound (从 bus 路由到 worker 队列)
- dispatchOutboundMedia (从 bus 路由到 media worker 队列)
- runTTLJanitor (每 10s 清理过期 typing/placeholder)
- runTTLJanitor (每 10s 清理过期 typing/reaction/placeholder)
4. 启动共享 HTTP 服务器(如已配置)
StopAll:
@@ -1206,18 +1247,20 @@ make test # 全量测试
| 子包 | 注册名 | 可选接口 |
|------|--------|----------|
| `pkg/channels/telegram/` | `"telegram"` | MessageEditor, MediaSender, TypingCapable, PlaceholderCapable |
| `pkg/channels/discord/` | `"discord"` | MessageEditor, TypingCapable, PlaceholderCapable |
| `pkg/channels/slack/` | `"slack"` | ReactionCapable |
| `pkg/channels/line/` | `"line"` | WebhookHandler, HealthChecker, TypingCapable |
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable |
| `pkg/channels/dingtalk/` | `"dingtalk"` | WebhookHandler |
| `pkg/channels/feishu/` | `"feishu"` | WebhookHandler (架构特定 build tags) |
| `pkg/channels/wecom/` | `"wecom"` + `"wecom_app"` | WebhookHandler |
| `pkg/channels/telegram/` | `"telegram"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
| `pkg/channels/discord/` | `"discord"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
| `pkg/channels/slack/` | `"slack"` | ReactionCapable, MediaSender |
| `pkg/channels/line/` | `"line"` | TypingCapable, MediaSender, WebhookHandler |
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender |
| `pkg/channels/dingtalk/` | `"dingtalk"` | |
| `pkg/channels/feishu/` | `"feishu"` | (架构特定 build tags: `feishu_32.go` / `feishu_64.go`) |
| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker |
| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker |
| `pkg/channels/qq/` | `"qq"` | — |
| `pkg/channels/whatsapp/` | `"whatsapp"` | — |
| `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge 模式) |
| `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (原生 whatsmeow 模式) |
| `pkg/channels/maixcam/` | `"maixcam"` | — |
| `pkg/channels/pico/` | `"pico"` | WebhookHandler (Pico Protocol), TypingCapable, PlaceholderCapable |
| `pkg/channels/pico/` | `"pico"` | TypingCapable, PlaceholderCapable, MessageEditor, WebhookHandler |
### A.3 接口速查表
@@ -1231,6 +1274,7 @@ type Channel interface {
IsRunning() bool
IsAllowed(senderID string) bool
IsAllowedSender(sender bus.SenderInfo) bool
ReasoningChannelID() string
}
// ===== 可选实现 =====
@@ -1324,8 +1368,16 @@ agentLoop.Stop() // 停止 Agent
1. **媒体清理暂时禁用**Agent loop 中的 `ReleaseAll` 调用被注释掉了(`refactor(loop): disable media cleanup to prevent premature file deletion`),因为会话边界尚未明确定义。TTL 清理仍然有效。
2. **Feishu 架构特定编译**Feishu channel 使用 build tags 区分 32 位和 64 位架构(`feishu_32.go` / `feishu_64.go`)。
2. **Feishu 架构特定编译**Feishu channel 使用 build tags 区分 32 位和 64 位架构(`feishu_32.go` / `feishu_64.go`)。Feishu 使用 SDK 的 WebSocket 模式(非 HTTP webhook),因此不实现 `WebhookHandler`
3. **WeCom 有两个工厂**`"wecom"`Bot 模式)和 `"wecom_app"`(应用模式)分别注册
3. **WeCom 有两个工厂**`"wecom"`Bot 模式,纯 webhook)和 `"wecom_app"`(应用模式,支持 MediaSender)分别注册。两者都实现了 `WebhookHandler``HealthChecker`
4. **Pico Protocol**`pkg/channels/pico/` 实现了一个自定义的 PicoClaw 原生协议 channel,通过 webhook 接收消息。
4. **Pico Protocol**`pkg/channels/pico/` 实现了一个自定义的 PicoClaw 原生协议 channel,通过 WebSocket webhook (`/pico/ws`) 接收消息。
5. **WhatsApp 有两种模式**`"whatsapp"`Bridge 模式,通过外部 bridge URL 通信)和 `"whatsapp_native"`(原生 whatsmeow 模式,直接连接 WhatsApp)。Manager 根据 `WhatsAppConfig.UseNative` 决定初始化哪个。
6. **DingTalk 使用 Stream 模式**DingTalk 使用 SDK 的 Stream/WebSocket 模式(非 HTTP webhook),因此不实现 `WebhookHandler`
7. **PlaceholderConfig 的配置与实现**`PlaceholderConfig` 出现在 6 个 channel config 中(Telegram、Discord、Slack、LINE、OneBot、Pico),但只有实现了 `PlaceholderCapable` + `MessageEditor` 的 channelTelegram、Discord、Pico)能真正使用占位消息编辑功能。其余 channel 的 `PlaceholderConfig` 为预留字段。
8. **ReasoningChannelID**:大多数 channel config 都包含 `reasoning_channel_id` 字段,用于将 LLM 的思维链(reasoning/thinking)路由到指定 channelWhatsApp、Telegram、Feishu、Discord、MaixCam、QQ、DingTalk、Slack、LINE、OneBot、WeCom、WeComApp)。注意:`PicoConfig` 目前不包含该字段。`BaseChannel` 通过 `WithReasoningChannelID` 选项和 `ReasoningChannelID()` 方法暴露此配置。
+11 -9
View File
@@ -45,11 +45,13 @@ type replyTokenEntry struct {
type LINEChannel struct {
*channels.BaseChannel
config config.LINEConfig
botUserID string // Bot's user ID
botBasicID string // Bot's basic ID (e.g. @216ru...)
botDisplayName string // Bot's display name for text-based mention detection
replyTokens sync.Map // chatID -> replyTokenEntry
quoteTokens sync.Map // chatID -> quoteToken (string)
infoClient *http.Client // for bot info lookups (short timeout)
apiClient *http.Client // for messaging API calls
botUserID string // Bot's user ID
botBasicID string // Bot's basic ID (e.g. @216ru...)
botDisplayName string // Bot's display name for text-based mention detection
replyTokens sync.Map // chatID -> replyTokenEntry
quoteTokens sync.Map // chatID -> quoteToken (string)
ctx context.Context
cancel context.CancelFunc
}
@@ -69,6 +71,8 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha
return &LINEChannel{
BaseChannel: base,
config: cfg,
infoClient: &http.Client{Timeout: 10 * time.Second},
apiClient: &http.Client{Timeout: 30 * time.Second},
}, nil
}
@@ -104,8 +108,7 @@ func (c *LINEChannel) fetchBotInfo() error {
}
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
resp, err := c.infoClient.Do(req)
if err != nil {
return err
}
@@ -644,8 +647,7 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
resp, err := c.apiClient.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
+1 -1
View File
@@ -313,7 +313,7 @@ func (m *Manager) StartAll(ctx context.Context) error {
if len(m.channels) == 0 {
logger.WarnC("channels", "No channels enabled")
return nil
return errors.New("no channels enabled")
}
logger.InfoC("channels", "Starting all channels")
+6 -8
View File
@@ -274,13 +274,12 @@ func TestWorkerRateLimiter(t *testing.T) {
limiter: rate.NewLimiter(2, 1),
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
go m.runWorker(ctx, "test", w)
// Enqueue 4 messages
for i := 0; i < 4; i++ {
for i := range 4 {
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}
}
@@ -352,8 +351,7 @@ func TestRunWorker_MessageSplitting(t *testing.T) {
limiter: rate.NewLimiter(rate.Inf, 1),
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
go m.runWorker(ctx, "test", w)
@@ -576,7 +574,7 @@ func TestRecordPlaceholder_ConcurrentSafe(t *testing.T) {
m := newTestManager()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
for i := range 100 {
wg.Add(1)
go func(i int) {
defer wg.Done()
@@ -591,7 +589,7 @@ func TestRecordTypingStop_ConcurrentSafe(t *testing.T) {
m := newTestManager()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
for i := range 100 {
wg.Add(1)
go func(i int) {
defer wg.Done()
@@ -834,7 +832,7 @@ func TestLazyWorkerCreation(t *testing.T) {
func TestBuildMediaScope_FastIDUniqueness(t *testing.T) {
seen := make(map[string]bool)
for i := 0; i < 1000; i++ {
for range 1000 {
scope := BuildMediaScope("test", "chat1", "")
if seen[scope] {
t.Fatalf("duplicate scope generated: %s", scope)
+1 -4
View File
@@ -337,10 +337,7 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D
}
func (c *OneBotChannel) reconnectLoop() {
interval := time.Duration(c.config.ReconnectInterval) * time.Second
if interval < 5*time.Second {
interval = 5 * time.Second
}
interval := max(time.Duration(c.config.ReconnectInterval)*time.Second, 5*time.Second)
for {
select {
+2 -2
View File
@@ -292,8 +292,8 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
// Check Authorization header
auth := r.Header.Get("Authorization")
if strings.HasPrefix(auth, "Bearer ") {
if strings.TrimPrefix(auth, "Bearer ") == token {
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
if after == token {
return true
}
}
+8 -24
View File
@@ -23,10 +23,7 @@ func SplitMessage(content string, maxLen int) []string {
var messages []string
// Dynamic buffer: 10% of maxLen, but at least 50 chars if possible
codeBlockBuffer := maxLen / 10
if codeBlockBuffer < 50 {
codeBlockBuffer = 50
}
codeBlockBuffer := max(maxLen/10, 50)
if codeBlockBuffer > maxLen/2 {
codeBlockBuffer = maxLen / 2
}
@@ -40,10 +37,7 @@ func SplitMessage(content string, maxLen int) []string {
}
// Effective split point: maxLen minus buffer, to leave room for code blocks
effectiveLimit := maxLen - codeBlockBuffer
if effectiveLimit < maxLen/2 {
effectiveLimit = maxLen / 2
}
effectiveLimit := max(maxLen-codeBlockBuffer, maxLen/2)
end := start + effectiveLimit
@@ -85,10 +79,9 @@ func SplitMessage(content string, maxLen int) []string {
// If we have a reasonable amount of content after the header, split inside
if msgEnd > headerEndIdx+20 {
// Find a better split point closer to maxLen
innerLimit := start + maxLen - 5 // Leave room for "\n```"
if innerLimit > totalLen {
innerLimit = totalLen
}
innerLimit := min(
// Leave room for "\n```"
start+maxLen-5, totalLen)
betterEnd := findLastNewlineInRange(runes, start, innerLimit, 200)
if betterEnd > headerEndIdx {
msgEnd = betterEnd
@@ -117,10 +110,7 @@ func SplitMessage(content string, maxLen int) []string {
if unclosedIdx-start > 20 {
msgEnd = unclosedIdx
} else {
splitAt := start + maxLen - 5
if splitAt > totalLen {
splitAt = totalLen
}
splitAt := min(start+maxLen-5, totalLen)
chunk := strings.TrimRight(string(runes[start:splitAt]), " \t\n\r") + "\n```"
messages = append(messages, chunk)
remaining := strings.TrimSpace(header + "\n" + string(runes[splitAt:totalLen]))
@@ -196,10 +186,7 @@ func findNewlineFrom(runes []rune, from int) int {
// findLastNewlineInRange finds the last newline within the last searchWindow runes
// of the range runes[start:end]. Returns the absolute index or start-1 (indicating not found).
func findLastNewlineInRange(runes []rune, start, end, searchWindow int) int {
searchStart := end - searchWindow
if searchStart < start {
searchStart = start
}
searchStart := max(end-searchWindow, start)
for i := end - 1; i >= searchStart; i-- {
if runes[i] == '\n' {
return i
@@ -211,10 +198,7 @@ func findLastNewlineInRange(runes []rune, start, end, searchWindow int) int {
// findLastSpaceInRange finds the last space/tab within the last searchWindow runes
// of the range runes[start:end]. Returns the absolute index or start-1 (indicating not found).
func findLastSpaceInRange(runes []rune, start, end, searchWindow int) int {
searchStart := end - searchWindow
if searchStart < start {
searchStart = start
}
searchStart := max(end-searchWindow, start)
for i := end - 1; i >= searchStart; i-- {
if runes[i] == ' ' || runes[i] == '\t' {
return i
+27 -13
View File
@@ -32,6 +32,7 @@ const (
type WeComAppChannel struct {
*channels.BaseChannel
config config.WeComAppConfig
client *http.Client
accessToken string
tokenExpiry time.Time
tokenMu sync.RWMutex
@@ -129,9 +130,20 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
// Client timeout must be >= the configured ReplyTimeout so the
// per-request context deadline is always the effective limit.
clientTimeout := 30 * time.Second
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
clientTimeout = d
}
ctx, cancel := context.WithCancel(context.Background())
return &WeComAppChannel{
BaseChannel: base,
config: cfg,
client: &http.Client{Timeout: clientTimeout},
ctx: ctx,
cancel: cancel,
processedMsgs: make(map[string]bool),
}, nil
}
@@ -145,6 +157,10 @@ func (c *WeComAppChannel) Name() string {
func (c *WeComAppChannel) Start(ctx context.Context) error {
logger.InfoC("wecom_app", "Starting WeCom App channel...")
// Cancel the context created in the constructor to avoid a resource leak.
if c.cancel != nil {
c.cancel()
}
c.ctx, c.cancel = context.WithCancel(ctx)
// Get initial access token
@@ -299,8 +315,7 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
}
req.Header.Set("Content-Type", writer.FormDataContentType())
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
resp, err := c.client.Do(req)
if err != nil {
return "", channels.ClassifyNetError(err)
}
@@ -357,8 +372,7 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
resp, err := client.Do(req)
resp, err := c.client.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
@@ -567,8 +581,9 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp
return
}
// Process the message with context
go c.processMessage(ctx, msg)
// Process the message with the channel's long-lived context (not the HTTP
// request context, which is canceled as soon as we return the response).
go c.processMessage(c.ctx, msg)
// Return success response immediately
// WeCom App requires response within configured timeout (default 5 seconds)
@@ -597,14 +612,14 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag
return
}
c.processedMsgs[msgID] = true
c.msgMu.Unlock()
// Clean up old messages periodically (keep last 1000)
// Clean up old messages while still holding the lock to avoid a data race
// on len(). Reset the map but re-insert the current msgID so it remains
// deduplicated.
if len(c.processedMsgs) > 1000 {
c.msgMu.Lock()
c.processedMsgs = make(map[string]bool)
c.msgMu.Unlock()
c.processedMsgs[msgID] = true
}
c.msgMu.Unlock()
senderID := msg.FromUserName
chatID := senderID // WeCom App uses user ID as chat ID for direct messages
@@ -738,8 +753,7 @@ func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, user
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
resp, err := client.Do(req)
resp, err := c.client.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
+1 -1
View File
@@ -43,7 +43,7 @@ func encryptTestMessageApp(message, aesKey string) (string, error) {
// Prepare message: random(16) + msg_len(4) + msg + corp_id
random := make([]byte, 0, 16)
for i := 0; i < 16; i++ {
for i := range 16 {
random = append(random, byte(i+1))
}
+25 -9
View File
@@ -25,6 +25,7 @@ import (
type WeComBotChannel struct {
*channels.BaseChannel
config config.WeComConfig
client *http.Client
ctx context.Context
cancel context.CancelFunc
processedMsgs map[string]bool // Message deduplication: msg_id -> processed
@@ -93,9 +94,20 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
// Client timeout must be >= the configured ReplyTimeout so the
// per-request context deadline is always the effective limit.
clientTimeout := 30 * time.Second
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
clientTimeout = d
}
ctx, cancel := context.WithCancel(context.Background())
return &WeComBotChannel{
BaseChannel: base,
config: cfg,
client: &http.Client{Timeout: clientTimeout},
ctx: ctx,
cancel: cancel,
processedMsgs: make(map[string]bool),
}, nil
}
@@ -109,6 +121,10 @@ func (c *WeComBotChannel) Name() string {
func (c *WeComBotChannel) Start(ctx context.Context) error {
logger.InfoC("wecom", "Starting WeCom Bot channel...")
// Cancel the context created in the constructor to avoid a resource leak.
if c.cancel != nil {
c.cancel()
}
c.ctx, c.cancel = context.WithCancel(ctx)
c.SetRunning(true)
@@ -292,8 +308,9 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp
return
}
// Process the message asynchronously with context
go c.processMessage(ctx, msg)
// Process the message with the channel's long-lived context (not the HTTP
// request context, which is canceled as soon as we return the response).
go c.processMessage(c.ctx, msg)
// Return success response immediately
// WeCom Bot requires response within configured timeout (default 5 seconds)
@@ -322,14 +339,14 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag
return
}
c.processedMsgs[msgID] = true
c.msgMu.Unlock()
// Clean up old messages periodically (keep last 1000)
// Clean up old messages while still holding the lock to avoid a data race
// on len(). Reset the map but re-insert the current msgID so it remains
// deduplicated.
if len(c.processedMsgs) > 1000 {
c.msgMu.Lock()
c.processedMsgs = make(map[string]bool)
c.msgMu.Unlock()
c.processedMsgs[msgID] = true
}
c.msgMu.Unlock()
senderID := msg.From.UserID
@@ -442,8 +459,7 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
resp, err := client.Do(req)
resp, err := c.client.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
+1 -1
View File
@@ -42,7 +42,7 @@ func encryptTestMessage(message, aesKey string) (string, error) {
// Prepare message: random(16) + msg_len(4) + msg + receiveid
random := make([]byte, 0, 16)
for i := 0; i < 16; i++ {
for i := range 16 {
random = append(random, byte(i))
}
+1 -1
View File
@@ -125,7 +125,7 @@ func pkcs7Unpad(data []byte) ([]byte, error) {
return nil, fmt.Errorf("padding size larger than data")
}
// Verify all padding bytes
for i := 0; i < padding; i++ {
for i := range padding {
if data[len(data)-1-i] != byte(padding) {
return nil, fmt.Errorf("invalid padding byte at position %d", i)
}
+126 -19
View File
@@ -15,6 +15,7 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/mdp/qrterminal/v3"
@@ -56,6 +57,8 @@ type WhatsAppNativeChannel struct {
runCancel context.CancelFunc
reconnectMu sync.Mutex
reconnecting bool
stopping atomic.Bool // set once Stop begins; prevents new wg.Add calls
wg sync.WaitGroup // tracks background goroutines (QR handler, reconnect)
}
// NewWhatsAppNativeChannel creates a WhatsApp channel that uses whatsmeow for connection.
@@ -80,6 +83,14 @@ func NewWhatsAppNativeChannel(
func (c *WhatsAppNativeChannel) Start(ctx context.Context) error {
logger.InfoCF("whatsapp", "Starting WhatsApp native channel (whatsmeow)", map[string]any{"store": c.storePath})
// Reset lifecycle state from any previous Stop() so a restarted channel
// behaves correctly. Use reconnectMu to be consistent with eventHandler
// and Stop() which coordinate under the same lock.
c.reconnectMu.Lock()
c.stopping.Store(false)
c.reconnecting = false
c.reconnectMu.Unlock()
if err := os.MkdirAll(c.storePath, 0o700); err != nil {
return fmt.Errorf("create session store dir: %w", err)
}
@@ -112,6 +123,12 @@ func (c *WhatsAppNativeChannel) Start(ctx context.Context) error {
}
client := whatsmeow.NewClient(deviceStore, waLogger)
// Create runCtx/runCancel BEFORE registering event handler and starting
// goroutines so that Stop() can cancel them at any time, including during
// the QR-login flow.
c.runCtx, c.runCancel = context.WithCancel(ctx)
client.AddEventHandler(c.eventHandler)
c.mu.Lock()
@@ -119,36 +136,75 @@ func (c *WhatsAppNativeChannel) Start(ctx context.Context) error {
c.client = client
c.mu.Unlock()
// cleanupOnError clears struct references and releases resources when
// Start() fails after fields are already assigned. This prevents
// Stop() from operating on stale references (double-close, disconnect
// of a partially-initialized client, or stray event handler callbacks).
startOK := false
defer func() {
if startOK {
return
}
c.runCancel()
client.Disconnect()
c.mu.Lock()
c.client = nil
c.container = nil
c.mu.Unlock()
_ = container.Close()
}()
if client.Store.ID == nil {
qrChan, err := client.GetQRChannel(ctx)
qrChan, err := client.GetQRChannel(c.runCtx)
if err != nil {
_ = container.Close()
return fmt.Errorf("get QR channel: %w", err)
}
if err := client.Connect(); err != nil {
_ = container.Close()
return fmt.Errorf("connect: %w", err)
}
for evt := range qrChan {
if evt.Event == "code" {
logger.InfoCF("whatsapp", "Scan this QR code with WhatsApp (Linked Devices):", nil)
qrterminal.GenerateWithConfig(evt.Code, qrterminal.Config{
Level: qrterminal.L,
Writer: os.Stdout,
HalfBlocks: true,
})
} else {
logger.InfoCF("whatsapp", "WhatsApp login event", map[string]any{"event": evt.Event})
}
// Handle QR events in a background goroutine so Start() returns
// promptly. The goroutine is tracked via c.wg and respects
// c.runCtx for cancellation.
// Guard wg.Add with reconnectMu + stopping check (same protocol
// as eventHandler) so a concurrent Stop() cannot enter wg.Wait()
// while we call wg.Add(1).
c.reconnectMu.Lock()
if c.stopping.Load() {
c.reconnectMu.Unlock()
return fmt.Errorf("channel stopped during QR setup")
}
c.wg.Add(1)
c.reconnectMu.Unlock()
go func() {
defer c.wg.Done()
for {
select {
case <-c.runCtx.Done():
return
case evt, ok := <-qrChan:
if !ok {
return
}
if evt.Event == "code" {
logger.InfoCF("whatsapp", "Scan this QR code with WhatsApp (Linked Devices):", nil)
qrterminal.GenerateWithConfig(evt.Code, qrterminal.Config{
Level: qrterminal.L,
Writer: os.Stdout,
HalfBlocks: true,
})
} else {
logger.InfoCF("whatsapp", "WhatsApp login event", map[string]any{"event": evt.Event})
}
}
}
}()
} else {
if err := client.Connect(); err != nil {
_ = container.Close()
return fmt.Errorf("connect: %w", err)
}
}
c.runCtx, c.runCancel = context.WithCancel(ctx)
startOK = true
c.SetRunning(true)
logger.InfoC("whatsapp", "WhatsApp native channel connected")
return nil
@@ -156,19 +212,53 @@ func (c *WhatsAppNativeChannel) Start(ctx context.Context) error {
func (c *WhatsAppNativeChannel) Stop(ctx context.Context) error {
logger.InfoC("whatsapp", "Stopping WhatsApp native channel")
// Mark as stopping under reconnectMu so the flag is visible to
// eventHandler atomically with respect to its wg.Add(1) call.
// This closes the TOCTOU window where eventHandler could check
// stopping (false), then Stop sets it true + enters wg.Wait,
// then eventHandler calls wg.Add(1) — causing a panic.
c.reconnectMu.Lock()
c.stopping.Store(true)
c.reconnectMu.Unlock()
if c.runCancel != nil {
c.runCancel()
}
// Disconnect the client first so any blocking Connect()/reconnect loops
// can be interrupted before we wait on the goroutines.
c.mu.Lock()
client := c.client
container := c.container
c.client = nil
c.container = nil
c.mu.Unlock()
if client != nil {
client.Disconnect()
}
// Wait for background goroutines (QR handler, reconnect) to finish in a
// context-aware way so Stop can be bounded by ctx.
done := make(chan struct{})
go func() {
c.wg.Wait()
close(done)
}()
select {
case <-done:
// All goroutines have finished.
case <-ctx.Done():
// Context canceled or timed out; log and proceed with best-effort cleanup.
logger.WarnC("whatsapp", fmt.Sprintf("Stop context canceled before all goroutines finished: %v", ctx.Err()))
}
// Now it is safe to clear and close resources.
c.mu.Lock()
c.client = nil
c.container = nil
c.mu.Unlock()
if container != nil {
_ = container.Close()
}
@@ -187,9 +277,20 @@ func (c *WhatsAppNativeChannel) eventHandler(evt any) {
c.reconnectMu.Unlock()
return
}
// Check stopping while holding the lock so the check and wg.Add
// are atomic with respect to Stop() setting the flag + calling
// wg.Wait(). This prevents the TOCTOU race.
if c.stopping.Load() {
c.reconnectMu.Unlock()
return
}
c.reconnecting = true
c.wg.Add(1)
c.reconnectMu.Unlock()
go c.reconnectWithBackoff()
go func() {
defer c.wg.Done()
c.reconnectWithBackoff()
}()
}
}
@@ -313,6 +414,12 @@ func (c *WhatsAppNativeChannel) Send(ctx context.Context, msg bus.OutboundMessag
return fmt.Errorf("whatsapp connection not established: %w", channels.ErrTemporary)
}
// Detect unpaired state: the client is connected (to WhatsApp servers)
// but has not completed QR-login yet, so sending would fail.
if client.Store.ID == nil {
return fmt.Errorf("whatsapp not yet paired (QR login pending): %w", channels.ErrTemporary)
}
to, err := parseJID(msg.ChatID)
if err != nil {
return fmt.Errorf("invalid chat id %q: %w", msg.ChatID, err)
+25
View File
@@ -442,3 +442,28 @@ func TestDefaultConfig_DMScope(t *testing.T) {
t.Errorf("Session.DMScope = %q, want 'per-channel-peer'", cfg.Session.DMScope)
}
}
func TestDefaultConfig_WorkspacePath_Default(t *testing.T) {
// Unset to ensure we test the default
t.Setenv("PICOCLAW_HOME", "")
// Set a known home for consistent test results
t.Setenv("HOME", "/tmp/home")
cfg := DefaultConfig()
want := filepath.Join("/tmp/home", ".picoclaw", "workspace")
if cfg.Agents.Defaults.Workspace != want {
t.Errorf("Default workspace path = %q, want %q", cfg.Agents.Defaults.Workspace, want)
}
}
func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw/home")
cfg := DefaultConfig()
want := "/custom/picoclaw/home/workspace"
if cfg.Agents.Defaults.Workspace != want {
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
}
}
+17 -1
View File
@@ -5,12 +5,28 @@
package config
import (
"os"
"path/filepath"
)
// DefaultConfig returns the default configuration for PicoClaw.
func DefaultConfig() *Config {
// Determine the base path for the workspace.
// Priority: $PICOCLAW_HOME > ~/.picoclaw
var homePath string
if picoclawHome := os.Getenv("PICOCLAW_HOME"); picoclawHome != "" {
homePath = picoclawHome
} else {
userHome, _ := os.UserHomeDir()
homePath = filepath.Join(userHome, ".picoclaw")
}
workspacePath := filepath.Join(homePath, "workspace")
return &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Workspace: "~/.picoclaw/workspace",
Workspace: workspacePath,
RestrictToWorkspace: true,
Provider: "",
Model: "",
+5 -7
View File
@@ -64,7 +64,7 @@ func TestGetModelConfig_RoundRobin(t *testing.T) {
// Test round-robin distribution
results := make(map[string]int)
for i := 0; i < 30; i++ {
for range 30 {
result, err := cfg.GetModelConfig("lb-model")
if err != nil {
t.Fatalf("GetModelConfig() error = %v", err)
@@ -94,17 +94,15 @@ func TestGetModelConfig_Concurrent(t *testing.T) {
var wg sync.WaitGroup
errors := make(chan error, goroutines*iterations)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < iterations; j++ {
for range goroutines {
wg.Go(func() {
for range iterations {
_, err := cfg.GetModelConfig("concurrent-model")
if err != nil {
errors <- err
}
}
}()
})
}
wg.Wait()
+2 -3
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"maps"
"net/http"
"sync"
"time"
@@ -122,9 +123,7 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
s.mu.RLock()
ready := s.ready
checks := make(map[string]Check)
for k, v := range s.checks {
checks[k] = v
}
maps.Copy(checks, s.checks)
s.mu.RUnlock()
if !ready {
+11 -11
View File
@@ -49,7 +49,7 @@ func TestReleaseAll(t *testing.T) {
paths := make([]string, 3)
refs := make([]string, 3)
for i := 0; i < 3; i++ {
for i := range 3 {
paths[i] = createTempFile(t, dir, strings.Repeat("a", i+1)+".jpg")
var err error
refs[i], err = store.Store(paths[i], MediaMeta{Source: "test"}, "scope1")
@@ -228,12 +228,12 @@ func TestConcurrentSafety(t *testing.T) {
var wg sync.WaitGroup
wg.Add(goroutines)
for g := 0; g < goroutines; g++ {
for g := range goroutines {
go func(gIdx int) {
defer wg.Done()
scope := strings.Repeat("s", gIdx+1)
for i := 0; i < filesPerGoroutine; i++ {
for i := range filesPerGoroutine {
path := createTempFile(t, dir, strings.Repeat("f", gIdx*filesPerGoroutine+i+1)+".tmp")
ref, err := store.Store(path, MediaMeta{Source: "test"}, scope)
if err != nil {
@@ -448,11 +448,11 @@ func TestConcurrentCleanupSafety(t *testing.T) {
wg.Add(workers * 4)
// Store workers
for w := 0; w < workers; w++ {
for w := range workers {
go func(wIdx int) {
defer wg.Done()
scope := fmt.Sprintf("scope-%d", wIdx)
for i := 0; i < ops; i++ {
for i := range ops {
p := createTempFile(t, dir, fmt.Sprintf("w%d-f%d.tmp", wIdx, i))
store.Store(p, MediaMeta{Source: "test"}, scope)
}
@@ -460,30 +460,30 @@ func TestConcurrentCleanupSafety(t *testing.T) {
}
// Resolve workers
for w := 0; w < workers; w++ {
for range workers {
go func() {
defer wg.Done()
for i := 0; i < ops; i++ {
for range ops {
store.Resolve("media://nonexistent")
}
}()
}
// ReleaseAll workers
for w := 0; w < workers; w++ {
for w := range workers {
go func(wIdx int) {
defer wg.Done()
for i := 0; i < ops; i++ {
for range ops {
store.ReleaseAll(fmt.Sprintf("scope-%d", wIdx))
}
}(w)
}
// CleanExpired workers
for w := 0; w < workers; w++ {
for range workers {
go func() {
defer wg.Done()
for i := 0; i < ops; i++ {
for range ops {
store.CleanExpired()
}
}()
-414
View File
@@ -1,414 +0,0 @@
package migrate
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"unicode"
"github.com/sipeed/picoclaw/pkg/config"
)
var supportedProviders = map[string]bool{
"anthropic": true,
"openai": true,
"openrouter": true,
"groq": true,
"zhipu": true,
"vllm": true,
"gemini": true,
"qwen": true,
"deepseek": true,
"github_copilot": true,
"mistral": true,
}
var supportedChannels = map[string]bool{
"telegram": true,
"discord": true,
"whatsapp": true,
"feishu": true,
"qq": true,
"dingtalk": true,
"maixcam": true,
}
func findOpenClawConfig(openclawHome string) (string, error) {
candidates := []string{
filepath.Join(openclawHome, "openclaw.json"),
filepath.Join(openclawHome, "config.json"),
}
for _, p := range candidates {
if _, err := os.Stat(p); err == nil {
return p, nil
}
}
return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome)
}
func LoadOpenClawConfig(configPath string) (map[string]any, error) {
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("reading OpenClaw config: %w", err)
}
var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("parsing OpenClaw config: %w", err)
}
converted := convertKeysToSnake(raw)
result, ok := converted.(map[string]any)
if !ok {
return nil, fmt.Errorf("unexpected config format")
}
return result, nil
}
func ConvertConfig(data map[string]any) (*config.Config, []string, error) {
cfg := config.DefaultConfig()
var warnings []string
if agents, ok := getMap(data, "agents"); ok {
if defaults, ok := getMap(agents, "defaults"); ok {
// Prefer model_name, fallback to model for backward compatibility
if v, ok := getString(defaults, "model_name"); ok {
cfg.Agents.Defaults.ModelName = v
} else if v, ok := getString(defaults, "model"); ok {
cfg.Agents.Defaults.Model = v
}
if v, ok := getFloat(defaults, "max_tokens"); ok {
cfg.Agents.Defaults.MaxTokens = int(v)
}
if v, ok := getFloat(defaults, "temperature"); ok {
cfg.Agents.Defaults.Temperature = &v
}
if v, ok := getFloat(defaults, "max_tool_iterations"); ok {
cfg.Agents.Defaults.MaxToolIterations = int(v)
}
if v, ok := getString(defaults, "workspace"); ok {
cfg.Agents.Defaults.Workspace = rewriteWorkspacePath(v)
}
}
}
if providers, ok := getMap(data, "providers"); ok {
for name, val := range providers {
pMap, ok := val.(map[string]any)
if !ok {
continue
}
apiKey, _ := getString(pMap, "api_key")
apiBase, _ := getString(pMap, "api_base")
if !supportedProviders[name] {
if apiKey != "" || apiBase != "" {
warnings = append(warnings, fmt.Sprintf("Provider '%s' not supported in PicoClaw, skipping", name))
}
continue
}
pc := config.ProviderConfig{APIKey: apiKey, APIBase: apiBase}
switch name {
case "anthropic":
cfg.Providers.Anthropic = pc
case "openai":
cfg.Providers.OpenAI = config.OpenAIProviderConfig{
ProviderConfig: pc,
WebSearch: getBoolOrDefault(pMap, "web_search", true),
}
case "openrouter":
cfg.Providers.OpenRouter = pc
case "groq":
cfg.Providers.Groq = pc
case "zhipu":
cfg.Providers.Zhipu = pc
case "vllm":
cfg.Providers.VLLM = pc
case "gemini":
cfg.Providers.Gemini = pc
}
}
}
if channels, ok := getMap(data, "channels"); ok {
for name, val := range channels {
cMap, ok := val.(map[string]any)
if !ok {
continue
}
if !supportedChannels[name] {
warnings = append(warnings, fmt.Sprintf("Channel '%s' not supported in PicoClaw, skipping", name))
continue
}
enabled, _ := getBool(cMap, "enabled")
allowFrom := getStringSlice(cMap, "allow_from")
switch name {
case "telegram":
cfg.Channels.Telegram.Enabled = enabled
cfg.Channels.Telegram.AllowFrom = allowFrom
if v, ok := getString(cMap, "token"); ok {
cfg.Channels.Telegram.Token = v
}
case "discord":
cfg.Channels.Discord.Enabled = enabled
cfg.Channels.Discord.AllowFrom = allowFrom
if v, ok := getString(cMap, "token"); ok {
cfg.Channels.Discord.Token = v
}
case "whatsapp":
cfg.Channels.WhatsApp.Enabled = enabled
cfg.Channels.WhatsApp.AllowFrom = allowFrom
if v, ok := getString(cMap, "bridge_url"); ok {
cfg.Channels.WhatsApp.BridgeURL = v
}
if v, ok := getBool(cMap, "use_native"); ok {
cfg.Channels.WhatsApp.UseNative = v
}
if v, ok := getString(cMap, "session_store_path"); ok {
cfg.Channels.WhatsApp.SessionStorePath = v
}
case "feishu":
cfg.Channels.Feishu.Enabled = enabled
cfg.Channels.Feishu.AllowFrom = allowFrom
if v, ok := getString(cMap, "app_id"); ok {
cfg.Channels.Feishu.AppID = v
}
if v, ok := getString(cMap, "app_secret"); ok {
cfg.Channels.Feishu.AppSecret = v
}
if v, ok := getString(cMap, "encrypt_key"); ok {
cfg.Channels.Feishu.EncryptKey = v
}
if v, ok := getString(cMap, "verification_token"); ok {
cfg.Channels.Feishu.VerificationToken = v
}
case "qq":
cfg.Channels.QQ.Enabled = enabled
cfg.Channels.QQ.AllowFrom = allowFrom
if v, ok := getString(cMap, "app_id"); ok {
cfg.Channels.QQ.AppID = v
}
if v, ok := getString(cMap, "app_secret"); ok {
cfg.Channels.QQ.AppSecret = v
}
case "dingtalk":
cfg.Channels.DingTalk.Enabled = enabled
cfg.Channels.DingTalk.AllowFrom = allowFrom
if v, ok := getString(cMap, "client_id"); ok {
cfg.Channels.DingTalk.ClientID = v
}
if v, ok := getString(cMap, "client_secret"); ok {
cfg.Channels.DingTalk.ClientSecret = v
}
case "maixcam":
cfg.Channels.MaixCam.Enabled = enabled
cfg.Channels.MaixCam.AllowFrom = allowFrom
if v, ok := getString(cMap, "host"); ok {
cfg.Channels.MaixCam.Host = v
}
if v, ok := getFloat(cMap, "port"); ok {
cfg.Channels.MaixCam.Port = int(v)
}
}
}
}
if gateway, ok := getMap(data, "gateway"); ok {
if v, ok := getString(gateway, "host"); ok {
cfg.Gateway.Host = v
}
if v, ok := getFloat(gateway, "port"); ok {
cfg.Gateway.Port = int(v)
}
}
if tools, ok := getMap(data, "tools"); ok {
if web, ok := getMap(tools, "web"); ok {
// Migrate old "search" config to "brave" if api_key is present
if search, ok := getMap(web, "search"); ok {
if v, ok := getString(search, "api_key"); ok {
cfg.Tools.Web.Brave.APIKey = v
if v != "" {
cfg.Tools.Web.Brave.Enabled = true
}
}
if v, ok := getFloat(search, "max_results"); ok {
cfg.Tools.Web.Brave.MaxResults = int(v)
cfg.Tools.Web.DuckDuckGo.MaxResults = int(v)
}
}
}
}
return cfg, warnings, nil
}
func MergeConfig(existing, incoming *config.Config) *config.Config {
if existing.Providers.Anthropic.APIKey == "" {
existing.Providers.Anthropic = incoming.Providers.Anthropic
}
if existing.Providers.OpenAI.APIKey == "" {
existing.Providers.OpenAI = incoming.Providers.OpenAI
}
if existing.Providers.OpenRouter.APIKey == "" {
existing.Providers.OpenRouter = incoming.Providers.OpenRouter
}
if existing.Providers.Groq.APIKey == "" {
existing.Providers.Groq = incoming.Providers.Groq
}
if existing.Providers.Zhipu.APIKey == "" {
existing.Providers.Zhipu = incoming.Providers.Zhipu
}
if existing.Providers.VLLM.APIKey == "" && existing.Providers.VLLM.APIBase == "" {
existing.Providers.VLLM = incoming.Providers.VLLM
}
if existing.Providers.Gemini.APIKey == "" {
existing.Providers.Gemini = incoming.Providers.Gemini
}
if existing.Providers.DeepSeek.APIKey == "" {
existing.Providers.DeepSeek = incoming.Providers.DeepSeek
}
if existing.Providers.GitHubCopilot.APIBase == "" {
existing.Providers.GitHubCopilot = incoming.Providers.GitHubCopilot
}
if existing.Providers.Qwen.APIKey == "" {
existing.Providers.Qwen = incoming.Providers.Qwen
}
if !existing.Channels.Telegram.Enabled && incoming.Channels.Telegram.Enabled {
existing.Channels.Telegram = incoming.Channels.Telegram
}
if !existing.Channels.Discord.Enabled && incoming.Channels.Discord.Enabled {
existing.Channels.Discord = incoming.Channels.Discord
}
if !existing.Channels.WhatsApp.Enabled && incoming.Channels.WhatsApp.Enabled {
existing.Channels.WhatsApp = incoming.Channels.WhatsApp
}
if !existing.Channels.Feishu.Enabled && incoming.Channels.Feishu.Enabled {
existing.Channels.Feishu = incoming.Channels.Feishu
}
if !existing.Channels.QQ.Enabled && incoming.Channels.QQ.Enabled {
existing.Channels.QQ = incoming.Channels.QQ
}
if !existing.Channels.DingTalk.Enabled && incoming.Channels.DingTalk.Enabled {
existing.Channels.DingTalk = incoming.Channels.DingTalk
}
if !existing.Channels.MaixCam.Enabled && incoming.Channels.MaixCam.Enabled {
existing.Channels.MaixCam = incoming.Channels.MaixCam
}
if existing.Tools.Web.Brave.APIKey == "" {
existing.Tools.Web.Brave = incoming.Tools.Web.Brave
}
return existing
}
func camelToSnake(s string) string {
var result strings.Builder
for i, r := range s {
if unicode.IsUpper(r) {
if i > 0 {
prev := rune(s[i-1])
if unicode.IsLower(prev) || unicode.IsDigit(prev) {
result.WriteRune('_')
} else if unicode.IsUpper(prev) && i+1 < len(s) && unicode.IsLower(rune(s[i+1])) {
result.WriteRune('_')
}
}
result.WriteRune(unicode.ToLower(r))
} else {
result.WriteRune(r)
}
}
return result.String()
}
func convertKeysToSnake(data any) any {
switch v := data.(type) {
case map[string]any:
result := make(map[string]any, len(v))
for key, val := range v {
result[camelToSnake(key)] = convertKeysToSnake(val)
}
return result
case []any:
result := make([]any, len(v))
for i, val := range v {
result[i] = convertKeysToSnake(val)
}
return result
default:
return data
}
}
func rewriteWorkspacePath(path string) string {
path = strings.Replace(path, ".openclaw", ".picoclaw", 1)
return path
}
func getMap(data map[string]any, key string) (map[string]any, bool) {
v, ok := data[key]
if !ok {
return nil, false
}
m, ok := v.(map[string]any)
return m, ok
}
func getString(data map[string]any, key string) (string, bool) {
v, ok := data[key]
if !ok {
return "", false
}
s, ok := v.(string)
return s, ok
}
func getFloat(data map[string]any, key string) (float64, bool) {
v, ok := data[key]
if !ok {
return 0, false
}
f, ok := v.(float64)
return f, ok
}
func getBool(data map[string]any, key string) (bool, bool) {
v, ok := data[key]
if !ok {
return false, false
}
b, ok := v.(bool)
return b, ok
}
func getBoolOrDefault(data map[string]any, key string, defaultVal bool) bool {
if v, ok := getBool(data, key); ok {
return v
}
return defaultVal
}
func getStringSlice(data map[string]any, key string) []string {
v, ok := data[key]
if !ok {
return []string{}
}
arr, ok := v.([]any)
if !ok {
return []string{}
}
result := make([]string, 0, len(arr))
for _, item := range arr {
if s, ok := item.(string); ok {
result = append(result, s)
}
}
return result
}
@@ -1,24 +1,50 @@
package migrate
package internal
import (
"fmt"
"io"
"os"
"path/filepath"
)
var migrateableFiles = []string{
"AGENTS.md",
"SOUL.md",
"USER.md",
"TOOLS.md",
"HEARTBEAT.md",
func ResolveTargetHome(override string) (string, error) {
if override != "" {
return ExpandHome(override), nil
}
if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
return ExpandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, ".picoclaw"), nil
}
var migrateableDirs = []string{
"memory",
"skills",
func ExpandHome(path string) string {
if path == "" {
return path
}
if path[0] == '~' {
home, _ := os.UserHomeDir()
if len(path) > 1 && path[1] == '/' {
return home + path[1:]
}
return home
}
return path
}
func PlanWorkspaceMigration(srcWorkspace, dstWorkspace string, force bool) ([]Action, error) {
func ResolveWorkspace(homeDir string) string {
return filepath.Join(homeDir, "workspace")
}
func PlanWorkspaceMigration(
srcWorkspace, dstWorkspace string,
migrateableFiles []string,
migrateableDirs []string,
force bool,
) ([]Action, error) {
var actions []Action
for _, filename := range migrateableFiles {
@@ -50,7 +76,7 @@ func planFileCopy(src, dst string, force bool) Action {
return Action{
Type: ActionSkip,
Source: src,
Destination: dst,
Target: dst,
Description: "source file not found",
}
}
@@ -60,7 +86,7 @@ func planFileCopy(src, dst string, force bool) Action {
return Action{
Type: ActionBackup,
Source: src,
Destination: dst,
Target: dst,
Description: "destination exists, will backup and overwrite",
}
}
@@ -68,7 +94,7 @@ func planFileCopy(src, dst string, force bool) Action {
return Action{
Type: ActionCopy,
Source: src,
Destination: dst,
Target: dst,
Description: "copy file",
}
}
@@ -91,7 +117,7 @@ func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) {
if info.IsDir() {
actions = append(actions, Action{
Type: ActionCreateDir,
Destination: dst,
Target: dst,
Description: "create directory",
})
return nil
@@ -104,3 +130,33 @@ func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) {
return actions, err
}
func RelPath(path, base string) string {
rel, err := filepath.Rel(base, path)
if err != nil {
return filepath.Base(path)
}
return rel
}
func CopyFile(src, dst string) error {
srcFile, err := os.Open(src)
if err != nil {
return err
}
defer srcFile.Close()
info, err := srcFile.Stat()
if err != nil {
return err
}
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
if err != nil {
return err
}
defer dstFile.Close()
_, err = io.Copy(dstFile, srcFile)
return err
}
+195
View File
@@ -0,0 +1,195 @@
package internal
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExpandHome(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"", ""},
{"/absolute/path", "/absolute/path"},
{"relative/path", "relative/path"},
}
for _, tt := range tests {
result := ExpandHome(tt.input)
assert.Equal(t, tt.expected, result)
}
}
func TestExpandHomeWithTilde(t *testing.T) {
home, err := os.UserHomeDir()
require.NoError(t, err)
result := ExpandHome("~/path")
assert.Equal(t, home+"/path", result)
result = ExpandHome("~")
assert.Equal(t, home, result)
}
func TestResolveWorkspace(t *testing.T) {
result := ResolveWorkspace("/home/user/.picoclaw")
assert.Equal(t, "/home/user/.picoclaw/workspace", result)
}
func TestRelPath(t *testing.T) {
result := RelPath("/home/user/.picoclaw/workspace/file.txt", "/home/user/.picoclaw")
assert.Equal(t, "workspace/file.txt", result)
}
func TestRelPathError(t *testing.T) {
result := RelPath("relative/path", "/different/base")
assert.Equal(t, "path", result)
}
func TestResolveTargetHome(t *testing.T) {
home, err := os.UserHomeDir()
require.NoError(t, err)
result, err := ResolveTargetHome("")
require.NoError(t, err)
assert.Equal(t, filepath.Join(home, ".picoclaw"), result)
}
func TestResolveTargetHomeWithOverride(t *testing.T) {
result, err := ResolveTargetHome("/custom/path")
require.NoError(t, err)
assert.Equal(t, "/custom/path", result)
}
func TestCopyFile(t *testing.T) {
tmpDir := t.TempDir()
sourceFile := filepath.Join(tmpDir, "source.txt")
err := os.WriteFile(sourceFile, []byte("test content"), 0o644)
require.NoError(t, err)
dstFile := filepath.Join(tmpDir, "dest.txt")
err = CopyFile(sourceFile, dstFile)
require.NoError(t, err)
content, err := os.ReadFile(dstFile)
require.NoError(t, err)
assert.Equal(t, "test content", string(content))
}
func TestCopyFileSourceNotFound(t *testing.T) {
tmpDir := t.TempDir()
err := CopyFile(filepath.Join(tmpDir, "nonexistent.txt"), filepath.Join(tmpDir, "dest.txt"))
require.Error(t, err)
}
func TestPlanWorkspaceMigration(t *testing.T) {
tmpDir := t.TempDir()
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
err := os.MkdirAll(srcWorkspace, 0o755)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("content"), 0o644)
require.NoError(t, err)
err = os.MkdirAll(filepath.Join(srcWorkspace, "subdir"), 0o755)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcWorkspace, "subdir", "file2.txt"), []byte("content"), 0o644)
require.NoError(t, err)
actions, err := PlanWorkspaceMigration(
srcWorkspace,
dstWorkspace,
[]string{"file1.txt"},
[]string{"subdir"},
false,
)
require.NoError(t, err)
assert.GreaterOrEqual(t, len(actions), 1)
}
func TestPlanWorkspaceMigrationWithExistingDestination(t *testing.T) {
tmpDir := t.TempDir()
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
err := os.MkdirAll(srcWorkspace, 0o755)
require.NoError(t, err)
err = os.MkdirAll(dstWorkspace, 0o755)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
require.NoError(t, err)
actions, err := PlanWorkspaceMigration(
srcWorkspace,
dstWorkspace,
[]string{"file1.txt"},
[]string{},
false,
)
require.NoError(t, err)
require.GreaterOrEqual(t, len(actions), 1)
assert.Equal(t, ActionBackup, actions[0].Type)
}
func TestPlanWorkspaceMigrationForce(t *testing.T) {
tmpDir := t.TempDir()
srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
err := os.MkdirAll(srcWorkspace, 0o755)
require.NoError(t, err)
err = os.MkdirAll(dstWorkspace, 0o755)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
require.NoError(t, err)
actions, err := PlanWorkspaceMigration(
srcWorkspace,
dstWorkspace,
[]string{"file1.txt"},
[]string{},
true,
)
require.NoError(t, err)
require.GreaterOrEqual(t, len(actions), 1)
assert.Equal(t, ActionCopy, actions[0].Type)
}
func TestPlanWorkspaceMigrationNonExistentSource(t *testing.T) {
tmpDir := t.TempDir()
actions, err := PlanWorkspaceMigration(
filepath.Join(tmpDir, "nonexistent"),
filepath.Join(tmpDir, "dst", "workspace"),
[]string{"file1.txt"},
[]string{},
false,
)
require.NoError(t, err)
require.Len(t, actions, 1)
assert.Equal(t, ActionSkip, actions[0].Type)
assert.Contains(t, actions[0].Description, "source file not found")
}
+52
View File
@@ -0,0 +1,52 @@
package internal
type Options struct {
DryRun bool
ConfigOnly bool
WorkspaceOnly bool
Force bool
Refresh bool
Source string
SourceHome string
TargetHome string
}
type Operation interface {
GetSourceName() string
GetSourceHome() (string, error)
GetSourceWorkspace() (string, error)
GetSourceConfigFile() (string, error)
ExecuteConfigMigration(srcConfigPath, dstConfigPath string) error
GetMigrateableFiles() []string
GetMigrateableDirs() []string
}
type HandlerFactory func(opts Options) Operation
type ActionType int
const (
ActionCopy ActionType = iota
ActionSkip
ActionBackup
ActionConvertConfig
ActionCreateDir
ActionMergeConfig
)
type Action struct {
Type ActionType
Source string
Target string
Description string
}
type Result struct {
FilesCopied int
FilesSkipped int
BackupsCreated int
ConfigMigrated bool
DirsCreated int
Warnings []string
Errors []error
}
+136 -214
View File
@@ -2,53 +2,73 @@ package migrate
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/migrate/internal"
"github.com/sipeed/picoclaw/pkg/migrate/sources/openclaw"
)
type ActionType int
type (
Options = internal.Options
Operation = internal.Operation
ActionType = internal.ActionType
Action = internal.Action
Result = internal.Result
HandlerFactory = internal.HandlerFactory
)
const (
ActionCopy ActionType = iota
ActionSkip
ActionBackup
ActionConvertConfig
ActionCreateDir
ActionMergeConfig
ActionCopy = internal.ActionCopy
ActionSkip = internal.ActionSkip
ActionBackup = internal.ActionBackup
ActionConvertConfig = internal.ActionConvertConfig
ActionCreateDir = internal.ActionCreateDir
ActionMergeConfig = internal.ActionMergeConfig
)
type Options struct {
DryRun bool
ConfigOnly bool
WorkspaceOnly bool
Force bool
Refresh bool
OpenClawHome string
PicoClawHome string
type MigrateInstance struct {
options Options
handlers map[string]Operation
}
type Action struct {
Type ActionType
Source string
Destination string
Description string
func NewMigrateInstance(opts Options) *MigrateInstance {
instance := &MigrateInstance{
options: opts,
handlers: make(map[string]Operation),
}
openclaw_handler, err := openclaw.NewOpenclawHandler(opts)
if err == nil {
instance.Register(openclaw_handler.GetSourceName(), openclaw_handler)
}
return instance
}
type Result struct {
FilesCopied int
FilesSkipped int
BackupsCreated int
ConfigMigrated bool
DirsCreated int
Warnings []string
Errors []error
func (m *MigrateInstance) Register(moduleName string, module Operation) {
m.handlers[moduleName] = module
}
func Run(opts Options) (*Result, error) {
func (m *MigrateInstance) getCurrentHandler() (Operation, error) {
source := m.options.Source
if source == "" {
source = "openclaw"
}
handler, ok := m.handlers[source]
if !ok {
return nil, fmt.Errorf("Source '%s' not found", source)
}
return handler, nil
}
func (m *MigrateInstance) Run(opts Options) (*Result, error) {
handler, err := m.getCurrentHandler()
if err != nil {
return nil, err
}
if opts.ConfigOnly && opts.WorkspaceOnly {
return nil, fmt.Errorf("--config-only and --workspace-only are mutually exclusive")
}
@@ -57,28 +77,28 @@ func Run(opts Options) (*Result, error) {
opts.WorkspaceOnly = true
}
openclawHome, err := resolveOpenClawHome(opts.OpenClawHome)
sourceHome, err := handler.GetSourceHome()
if err != nil {
return nil, err
}
picoClawHome, err := resolvePicoClawHome(opts.PicoClawHome)
targetHome, err := internal.ResolveTargetHome(opts.TargetHome)
if err != nil {
return nil, err
}
if _, err = os.Stat(openclawHome); os.IsNotExist(err) {
return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome)
if _, err = os.Stat(sourceHome); os.IsNotExist(err) {
return nil, fmt.Errorf("Source installation not found at %s", sourceHome)
}
actions, warnings, err := Plan(opts, openclawHome, picoClawHome)
actions, warnings, err := m.Plan(opts, sourceHome, targetHome)
if err != nil {
return nil, err
}
fmt.Println("Migrating from OpenClaw to PicoClaw")
fmt.Printf(" Source: %s\n", openclawHome)
fmt.Printf(" Destination: %s\n", picoClawHome)
fmt.Println("Migrating from Source to PicoClaw")
fmt.Printf(" Source: %s\n", sourceHome)
fmt.Printf(" Target: %s\n", targetHome)
fmt.Println()
if opts.DryRun {
@@ -95,19 +115,23 @@ func Run(opts Options) (*Result, error) {
fmt.Println()
}
result := Execute(actions, openclawHome, picoClawHome)
result := m.Execute(actions, sourceHome, targetHome)
result.Warnings = warnings
return result, nil
}
func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string, error) {
func (m *MigrateInstance) Plan(opts Options, sourceHome, targetHome string) ([]Action, []string, error) {
var actions []Action
var warnings []string
handler, err := m.getCurrentHandler()
if err != nil {
return nil, nil, err
}
force := opts.Force || opts.Refresh
if !opts.WorkspaceOnly {
configPath, err := findOpenClawConfig(openclawHome)
configPath, err := handler.GetSourceConfigFile()
if err != nil {
if opts.ConfigOnly {
return nil, nil, err
@@ -117,91 +141,95 @@ func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string,
actions = append(actions, Action{
Type: ActionConvertConfig,
Source: configPath,
Destination: filepath.Join(picoClawHome, "config.json"),
Description: "convert OpenClaw config to PicoClaw format",
Target: filepath.Join(targetHome, "config.json"),
Description: "convert Source config to PicoClaw format",
})
data, err := LoadOpenClawConfig(configPath)
if err == nil {
_, configWarnings, _ := ConvertConfig(data)
warnings = append(warnings, configWarnings...)
}
}
}
if !opts.ConfigOnly {
srcWorkspace := resolveWorkspace(openclawHome)
dstWorkspace := resolveWorkspace(picoClawHome)
srcWorkspace, err := handler.GetSourceWorkspace()
if err != nil {
return nil, nil, fmt.Errorf("getting source workspace: %w", err)
}
dstWorkspace := internal.ResolveWorkspace(targetHome)
if _, err := os.Stat(srcWorkspace); err == nil {
wsActions, err := PlanWorkspaceMigration(srcWorkspace, dstWorkspace, force)
wsActions, err := internal.PlanWorkspaceMigration(srcWorkspace, dstWorkspace,
handler.GetMigrateableFiles(),
handler.GetMigrateableDirs(),
force)
if err != nil {
return nil, nil, fmt.Errorf("planning workspace migration: %w", err)
}
actions = append(actions, wsActions...)
} else {
warnings = append(warnings, "OpenClaw workspace directory not found, skipping workspace migration")
warnings = append(warnings, "Source workspace directory not found, skipping workspace migration")
}
}
return actions, warnings, nil
}
func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
func (m *MigrateInstance) Execute(actions []Action, sourceHome, targetHome string) *Result {
result := &Result{}
handler, err := m.getCurrentHandler()
if err != nil {
return result
}
for _, action := range actions {
switch action.Type {
case ActionConvertConfig:
if err := executeConfigMigration(action.Source, action.Destination, picoClawHome); err != nil {
if err := handler.ExecuteConfigMigration(action.Source, action.Target); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("config migration: %w", err))
fmt.Printf(" ✗ Config migration failed: %v\n", err)
} else {
result.ConfigMigrated = true
fmt.Printf(" ✓ Converted config: %s\n", action.Destination)
fmt.Printf(" ✓ Converted config: %s\n", action.Target)
}
case ActionCreateDir:
if err := os.MkdirAll(action.Destination, 0o755); err != nil {
if err := os.MkdirAll(action.Target, 0o755); err != nil {
result.Errors = append(result.Errors, err)
} else {
result.DirsCreated++
}
case ActionBackup:
bakPath := action.Destination + ".bak"
if err := copyFile(action.Destination, bakPath); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Destination, err))
fmt.Printf(" ✗ Backup failed: %s\n", action.Destination)
bakPath := action.Target + ".bak"
if err := internal.CopyFile(action.Target, bakPath); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Target, err))
fmt.Printf(" ✗ Backup failed: %s\n", action.Target)
continue
}
result.BackupsCreated++
fmt.Printf(
" ✓ Backed up %s -> %s.bak\n",
filepath.Base(action.Destination),
filepath.Base(action.Destination),
filepath.Base(action.Target),
filepath.Base(action.Target),
)
if err := os.MkdirAll(filepath.Dir(action.Destination), 0o755); err != nil {
if err := os.MkdirAll(filepath.Dir(action.Target), 0o755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
if err := copyFile(action.Source, action.Destination); err != nil {
if err := internal.CopyFile(action.Source, action.Target); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
} else {
result.FilesCopied++
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
fmt.Printf(" ✓ Copied %s\n", internal.RelPath(action.Source, sourceHome))
}
case ActionCopy:
if err := os.MkdirAll(filepath.Dir(action.Destination), 0o755); err != nil {
if err := os.MkdirAll(filepath.Dir(action.Target), 0o755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
if err := copyFile(action.Source, action.Destination); err != nil {
if err := internal.CopyFile(action.Source, action.Target); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
} else {
result.FilesCopied++
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
fmt.Printf(" ✓ Copied %s\n", internal.RelPath(action.Source, sourceHome))
}
case ActionSkip:
result.FilesSkipped++
@@ -211,31 +239,6 @@ func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
return result
}
func executeConfigMigration(srcConfigPath, dstConfigPath, picoClawHome string) error {
data, err := LoadOpenClawConfig(srcConfigPath)
if err != nil {
return err
}
incoming, _, err := ConvertConfig(data)
if err != nil {
return err
}
if _, err := os.Stat(dstConfigPath); err == nil {
existing, err := config.LoadConfig(dstConfigPath)
if err != nil {
return fmt.Errorf("loading existing PicoClaw config: %w", err)
}
incoming = MergeConfig(existing, incoming)
}
if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0o755); err != nil {
return err
}
return config.SaveConfig(dstConfigPath, incoming)
}
func Confirm() bool {
fmt.Print("Proceed with migration? (y/n): ")
var response string
@@ -243,49 +246,7 @@ func Confirm() bool {
return strings.ToLower(strings.TrimSpace(response)) == "y"
}
func PrintPlan(actions []Action, warnings []string) {
fmt.Println("Planned actions:")
copies := 0
skips := 0
backups := 0
configCount := 0
for _, action := range actions {
switch action.Type {
case ActionConvertConfig:
fmt.Printf(" [config] %s -> %s\n", action.Source, action.Destination)
configCount++
case ActionCopy:
fmt.Printf(" [copy] %s\n", filepath.Base(action.Source))
copies++
case ActionBackup:
fmt.Printf(" [backup] %s (exists, will backup and overwrite)\n", filepath.Base(action.Destination))
backups++
copies++
case ActionSkip:
if action.Description != "" {
fmt.Printf(" [skip] %s (%s)\n", filepath.Base(action.Source), action.Description)
}
skips++
case ActionCreateDir:
fmt.Printf(" [mkdir] %s\n", action.Destination)
}
}
if len(warnings) > 0 {
fmt.Println()
fmt.Println("Warnings:")
for _, w := range warnings {
fmt.Printf(" - %s\n", w)
}
}
fmt.Println()
fmt.Printf("%d files to copy, %d configs to convert, %d backups needed, %d skipped\n",
copies, configCount, backups, skips)
}
func PrintSummary(result *Result) {
func (m *MigrateInstance) PrintSummary(result *Result) {
fmt.Println()
parts := []string{}
if result.FilesCopied > 0 {
@@ -316,83 +277,44 @@ func PrintSummary(result *Result) {
}
}
func resolveOpenClawHome(override string) (string, error) {
if override != "" {
return expandHome(override), nil
}
if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
return expandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, ".openclaw"), nil
}
func PrintPlan(actions []Action, warnings []string) {
fmt.Println("Planned actions:")
copies := 0
skips := 0
backups := 0
configCount := 0
func resolvePicoClawHome(override string) (string, error) {
if override != "" {
return expandHome(override), nil
}
if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
return expandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, ".picoclaw"), nil
}
func resolveWorkspace(homeDir string) string {
return filepath.Join(homeDir, "workspace")
}
func expandHome(path string) string {
if path == "" {
return path
}
if path[0] == '~' {
home, _ := os.UserHomeDir()
if len(path) > 1 && path[1] == '/' {
return home + path[1:]
for _, action := range actions {
switch action.Type {
case ActionConvertConfig:
fmt.Printf(" [config] %s -> %s\n", action.Source, action.Target)
configCount++
case ActionCopy:
fmt.Printf(" [copy] %s\n", filepath.Base(action.Source))
copies++
case ActionBackup:
fmt.Printf(" [backup] %s (exists, will backup and overwrite)\n", filepath.Base(action.Target))
backups++
copies++
case ActionSkip:
if action.Description != "" {
fmt.Printf(" [skip] %s (%s)\n", filepath.Base(action.Source), action.Description)
}
skips++
case ActionCreateDir:
fmt.Printf(" [mkdir] %s\n", action.Target)
}
return home
}
return path
}
func backupFile(path string) error {
bakPath := path + ".bak"
return copyFile(path, bakPath)
}
func copyFile(src, dst string) error {
srcFile, err := os.Open(src)
if err != nil {
return err
}
defer srcFile.Close()
info, err := srcFile.Stat()
if err != nil {
return err
}
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
if err != nil {
return err
}
defer dstFile.Close()
_, err = io.Copy(dstFile, srcFile)
return err
}
func relPath(path, base string) string {
rel, err := filepath.Rel(base, path)
if err != nil {
return filepath.Base(path)
}
return rel
if len(warnings) > 0 {
fmt.Println()
fmt.Println("Warnings:")
for _, w := range warnings {
fmt.Printf(" - %s\n", w)
}
}
fmt.Println()
fmt.Printf("%d files to copy, %d configs to convert, %d backups needed, %d skipped\n",
copies, configCount, backups, skips)
}
File diff suppressed because it is too large Load Diff
+29
View File
@@ -0,0 +1,29 @@
package openclaw
var migrateableFiles = []string{
"AGENTS.md",
"SOUL.md",
"USER.md",
"TOOLS.md",
"HEARTBEAT.md",
}
var migrateableDirs = []string{
"memory",
"skills",
}
var supportedChannels = map[string]bool{
"whatsapp": true,
"telegram": true,
"feishu": true,
"discord": true,
"maixcam": true,
"qq": true,
"dingtalk": true,
"slack": true,
"line": true,
"onebot": true,
"wecom": true,
"wecom_app": true,
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,714 @@
package openclaw
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestLoadOpenClawConfig(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
testConfig := `{
"agents": {
"defaults": {
"model": {
"primary": "anthropic/claude-sonnet-4-20250514"
},
"workspace": "~/.openclaw/workspace"
},
"list": [
{
"id": "main",
"name": "Main Agent",
"model": {
"primary": "openai/gpt-4o",
"fallbacks": ["claude-3-opus"]
}
}
]
},
"channels": {
"telegram": {
"enabled": true,
"botToken": "test-token",
"allowFrom": ["user1", "user2"]
},
"discord": {
"enabled": true,
"token": "discord-token"
}
},
"models": {
"providers": {
"anthropic": {
"api_key": "sk-ant-test",
"base_url": "https://api.anthropic.com"
},
"openai": {
"api_key": "sk-test"
}
}
}
}`
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
if err != nil {
t.Fatalf("failed to write test config: %v", err)
}
cfg, err := LoadOpenClawConfig(configPath)
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
if cfg.Agents == nil {
t.Error("agents should not be nil")
}
if cfg.Agents.Defaults == nil {
t.Error("agents.defaults should not be nil")
}
provider, model := cfg.GetDefaultModel()
if provider != "anthropic" {
t.Errorf("expected provider 'anthropic', got '%s'", provider)
}
if model != "claude-sonnet-4-20250514" {
t.Errorf("expected model 'claude-sonnet-4-20250514', got '%s'", model)
}
workspace := cfg.GetDefaultWorkspace()
if workspace != "~/.picoclaw/workspace" {
t.Errorf("expected workspace '~/.picoclaw/workspace', got '%s'", workspace)
}
agents := cfg.GetAgents()
if len(agents) != 1 {
t.Errorf("expected 1 agent, got %d", len(agents))
}
if agents[0].ID != "main" {
t.Errorf("expected agent id 'main', got '%s'", agents[0].ID)
}
if cfg.Channels == nil {
t.Error("channels should not be nil")
}
if cfg.Channels.Telegram == nil {
t.Error("telegram channel should not be nil")
}
if cfg.Channels.Telegram.BotToken == nil || *cfg.Channels.Telegram.BotToken != "test-token" {
t.Error("telegram bot token not parsed correctly")
}
}
func TestGetProviderConfig(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
testConfig := `{
"models": {
"providers": {
"anthropic": {
"api_key": "sk-ant-test",
"base_url": "https://api.anthropic.com",
"max_tokens": 4096
},
"openai": {
"api_key": "sk-test",
"base_url": "https://api.openai.com"
}
}
}
}`
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
if err != nil {
t.Fatalf("failed to write test config: %v", err)
}
cfg, err := LoadOpenClawConfig(configPath)
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
providers := GetProviderConfig(cfg.Models)
if len(providers) != 2 {
t.Errorf("expected 2 providers, got %d", len(providers))
}
if anthropic, ok := providers["anthropic"]; ok {
if anthropic.APIKey != "sk-ant-test" {
t.Errorf("expected anthropic api_key 'sk-ant-test', got '%s'", anthropic.APIKey)
}
if anthropic.BaseURL != "https://api.anthropic.com" {
t.Errorf("expected anthropic base_url 'https://api.anthropic.com', got '%s'", anthropic.BaseURL)
}
} else {
t.Error("anthropic provider not found")
}
if openai, ok := providers["openai"]; ok {
if openai.APIKey != "sk-test" {
t.Errorf("expected openai api_key 'sk-test', got '%s'", openai.APIKey)
}
} else {
t.Error("openai provider not found")
}
}
func TestConvertToPicoClaw(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
testConfig := `{
"agents": {
"defaults": {
"model": {
"primary": "anthropic/claude-sonnet-4-20250514"
},
"workspace": "~/.openclaw/workspace"
},
"list": [
{
"id": "main",
"name": "Main Agent"
},
{
"id": "assistant",
"name": "Assistant",
"skills": ["skill1", "skill2"]
}
]
},
"channels": {
"telegram": {
"enabled": true,
"botToken": "test-token",
"allowFrom": ["user1", "user2"]
},
"discord": {
"enabled": false,
"token": "discord-token"
},
"whatsapp": {
"enabled": true,
"bridgeUrl": "http://localhost:3000"
},
"feishu": {
"enabled": true,
"appId": "app-id",
"appSecret": "app-secret",
"allowFrom": ["user3"]
},
"signal": {
"enabled": true
}
},
"models": {
"providers": {
"anthropic": {
"api_key": "sk-ant-test"
},
"openai": {
"api_key": "sk-test"
}
}
},
"skills": {
"entries": {
"skill1": {}
}
},
"memory": {"enabled": true},
"cron": {"enabled": true}
}`
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
if err != nil {
t.Fatalf("failed to write test config: %v", err)
}
cfg, err := LoadOpenClawConfig(configPath)
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
picoCfg, warnings, err := cfg.ConvertToPicoClaw("")
if err != nil {
t.Fatalf("failed to convert config: %v", err)
}
if picoCfg.Agents.Defaults.ModelName != "claude-sonnet-4-20250514" {
t.Errorf("expected model 'claude-sonnet-4-20250514', got '%s'", picoCfg.Agents.Defaults.ModelName)
}
if picoCfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
t.Errorf("expected workspace '~/.picoclaw/workspace', got '%s'", picoCfg.Agents.Defaults.Workspace)
}
if len(picoCfg.Agents.List) != 2 {
t.Errorf("expected 2 agents, got %d", len(picoCfg.Agents.List))
}
if picoCfg.Agents.List[0].ID != "main" {
t.Errorf("expected first agent id 'main', got '%s'", picoCfg.Agents.List[0].ID)
}
if picoCfg.Agents.List[1].Skills == nil || len(picoCfg.Agents.List[1].Skills) != 2 {
t.Errorf("expected 2 skills for assistant agent")
}
if !picoCfg.Channels.Telegram.Enabled {
t.Error("telegram should be enabled")
}
if picoCfg.Channels.Telegram.Token != "test-token" {
t.Errorf("expected telegram token 'test-token', got '%s'", picoCfg.Channels.Telegram.Token)
}
if picoCfg.Channels.WhatsApp.BridgeURL != "http://localhost:3000" {
t.Errorf("expected whatsapp bridge URL 'http://localhost:3000', got '%s'", picoCfg.Channels.WhatsApp.BridgeURL)
}
if picoCfg.Channels.Feishu.AppID != "app-id" {
t.Errorf("expected feishu app ID 'app-id', got '%s'", picoCfg.Channels.Feishu.AppID)
}
if len(picoCfg.ModelList) != 1 {
t.Errorf("expected 1 model config (no models.json provided), got %d", len(picoCfg.ModelList))
}
foundWarning := false
for _, w := range warnings {
if len(w) > 0 {
foundWarning = true
break
}
}
if !foundWarning {
t.Log("warnings should be generated for skills, memory, cron, and unsupported channels")
}
}
func TestConvertToPicoClawWithQQAndDingTalk(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
testConfig := `{
"agents": {
"defaults": {
"model": {
"primary": "anthropic/claude-sonnet-4-20250514"
}
}
},
"channels": {
"qq": {
"enabled": true,
"appId": "qq-app-id",
"appSecret": "qq-app-secret"
},
"dingtalk": {
"enabled": true,
"appId": "ding-app-id",
"appSecret": "ding-app-secret"
},
"maixcam": {
"enabled": true,
"host": "192.168.1.100",
"port": 9000
},
"slack": {
"enabled": true,
"botToken": "xoxb-test",
"appToken": "xapp-test"
}
}
}`
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
if err != nil {
t.Fatalf("failed to write test config: %v", err)
}
cfg, err := LoadOpenClawConfig(configPath)
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
picoCfg, _, err := cfg.ConvertToPicoClaw("")
if err != nil {
t.Fatalf("failed to convert config: %v", err)
}
if !picoCfg.Channels.QQ.Enabled {
t.Error("qq should be enabled")
}
if picoCfg.Channels.QQ.AppID != "qq-app-id" {
t.Errorf("expected qq app ID 'qq-app-id', got '%s'", picoCfg.Channels.QQ.AppID)
}
if !picoCfg.Channels.DingTalk.Enabled {
t.Error("dingtalk should be enabled")
}
if picoCfg.Channels.DingTalk.ClientID != "ding-app-id" {
t.Errorf("expected dingtalk client ID 'ding-app-id', got '%s'", picoCfg.Channels.DingTalk.ClientID)
}
if !picoCfg.Channels.MaixCam.Enabled {
t.Error("maixcam should be enabled")
}
if picoCfg.Channels.MaixCam.Host != "192.168.1.100" {
t.Errorf("expected maixcam host '192.168.1.100', got '%s'", picoCfg.Channels.MaixCam.Host)
}
if picoCfg.Channels.MaixCam.Port != 9000 {
t.Errorf("expected maixcam port 9000, got %d", picoCfg.Channels.MaixCam.Port)
}
if !picoCfg.Channels.Slack.Enabled {
t.Error("slack should be enabled")
}
if picoCfg.Channels.Slack.BotToken != "xoxb-test" {
t.Errorf("expected slack bot token 'xoxb-test', got '%s'", picoCfg.Channels.Slack.BotToken)
}
if picoCfg.Channels.Slack.AppToken != "xapp-test" {
t.Errorf("expected slack app token 'xapp-test', got '%s'", picoCfg.Channels.Slack.AppToken)
}
}
func TestOpenClawAgentModel(t *testing.T) {
model := &OpenClawAgentModel{
Primary: strPtr("anthropic/claude-3-opus"),
Fallbacks: []string{"claude-3-sonnet", "claude-3-haiku"},
}
primary := model.GetPrimary()
if primary != "anthropic/claude-3-opus" {
t.Errorf("expected primary 'anthropic/claude-3-opus', got '%s'", primary)
}
fallbacks := model.GetFallbacks()
if len(fallbacks) != 2 {
t.Errorf("expected 2 fallbacks, got %d", len(fallbacks))
}
model2 := &OpenClawAgentModel{
Simple: "claude-3-opus",
}
primary2 := model2.GetPrimary()
if primary2 != "claude-3-opus" {
t.Errorf("expected primary 'claude-3-opus' from Simple, got '%s'", primary2)
}
}
func TestChannelEnabled(t *testing.T) {
cfg := &OpenClawConfig{
Channels: &OpenClawChannels{
Telegram: &OpenClawTelegramConfig{
Enabled: boolPtr(true),
},
Discord: &OpenClawDiscordConfig{
Enabled: boolPtr(false),
},
Slack: &OpenClawSlackConfig{
Enabled: boolPtr(true),
},
},
}
if !cfg.IsChannelEnabled("telegram") {
t.Error("telegram should be enabled")
}
if cfg.IsChannelEnabled("discord") {
t.Error("discord should be disabled")
}
if !cfg.IsChannelEnabled("slack") {
t.Error("slack should be enabled (explicitly set)")
}
if cfg.IsChannelEnabled("line") {
t.Error("line should return false (not in switch cases)")
}
}
func TestGetDefaultModel(t *testing.T) {
cfg := &OpenClawConfig{
Agents: &OpenClawAgents{
Defaults: &OpenClawAgentDefaults{
Model: &OpenClawAgentModel{
Primary: strPtr("openai/gpt-4"),
},
},
},
}
provider, model := cfg.GetDefaultModel()
if provider != "openai" {
t.Errorf("expected provider 'openai', got '%s'", provider)
}
if model != "gpt-4" {
t.Errorf("expected model 'gpt-4', got '%s'", model)
}
}
func TestGetDefaultModelWithNoDefaults(t *testing.T) {
cfg := &OpenClawConfig{}
provider, model := cfg.GetDefaultModel()
if provider != "anthropic" {
t.Errorf("expected default provider 'anthropic', got '%s'", provider)
}
if model != "claude-sonnet-4-20250514" {
t.Errorf("expected default model 'claude-sonnet-4-20250514', got '%s'", model)
}
}
func TestHasFunctions(t *testing.T) {
cfg := &OpenClawConfig{
Skills: &OpenClawSkills{Entries: map[string]json.RawMessage{"skill1": nil}},
Memory: json.RawMessage(`{"enabled": true}`),
Cron: json.RawMessage(`{"enabled": true}`),
Hooks: json.RawMessage(`{"enabled": true}`),
Session: json.RawMessage(`{"enabled": true}`),
Auth: &OpenClawAuth{Profiles: json.RawMessage(`{"profile1": {}}`)},
}
if !cfg.HasSkills() {
t.Error("should have skills")
}
if !cfg.HasMemory() {
t.Error("should have memory")
}
if !cfg.HasCron() {
t.Error("should have cron")
}
if !cfg.HasHooks() {
t.Error("should have hooks")
}
if !cfg.HasSession() {
t.Error("should have session")
}
if !cfg.HasAuthProfiles() {
t.Error("should have auth profiles")
}
cfg2 := &OpenClawConfig{}
if cfg2.HasSkills() {
t.Error("should not have skills")
}
if cfg2.HasMemory() {
t.Error("should not have memory")
}
}
func TestLoadOpenClawConfigFromDir(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
testConfig := `{"agents": {}}`
err := os.WriteFile(configPath, []byte(testConfig), 0o644)
if err != nil {
t.Fatalf("failed to write test config: %v", err)
}
cfg, err := LoadOpenClawConfigFromDir(tmpDir)
if err != nil {
t.Fatalf("failed to load config from dir: %v", err)
}
if cfg.Agents == nil {
t.Error("agents should not be nil")
}
_, err = LoadOpenClawConfigFromDir("/nonexistent/dir")
if err == nil {
t.Error("should return error for nonexistent dir")
}
}
func TestToStandardConfig(t *testing.T) {
picoCfg := &PicoClawConfig{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Provider: "anthropic",
ModelName: "claude-sonnet-4-20250514",
Workspace: "~/.picoclaw/workspace",
},
List: []AgentConfig{
{
ID: "main",
Name: "Main Agent",
Default: true,
},
},
},
ModelList: []ModelConfig{
{
ModelName: "claude-sonnet-4-20250514",
Model: "anthropic/claude-sonnet-4-20250514",
APIKey: "sk-ant-test",
},
},
Channels: ChannelsConfig{
Telegram: TelegramConfig{
Enabled: true,
Token: "test-token",
AllowFrom: []string{"user1"},
},
WhatsApp: WhatsAppConfig{
Enabled: true,
BridgeURL: "http://localhost:3000",
},
},
Gateway: GatewayConfig{
Host: "0.0.0.0",
Port: 8080,
},
}
stdCfg := picoCfg.ToStandardConfig()
if stdCfg.Agents.Defaults.Provider != "anthropic" {
t.Errorf("expected provider 'anthropic', got '%s'", stdCfg.Agents.Defaults.Provider)
}
if stdCfg.Agents.Defaults.ModelName != "claude-sonnet-4-20250514" {
t.Errorf("expected model name 'claude-sonnet-4-20250514', got '%s'", stdCfg.Agents.Defaults.ModelName)
}
if stdCfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
t.Errorf("expected workspace '~/.picoclaw/workspace', got '%s'", stdCfg.Agents.Defaults.Workspace)
}
if len(stdCfg.Agents.List) != 1 {
t.Errorf("expected 1 agent, got %d", len(stdCfg.Agents.List))
}
if stdCfg.Agents.List[0].ID != "main" {
t.Errorf("expected agent id 'main', got '%s'", stdCfg.Agents.List[0].ID)
}
foundModel := false
var foundAPIKey string
for _, m := range stdCfg.ModelList {
if m.ModelName == "claude-sonnet-4-20250514" {
foundModel = true
foundAPIKey = m.APIKey
break
}
}
if !foundModel {
t.Error("expected to find claude-sonnet-4-20250514 model config")
}
if foundAPIKey != "sk-ant-test" {
t.Errorf("expected api key 'sk-ant-test', got '%s'", foundAPIKey)
}
if !stdCfg.Channels.Telegram.Enabled {
t.Error("telegram should be enabled")
}
if stdCfg.Channels.Telegram.Token != "test-token" {
t.Errorf("expected token 'test-token', got '%s'", stdCfg.Channels.Telegram.Token)
}
if stdCfg.Gateway.Port != 8080 {
t.Errorf("expected gateway port 8080, got %d", stdCfg.Gateway.Port)
}
}
func TestLoadProviderConfigFromAgentsDir(t *testing.T) {
tmpDir := t.TempDir()
agentsDir := filepath.Join(tmpDir, "agents", "main", "agent")
err := os.MkdirAll(agentsDir, 0o755)
if err != nil {
t.Fatalf("failed to create agents dir: %v", err)
}
modelsJSON := `{
"providers": {
"anthropic": {
"baseUrl": "https://api.anthropic.com",
"api": "anthropic",
"apiKey": "sk-ant-from-models",
"models": [
{
"id": "claude-sonnet-4-20250514",
"name": "Claude Sonnet 4"
}
]
},
"openai": {
"baseUrl": "https://api.openai.com",
"api": "openai",
"apiKey": "sk-from-models",
"models": [
{
"id": "gpt-4o",
"name": "GPT-4o"
}
]
},
"zhipu": {
"baseUrl": "https://open.bigmodel.cn/api/paas/v4",
"api": "openai",
"apiKey": "zhipu-key",
"models": []
}
}
}`
err = os.WriteFile(filepath.Join(agentsDir, "models.json"), []byte(modelsJSON), 0o644)
if err != nil {
t.Fatalf("failed to write models.json: %v", err)
}
providers := GetProviderConfigFromDir(tmpDir)
if len(providers) != 3 {
t.Errorf("expected 3 providers, got %d", len(providers))
}
if anthropic, ok := providers["anthropic"]; ok {
if anthropic.ApiKey != "sk-ant-from-models" {
t.Errorf("expected anthropic apiKey 'sk-ant-from-models', got '%s'", anthropic.ApiKey)
}
if anthropic.BaseUrl != "https://api.anthropic.com" {
t.Errorf("expected anthropic baseUrl 'https://api.anthropic.com', got '%s'", anthropic.BaseUrl)
}
} else {
t.Error("anthropic provider not found")
}
if openai, ok := providers["openai"]; ok {
if openai.ApiKey != "sk-from-models" {
t.Errorf("expected openai apiKey 'sk-from-models', got '%s'", openai.ApiKey)
}
if openai.BaseUrl != "https://api.openai.com" {
t.Errorf("expected openai baseUrl 'https://api.openai.com', got '%s'", openai.BaseUrl)
}
} else {
t.Error("openai provider not found")
}
if zhipu, ok := providers["zhipu"]; ok {
if zhipu.ApiKey != "zhipu-key" {
t.Errorf("expected zhipu apiKey 'zhipu-key', got '%s'", zhipu.ApiKey)
}
if zhipu.BaseUrl != "https://open.bigmodel.cn/api/paas/v4" {
t.Errorf("expected zhipu baseUrl 'https://open.bigmodel.cn/api/paas/v4', got '%s'", zhipu.BaseUrl)
}
} else {
t.Error("zhipu provider not found")
}
}
func TestGetProviderConfigFromDirNotExist(t *testing.T) {
providers := GetProviderConfigFromDir("/nonexistent/path")
if len(providers) != 0 {
t.Errorf("expected 0 providers for nonexistent path, got %d", len(providers))
}
}
func strPtr(s string) *string {
return &s
}
func boolPtr(b bool) *bool {
return &b
}
@@ -0,0 +1,148 @@
package openclaw
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/migrate/internal"
)
var providerMapping = map[string]string{
"anthropic": "anthropic",
"claude": "anthropic",
"openai": "openai",
"gpt": "openai",
"groq": "groq",
"ollama": "ollama",
"openrouter": "openrouter",
"deepseek": "deepseek",
"together": "together",
"mistral": "mistral",
"fireworks": "fireworks",
"google": "google",
"gemini": "google",
"xai": "xai",
"grok": "xai",
"cerebras": "cerebras",
"sambanova": "sambanova",
}
type OpenclawHandler struct {
opts Options
sourceConfigFile string
sourceWorkspace string
}
type (
Options = internal.Options
Action = internal.Action
Result = internal.Result
Operation = internal.Operation
)
func NewOpenclawHandler(opts Options) (Operation, error) {
home, err := resolveSourceHome(opts.SourceHome)
if err != nil {
return nil, err
}
opts.SourceHome = home
configFile, err := findSourceConfig(home)
if err != nil {
return nil, err
}
return &OpenclawHandler{
opts: opts,
sourceWorkspace: filepath.Join(opts.SourceHome, "workspace"),
sourceConfigFile: configFile,
}, nil
}
func (o *OpenclawHandler) GetSourceName() string {
return "openclaw"
}
func (o *OpenclawHandler) GetSourceHome() (string, error) {
return o.opts.SourceHome, nil
}
func (o *OpenclawHandler) GetSourceWorkspace() (string, error) {
return o.sourceWorkspace, nil
}
func (o *OpenclawHandler) GetSourceConfigFile() (string, error) {
return o.sourceConfigFile, nil
}
func (o *OpenclawHandler) GetMigrateableFiles() []string {
return migrateableFiles
}
func (o *OpenclawHandler) GetMigrateableDirs() []string {
return migrateableDirs
}
func (o *OpenclawHandler) ExecuteConfigMigration(srcConfigPath, dstConfigPath string) error {
openclawCfg, err := LoadOpenClawConfig(srcConfigPath)
if err != nil {
return err
}
picoCfg, warnings, err := openclawCfg.ConvertToPicoClaw(o.opts.SourceHome)
if err != nil {
return err
}
for _, w := range warnings {
fmt.Printf(" Warning: %s\n", w)
}
incoming := picoCfg.ToStandardConfig()
if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0o755); err != nil {
return err
}
return config.SaveConfig(dstConfigPath, incoming)
}
func resolveSourceHome(override string) (string, error) {
if override != "" {
return internal.ExpandHome(override), nil
}
if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
return internal.ExpandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, ".openclaw"), nil
}
func findSourceConfig(sourceHome string) (string, error) {
candidates := []string{
filepath.Join(sourceHome, "openclaw.json"),
filepath.Join(sourceHome, "config.json"),
}
for _, p := range candidates {
if _, err := os.Stat(p); err == nil {
return p, nil
}
}
return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", sourceHome)
}
func rewriteWorkspacePath(path string) string {
path = strings.Replace(path, ".openclaw", ".picoclaw", 1)
return path
}
func mapProvider(provider string) string {
if mapped, ok := providerMapping[strings.ToLower(provider)]; ok {
return mapped
}
return strings.ToLower(provider)
}
@@ -0,0 +1,247 @@
package openclaw
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewOpenclawHandler(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
handler, err := NewOpenclawHandler(Options{
SourceHome: tmpDir,
})
require.NoError(t, err)
require.NotNil(t, handler)
}
func TestNewOpenclawHandlerNoConfig(t *testing.T) {
tmpDir := t.TempDir()
_, err := NewOpenclawHandler(Options{
SourceHome: tmpDir,
})
require.Error(t, err)
}
func TestOpenclawHandlerGetSourceName(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
handler, err := NewOpenclawHandler(Options{
SourceHome: tmpDir,
})
require.NoError(t, err)
assert.Equal(t, "openclaw", handler.GetSourceName())
}
func TestOpenclawHandlerGetSourceHome(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
handler, err := NewOpenclawHandler(Options{
SourceHome: tmpDir,
})
require.NoError(t, err)
home, err := handler.GetSourceHome()
require.NoError(t, err)
assert.Equal(t, tmpDir, home)
}
func TestOpenclawHandlerGetSourceWorkspace(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
handler, err := NewOpenclawHandler(Options{
SourceHome: tmpDir,
})
require.NoError(t, err)
workspace, err := handler.GetSourceWorkspace()
require.NoError(t, err)
assert.Equal(t, filepath.Join(tmpDir, "workspace"), workspace)
}
func TestOpenclawHandlerGetSourceConfigFile(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
handler, err := NewOpenclawHandler(Options{
SourceHome: tmpDir,
})
require.NoError(t, err)
configFile, err := handler.GetSourceConfigFile()
require.NoError(t, err)
assert.Equal(t, configPath, configFile)
}
func TestOpenclawHandlerGetSourceConfigFileWithConfigJson(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
handler, err := NewOpenclawHandler(Options{
SourceHome: tmpDir,
})
require.NoError(t, err)
configFile, err := handler.GetSourceConfigFile()
require.NoError(t, err)
assert.Equal(t, configPath, configFile)
}
func TestOpenclawHandlerGetMigrateableFiles(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
handler, err := NewOpenclawHandler(Options{
SourceHome: tmpDir,
})
require.NoError(t, err)
files := handler.GetMigrateableFiles()
assert.NotEmpty(t, files)
assert.Contains(t, files, "AGENTS.md")
assert.Contains(t, files, "SOUL.md")
assert.Contains(t, files, "USER.md")
}
func TestOpenclawHandlerGetMigrateableDirs(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
handler, err := NewOpenclawHandler(Options{
SourceHome: tmpDir,
})
require.NoError(t, err)
dirs := handler.GetMigrateableDirs()
assert.NotEmpty(t, dirs)
assert.Contains(t, dirs, "memory")
assert.Contains(t, dirs, "skills")
}
func TestResolveSourceHome(t *testing.T) {
result, err := resolveSourceHome("/custom/path")
require.NoError(t, err)
assert.Equal(t, "/custom/path", result)
}
func TestResolveSourceHomeWithEnvVar(t *testing.T) {
t.Setenv("OPENCLAW_HOME", "/env/path")
result, err := resolveSourceHome("")
require.NoError(t, err)
assert.Equal(t, "/env/path", result)
}
func TestResolveSourceHomeWithTilde(t *testing.T) {
home, err := os.UserHomeDir()
require.NoError(t, err)
result, err := resolveSourceHome("~/openclaw")
require.NoError(t, err)
assert.Equal(t, filepath.Join(home, "openclaw"), result)
}
func TestFindSourceConfig(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
result, err := findSourceConfig(tmpDir)
require.NoError(t, err)
assert.Equal(t, configPath, result)
}
func TestFindSourceConfigWithConfigJson(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
err := os.WriteFile(configPath, []byte("{}"), 0o644)
require.NoError(t, err)
result, err := findSourceConfig(tmpDir)
require.NoError(t, err)
assert.Equal(t, configPath, result)
}
func TestFindSourceConfigNotFound(t *testing.T) {
tmpDir := t.TempDir()
_, err := findSourceConfig(tmpDir)
require.Error(t, err)
assert.Contains(t, err.Error(), "no config file found")
}
func TestMapProvider(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"anthropic", "anthropic"},
{"claude", "anthropic"},
{"openai", "openai"},
{"gpt", "openai"},
{"groq", "groq"},
{"ollama", "ollama"},
{"openrouter", "openrouter"},
{"deepseek", "deepseek"},
{"together", "together"},
{"mistral", "mistral"},
{"fireworks", "fireworks"},
{"google", "google"},
{"gemini", "google"},
{"xai", "xai"},
{"grok", "xai"},
{"cerebras", "cerebras"},
{"sambanova", "sambanova"},
{"unknown", "unknown"},
{"", ""},
}
for _, tt := range tests {
result := mapProvider(tt.input)
assert.Equal(t, tt.expected, result, "mapProvider(%q)", tt.input)
}
}
func TestRewriteWorkspacePath(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"~/.openclaw/workspace", "~/.picoclaw/workspace"},
{"/home/user/.openclaw/workspace", "/home/user/.picoclaw/workspace"},
{"/path/without/openclaw/change", "/path/without/openclaw/change"},
{"", ""},
}
for _, tt := range tests {
result := rewriteWorkspacePath(tt.input)
assert.Equal(t, tt.expected, result, "rewriteWorkspacePath(%q)", tt.input)
}
}
+5 -5
View File
@@ -212,14 +212,14 @@ func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
}
func parseResponse(resp *anthropic.Message) *LLMResponse {
var content string
var content strings.Builder
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
tb := block.AsText()
content += tb.Text
content.WriteString(tb.Text)
case "tool_use":
tu := block.AsToolUse()
var args map[string]any
@@ -246,7 +246,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
}
return &LLMResponse{
Content: content,
Content: content.String(),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
@@ -264,8 +264,8 @@ func normalizeBaseURL(apiBase string) string {
}
base = strings.TrimRight(base, "/")
if strings.HasSuffix(base, "/v1") {
base = strings.TrimSuffix(base, "/v1")
if before, ok := strings.CutSuffix(base, "/v1"); ok {
base = before
}
if base == "" {
return defaultBaseURL
+2 -2
View File
@@ -163,8 +163,8 @@ func resolveCodexModel(model string) (string, string) {
return codexDefaultModel, "empty model"
}
if strings.HasPrefix(m, "openai/") {
m = strings.TrimPrefix(m, "openai/")
if after, ok := strings.CutPrefix(m, "openai/"); ok {
m = after
} else if strings.Contains(m, "/") {
return codexDefaultModel, "non-openai model namespace"
}
+2 -2
View File
@@ -138,7 +138,7 @@ func TestCooldown_FailureWindowReset(t *testing.T) {
ct, current := newTestTracker(now)
// 4 errors → 1h cooldown
for i := 0; i < 4; i++ {
for range 4 {
ct.MarkFailure("openai", FailoverRateLimit)
*current = current.Add(2 * time.Second) // small advance between errors
}
@@ -230,7 +230,7 @@ func TestCooldown_ConcurrentAccess(t *testing.T) {
ct := NewCooldownTracker()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
for range 100 {
wg.Add(3)
go func() {
defer wg.Done()
+9 -9
View File
@@ -6,6 +6,13 @@ import (
"strings"
)
// Common patterns in Go HTTP error messages
var httpStatusPatterns = []*regexp.Regexp{
regexp.MustCompile(`status[:\s]+(\d{3})`),
regexp.MustCompile(`http[/\s]+\d*\.?\d*\s+(\d{3})`),
regexp.MustCompile(`\b([3-5]\d{2})\b`),
}
// errorPattern defines a single pattern (string or regex) for error classification.
type errorPattern struct {
substring string
@@ -198,20 +205,13 @@ func classifyByMessage(msg string) FailoverReason {
}
// extractHTTPStatus extracts an HTTP status code from an error message.
// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429".
// Looks for patterns like "status: 429", "status 429", "http/1.1 429", "http 429", or standalone "429".
func extractHTTPStatus(msg string) int {
// Common patterns in Go HTTP error messages
patterns := []*regexp.Regexp{
regexp.MustCompile(`status[:\s]+(\d{3})`),
regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`),
}
for _, p := range patterns {
for _, p := range httpStatusPatterns {
if m := p.FindStringSubmatch(msg); len(m) > 1 {
return parseDigits(m[1])
}
}
return 0
}
+2 -1
View File
@@ -305,7 +305,8 @@ func TestExtractHTTPStatus(t *testing.T) {
}{
{"status: 429 rate limited", 429},
{"status 401 unauthorized", 401},
{"HTTP/1.1 502 Bad Gateway", 502},
{"http/1.1 502 bad gateway", 502},
{"error 429", 429},
{"no status code here", 0},
{"random number 12345", 0},
}
+7 -3
View File
@@ -26,8 +26,9 @@ func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*Gi
switch connectMode {
case "stdio":
// TODO:
return nil, fmt.Errorf("stdio mode not implemented")
// TODO: Implement stdio mode for GitHub Copilot provider
// See https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md for details
return nil, fmt.Errorf("stdio mode not implemented for GitHub Copilot provider; please use 'grpc' mode instead")
case "grpc":
client := copilot.NewClient(&copilot.ClientOptions{
CLIUrl: uri,
@@ -100,9 +101,12 @@ func (p *GitHubCopilotProvider) Chat(
return nil, fmt.Errorf("provider closed")
}
resp, _ := session.SendAndWait(ctx, copilot.MessageOptions{
resp, err := session.SendAndWait(ctx, copilot.MessageOptions{
Prompt: string(fullcontent),
})
if err != nil {
return nil, fmt.Errorf("failed to send message to copilot: %w", err)
}
if resp == nil {
return nil, fmt.Errorf("empty response from copilot")
+4 -4
View File
@@ -312,8 +312,8 @@ func stripSystemParts(messages []Message) []openaiMessage {
}
func normalizeModel(model, apiBase string) string {
idx := strings.Index(model, "/")
if idx == -1 {
before, after, ok := strings.Cut(model, "/")
if !ok {
return model
}
@@ -321,10 +321,10 @@ func normalizeModel(model, apiBase string) string {
return model
}
prefix := strings.ToLower(model[:idx])
prefix := strings.ToLower(before)
switch prefix {
case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
return model[idx+1:]
return after
default:
return model
}
+8 -5
View File
@@ -1,6 +1,9 @@
package routing
import "testing"
import (
"strings"
"testing"
)
func TestNormalizeAgentID_Empty(t *testing.T) {
if got := NormalizeAgentID(""); got != DefaultAgentID {
@@ -57,11 +60,11 @@ func TestNormalizeAgentID_AllInvalid(t *testing.T) {
}
func TestNormalizeAgentID_TruncatesAt64(t *testing.T) {
long := ""
for i := 0; i < 100; i++ {
long += "a"
var long strings.Builder
for range 100 {
long.WriteString("a")
}
got := NormalizeAgentID(long)
got := NormalizeAgentID(long.String())
if len(got) > MaxAgentIDLength {
t.Errorf("length = %d, want <= %d", len(got), MaxAgentIDLength)
}
-41
View File
@@ -2,7 +2,6 @@ package skills
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -18,14 +17,6 @@ type SkillInstaller struct {
workspace string
}
type AvailableSkill struct {
Name string `json:"name"`
Repository string `json:"repository"`
Description string `json:"description"`
Author string `json:"author"`
Tags []string `json:"tags"`
}
func NewSkillInstaller(workspace string) *SkillInstaller {
return &SkillInstaller{
workspace: workspace,
@@ -89,35 +80,3 @@ func (si *SkillInstaller) Uninstall(skillName string) error {
return nil
}
func (si *SkillInstaller) ListAvailableSkills(ctx context.Context) ([]AvailableSkill, error) {
url := "https://raw.githubusercontent.com/sipeed/picoclaw-skills/main/skills.json"
client := &http.Client{Timeout: 15 * time.Second}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := utils.DoRequestWithRetry(client, req)
if err != nil {
return nil, fmt.Errorf("failed to fetch skills list: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return nil, fmt.Errorf("failed to fetch skills list: HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
var skills []AvailableSkill
if err := json.Unmarshal(body, &skills); err != nil {
return nil, fmt.Errorf("failed to parse skills list: %w", err)
}
return skills, nil
}
+1 -1
View File
@@ -240,7 +240,7 @@ func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string {
normalized := strings.ReplaceAll(content, "\r\n", "\n")
normalized = strings.ReplaceAll(normalized, "\r", "\n")
for _, line := range strings.Split(normalized, "\n") {
for line := range strings.SplitSeq(normalized, "\n") {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
+2 -2
View File
@@ -1,7 +1,7 @@
package skills
import (
"sort"
"slices"
"strings"
"sync"
"time"
@@ -183,7 +183,7 @@ func buildTrigrams(s string) []uint32 {
}
// Sort and Deduplication
sort.Slice(trigrams, func(i, j int) bool { return trigrams[i] < trigrams[j] })
slices.Sort(trigrams)
n := 1
for i := 1; i < len(trigrams); i++ {
if trigrams[i] != trigrams[i-1] {
+2 -2
View File
@@ -153,7 +153,7 @@ func TestSearchCacheConcurrency(t *testing.T) {
// Concurrent writes
go func() {
for i := 0; i < 100; i++ {
for i := range 100 {
cache.Put("query-write-"+string(rune('a'+i%26)), []SearchResult{{Slug: "x"}})
}
done <- struct{}{}
@@ -161,7 +161,7 @@ func TestSearchCacheConcurrency(t *testing.T) {
// Concurrent reads
go func() {
for i := 0; i < 100; i++ {
for range 100 {
cache.Get("query-write-a")
}
done <- struct{}{}
+2 -2
View File
@@ -135,7 +135,7 @@ func TestConcurrentAccess(t *testing.T) {
// Test concurrent writes
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
for i := range 10 {
go func(idx int) {
channel := fmt.Sprintf("channel-%d", idx)
sm.SetLastChannel(channel)
@@ -144,7 +144,7 @@ func TestConcurrentAccess(t *testing.T) {
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
for range 10 {
<-done
}
+12 -6
View File
@@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
"strings"
"sync"
"time"
@@ -33,15 +34,19 @@ type CronTool struct {
func NewCronTool(
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
execTimeout time.Duration, config *config.Config,
) *CronTool {
execTool := NewExecToolWithConfig(workspace, restrict, config)
) (*CronTool, error) {
execTool, err := NewExecToolWithConfig(workspace, restrict, config)
if err != nil {
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
}
execTool.SetTimeout(execTimeout)
return &CronTool{
cronService: cronService,
executor: executor,
msgBus: msgBus,
execTool: execTool,
}
}, nil
}
// Name returns the tool name
@@ -218,7 +223,8 @@ func (t *CronTool) listJobs() *ToolResult {
return SilentResult("No scheduled jobs")
}
result := "Scheduled jobs:\n"
var result strings.Builder
result.WriteString("Scheduled jobs:\n")
for _, j := range jobs {
var scheduleInfo string
if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil {
@@ -230,10 +236,10 @@ func (t *CronTool) listJobs() *ToolResult {
} else {
scheduleInfo = "unknown"
}
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
result.WriteString(fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo))
}
return SilentResult(result)
return SilentResult(result.String())
}
func (t *CronTool) removeJob(args map[string]any) *ToolResult {
+1 -1
View File
@@ -329,7 +329,7 @@ func TestToolRegistry_ConcurrentAccess(t *testing.T) {
r := NewToolRegistry()
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
for i := range 50 {
wg.Add(1)
go func(n int) {
defer wg.Done()
+57 -51
View File
@@ -24,56 +24,64 @@ type ExecTool struct {
restrictToWorkspace bool
}
var defaultDenyPatterns = []*regexp.Regexp{
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
regexp.MustCompile(`\bdel\s+/[fq]\b`),
regexp.MustCompile(`\brmdir\s+/s\b`),
regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args)
regexp.MustCompile(`\bdd\s+if=`),
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
regexp.MustCompile(`\$\([^)]+\)`),
regexp.MustCompile(`\$\{[^}]+\}`),
regexp.MustCompile("`[^`]+`"),
regexp.MustCompile(`\|\s*sh\b`),
regexp.MustCompile(`\|\s*bash\b`),
regexp.MustCompile(`;\s*rm\s+-[rf]`),
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
regexp.MustCompile(`<<\s*EOF`),
regexp.MustCompile(`\$\(\s*cat\s+`),
regexp.MustCompile(`\$\(\s*curl\s+`),
regexp.MustCompile(`\$\(\s*wget\s+`),
regexp.MustCompile(`\$\(\s*which\s+`),
regexp.MustCompile(`\bsudo\b`),
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
regexp.MustCompile(`\bchown\b`),
regexp.MustCompile(`\bpkill\b`),
regexp.MustCompile(`\bkillall\b`),
regexp.MustCompile(`\bkill\s+-[9]\b`),
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
regexp.MustCompile(`\byum\s+(install|remove)\b`),
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
regexp.MustCompile(`\bdocker\s+run\b`),
regexp.MustCompile(`\bdocker\s+exec\b`),
regexp.MustCompile(`\bgit\s+push\b`),
regexp.MustCompile(`\bgit\s+force\b`),
regexp.MustCompile(`\bssh\b.*@`),
regexp.MustCompile(`\beval\b`),
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
}
var (
defaultDenyPatterns = []*regexp.Regexp{
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
regexp.MustCompile(`\bdel\s+/[fq]\b`),
regexp.MustCompile(`\brmdir\s+/s\b`),
// Match disk wiping commands (must be followed by space/args)
regexp.MustCompile(
`\b(format|mkfs|diskpart)\b\s`,
),
regexp.MustCompile(`\bdd\s+if=`),
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
regexp.MustCompile(`\$\([^)]+\)`),
regexp.MustCompile(`\$\{[^}]+\}`),
regexp.MustCompile("`[^`]+`"),
regexp.MustCompile(`\|\s*sh\b`),
regexp.MustCompile(`\|\s*bash\b`),
regexp.MustCompile(`;\s*rm\s+-[rf]`),
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
regexp.MustCompile(`<<\s*EOF`),
regexp.MustCompile(`\$\(\s*cat\s+`),
regexp.MustCompile(`\$\(\s*curl\s+`),
regexp.MustCompile(`\$\(\s*wget\s+`),
regexp.MustCompile(`\$\(\s*which\s+`),
regexp.MustCompile(`\bsudo\b`),
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
regexp.MustCompile(`\bchown\b`),
regexp.MustCompile(`\bpkill\b`),
regexp.MustCompile(`\bkillall\b`),
regexp.MustCompile(`\bkill\s+-[9]\b`),
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
regexp.MustCompile(`\byum\s+(install|remove)\b`),
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
regexp.MustCompile(`\bdocker\s+run\b`),
regexp.MustCompile(`\bdocker\s+exec\b`),
regexp.MustCompile(`\bgit\s+push\b`),
regexp.MustCompile(`\bgit\s+force\b`),
regexp.MustCompile(`\bssh\b.*@`),
regexp.MustCompile(`\beval\b`),
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
}
func NewExecTool(workingDir string, restrict bool) *ExecTool {
// absolutePathPattern matches absolute file paths in commands (Unix and Windows).
absolutePathPattern = regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
)
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
return NewExecToolWithConfig(workingDir, restrict, nil)
}
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) *ExecTool {
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
denyPatterns := make([]*regexp.Regexp, 0)
if config != nil {
@@ -86,8 +94,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
for _, pattern := range execConfig.CustomDenyPatterns {
re, err := regexp.Compile(pattern)
if err != nil {
fmt.Printf("Invalid custom deny pattern %q: %v\n", pattern, err)
continue
return nil, fmt.Errorf("invalid custom deny pattern %q: %w", pattern, err)
}
denyPatterns = append(denyPatterns, re)
}
@@ -106,7 +113,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
denyPatterns: denyPatterns,
allowPatterns: nil,
restrictToWorkspace: restrict,
}
}, nil
}
func (t *ExecTool) Name() string {
@@ -288,8 +295,7 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
return ""
}
pathPattern := regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
matches := pathPattern.FindAllString(cmd, -1)
matches := absolutePathPattern.FindAllString(cmd, -1)
for _, raw := range matches {
p, err := filepath.Abs(raw)
+48 -11
View File
@@ -11,7 +11,10 @@ import (
// TestShellTool_Success verifies successful command execution
func TestShellTool_Success(t *testing.T) {
tool := NewExecTool("", false)
tool, err := NewExecTool("", false)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
ctx := context.Background()
args := map[string]any{
@@ -38,7 +41,10 @@ func TestShellTool_Success(t *testing.T) {
// TestShellTool_Failure verifies failed command execution
func TestShellTool_Failure(t *testing.T) {
tool := NewExecTool("", false)
tool, err := NewExecTool("", false)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
ctx := context.Background()
args := map[string]any{
@@ -65,7 +71,11 @@ func TestShellTool_Failure(t *testing.T) {
// TestShellTool_Timeout verifies command timeout handling
func TestShellTool_Timeout(t *testing.T) {
tool := NewExecTool("", false)
tool, err := NewExecTool("", false)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
tool.SetTimeout(100 * time.Millisecond)
ctx := context.Background()
@@ -93,7 +103,10 @@ func TestShellTool_WorkingDir(t *testing.T) {
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0o644)
tool := NewExecTool("", false)
tool, err := NewExecTool("", false)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
ctx := context.Background()
args := map[string]any{
@@ -114,7 +127,10 @@ func TestShellTool_WorkingDir(t *testing.T) {
// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands
func TestShellTool_DangerousCommand(t *testing.T) {
tool := NewExecTool("", false)
tool, err := NewExecTool("", false)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
ctx := context.Background()
args := map[string]any{
@@ -135,7 +151,10 @@ func TestShellTool_DangerousCommand(t *testing.T) {
// TestShellTool_MissingCommand verifies error handling for missing command
func TestShellTool_MissingCommand(t *testing.T) {
tool := NewExecTool("", false)
tool, err := NewExecTool("", false)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
ctx := context.Background()
args := map[string]any{}
@@ -150,7 +169,10 @@ func TestShellTool_MissingCommand(t *testing.T) {
// TestShellTool_StderrCapture verifies stderr is captured and included
func TestShellTool_StderrCapture(t *testing.T) {
tool := NewExecTool("", false)
tool, err := NewExecTool("", false)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
ctx := context.Background()
args := map[string]any{
@@ -170,7 +192,10 @@ func TestShellTool_StderrCapture(t *testing.T) {
// TestShellTool_OutputTruncation verifies long output is truncated
func TestShellTool_OutputTruncation(t *testing.T) {
tool := NewExecTool("", false)
tool, err := NewExecTool("", false)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
ctx := context.Background()
// Generate long output (>10000 chars)
@@ -198,7 +223,11 @@ func TestShellTool_WorkingDir_OutsideWorkspace(t *testing.T) {
t.Fatalf("failed to create outside dir: %v", err)
}
tool := NewExecTool(workspace, true)
tool, err := NewExecTool(workspace, true)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
result := tool.Execute(context.Background(), map[string]any{
"command": "pwd",
"working_dir": outsideDir,
@@ -232,7 +261,11 @@ func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) {
t.Skipf("symlinks not supported in this environment: %v", err)
}
tool := NewExecTool(workspace, true)
tool, err := NewExecTool(workspace, true)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
result := tool.Execute(context.Background(), map[string]any{
"command": "cat secret.txt",
"working_dir": link,
@@ -249,7 +282,11 @@ func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) {
// TestShellTool_RestrictToWorkspace verifies workspace restriction
func TestShellTool_RestrictToWorkspace(t *testing.T) {
tmpDir := t.TempDir()
tool := NewExecTool(tmpDir, false)
tool, err := NewExecTool(tmpDir, false)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
tool.SetRestrictToWorkspace(true)
ctx := context.Background()
+5 -1
View File
@@ -22,7 +22,11 @@ func processExists(pid int) bool {
}
func TestShellTool_TimeoutKillsChildProcess(t *testing.T) {
tool := NewExecTool(t.TempDir(), false)
tool, err := NewExecTool(t.TempDir(), false)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
tool.SetTimeout(500 * time.Millisecond)
args := map[string]any{
+58 -48
View File
@@ -16,6 +16,14 @@ import (
const (
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
// HTTP client timeouts for web tool providers.
searchTimeout = 10 * time.Second // Brave, Tavily, DuckDuckGo
perplexityTimeout = 30 * time.Second // Perplexity (LLM-based, slower)
fetchTimeout = 60 * time.Second // WebFetchTool
defaultMaxChars = 50000
maxRedirects = 5
)
// Pre-compiled regexes for HTML text extraction
@@ -75,6 +83,7 @@ type SearchProvider interface {
type BraveSearchProvider struct {
apiKey string
proxy string
client *http.Client
}
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -89,11 +98,7 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", p.apiKey)
client, err := createHTTPClient(p.proxy, 10*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -144,6 +149,7 @@ type TavilySearchProvider struct {
apiKey string
baseURL string
proxy string
client *http.Client
}
func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -175,11 +181,7 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent)
client, err := createHTTPClient(p.proxy, 10*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -227,7 +229,8 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
}
type DuckDuckGoSearchProvider struct {
proxy string
proxy string
client *http.Client
}
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -240,11 +243,7 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou
req.Header.Set("User-Agent", userAgent)
client, err := createHTTPClient(p.proxy, 10*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -286,7 +285,7 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
maxItems := min(len(matches), count)
for i := 0; i < maxItems; i++ {
for i := range maxItems {
urlStr := matches[i][1]
title := stripTags(matches[i][2])
title = strings.TrimSpace(title)
@@ -294,9 +293,9 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
// URL decoding if needed
if strings.Contains(urlStr, "uddg=") {
if u, err := url.QueryUnescape(urlStr); err == nil {
idx := strings.Index(u, "uddg=")
if idx != -1 {
urlStr = u[idx+5:]
_, after, ok := strings.Cut(u, "uddg=")
if ok {
urlStr = after
}
}
}
@@ -323,6 +322,7 @@ func stripTags(content string) string {
type PerplexitySearchProvider struct {
apiKey string
proxy string
client *http.Client
}
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -357,11 +357,7 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
req.Header.Set("Authorization", "Bearer "+p.apiKey)
req.Header.Set("User-Agent", userAgent)
client, err := createHTTPClient(p.proxy, 30*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -416,43 +412,60 @@ type WebSearchToolOptions struct {
Proxy string
}
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
var provider SearchProvider
maxResults := 5
// Priority: Perplexity > Brave > Tavily > DuckDuckGo
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy}
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err)
}
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client}
if opts.PerplexityMaxResults > 0 {
maxResults = opts.PerplexityMaxResults
}
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy}
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err)
}
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client}
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
}
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err)
}
provider = &TavilySearchProvider{
apiKey: opts.TavilyAPIKey,
baseURL: opts.TavilyBaseURL,
proxy: opts.Proxy,
client: client,
}
if opts.TavilyMaxResults > 0 {
maxResults = opts.TavilyMaxResults
}
} else if opts.DuckDuckGoEnabled {
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy}
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for DuckDuckGo: %w", err)
}
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy, client: client}
if opts.DuckDuckGoMaxResults > 0 {
maxResults = opts.DuckDuckGoMaxResults
}
} else {
return nil
return nil, nil
}
return &WebSearchTool{
provider: provider,
maxResults: maxResults,
}
}, nil
}
func (t *WebSearchTool) Name() string {
@@ -527,7 +540,17 @@ func NewWebFetchTool(maxChars int, fetchLimitBytes int64) *WebFetchTool {
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) *WebFetchTool {
if maxChars <= 0 {
maxChars = 50000
maxChars = defaultMaxChars
}
client, err := createHTTPClient(proxy, fetchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirects {
return fmt.Errorf("stopped after %d redirects", maxRedirects)
}
return nil
}
if fetchLimitBytes <= 0 {
fetchLimitBytes = 10 * 1024 * 1024 // Security Fallback
@@ -598,20 +621,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
req.Header.Set("User-Agent", userAgent)
client, err := createHTTPClient(t.proxy, 60*time.Second)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err))
}
// Configure redirect handling
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= 5 {
return fmt.Errorf("stopped after 5 redirects")
}
return nil
}
resp, err := client.Do(req)
resp, err := t.client.Do(req)
if err != nil {
return ErrorResult(fmt.Sprintf("request failed: %v", err))
}
@@ -669,14 +679,14 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
resultJSON, _ := json.MarshalIndent(result, "", " ")
return &ToolResult{
ForLLM: fmt.Sprintf(
ForLLM: string(resultJSON),
ForUser: fmt.Sprintf(
"Fetched %d bytes from %s (extractor: %s, truncated: %v)",
len(text),
urlStr,
extractor,
truncated,
),
ForUser: string(resultJSON),
}
}
+45 -24
View File
@@ -36,14 +36,14 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain the fetched content
if !strings.Contains(result.ForUser, "Test Page") {
t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
// ForLLM should contain the fetched content (full JSON result)
if !strings.Contains(result.ForLLM, "Test Page") {
t.Errorf("Expected ForLLM to contain 'Test Page', got: %s", result.ForLLM)
}
// ForLLM should contain summary
if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
// ForUser should contain summary
if !strings.Contains(result.ForUser, "bytes") && !strings.Contains(result.ForUser, "extractor") {
t.Errorf("Expected ForUser to contain summary, got: %s", result.ForUser)
}
}
@@ -72,9 +72,9 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain formatted JSON
if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
// ForLLM should contain formatted JSON
if !strings.Contains(result.ForLLM, "key") && !strings.Contains(result.ForLLM, "value") {
t.Errorf("Expected ForLLM to contain JSON data, got: %s", result.ForLLM)
}
}
@@ -163,9 +163,9 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain truncated content (not the full 20000 chars)
// ForLLM should contain truncated content (not the full 20000 chars)
resultMap := make(map[string]any)
json.Unmarshal([]byte(result.ForUser), &resultMap)
json.Unmarshal([]byte(result.ForLLM), &resultMap)
if text, ok := resultMap["text"].(string); ok {
if len(text) > 1100 { // Allow some margin
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
@@ -220,13 +220,19 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if tool != nil {
t.Errorf("Expected nil tool when Brave API key is empty")
}
// Also nil when nothing is enabled
tool = NewWebSearchTool(WebSearchToolOptions{})
tool, err = NewWebSearchTool(WebSearchToolOptions{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if tool != nil {
t.Errorf("Expected nil tool when no provider is enabled")
}
@@ -234,7 +240,10 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
ctx := context.Background()
args := map[string]any{}
@@ -272,14 +281,14 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain extracted text (without script/style tags)
if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
// ForLLM should contain extracted text (without script/style tags)
if !strings.Contains(result.ForLLM, "Title") && !strings.Contains(result.ForLLM, "Content") {
t.Errorf("Expected ForLLM to contain extracted text, got: %s", result.ForLLM)
}
// Should NOT contain script or style tags
if strings.Contains(result.ForUser, "<script>") || strings.Contains(result.ForUser, "<style>") {
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForUser)
// Should NOT contain script or style tags in ForLLM
if strings.Contains(result.ForLLM, "<script>") || strings.Contains(result.ForLLM, "<style>") {
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForLLM)
}
}
@@ -498,12 +507,15 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
t.Run("perplexity", func(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{
tool, err := NewWebSearchTool(WebSearchToolOptions{
PerplexityEnabled: true,
PerplexityAPIKey: "k",
PerplexityMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
p, ok := tool.provider.(*PerplexitySearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider)
@@ -514,12 +526,15 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
})
t.Run("brave", func(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{
tool, err := NewWebSearchTool(WebSearchToolOptions{
BraveEnabled: true,
BraveAPIKey: "k",
BraveMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
p, ok := tool.provider.(*BraveSearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider)
@@ -530,11 +545,14 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
})
t.Run("duckduckgo", func(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{
tool, err := NewWebSearchTool(WebSearchToolOptions{
DuckDuckGoEnabled: true,
DuckDuckGoMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
p, ok := tool.provider.(*DuckDuckGoSearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider)
@@ -586,12 +604,15 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
}))
defer server.Close()
tool := NewWebSearchTool(WebSearchToolOptions{
tool, err := NewWebSearchTool(WebSearchToolOptions{
TavilyEnabled: true,
TavilyAPIKey: "test-key",
TavilyBaseURL: server.URL,
TavilyMaxResults: 5,
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
ctx := context.Background()
args := map[string]any{
+3
View File
@@ -37,6 +37,9 @@ func DoRequestWithRetry(client *http.Client, req *http.Request) (*http.Response,
if i < maxRetries-1 {
if err = sleepWithCtx(req.Context(), retryDelayUnit*time.Duration(i+1)); err != nil {
if resp != nil {
resp.Body.Close()
}
return nil, fmt.Errorf("failed to sleep: %w", err)
}
}
+88
View File
@@ -1,8 +1,11 @@
package utils
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
@@ -77,6 +80,91 @@ func TestDoRequestWithRetry(t *testing.T) {
}
}
func TestDoRequestWithRetry_ContextCancel(t *testing.T) {
// Use a long retry delay so cancellation always hits during sleepWithCtx.
retryDelayUnit = 10 * time.Second
t.Cleanup(func() { retryDelayUnit = time.Second })
bodyClosed := false
firstRoundTripDone := make(chan struct{}, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("error"))
}))
defer server.Close()
client := server.Client()
client.Timeout = 30 * time.Second
client.Transport = &bodyCloseTracker{
rt: client.Transport,
onClose: func() { bodyClosed = true },
// Signal after the first round-trip response is fully constructed on the client side.
onRoundTrip: func() {
select {
case firstRoundTripDone <- struct{}{}:
default:
}
},
trackURL: server.URL,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Cancel the context after the first round-trip completes on the client side.
// This ensures client.Do has returned a valid resp (with body) and the retry
// loop is about to enter sleepWithCtx, where the cancel will be detected.
go func() {
<-firstRoundTripDone
cancel()
}()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := DoRequestWithRetry(client, req)
if resp != nil {
resp.Body.Close()
}
require.Error(t, err, "expected error from context cancellation")
assert.Nil(t, resp, "expected nil response when context is canceled")
assert.True(t, bodyClosed, "expected resp.Body to be closed on context cancellation")
}
// bodyCloseTracker wraps an http.RoundTripper and records when response bodies are closed.
type bodyCloseTracker struct {
rt http.RoundTripper
onClose func()
onRoundTrip func() // called after each successful round-trip
trackURL string
}
func (t *bodyCloseTracker) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := t.rt.RoundTrip(req)
if err != nil {
return resp, err
}
if strings.HasPrefix(req.URL.String(), t.trackURL) {
resp.Body = &closeNotifier{ReadCloser: resp.Body, onClose: t.onClose}
if t.onRoundTrip != nil {
t.onRoundTrip()
}
}
return resp, nil
}
// closeNotifier wraps an io.ReadCloser to detect Close calls.
type closeNotifier struct {
io.ReadCloser
onClose func()
}
func (c *closeNotifier) Close() error {
c.onClose()
return c.ReadCloser.Close()
}
func TestDoRequestWithRetry_Delay(t *testing.T) {
retryDelayUnit = time.Millisecond
t.Cleanup(func() { retryDelayUnit = time.Second })