Merge branch 'main' into feat/markdown-output-format-web-fetch

This commit is contained in:
Mauro
2026-03-17 16:37:22 +01:00
committed by GitHub
104 changed files with 6151 additions and 1202 deletions
+22 -3
View File
@@ -458,7 +458,23 @@ func (cb *ContextBuilder) LoadBootstrapFiles() string {
//
// See: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
// See: https://platform.openai.com/docs/guides/prompt-caching
func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string {
func formatCurrentSenderLine(senderID, senderDisplayName string) string {
senderID = strings.TrimSpace(senderID)
senderDisplayName = strings.TrimSpace(senderDisplayName)
switch {
case senderDisplayName != "" && senderID != "":
return fmt.Sprintf("Current sender: %s (ID: %s)", senderDisplayName, senderID)
case senderDisplayName != "":
return fmt.Sprintf("Current sender: %s", senderDisplayName)
case senderID != "":
return fmt.Sprintf("Current sender: %s", senderID)
default:
return ""
}
}
func (cb *ContextBuilder) buildDynamicContext(channel, chatID, senderID, senderDisplayName string) string {
now := time.Now().Format("2006-01-02 15:04 (Monday)")
rt := fmt.Sprintf("%s %s, Go %s", runtime.GOOS, runtime.GOARCH, runtime.Version())
@@ -468,6 +484,9 @@ func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string {
if channel != "" && chatID != "" {
fmt.Fprintf(&sb, "\n\n## Current Session\nChannel: %s\nChat ID: %s", channel, chatID)
}
if senderLine := formatCurrentSenderLine(senderID, senderDisplayName); senderLine != "" {
fmt.Fprintf(&sb, "\n\n## Current Sender\n%s", senderLine)
}
return sb.String()
}
@@ -477,7 +496,7 @@ func (cb *ContextBuilder) BuildMessages(
summary string,
currentMessage string,
media []string,
channel, chatID string,
channel, chatID, senderID, senderDisplayName string,
) []providers.Message {
messages := []providers.Message{}
@@ -493,7 +512,7 @@ func (cb *ContextBuilder) BuildMessages(
staticPrompt := cb.BuildSystemPromptWithCache()
// Build short dynamic context (time, runtime, session) — changes per request
dynamicCtx := cb.buildDynamicContext(channel, chatID)
dynamicCtx := cb.buildDynamicContext(channel, chatID, senderID, senderDisplayName)
// Compose a single system message: static (cached) + dynamic + optional summary.
// Keeping all system content in one message ensures every provider adapter can
+65 -3
View File
@@ -82,7 +82,7 @@ func TestSingleSystemMessage(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1")
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1", "", "")
systemCount := 0
for _, m := range msgs {
@@ -126,6 +126,68 @@ func TestSingleSystemMessage(t *testing.T) {
}
}
func TestBuildMessages_CurrentSenderDynamicContext(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"IDENTITY.md": "# Identity\nTest agent.",
})
defer os.RemoveAll(tmpDir)
cb := NewContextBuilder(tmpDir)
tests := []struct {
name string
senderID string
senderDisplayName string
wantLine string
wantSection bool
}{
{
name: "both id and display name",
senderID: "feishu:ou_xxx",
senderDisplayName: "Zhang San",
wantLine: "Current sender: Zhang San (ID: feishu:ou_xxx)",
wantSection: true,
},
{
name: "display name only",
senderDisplayName: "Alice",
wantLine: "Current sender: Alice",
wantSection: true,
},
{
name: "id only",
senderID: "discord:123",
wantLine: "Current sender: discord:123",
wantSection: true,
},
{
name: "no sender info",
wantSection: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgs := cb.BuildMessages(nil, "", "hello", nil, "discord", "chat1", tt.senderID, tt.senderDisplayName)
sys := msgs[0].Content
if tt.wantSection {
if !strings.Contains(sys, "## Current Sender") {
t.Fatalf("system prompt missing Current Sender section:\n%s", sys)
}
if !strings.Contains(sys, tt.wantLine) {
t.Fatalf("system prompt missing sender line %q:\n%s", tt.wantLine, sys)
}
return
}
if strings.Contains(sys, "## Current Sender") {
t.Fatalf("system prompt should omit Current Sender section:\n%s", sys)
}
})
}
}
// TestMtimeAutoInvalidation verifies that the cache detects source file changes
// via mtime without requiring explicit InvalidateCache().
// Fix: original implementation had no auto-invalidation — edits to bootstrap files,
@@ -576,7 +638,7 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
}
// Also exercise BuildMessages concurrently
msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat")
msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat", "", "")
if len(msgs) < 2 {
errs <- "BuildMessages returned fewer than 2 messages"
return
@@ -664,6 +726,6 @@ func BenchmarkBuildMessagesWithCache(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test")
_ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test", "", "")
}
}
+25 -2
View File
@@ -10,6 +10,7 @@ import (
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/memory"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
@@ -66,7 +67,7 @@ func NewAgentInstance(
readRestrict := restrict && !defaults.AllowReadOutsideWorkspace
// Compile path whitelist patterns from config.
allowReadPaths := compilePatterns(cfg.Tools.AllowReadPaths)
allowReadPaths := buildAllowReadPatterns(cfg)
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
toolsRegistry := tools.NewToolRegistry()
@@ -82,7 +83,7 @@ func NewAgentInstance(
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
}
if cfg.Tools.IsToolEnabled("exec") {
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg, allowReadPaths)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
}
@@ -282,6 +283,28 @@ func compilePatterns(patterns []string) []*regexp.Regexp {
return compiled
}
func buildAllowReadPatterns(cfg *config.Config) []*regexp.Regexp {
var configured []string
if cfg != nil {
configured = cfg.Tools.AllowReadPaths
}
compiled := compilePatterns(configured)
mediaDirPattern := regexp.MustCompile(mediaTempDirPattern())
for _, pattern := range compiled {
if pattern.String() == mediaDirPattern.String() {
return compiled
}
}
return append(compiled, mediaDirPattern)
}
func mediaTempDirPattern() string {
sep := regexp.QuoteMeta(string(os.PathSeparator))
return "^" + regexp.QuoteMeta(filepath.Clean(media.TempDir())) + "(?:" + sep + "|$)"
}
// Close releases resources held by the agent's session store.
func (a *AgentInstance) Close() error {
if a.Sessions != nil {
+86
View File
@@ -1,10 +1,14 @@
package agent
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
@@ -160,3 +164,85 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
})
}
}
func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
workspace := t.TempDir()
mediaDir := media.TempDir()
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
}
mediaFile, err := os.CreateTemp(mediaDir, "instance-tool-*.txt")
if err != nil {
t.Fatalf("CreateTemp(mediaDir) error = %v", err)
}
mediaPath := mediaFile.Name()
if _, err := mediaFile.WriteString("attachment content"); err != nil {
mediaFile.Close()
t.Fatalf("WriteString(mediaFile) error = %v", err)
}
if err := mediaFile.Close(); err != nil {
t.Fatalf("Close(mediaFile) error = %v", err)
}
t.Cleanup(func() { _ = os.Remove(mediaPath) })
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: workspace,
ModelName: "test-model",
RestrictToWorkspace: true,
},
},
Tools: config.ToolsConfig{
ReadFile: config.ReadFileToolConfig{Enabled: true},
ListDir: config.ToolConfig{Enabled: true},
Exec: config.ExecConfig{
ToolConfig: config.ToolConfig{Enabled: true},
EnableDenyPatterns: true,
AllowRemote: true,
},
},
}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
readTool, ok := agent.Tools.Get("read_file")
if !ok {
t.Fatal("read_file tool not registered")
}
readResult := readTool.Execute(context.Background(), map[string]any{"path": mediaPath})
if readResult.IsError {
t.Fatalf("read_file should allow media temp dir, got: %s", readResult.ForLLM)
}
if !strings.Contains(readResult.ForLLM, "attachment content") {
t.Fatalf("read_file output missing media content: %s", readResult.ForLLM)
}
listTool, ok := agent.Tools.Get("list_dir")
if !ok {
t.Fatal("list_dir tool not registered")
}
listResult := listTool.Execute(context.Background(), map[string]any{"path": mediaDir})
if listResult.IsError {
t.Fatalf("list_dir should allow media temp dir, got: %s", listResult.ForLLM)
}
if !strings.Contains(listResult.ForLLM, filepath.Base(mediaPath)) {
t.Fatalf("list_dir output missing media file: %s", listResult.ForLLM)
}
execTool, ok := agent.Tools.Get("exec")
if !ok {
t.Fatal("exec tool not registered")
}
execResult := execTool.Execute(context.Background(), map[string]any{
"command": "cat " + filepath.Base(mediaPath),
"working_dir": mediaDir,
})
if execResult.IsError {
t.Fatalf("exec should allow media temp dir, got: %s", execResult.ForLLM)
}
if !strings.Contains(execResult.ForLLM, "attachment content") {
t.Fatalf("exec output missing media content: %s", execResult.ForLLM)
}
}
+42 -26
View File
@@ -55,15 +55,17 @@ type AgentLoop struct {
// processOptions configures how a message is processed
type processOptions struct {
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
SenderID string // Current sender ID for dynamic context
SenderDisplayName string // Current sender display name for dynamic context
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
}
const (
@@ -117,6 +119,8 @@ func registerSharedTools(
registry *AgentRegistry,
provider providers.LLMProvider,
) {
allowReadPaths := buildAllowReadPatterns(cfg)
for _, agentID := range registry.ListAgentIDs() {
agent, ok := registry.GetAgent(agentID)
if !ok {
@@ -161,7 +165,8 @@ func registerSharedTools(
50000,
cfg.Tools.Web.Proxy,
cfg.Tools.Web.Format,
cfg.Tools.Web.FetchLimitBytes)
cfg.Tools.Web.FetchLimitBytes,
cfg.Tools.Web.PrivateHostWhitelist)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
@@ -199,6 +204,7 @@ func registerSharedTools(
cfg.Agents.Defaults.RestrictToWorkspace,
cfg.Agents.Defaults.GetMaxMediaSize(),
nil,
allowReadPaths,
)
agent.Tools.Register(sendFileTool)
}
@@ -226,20 +232,26 @@ func registerSharedTools(
}
}
// Spawn tool with allowlist checker
if cfg.Tools.IsToolEnabled("spawn") {
if cfg.Tools.IsToolEnabled("subagent") {
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
// Spawn and spawn_status tools share a SubagentManager.
// Construct it when either tool is enabled (both require subagent).
spawnEnabled := cfg.Tools.IsToolEnabled("spawn")
spawnStatusEnabled := cfg.Tools.IsToolEnabled("spawn_status")
if (spawnEnabled || spawnStatusEnabled) && cfg.Tools.IsToolEnabled("subagent") {
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
if spawnEnabled {
spawnTool := tools.NewSpawnTool(subagentManager)
currentAgentID := agentID
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
})
agent.Tools.Register(spawnTool)
} else {
logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil)
}
if spawnStatusEnabled {
agent.Tools.Register(tools.NewSpawnStatusTool(subagentManager))
}
} else if (spawnEnabled || spawnStatusEnabled) && !cfg.Tools.IsToolEnabled("subagent") {
logger.WarnCF("agent", "spawn/spawn_status tools require subagent to be enabled", nil)
}
}
}
@@ -736,14 +748,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
})
opts := processOptions{
SessionKey: sessionKey,
Channel: msg.Channel,
ChatID: msg.ChatID,
UserMessage: msg.Content,
Media: msg.Media,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
SessionKey: sessionKey,
Channel: msg.Channel,
ChatID: msg.ChatID,
SenderID: msg.SenderID,
SenderDisplayName: msg.Sender.DisplayName,
UserMessage: msg.Content,
Media: msg.Media,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
}
// context-dependent commands check their own Runtime fields and report
@@ -883,6 +897,8 @@ func (al *AgentLoop) runAgentLoop(
opts.Media,
opts.Channel,
opts.ChatID,
opts.SenderID,
opts.SenderDisplayName,
)
// Resolve media:// refs: images→base64 data URLs, non-images→local paths in content
@@ -1154,7 +1170,7 @@ func (al *AgentLoop) runLLMIteration(
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
messages = agent.ContextBuilder.BuildMessages(
newHistory, newSummary, "",
nil, opts.Channel, opts.ChatID,
nil, opts.Channel, opts.ChatID, opts.SenderID, opts.SenderDisplayName,
)
continue
}
+75
View File
@@ -30,6 +30,28 @@ func (f *fakeChannel) IsAllowed(string) bool {
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
type recordingProvider struct {
lastMessages []providers.Message
}
func (r *recordingProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
r.lastMessages = append([]providers.Message(nil), messages...)
return &providers.LLMResponse{
Content: "Mock response",
ToolCalls: []providers.ToolCall{},
}, nil
}
func (r *recordingProvider) GetDefaultModel() string {
return "mock-model"
}
func newTestAgentLoop(
t *testing.T,
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
@@ -54,6 +76,59 @@ func newTestAgentLoop(
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
}
func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "discord",
SenderID: "discord:123",
Sender: bus.SenderInfo{
DisplayName: "Alice",
},
ChatID: "group-1",
Content: "hello",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "Mock response" {
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
}
if len(provider.lastMessages) == 0 {
t.Fatal("provider did not receive any messages")
}
systemPrompt := provider.lastMessages[0].Content
wantSender := "## Current Sender\nCurrent sender: Alice (ID: discord:123)"
if !strings.Contains(systemPrompt, wantSender) {
t.Fatalf("system prompt missing sender context %q:\n%s", wantSender, systemPrompt)
}
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
if lastMessage.Role != "user" || lastMessage.Content != "hello" {
t.Fatalf("last provider message = %+v, want unchanged user message", lastMessage)
}
}
func TestRecordLastChannel(t *testing.T) {
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
+1 -1
View File
@@ -618,7 +618,7 @@ func (c *FeishuChannel) downloadResource(
}
// Write to the shared picoclaw_media directory using a unique name to avoid collisions.
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
mediaDir := media.TempDir()
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{
"error": mkdirErr.Error(),
+1 -2
View File
@@ -357,7 +357,6 @@ func (m *Manager) StartAll(ctx context.Context) error {
if len(m.channels) == 0 {
logger.WarnC("channels", "No channels enabled")
return errors.New("no channels enabled")
}
logger.InfoC("channels", "Starting all channels")
@@ -397,7 +396,7 @@ func (m *Manager) StartAll(ctx context.Context) error {
"addr": m.httpServer.Addr,
})
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.ErrorCF("channels", "Shared HTTP server error", map[string]any{
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
"error": err.Error(),
})
}
+1 -3
View File
@@ -35,8 +35,6 @@ const (
roomKindCacheTTL = 5 * time.Minute
roomKindCacheCleanupPeriod = 1 * time.Minute
roomKindCacheMaxEntries = 2048
matrixMediaTempDirName = "picoclaw_media"
)
var matrixMentionHrefRegexp = regexp.MustCompile(`(?i)<a[^>]+href=["']([^"']+)["']`)
@@ -1105,7 +1103,7 @@ func (c *MatrixChannel) stripSelfMention(text string) string {
}
func matrixMediaTempDir() (string, error) {
mediaDir := filepath.Join(os.TempDir(), matrixMediaTempDirName)
mediaDir := media.TempDir()
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
return "", err
}
+2 -1
View File
@@ -15,6 +15,7 @@ import (
"maunium.net/go/mautrix/id"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
func TestMatrixLocalpartMentionRegexp(t *testing.T) {
@@ -165,7 +166,7 @@ func TestMatrixMediaTempDir(t *testing.T) {
if err != nil {
t.Fatalf("matrixMediaTempDir failed: %v", err)
}
if filepath.Base(dir) != matrixMediaTempDirName {
if filepath.Base(dir) != media.TempDirName {
t.Fatalf("unexpected media dir base: %q", filepath.Base(dir))
}
+28 -3
View File
@@ -251,7 +251,13 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
return
}
conn, err := c.upgrader.Upgrade(w, r, nil)
// Echo the matched subprotocol back so the browser accepts the upgrade.
var responseHeader http.Header
if proto := c.matchedSubprotocol(r); proto != "" {
responseHeader = http.Header{"Sec-WebSocket-Protocol": {proto}}
}
conn, err := c.upgrader.Upgrade(w, r, responseHeader)
if err != nil {
logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{
"error": err.Error(),
@@ -282,8 +288,10 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
go c.readLoop(pc)
}
// authenticate checks the Bearer token from the Authorization header.
// Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled.
// authenticate checks the request for a valid token:
// 1. Authorization: Bearer <token> header
// 2. Sec-WebSocket-Protocol "token.<value>" (for browsers that can't set headers)
// 3. Query parameter "token" (only when AllowTokenQuery is on)
func (c *PicoChannel) authenticate(r *http.Request) bool {
token := c.config.Token
if token == "" {
@@ -298,6 +306,11 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
}
}
// Check Sec-WebSocket-Protocol subprotocol ("token.<value>")
if c.matchedSubprotocol(r) != "" {
return true
}
// Check query parameter only when explicitly allowed
if c.config.AllowTokenQuery {
if r.URL.Query().Get("token") == token {
@@ -308,6 +321,18 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
return false
}
// matchedSubprotocol returns the "token.<value>" subprotocol that matches
// the configured token, or "" if none do.
func (c *PicoChannel) matchedSubprotocol(r *http.Request) string {
token := c.config.Token
for _, proto := range websocket.Subprotocols(r) {
if after, ok := strings.CutPrefix(proto, "token."); ok && after == token {
return proto
}
}
return ""
}
// readLoop reads messages from a WebSocket connection.
func (c *PicoChannel) readLoop(pc *picoConn) {
defer func() {
+80 -6
View File
@@ -4,11 +4,13 @@ import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync/atomic"
"github.com/caarlos0/env/v11"
"github.com/sipeed/picoclaw/pkg/credential"
"github.com/sipeed/picoclaw/pkg/fileutil"
)
@@ -623,8 +625,9 @@ func (c *ModelConfig) Validate() error {
}
type GatewayConfig struct {
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
HotReload bool `json:"hot_reload" env:"PICOCLAW_GATEWAY_HOT_RELOAD"`
}
type ToolDiscoveryConfig struct {
@@ -695,11 +698,13 @@ type WebToolsConfig struct {
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
Format string `json:"format,omitempty" env:"PICOCLAW_TOOLS_WEB_FORMAT"`
PrivateHostWhitelist FlexibleStringSlice `json:"private_host_whitelist,omitempty" env:"PICOCLAW_TOOLS_WEB_PRIVATE_HOST_WHITELIST"`
}
type CronToolsConfig struct {
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
AllowCommand bool ` env:"PICOCLAW_TOOLS_CRON_ALLOW_COMMAND" json:"allow_command"`
}
type ExecConfig struct {
@@ -749,6 +754,7 @@ type ToolsConfig struct {
ReadFile ReadFileToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
SendFile ToolConfig `json:"send_file" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
SpawnStatus ToolConfig `json:"spawn_status" envPrefix:"PICOCLAW_TOOLS_SPAWN_STATUS_"`
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"`
WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"`
@@ -838,10 +844,24 @@ func LoadConfig(path string) (*Config, error) {
return nil, err
}
if passphrase := credential.PassphraseProvider(); passphrase != "" {
for _, m := range cfg.ModelList {
if m.APIKey != "" && !strings.HasPrefix(m.APIKey, "enc://") && !strings.HasPrefix(m.APIKey, "file://") {
fmt.Fprintf(os.Stderr,
"picoclaw: warning: model %q has a plaintext api_key; call SaveConfig to encrypt it\n",
m.ModelName)
}
}
}
if err := env.Parse(cfg); err != nil {
return nil, err
}
if err := resolveAPIKeys(cfg.ModelList, filepath.Dir(path)); err != nil {
return nil, err
}
// Migrate legacy channel config fields to new unified structures
cfg.migrateChannelConfigs()
@@ -858,6 +878,48 @@ func LoadConfig(path string) (*Config, error) {
return cfg, nil
}
// encryptPlaintextAPIKeys returns a copy of models with plaintext api_key values
// encrypted. Returns (nil, nil) when nothing changed (all keys already sealed or
// empty). Returns (nil, error) if any key fails to encrypt — callers must treat
// this as a hard failure to prevent a mixed plaintext/ciphertext state on disk.
// Symmetric counterpart of resolveAPIKeys: both operate purely on []ModelConfig
// and leave JSON marshaling to the caller.
func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelConfig, error) {
sealed := make([]ModelConfig, len(models))
copy(sealed, models)
changed := false
for i := range sealed {
m := &sealed[i]
if m.APIKey == "" || strings.HasPrefix(m.APIKey, "enc://") || strings.HasPrefix(m.APIKey, "file://") {
continue
}
encrypted, err := credential.Encrypt(passphrase, "", m.APIKey)
if err != nil {
return nil, fmt.Errorf("cannot seal api_key for model %q: %w", m.ModelName, err)
}
m.APIKey = encrypted
changed = true
}
if !changed {
return nil, nil
}
return sealed, nil
}
// resolveAPIKeys decrypts or dereferences each api_key in models in-place.
// Supports plaintext (no-op), file:// (read from configDir), and enc:// (AES-GCM decrypt).
func resolveAPIKeys(models []ModelConfig, configDir string) error {
cr := credential.NewResolver(configDir)
for i := range models {
resolved, err := cr.Resolve(models[i].APIKey)
if err != nil {
return fmt.Errorf("model_list[%d] (%s): %w", i, models[i].ModelName, err)
}
models[i].APIKey = resolved
}
return nil
}
func (c *Config) migrateChannelConfigs() {
// Discord: mention_only -> group_trigger.mention_only
if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly {
@@ -872,12 +934,22 @@ func (c *Config) migrateChannelConfigs() {
}
func SaveConfig(path string, cfg *Config) error {
if passphrase := credential.PassphraseProvider(); passphrase != "" {
sealed, err := encryptPlaintextAPIKeys(cfg.ModelList, passphrase)
if err != nil {
return err
}
if sealed != nil {
tmp := *cfg
tmp.ModelList = sealed
cfg = &tmp
}
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
// Use unified atomic write utility with explicit sync for flash storage reliability.
return fileutil.WriteFileAtomic(path, data, 0o600)
}
@@ -1044,6 +1116,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
return t.ReadFile.Enabled
case "spawn":
return t.Spawn.Enabled
case "spawn_status":
return t.SpawnStatus.Enabled
case "spi":
return t.SPI.Enabled
case "subagent":
+386 -5
View File
@@ -7,8 +7,22 @@ import (
"runtime"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/credential"
)
// mustSetupSSHKey generates a temporary Ed25519 SSH key in t.TempDir() and sets
// PICOCLAW_SSH_KEY_PATH to its path for the duration of the test. This is required
// whenever a test exercises encryption/decryption via credential.Encrypt or SaveConfig.
func mustSetupSSHKey(t *testing.T) {
t.Helper()
keyPath := filepath.Join(t.TempDir(), "picoclaw_ed25519.key")
if err := credential.GenerateSSHKey(keyPath); err != nil {
t.Fatalf("mustSetupSSHKey: %v", err)
}
t.Setenv("PICOCLAW_SSH_KEY_PATH", keyPath)
}
func TestAgentModelConfig_UnmarshalString(t *testing.T) {
var m AgentModelConfig
if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil {
@@ -253,6 +267,9 @@ func TestDefaultConfig_Gateway(t *testing.T) {
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
if cfg.Gateway.HotReload {
t.Error("Gateway hot reload should be disabled by default")
}
}
// TestDefaultConfig_Providers verifies provider structure
@@ -391,6 +408,13 @@ func TestDefaultConfig_ExecAllowRemoteEnabled(t *testing.T) {
}
}
func TestDefaultConfig_CronAllowCommandEnabled(t *testing.T) {
cfg := DefaultConfig()
if !cfg.Tools.Cron.AllowCommand {
t.Fatal("DefaultConfig().Tools.Cron.AllowCommand should be true")
}
}
func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
@@ -423,6 +447,22 @@ func TestLoadConfig_ExecAllowRemoteDefaultsTrueWhenUnset(t *testing.T) {
}
}
func TestLoadConfig_CronAllowCommandDefaultsTrueWhenUnset(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
if err := os.WriteFile(configPath, []byte(`{"tools":{"cron":{"exec_timeout_minutes":5}}}`), 0o600); err != nil {
t.Fatalf("WriteFile() error: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if !cfg.Tools.Cron.AllowCommand {
t.Fatal("tools.cron.allow_command should remain true when unset in config file")
}
}
func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
@@ -482,13 +522,19 @@ func TestDefaultConfig_DMScope(t *testing.T) {
}
func TestDefaultConfig_WorkspacePath_Default(t *testing.T) {
// Unset to ensure we test the default
t.Setenv("PICOCLAW_HOME", "")
// Set a known home for consistent test results
t.Setenv("HOME", "/tmp/home")
var fakeHome string
if runtime.GOOS == "windows" {
fakeHome = `C:\tmp\home`
t.Setenv("USERPROFILE", fakeHome)
} else {
fakeHome = "/tmp/home"
t.Setenv("HOME", fakeHome)
}
cfg := DefaultConfig()
want := filepath.Join("/tmp/home", ".picoclaw", "workspace")
want := filepath.Join(fakeHome, ".picoclaw", "workspace")
if cfg.Agents.Defaults.Workspace != want {
t.Errorf("Default workspace path = %q, want %q", cfg.Agents.Defaults.Workspace, want)
@@ -499,7 +545,7 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw/home")
cfg := DefaultConfig()
want := "/custom/picoclaw/home/workspace"
want := filepath.Join("/custom/picoclaw/home", "workspace")
if cfg.Agents.Defaults.Workspace != want {
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
@@ -621,3 +667,338 @@ func TestFlexibleStringSlice_UnmarshalText_EmptySliceConsistency(t *testing.T) {
}
})
}
// TestLoadConfig_WarnsForPlaintextAPIKey verifies that LoadConfig resolves a plaintext
// api_key into memory but does NOT rewrite the config file. File writes are the sole
// responsibility of SaveConfig.
func TestLoadConfig_WarnsForPlaintextAPIKey(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
const original = `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
if err := os.WriteFile(cfgPath, []byte(original), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
cfg, err := LoadConfig(cfgPath)
if err != nil {
t.Fatalf("LoadConfig: %v", err)
}
// In-memory value must be the resolved plaintext.
if cfg.ModelList[0].APIKey != "sk-plaintext" {
t.Errorf("in-memory api_key = %q, want %q", cfg.ModelList[0].APIKey, "sk-plaintext")
}
// The file on disk must remain unchanged — LoadConfig must not write anything.
raw, _ := os.ReadFile(cfgPath)
if string(raw) != original {
t.Errorf("LoadConfig must not modify the config file; got:\n%s", string(raw))
}
}
// TestSaveConfig_EncryptsPlaintextAPIKey verifies that SaveConfig writes enc:// ciphertext
// to disk and that a subsequent LoadConfig decrypts it back to the original plaintext.
func TestSaveConfig_EncryptsPlaintextAPIKey(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
mustSetupSSHKey(t)
cfg := DefaultConfig()
cfg.ModelList = []ModelConfig{
{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
}
if err := SaveConfig(cfgPath, cfg); err != nil {
t.Fatalf("SaveConfig: %v", err)
}
// Disk must contain enc://, not the raw key.
raw, _ := os.ReadFile(cfgPath)
if !strings.Contains(string(raw), "enc://") {
t.Errorf("saved file should contain enc://, got:\n%s", string(raw))
}
if strings.Contains(string(raw), "sk-plaintext") {
t.Errorf("saved file must not contain the plaintext key")
}
// A fresh load must decrypt back to the original plaintext.
cfg2, err := LoadConfig(cfgPath)
if err != nil {
t.Fatalf("LoadConfig after SaveConfig: %v", err)
}
if cfg2.ModelList[0].APIKey != "sk-plaintext" {
t.Errorf("loaded api_key = %q, want %q", cfg2.ModelList[0].APIKey, "sk-plaintext")
}
}
// TestLoadConfig_NoSealWithoutPassphrase verifies that api_key values are left
// unchanged when PICOCLAW_KEY_PASSPHRASE is not set.
func TestLoadConfig_NoSealWithoutPassphrase(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
if _, err := LoadConfig(cfgPath); err != nil {
t.Fatalf("LoadConfig: %v", err)
}
raw, _ := os.ReadFile(cfgPath)
if strings.Contains(string(raw), "enc://") {
t.Error("config file must not be modified when no passphrase is set")
}
}
// TestLoadConfig_FileRefNotSealed verifies that file:// api_key references are not
// converted to enc:// values (they are resolved at runtime by the Resolver).
func TestLoadConfig_FileRefNotSealed(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
keyFile := filepath.Join(dir, "openai.key")
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"file://openai.key"}]}`
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
if _, err := LoadConfig(cfgPath); err != nil {
t.Fatalf("LoadConfig: %v", err)
}
raw, _ := os.ReadFile(cfgPath)
if !strings.Contains(string(raw), "file://openai.key") {
t.Error("file:// reference should be preserved unchanged in the config file")
}
if strings.Contains(string(raw), "enc://") {
t.Error("file:// reference must not be converted to enc://")
}
}
// TestSaveConfig_MixedKeys verifies that SaveConfig encrypts only plaintext api_keys
// and leaves already-encrypted (enc://) and file:// entries unchanged.
func TestSaveConfig_MixedKeys(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
mustSetupSSHKey(t)
// Pre-encrypt one key so we have a genuine enc:// value to put in the config.
if err := SaveConfig(cfgPath, &Config{
ModelList: []ModelConfig{
{ModelName: "pre", Model: "openai/gpt-4", APIKey: "sk-already-plain"},
},
}); err != nil {
t.Fatalf("setup SaveConfig: %v", err)
}
raw, _ := os.ReadFile(cfgPath)
// Extract the enc:// value from the saved file.
var tmp struct {
ModelList []struct {
APIKey string `json:"api_key"`
} `json:"model_list"`
}
if err := json.Unmarshal(raw, &tmp); err != nil || len(tmp.ModelList) == 0 {
t.Fatalf("setup: could not parse saved config: %v", err)
}
alreadyEncrypted := tmp.ModelList[0].APIKey
if !strings.HasPrefix(alreadyEncrypted, "enc://") {
t.Fatalf("setup: expected enc:// key, got %q", alreadyEncrypted)
}
// Build a config with three models:
// 1. plaintext → must be encrypted by SaveConfig
// 2. enc:// → must be left unchanged (already encrypted)
// 3. file:// → must be left unchanged (file reference)
keyFile := filepath.Join(dir, "api.key")
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
cfg := &Config{
ModelList: []ModelConfig{
{ModelName: "plain", Model: "openai/gpt-4", APIKey: "sk-new-plaintext"},
{ModelName: "enc", Model: "openai/gpt-4", APIKey: alreadyEncrypted},
{ModelName: "file", Model: "openai/gpt-4", APIKey: "file://api.key"},
},
}
if err := SaveConfig(cfgPath, cfg); err != nil {
t.Fatalf("SaveConfig: %v", err)
}
raw, _ = os.ReadFile(cfgPath)
s := string(raw)
// 1. Plaintext must be encrypted.
if strings.Contains(s, "sk-new-plaintext") {
t.Error("plaintext key must not appear in saved file")
}
// 2. The pre-existing enc:// value must still be present (byte-for-byte unchanged).
if !strings.Contains(s, alreadyEncrypted) {
t.Error("pre-existing enc:// entry must be preserved unchanged")
}
// 3. file:// must be preserved.
if !strings.Contains(s, "file://api.key") {
t.Error("file:// reference must be preserved unchanged")
}
// Now load and verify all three decrypt/resolve correctly.
cfg2, err := LoadConfig(cfgPath)
if err != nil {
t.Fatalf("LoadConfig after SaveConfig: %v", err)
}
byName := make(map[string]string)
for _, m := range cfg2.ModelList {
byName[m.ModelName] = m.APIKey
}
if byName["plain"] != "sk-new-plaintext" {
t.Errorf("plain model api_key = %q, want %q", byName["plain"], "sk-new-plaintext")
}
if byName["enc"] != "sk-already-plain" {
t.Errorf("enc model api_key = %q, want %q", byName["enc"], "sk-already-plain")
}
if byName["file"] != "sk-from-file" {
t.Errorf("file model api_key = %q, want %q", byName["file"], "sk-from-file")
}
}
// TestLoadConfig_MixedKeys_NoPassphrase verifies that when PICOCLAW_KEY_PASSPHRASE
// is not set, enc:// entries cause LoadConfig to return an error, while plaintext
// and file:// entries in the same config are not affected.
func TestLoadConfig_MixedKeys_NoPassphrase(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
// First encrypt a key so we have a real enc:// value.
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
mustSetupSSHKey(t)
if err := SaveConfig(cfgPath, &Config{
ModelList: []ModelConfig{
{ModelName: "m", Model: "openai/gpt-4", APIKey: "sk-secret"},
},
}); err != nil {
t.Fatalf("setup SaveConfig: %v", err)
}
raw, _ := os.ReadFile(cfgPath)
var tmp struct {
ModelList []struct {
APIKey string `json:"api_key"`
} `json:"model_list"`
}
if err := json.Unmarshal(raw, &tmp); err != nil {
t.Fatalf("setup parse: %v", err)
}
encValue := tmp.ModelList[0].APIKey
// Write a mixed config: enc:// + plaintext + file://
keyFile := filepath.Join(dir, "api.key")
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
mixed, _ := json.Marshal(map[string]any{
"model_list": []map[string]any{
{"model_name": "enc", "model": "openai/gpt-4", "api_key": encValue},
{"model_name": "plain", "model": "openai/gpt-4", "api_key": "sk-plain"},
{"model_name": "file", "model": "openai/gpt-4", "api_key": "file://api.key"},
},
})
if err := os.WriteFile(cfgPath, mixed, 0o600); err != nil {
t.Fatalf("setup write: %v", err)
}
// Now clear the passphrase — LoadConfig must fail because enc:// cannot be decrypted.
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
_, err := LoadConfig(cfgPath)
if err == nil {
t.Fatal("LoadConfig should fail when enc:// key is present and no passphrase is set")
}
if !strings.Contains(err.Error(), "passphrase required") {
t.Errorf("error should mention passphrase required, got: %v", err)
}
}
// TestSaveConfig_UsesPassphraseProvider verifies that SaveConfig encrypts plaintext
// api_keys using credential.PassphraseProvider() rather than os.Getenv directly.
// This matters for the launcher, which clears the environment variable and redirects
// PassphraseProvider to an in-memory SecureStore.
func TestSaveConfig_UsesPassphraseProvider(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
// Ensure the env var is empty — passphrase must come from PassphraseProvider only.
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
mustSetupSSHKey(t)
// Replace PassphraseProvider with an in-memory function (simulating SecureStore).
const testPassphrase = "provider-passphrase"
orig := credential.PassphraseProvider
credential.PassphraseProvider = func() string { return testPassphrase }
t.Cleanup(func() { credential.PassphraseProvider = orig })
cfg := DefaultConfig()
cfg.ModelList = []ModelConfig{
{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
}
if err := SaveConfig(cfgPath, cfg); err != nil {
t.Fatalf("SaveConfig: %v", err)
}
raw, _ := os.ReadFile(cfgPath)
if !strings.Contains(string(raw), "enc://") {
t.Errorf("SaveConfig should have encrypted plaintext key via PassphraseProvider; got:\n%s", raw)
}
}
// TestLoadConfig_UsesPassphraseProvider verifies that LoadConfig decrypts enc:// keys
// using credential.PassphraseProvider() rather than os.Getenv directly.
func TestLoadConfig_UsesPassphraseProvider(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
// Ensure the env var is empty throughout.
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
mustSetupSSHKey(t)
const testPassphrase = "provider-passphrase"
const plainKey = "sk-secret"
// First, encrypt the key using the same passphrase.
encrypted, err := credential.Encrypt(testPassphrase, "", plainKey)
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
raw, _ := json.Marshal(map[string]any{
"model_list": []map[string]any{
{"model_name": "test", "model": "openai/gpt-4", "api_key": encrypted},
},
})
if err = os.WriteFile(cfgPath, raw, 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
// Redirect PassphraseProvider — env var is empty, so without this the load would fail.
orig := credential.PassphraseProvider
credential.PassphraseProvider = func() string { return testPassphrase }
t.Cleanup(func() { credential.PassphraseProvider = orig })
cfg, err := LoadConfig(cfgPath)
if err != nil {
t.Fatalf("LoadConfig: %v", err)
}
if cfg.ModelList[0].APIKey != plainKey {
t.Errorf("api_key = %q, want %q", cfg.ModelList[0].APIKey, plainKey)
}
}
+7 -2
View File
@@ -395,8 +395,9 @@ func DefaultConfig() *Config {
},
},
Gateway: GatewayConfig{
Host: "127.0.0.1",
Port: 18790,
Host: "127.0.0.1",
Port: 18790,
HotReload: false,
},
Tools: ToolsConfig{
MediaCleanup: MediaCleanupConfig{
@@ -453,6 +454,7 @@ func DefaultConfig() *Config {
Enabled: true,
},
ExecTimeoutMinutes: 5,
AllowCommand: true,
},
Exec: ExecConfig{
ToolConfig: ToolConfig{
@@ -522,6 +524,9 @@ func DefaultConfig() *Config {
Spawn: ToolConfig{
Enabled: true,
},
SpawnStatus: ToolConfig{
Enabled: false,
},
SPI: ToolConfig{
Enabled: false, // Hardware tool - Linux only
},
+335
View File
@@ -0,0 +1,335 @@
// Package credential resolves API credential values for model_list entries.
//
// An API key is a form of authorization credential. This package centralizes
// how raw credential strings—plaintext or file references—are resolved into
// their actual values, keeping that logic out of the config loader.
//
// Supported formats for the api_key field:
//
// - Plaintext: "sk-abc123" → returned as-is
// - File ref: "file://filename.key" → content read from configDir/filename.key
// - Encrypted: "enc://<base64>" → AES-256-GCM decrypt via PICOCLAW_KEY_PASSPHRASE
// - Empty: "" → returned as-is (auth_method=oauth etc.)
//
// Encryption uses AES-256-GCM with HKDF-SHA256 key derivation (< 1ms, safe for embedded Linux).
// An SSH private key is required for both encryption and decryption.
// Key derivation:
//
// HKDF-SHA256(ikm=HMAC-SHA256(SHA256(sshKeyBytes), passphrase), salt, info)
//
// SSH key path resolution priority:
//
// 1. sshKeyPath argument to Encrypt (explicit)
// 2. PICOCLAW_SSH_KEY_PATH env var
// 3. ~/.ssh/picoclaw_ed25519.key (os.UserHomeDir is cross-platform)
package credential
import (
"crypto/aes"
"crypto/cipher"
"crypto/hkdf"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// PassphraseEnvVar is the environment variable that holds the encryption passphrase.
// Other packages (e.g. config) reference this constant to avoid duplicating the string.
const PassphraseEnvVar = "PICOCLAW_KEY_PASSPHRASE"
// PassphraseProvider is the function used to retrieve the passphrase for enc://
// credential decryption. It defaults to reading PICOCLAW_KEY_PASSPHRASE from the
// process environment. Replace it at startup to use a different source, such as
// an in-memory SecureStore, so that all LoadConfig() calls everywhere share the
// same passphrase source without needing os.Environ.
//
// Example (launcher main.go):
//
// credential.PassphraseProvider = apiHandler.passphraseStore.Get
var PassphraseProvider func() string = func() string {
return os.Getenv(PassphraseEnvVar)
}
// ErrPassphraseRequired is returned when an enc:// credential is encountered but
// no passphrase is available from PassphraseProvider. Callers can detect this
// with errors.Is to distinguish a missing-passphrase condition from other errors.
var ErrPassphraseRequired = errors.New("credential: enc:// passphrase required")
// ErrDecryptionFailed is returned when an enc:// credential cannot be decrypted,
// indicating a wrong passphrase or SSH key. Callers can detect this with errors.Is.
var ErrDecryptionFailed = errors.New("credential: enc:// decryption failed (wrong passphrase or SSH key?)")
const (
fileScheme = "file://"
encScheme = "enc://"
hkdfInfo = "picoclaw-credential-v1"
saltLen = 16
nonceLen = 12
keyLen = 32
sshKeyEnv = "PICOCLAW_SSH_KEY_PATH"
)
// Resolver resolves raw credential strings for model_list api_key fields.
// File references are resolved relative to the directory of the config file.
type Resolver struct {
configDir string
resolvedConfigDir string // symlink-resolved form of configDir
}
// NewResolver returns a Resolver that resolves file:// references relative to
// configDir (typically filepath.Dir of the config file path).
func NewResolver(configDir string) *Resolver {
resolved := configDir
if configDir != "" {
if linkedPath, err := filepath.EvalSymlinks(configDir); err == nil {
resolved = linkedPath
}
}
return &Resolver{configDir: configDir, resolvedConfigDir: resolved}
}
// Resolve returns the actual credential value for raw:
//
// - "" → "" (no error; auth_method=oauth needs no key)
// - "file://name.key" → trimmed content of configDir/name.key
// - anything else → raw unchanged (plaintext credential)
func (r *Resolver) Resolve(raw string) (string, error) {
if raw == "" {
return "", nil
}
if strings.HasPrefix(raw, fileScheme) {
fileName := strings.TrimSpace(strings.TrimPrefix(raw, fileScheme))
if fileName == "" {
return "", fmt.Errorf("credential: file:// reference has no filename")
}
baseDir := r.resolvedConfigDir
if baseDir == "" {
baseDir = r.configDir
}
keyPath := filepath.Join(baseDir, fileName)
// Resolve symlinks before enforcing containment to prevent escaping via symlinks.
realKeyPath, err := filepath.EvalSymlinks(keyPath)
if err != nil {
return "", fmt.Errorf("credential: failed to resolve credential file path %q: %w", keyPath, err)
}
if !isWithinDir(realKeyPath, baseDir) {
return "", fmt.Errorf("credential: file:// path escapes config directory")
}
data, err := os.ReadFile(realKeyPath)
if err != nil {
return "", fmt.Errorf("credential: failed to read credential file %q: %w", realKeyPath, err)
}
value := strings.TrimSpace(string(data))
if value == "" {
return "", fmt.Errorf("credential: credential file %q is empty", realKeyPath)
}
return value, nil
}
if strings.HasPrefix(raw, encScheme) {
return resolveEncrypted(raw)
}
// Plaintext credential — return unchanged.
return raw, nil
}
// resolveEncrypted decrypts an enc:// credential using PassphraseProvider.
func resolveEncrypted(raw string) (string, error) {
passphrase := PassphraseProvider()
if passphrase == "" {
return "", ErrPassphraseRequired
}
sshKeyPath := pickSSHKeyPath("") // override="": consult env then auto-detect
b64 := strings.TrimPrefix(raw, encScheme)
blob, err := base64.StdEncoding.DecodeString(b64)
if err != nil {
return "", fmt.Errorf("credential: enc:// invalid base64: %w", err)
}
if len(blob) < saltLen+nonceLen+1 {
return "", fmt.Errorf("credential: enc:// payload too short")
}
salt := blob[:saltLen]
nonce := blob[saltLen : saltLen+nonceLen]
ciphertext := blob[saltLen+nonceLen:]
key, err := deriveKey(passphrase, sshKeyPath, salt)
if err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", fmt.Errorf("credential: enc:// cipher init: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("credential: enc:// gcm init: %w", err)
}
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", fmt.Errorf("%w: %w", ErrDecryptionFailed, err)
}
return string(plaintext), nil
}
// Encrypt encrypts plaintext and returns an enc:// credential string.
//
// passphrase is required (PICOCLAW_KEY_PASSPHRASE value).
// sshKeyPath is the SSH private key file to use; pass "" to auto-detect via
// PICOCLAW_SSH_KEY_PATH env var or ~/.ssh/picoclaw_ed25519.key.
// An SSH private key must be resolvable or Encrypt returns an error.
func Encrypt(passphrase, sshKeyPath, plaintext string) (string, error) {
if passphrase == "" {
return "", fmt.Errorf("credential: passphrase must not be empty")
}
sshKeyPath = pickSSHKeyPath(sshKeyPath)
salt := make([]byte, saltLen)
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
return "", fmt.Errorf("credential: failed to generate salt: %w", err)
}
key, err := deriveKey(passphrase, sshKeyPath, salt)
if err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", fmt.Errorf("credential: cipher init: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("credential: gcm init: %w", err)
}
nonce := make([]byte, nonceLen)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("credential: failed to generate nonce: %w", err)
}
ciphertext := gcm.Seal(nil, nonce, []byte(plaintext), nil)
blob := make([]byte, 0, saltLen+nonceLen+len(ciphertext))
blob = append(blob, salt...)
blob = append(blob, nonce...)
blob = append(blob, ciphertext...)
return encScheme + base64.StdEncoding.EncodeToString(blob), nil
}
// isWithinDir reports whether path is contained within (or equal to) dir.
// Uses filepath.IsLocal on the relative path for robust cross-platform traversal detection.
func isWithinDir(path, dir string) bool {
rel, err := filepath.Rel(filepath.Clean(dir), filepath.Clean(path))
return err == nil && filepath.IsLocal(rel)
}
// allowedSSHKeyPath reports whether path is in a permitted location for SSH key files:
// - exact match with PICOCLAW_SSH_KEY_PATH env var
// - within the PICOCLAW_HOME env var directory
// - within ~/.ssh/
func allowedSSHKeyPath(path string) bool {
if path == "" {
return true // passphrase-only mode; no file will be read
}
clean := filepath.Clean(path)
// Exact match with PICOCLAW_SSH_KEY_PATH.
if envPath, ok := os.LookupEnv(sshKeyEnv); ok && envPath != "" {
if clean == filepath.Clean(envPath) {
return true
}
}
// Within PICOCLAW_HOME.
if picoHome := os.Getenv("PICOCLAW_HOME"); picoHome != "" {
if isWithinDir(clean, picoHome) {
return true
}
}
// Within ~/.ssh/.
if userHome, err := os.UserHomeDir(); err == nil {
if isWithinDir(clean, filepath.Join(userHome, ".ssh")) {
return true
}
}
return false
}
// deriveKey derives a 32-byte AES-256 key from passphrase and SSH private key.
//
// ikm = HMAC-SHA256(key=SHA256(sshKeyBytes), msg=passphrase)
// Final key: HKDF-SHA256(ikm, salt, info="picoclaw-credential-v1", 32 bytes)
// sshKeyPath must be non-empty; returns an error otherwise.
func deriveKey(passphrase, sshKeyPath string, salt []byte) ([]byte, error) {
if sshKeyPath == "" {
return nil, fmt.Errorf(
"credential: SSH private key is required but not found" +
" (set PICOCLAW_SSH_KEY_PATH or place key at ~/.ssh/picoclaw_ed25519.key)")
}
if !allowedSSHKeyPath(sshKeyPath) {
return nil, fmt.Errorf(
"credential: SSH key path %q is not in an allowed location (PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/)",
sshKeyPath,
)
}
sshBytes, err := os.ReadFile(sshKeyPath)
if err != nil {
return nil, fmt.Errorf("credential: cannot read SSH key %q: %w", sshKeyPath, err)
}
sshHash := sha256.Sum256(sshBytes)
mac := hmac.New(sha256.New, sshHash[:])
mac.Write([]byte(passphrase))
ikm := mac.Sum(nil)
key, err := hkdf.Key(sha256.New, ikm, salt, hkdfInfo, keyLen)
if err != nil {
return nil, fmt.Errorf("credential: HKDF expand failed: %w", err)
}
return key, nil
}
// pickSSHKeyPath returns the SSH private key path to use for encryption/decryption.
//
// Priority:
// 1. override (non-empty explicit argument)
// 2. PICOCLAW_SSH_KEY_PATH env var
// 3. ~/.ssh/picoclaw_ed25519.key (auto-detection)
//
// Returns "" when no key is found; deriveKey will return an error in that case.
func pickSSHKeyPath(override string) string {
if override != "" {
return override
}
if p, ok := os.LookupEnv(sshKeyEnv); ok {
return p // respect explicit setting, even if ""
}
return findDefaultSSHKey()
}
// findDefaultSSHKey returns the picoclaw-specific SSH key path if it exists.
func findDefaultSSHKey() string {
p, err := DefaultSSHKeyPath()
if err != nil {
return ""
}
if _, err := os.Stat(p); err == nil {
return p
}
return ""
}
+283
View File
@@ -0,0 +1,283 @@
package credential_test
import (
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/credential"
)
func TestResolve_PlainKey(t *testing.T) {
r := credential.NewResolver(t.TempDir())
got, err := r.Resolve("sk-plaintext-key")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "sk-plaintext-key" {
t.Fatalf("got %q, want %q", got, "sk-plaintext-key")
}
}
func TestResolve_FileKey_Success(t *testing.T) {
dir := t.TempDir()
keyFile := "openai_plain.key"
if err := os.WriteFile(filepath.Join(dir, keyFile), []byte("sk-from-file\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
r := credential.NewResolver(dir)
got, err := r.Resolve("file://" + keyFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "sk-from-file" {
t.Fatalf("got %q, want %q", got, "sk-from-file")
}
}
func TestResolve_FileKey_NotFound(t *testing.T) {
r := credential.NewResolver(t.TempDir())
_, err := r.Resolve("file://missing.key")
if err == nil {
t.Fatal("expected error for missing file, got nil")
}
}
func TestResolve_FileKey_Empty(t *testing.T) {
dir := t.TempDir()
keyFile := "empty.key"
if err := os.WriteFile(filepath.Join(dir, keyFile), []byte(" \n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
r := credential.NewResolver(dir)
_, err := r.Resolve("file://" + keyFile)
if err == nil {
t.Fatal("expected error for empty credential file, got nil")
}
}
// TestResolve_EncKey_RoundTrip tests basic encryption/decryption round-trip with an SSH key.
func TestResolve_EncKey_RoundTrip(t *testing.T) {
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key-material\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
const passphrase = "test-passphrase-32bytes-long-ok!"
const plaintext = "sk-encrypted-secret"
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
enc, err := credential.Encrypt(passphrase, "", plaintext)
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
r := credential.NewResolver(t.TempDir())
got, err := r.Resolve(enc)
if err != nil {
t.Fatalf("Resolve: %v", err)
}
if got != plaintext {
t.Fatalf("got %q, want %q", got, plaintext)
}
}
// TestResolve_EncKey_WithSSHKey tests that the SSH key file is incorporated into key derivation.
func TestResolve_EncKey_WithSSHKey(t *testing.T) {
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-private-key-material\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
const passphrase = "test-passphrase"
const plaintext = "sk-ssh-protected-secret"
// Set PICOCLAW_SSH_KEY_PATH before Encrypt so the path passes allowedSSHKeyPath validation.
t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
enc, err := credential.Encrypt(passphrase, sshKeyPath, plaintext)
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
r := credential.NewResolver(t.TempDir())
got, err := r.Resolve(enc)
if err != nil {
t.Fatalf("Resolve: %v", err)
}
if got != plaintext {
t.Fatalf("got %q, want %q", got, plaintext)
}
}
func TestResolve_EncKey_NoPassphrase(t *testing.T) {
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
enc, err := credential.Encrypt("some-passphrase", "", "sk-secret")
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
r := credential.NewResolver(t.TempDir())
_, err = r.Resolve(enc)
if err == nil {
t.Fatal("expected error when PICOCLAW_KEY_PASSPHRASE is unset, got nil")
}
}
func TestResolve_EncKey_BadCiphertext(t *testing.T) {
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
r := credential.NewResolver(t.TempDir())
_, err := r.Resolve("enc://!!not-valid-base64!!")
if err == nil {
t.Fatal("expected error for invalid enc:// payload, got nil")
}
}
func TestResolve_EncKey_PayloadTooShort(t *testing.T) {
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
// Valid base64 but fewer bytes than salt(16)+nonce(12)+1 minimum.
import64 := "dG9vc2hvcnQ=" // "tooshort" = 8 bytes
r := credential.NewResolver(t.TempDir())
_, err := r.Resolve("enc://" + import64)
if err == nil {
t.Fatal("expected error for too-short enc:// payload, got nil")
}
}
func TestResolve_EncKey_WrongPassphrase(t *testing.T) {
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
enc, err := credential.Encrypt("correct-passphrase", "", "sk-secret")
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "wrong-passphrase")
r := credential.NewResolver(t.TempDir())
_, err = r.Resolve(enc)
if err == nil {
t.Fatal("expected decryption error for wrong passphrase, got nil")
}
}
func TestEncrypt_EmptyPassphrase(t *testing.T) {
_, err := credential.Encrypt("", "", "sk-secret")
if err == nil {
t.Fatal("expected error for empty passphrase, got nil")
}
}
func TestDeriveKey_SSHKeyNotFound(t *testing.T) {
// Encrypt with a real SSH key path, then try to decrypt with a missing path.
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err := os.WriteFile(sshKeyPath, []byte("fake-key\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
// Register the real key path so allowedSSHKeyPath validation passes for Encrypt.
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
enc, err := credential.Encrypt("passphrase", sshKeyPath, "sk-secret")
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
// Point to a non-existent SSH key so deriveKey's ReadFile fails.
// The path is still under the same dir, so allowedSSHKeyPath passes (exact env match).
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "passphrase")
t.Setenv("PICOCLAW_SSH_KEY_PATH", filepath.Join(dir, "nonexistent_key"))
r := credential.NewResolver(t.TempDir())
_, err = r.Resolve(enc)
if err == nil {
t.Fatal("expected error when SSH key file is missing, got nil")
}
}
// TestResolve_FileRef_PathTraversal verifies that file:// references cannot escape configDir
// via relative traversal ("../../etc/passwd") or absolute paths ("/abs/path").
func TestResolve_FileRef_PathTraversal(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
// Create a file outside configDir that the traversal would point to.
outsideFile := filepath.Join(t.TempDir(), "secret.key")
if err := os.WriteFile(outsideFile, []byte("stolen"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
r := credential.NewResolver(filepath.Dir(cfgPath))
cases := []string{
"file://../../secret.key",
"file://../secret.key",
"file://" + outsideFile, // absolute path
}
for _, raw := range cases {
_, err := r.Resolve(raw)
if err == nil {
t.Errorf("Resolve(%q): expected path traversal error, got nil", raw)
}
}
}
// TestResolve_FileRef_withinConfigDir verifies that a legitimate relative file:// ref works.
func TestResolve_FileRef_withinConfigDir(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "my.key"), []byte("sk-valid\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
r := credential.NewResolver(dir)
got, err := r.Resolve("file://my.key")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "sk-valid" {
t.Fatalf("got %q, want %q", got, "sk-valid")
}
}
// TestEncrypt_SSHKeyOutsideAllowedDirs verifies that Encrypt rejects SSH key paths
// that are not under PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/.
func TestEncrypt_SSHKeyOutsideAllowedDirs(t *testing.T) {
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err := os.WriteFile(sshKeyPath, []byte("fake-key\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
// Make sure none of the allowed env vars point here.
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
t.Setenv("PICOCLAW_HOME", "")
_, err := credential.Encrypt("passphrase", sshKeyPath, "sk-secret")
if err == nil {
t.Fatal("expected error for SSH key outside allowed directories, got nil")
}
}
+62
View File
@@ -0,0 +1,62 @@
package credential
import (
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"fmt"
"os"
"path/filepath"
"golang.org/x/crypto/ssh"
)
// DefaultSSHKeyPath returns the canonical path for the picoclaw-specific SSH key.
// The path is always ~/.ssh/picoclaw_ed25519.key (os.UserHomeDir is cross-platform).
func DefaultSSHKeyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("credential: cannot determine home directory: %w", err)
}
return filepath.Join(home, ".ssh", "picoclaw_ed25519.key"), nil
}
// GenerateSSHKey generates an Ed25519 SSH key pair and writes the private key
// to path (permissions 0600) and the public key to path+".pub" (permissions 0644).
// The ~/.ssh/ directory is created with 0700 if it does not exist.
// If the files already exist they are overwritten.
func GenerateSSHKey(path string) error {
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
return fmt.Errorf("credential: keygen: cannot create directory %q: %w", filepath.Dir(path), err)
}
pubRaw, privRaw, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return fmt.Errorf("credential: keygen: ed25519 key generation failed: %w", err)
}
// Marshal private key as OpenSSH PEM.
block, err := ssh.MarshalPrivateKey(privRaw, "")
if err != nil {
return fmt.Errorf("credential: keygen: marshal private key: %w", err)
}
privPEM := pem.EncodeToMemory(block)
if err = os.WriteFile(path, privPEM, 0o600); err != nil {
return fmt.Errorf("credential: keygen: write private key %q: %w", path, err)
}
// Marshal public key as authorized_keys line.
sshPub, err := ssh.NewPublicKey(pubRaw)
if err != nil {
return fmt.Errorf("credential: keygen: marshal public key: %w", err)
}
pubLine := ssh.MarshalAuthorizedKey(sshPub)
pubPath := path + ".pub"
if err := os.WriteFile(pubPath, pubLine, 0o644); err != nil {
return fmt.Errorf("credential: keygen: write public key %q: %w", pubPath, err)
}
return nil
}
+115
View File
@@ -0,0 +1,115 @@
package credential
import (
"crypto/ed25519"
"os"
"path/filepath"
"runtime"
"testing"
"golang.org/x/crypto/ssh"
)
func TestGenerateSSHKey_CreatesFiles(t *testing.T) {
dir := t.TempDir()
keyPath := filepath.Join(dir, "test_ed25519.key")
if err := GenerateSSHKey(keyPath); err != nil {
t.Fatalf("GenerateSSHKey() error = %v", err)
}
// Private key must exist.
privInfo, err := os.Stat(keyPath)
if err != nil {
t.Fatalf("private key file missing: %v", err)
}
// Check permissions on non-Windows (Windows does not support Unix permission bits).
if runtime.GOOS != "windows" {
if got := privInfo.Mode().Perm(); got != 0o600 {
t.Errorf("private key permissions = %04o, want 0600", got)
}
}
// Public key must exist.
pubPath := keyPath + ".pub"
pubInfo, err := os.Stat(pubPath)
if err != nil {
t.Fatalf("public key file missing: %v", err)
}
if runtime.GOOS != "windows" {
if got := pubInfo.Mode().Perm(); got != 0o644 {
t.Errorf("public key permissions = %04o, want 0644", got)
}
}
// Private key must be parseable as an OpenSSH ed25519 key.
privPEM, err := os.ReadFile(keyPath)
if err != nil {
t.Fatalf("read private key: %v", err)
}
privKey, err := ssh.ParseRawPrivateKey(privPEM)
if err != nil {
t.Fatalf("parse private key: %v", err)
}
if _, ok := privKey.(*ed25519.PrivateKey); !ok {
t.Errorf("private key type = %T, want *ed25519.PrivateKey", privKey)
}
// Public key must be parseable as authorized_keys line.
pubBytes, err := os.ReadFile(pubPath)
if err != nil {
t.Fatalf("read public key: %v", err)
}
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(pubBytes)
if err != nil {
t.Fatalf("parse public key: %v", err)
}
if pubKey == nil {
t.Fatal("expected non-nil public key")
}
if len(rest) > 0 {
t.Errorf("unexpected trailing bytes after public key: %d bytes", len(rest))
}
}
func TestGenerateSSHKey_OverwritesExisting(t *testing.T) {
dir := t.TempDir()
keyPath := filepath.Join(dir, "test_ed25519.key")
// Generate twice; second call must not error and must produce a different key.
if err := GenerateSSHKey(keyPath); err != nil {
t.Fatalf("first GenerateSSHKey() error = %v", err)
}
first, err := os.ReadFile(keyPath)
if err != nil {
t.Fatalf("read first key: %v", err)
}
if err = GenerateSSHKey(keyPath); err != nil {
t.Fatalf("second GenerateSSHKey() error = %v", err)
}
second, err := os.ReadFile(keyPath)
if err != nil {
t.Fatalf("read second key: %v", err)
}
// Two independently generated Ed25519 keys must differ.
if string(first) == string(second) {
t.Error("expected overwritten key to differ from original")
}
}
func TestGenerateSSHKey_CreatesDirectory(t *testing.T) {
dir := t.TempDir()
// Nested directory that does not yet exist.
keyPath := filepath.Join(dir, "subdir", ".ssh", "picoclaw_ed25519.key")
if err := GenerateSSHKey(keyPath); err != nil {
t.Fatalf("GenerateSSHKey() error = %v", err)
}
if _, err := os.Stat(keyPath); err != nil {
t.Fatalf("private key not created: %v", err)
}
}
+44
View File
@@ -0,0 +1,44 @@
package credential
import "sync/atomic"
// SecureStore holds a passphrase in memory.
//
// Uses atomic.Pointer so reads and writes are lock-free.
// The passphrase is never written to disk; callers decide how to
// transport it outside this store (e.g., via cmd.Env or os.Environ).
type SecureStore struct {
val atomic.Pointer[string]
}
// NewSecureStore creates an empty SecureStore.
func NewSecureStore() *SecureStore {
return &SecureStore{}
}
// SetString stores the passphrase. An empty string clears the store.
func (s *SecureStore) SetString(passphrase string) {
if passphrase == "" {
s.val.Store(nil)
return
}
s.val.Store(&passphrase)
}
// Get returns the stored passphrase, or "" if not set.
func (s *SecureStore) Get() string {
if p := s.val.Load(); p != nil {
return *p
}
return ""
}
// IsSet reports whether a passphrase is currently stored.
func (s *SecureStore) IsSet() bool {
return s.val.Load() != nil
}
// Clear removes the stored passphrase.
func (s *SecureStore) Clear() {
s.val.Store(nil)
}
+81
View File
@@ -0,0 +1,81 @@
package credential
import (
"sync"
"testing"
)
func TestSecureStore_SetGet(t *testing.T) {
s := NewSecureStore()
if s.IsSet() {
t.Error("expected empty store")
}
s.SetString("hunter2")
if !s.IsSet() {
t.Error("expected store to be set")
}
if got := s.Get(); got != "hunter2" {
t.Errorf("Get() = %q, want %q", got, "hunter2")
}
}
func TestSecureStore_Clear(t *testing.T) {
s := NewSecureStore()
s.SetString("secret")
s.Clear()
if s.IsSet() {
t.Error("expected store to be empty after Clear()")
}
if got := s.Get(); got != "" {
t.Errorf("Get() after Clear() = %q, want empty", got)
}
}
func TestSecureStore_SetOverwrites(t *testing.T) {
s := NewSecureStore()
s.SetString("first")
s.SetString("second")
if got := s.Get(); got != "second" {
t.Errorf("Get() = %q, want %q", got, "second")
}
}
func TestSecureStore_EmptyPassphrase(t *testing.T) {
s := NewSecureStore()
s.SetString("") // empty → should not mark as set
if s.IsSet() {
t.Error("empty passphrase should not mark store as set")
}
}
func TestSecureStore_ConcurrentSetGet(t *testing.T) {
s := NewSecureStore()
const goroutines = 10
const iterations = 1000
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
if id%2 == 0 {
s.SetString("even")
} else {
s.SetString("odd")
}
_ = s.Get()
}
}(i)
}
wg.Wait()
final := s.Get()
if final != "" && final != "even" && final != "odd" {
t.Errorf("Get() returned unexpected value %q after concurrent Set/Get", final)
}
}
+594
View File
@@ -0,0 +1,594 @@
package gateway
import (
"context"
"fmt"
"os"
"os/signal"
"path/filepath"
"sync"
"syscall"
"time"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
_ "github.com/sipeed/picoclaw/pkg/channels/dingtalk"
_ "github.com/sipeed/picoclaw/pkg/channels/discord"
_ "github.com/sipeed/picoclaw/pkg/channels/feishu"
_ "github.com/sipeed/picoclaw/pkg/channels/irc"
_ "github.com/sipeed/picoclaw/pkg/channels/line"
_ "github.com/sipeed/picoclaw/pkg/channels/maixcam"
_ "github.com/sipeed/picoclaw/pkg/channels/matrix"
_ "github.com/sipeed/picoclaw/pkg/channels/onebot"
_ "github.com/sipeed/picoclaw/pkg/channels/pico"
_ "github.com/sipeed/picoclaw/pkg/channels/qq"
_ "github.com/sipeed/picoclaw/pkg/channels/slack"
_ "github.com/sipeed/picoclaw/pkg/channels/telegram"
_ "github.com/sipeed/picoclaw/pkg/channels/wecom"
_ "github.com/sipeed/picoclaw/pkg/channels/whatsapp"
_ "github.com/sipeed/picoclaw/pkg/channels/whatsapp_native"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/devices"
"github.com/sipeed/picoclaw/pkg/health"
"github.com/sipeed/picoclaw/pkg/heartbeat"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/voice"
)
const (
serviceShutdownTimeout = 30 * time.Second
providerReloadTimeout = 30 * time.Second
gracefulShutdownTimeout = 15 * time.Second
)
type services struct {
CronService *cron.CronService
HeartbeatService *heartbeat.HeartbeatService
MediaStore media.MediaStore
ChannelManager *channels.Manager
DeviceService *devices.Service
HealthServer *health.Server
}
type startupBlockedProvider struct {
reason string
}
func (p *startupBlockedProvider) Chat(
_ context.Context,
_ []providers.Message,
_ []providers.ToolDefinition,
_ string,
_ map[string]any,
) (*providers.LLMResponse, error) {
return nil, fmt.Errorf("%s", p.reason)
}
func (p *startupBlockedProvider) GetDefaultModel() string {
return ""
}
// Run starts the gateway runtime using the configuration loaded from configPath.
func Run(debug bool, configPath string, allowEmptyStartup bool) error {
if debug {
logger.SetLevel(logger.DEBUG)
fmt.Println("🔍 Debug mode enabled")
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
return fmt.Errorf("error loading config: %w", err)
}
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
if err != nil {
return fmt.Errorf("error creating provider: %w", err)
}
if modelID != "" {
cfg.Agents.Defaults.ModelName = modelID
}
msgBus := bus.NewMessageBus()
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
fmt.Println("\n📦 Agent Status:")
startupInfo := agentLoop.GetStartupInfo()
toolsInfo := startupInfo["tools"].(map[string]any)
skillsInfo := startupInfo["skills"].(map[string]any)
fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"])
fmt.Printf(" • Skills: %d/%d available\n", skillsInfo["available"], skillsInfo["total"])
logger.InfoCF("agent", "Agent initialized",
map[string]any{
"tools_count": toolsInfo["count"],
"skills_total": skillsInfo["total"],
"skills_available": skillsInfo["available"],
})
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus)
if err != nil {
return err
}
fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port)
fmt.Println("Press Ctrl+C to stop")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go agentLoop.Run(ctx)
var configReloadChan <-chan *config.Config
stopWatch := func() {}
if cfg.Gateway.HotReload {
configReloadChan, stopWatch = setupConfigWatcherPolling(configPath, debug)
logger.Info("Config hot reload enabled")
}
defer stopWatch()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
for {
select {
case <-sigChan:
logger.Info("Shutting down...")
shutdownGateway(runningServices, agentLoop, provider, true)
return nil
case newCfg := <-configReloadChan:
err := handleConfigReload(ctx, agentLoop, newCfg, &provider, runningServices, msgBus, allowEmptyStartup)
if err != nil {
logger.Errorf("Config reload failed: %v", err)
}
}
}
}
func createStartupProvider(
cfg *config.Config,
allowEmptyStartup bool,
) (providers.LLMProvider, string, error) {
modelName := cfg.Agents.Defaults.GetModelName()
if modelName == "" && allowEmptyStartup {
reason := "no default model configured; gateway started in limited mode"
fmt.Printf("⚠ Warning: %s\n", reason)
logger.WarnCF("gateway", "Gateway started without default model", map[string]any{
"limited_mode": true,
})
return &startupBlockedProvider{reason: reason}, "", nil
}
return providers.CreateProvider(cfg)
}
func setupAndStartServices(
cfg *config.Config,
agentLoop *agent.AgentLoop,
msgBus *bus.MessageBus,
) (*services, error) {
runningServices := &services{}
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
var err error
runningServices.CronService, err = setupCronTool(
agentLoop,
msgBus,
cfg.WorkspacePath(),
cfg.Agents.Defaults.RestrictToWorkspace,
execTimeout,
cfg,
)
if err != nil {
return nil, fmt.Errorf("error setting up cron service: %w", err)
}
if err = runningServices.CronService.Start(); err != nil {
return nil, fmt.Errorf("error starting cron service: %w", err)
}
fmt.Println("✓ Cron service started")
runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
cfg.Heartbeat.Interval,
cfg.Heartbeat.Enabled,
)
runningServices.HeartbeatService.SetBus(msgBus)
runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(agentLoop))
if err = runningServices.HeartbeatService.Start(); err != nil {
return nil, fmt.Errorf("error starting heartbeat service: %w", err)
}
fmt.Println("✓ Heartbeat service started")
runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
Enabled: cfg.Tools.MediaCleanup.Enabled,
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
})
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
fms.Start()
}
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
if err != nil {
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
fms.Stop()
}
return nil, fmt.Errorf("error creating channel manager: %w", err)
}
agentLoop.SetChannelManager(runningServices.ChannelManager)
agentLoop.SetMediaStore(runningServices.MediaStore)
if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
agentLoop.SetTranscriber(transcriber)
logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
}
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
if len(enabledChannels) > 0 {
fmt.Printf("✓ Channels enabled: %s\n", enabledChannels)
} else {
fmt.Println("⚠ Warning: No channels enabled")
}
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
return nil, fmt.Errorf("error starting channels: %w", err)
}
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
stateManager := state.NewManager(cfg.WorkspacePath())
runningServices.DeviceService = devices.NewService(devices.Config{
Enabled: cfg.Devices.Enabled,
MonitorUSB: cfg.Devices.MonitorUSB,
}, stateManager)
runningServices.DeviceService.SetBus(msgBus)
if err = runningServices.DeviceService.Start(context.Background()); err != nil {
logger.ErrorCF("device", "Error starting device service", map[string]any{"error": err.Error()})
} else if cfg.Devices.Enabled {
fmt.Println("✓ Device event service started")
}
return runningServices, nil
}
func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration) {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer shutdownCancel()
if runningServices.ChannelManager != nil {
runningServices.ChannelManager.StopAll(shutdownCtx)
}
if runningServices.DeviceService != nil {
runningServices.DeviceService.Stop()
}
if runningServices.HeartbeatService != nil {
runningServices.HeartbeatService.Stop()
}
if runningServices.CronService != nil {
runningServices.CronService.Stop()
}
if runningServices.MediaStore != nil {
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
fms.Stop()
}
}
}
func shutdownGateway(
runningServices *services,
agentLoop *agent.AgentLoop,
provider providers.LLMProvider,
fullShutdown bool,
) {
if cp, ok := provider.(providers.StatefulProvider); ok && fullShutdown {
cp.Close()
}
stopAndCleanupServices(runningServices, gracefulShutdownTimeout)
agentLoop.Stop()
agentLoop.Close()
logger.Info("✓ Gateway stopped")
}
func handleConfigReload(
ctx context.Context,
al *agent.AgentLoop,
newCfg *config.Config,
providerRef *providers.LLMProvider,
runningServices *services,
msgBus *bus.MessageBus,
allowEmptyStartup bool,
) error {
logger.Info("🔄 Config file changed, reloading...")
newModel := newCfg.Agents.Defaults.ModelName
if newModel == "" {
newModel = newCfg.Agents.Defaults.Model
}
logger.Infof(" New model is '%s', recreating provider...", newModel)
logger.Info(" Stopping all services...")
stopAndCleanupServices(runningServices, serviceShutdownTimeout)
newProvider, newModelID, err := createStartupProvider(newCfg, allowEmptyStartup)
if err != nil {
logger.Errorf(" ⚠ Error creating new provider: %v", err)
logger.Warn(" Attempting to restart services with old provider and config...")
if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
}
return fmt.Errorf("error creating new provider: %w", err)
}
if newModelID != "" {
newCfg.Agents.Defaults.ModelName = newModelID
}
reloadCtx, reloadCancel := context.WithTimeout(context.Background(), providerReloadTimeout)
defer reloadCancel()
if err := al.ReloadProviderAndConfig(reloadCtx, newProvider, newCfg); err != nil {
logger.Errorf(" ⚠ Error reloading agent loop: %v", err)
if cp, ok := newProvider.(providers.StatefulProvider); ok {
cp.Close()
}
logger.Warn(" Attempting to restart services with old provider and config...")
if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
}
return fmt.Errorf("error reloading agent loop: %w", err)
}
*providerRef = newProvider
logger.Info(" Restarting all services with new configuration...")
if err := restartServices(al, runningServices, msgBus); err != nil {
logger.Errorf(" ⚠ Error restarting services: %v", err)
return fmt.Errorf("error restarting services: %w", err)
}
logger.Info(" ✓ Provider, configuration, and services reloaded successfully (thread-safe)")
return nil
}
func restartServices(
al *agent.AgentLoop,
runningServices *services,
msgBus *bus.MessageBus,
) error {
cfg := al.GetConfig()
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
var err error
runningServices.CronService, err = setupCronTool(
al,
msgBus,
cfg.WorkspacePath(),
cfg.Agents.Defaults.RestrictToWorkspace,
execTimeout,
cfg,
)
if err != nil {
return fmt.Errorf("error restarting cron service: %w", err)
}
if err = runningServices.CronService.Start(); err != nil {
return fmt.Errorf("error restarting cron service: %w", err)
}
fmt.Println(" ✓ Cron service restarted")
runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
cfg.Heartbeat.Interval,
cfg.Heartbeat.Enabled,
)
runningServices.HeartbeatService.SetBus(msgBus)
runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(al))
if err = runningServices.HeartbeatService.Start(); err != nil {
return fmt.Errorf("error restarting heartbeat service: %w", err)
}
fmt.Println(" ✓ Heartbeat service restarted")
runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
Enabled: cfg.Tools.MediaCleanup.Enabled,
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
})
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
fms.Start()
}
al.SetMediaStore(runningServices.MediaStore)
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
if err != nil {
return fmt.Errorf("error recreating channel manager: %w", err)
}
al.SetChannelManager(runningServices.ChannelManager)
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
if len(enabledChannels) > 0 {
fmt.Printf(" ✓ Channels enabled: %s\n", enabledChannels)
} else {
fmt.Println(" ⚠ Warning: No channels enabled")
}
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
return fmt.Errorf("error restarting channels: %w", err)
}
fmt.Printf(
" ✓ Channels restarted, health endpoints at http://%s:%d/health and ready\n",
cfg.Gateway.Host,
cfg.Gateway.Port,
)
stateManager := state.NewManager(cfg.WorkspacePath())
runningServices.DeviceService = devices.NewService(devices.Config{
Enabled: cfg.Devices.Enabled,
MonitorUSB: cfg.Devices.MonitorUSB,
}, stateManager)
runningServices.DeviceService.SetBus(msgBus)
if err := runningServices.DeviceService.Start(context.Background()); err != nil {
logger.WarnCF("device", "Failed to restart device service", map[string]any{"error": err.Error()})
} else if cfg.Devices.Enabled {
fmt.Println(" ✓ Device event service restarted")
}
transcriber := voice.DetectTranscriber(cfg)
al.SetTranscriber(transcriber)
if transcriber != nil {
logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
} else {
logger.InfoCF("voice", "Transcription disabled", nil)
}
return nil
}
func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Config, func()) {
configChan := make(chan *config.Config, 1)
stop := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
lastModTime := getFileModTime(configPath)
lastSize := getFileSize(configPath)
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
currentModTime := getFileModTime(configPath)
currentSize := getFileSize(configPath)
if currentModTime.After(lastModTime) || currentSize != lastSize {
if debug {
logger.Debugf("🔍 Config file change detected")
}
time.Sleep(500 * time.Millisecond)
lastModTime = currentModTime
lastSize = currentSize
newCfg, err := config.LoadConfig(configPath)
if err != nil {
logger.Errorf("⚠ Error loading new config: %v", err)
logger.Warn(" Using previous valid config")
continue
}
if err := newCfg.ValidateModelList(); err != nil {
logger.Errorf(" ⚠ New config validation failed: %v", err)
logger.Warn(" Using previous valid config")
continue
}
logger.Info("✓ Config file validated and loaded")
select {
case configChan <- newCfg:
default:
logger.Warn("⚠ Previous config reload still in progress, skipping")
}
}
case <-stop:
return
}
}
}()
stopFunc := func() {
close(stop)
wg.Wait()
}
return configChan, stopFunc
}
func getFileModTime(path string) time.Time {
info, err := os.Stat(path)
if err != nil {
return time.Time{}
}
return info.ModTime()
}
func getFileSize(path string) int64 {
info, err := os.Stat(path)
if err != nil {
return 0
}
return info.Size()
}
func setupCronTool(
agentLoop *agent.AgentLoop,
msgBus *bus.MessageBus,
workspace string,
restrict bool,
execTimeout time.Duration,
cfg *config.Config,
) (*cron.CronService, error) {
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
cronService := cron.NewCronService(cronStorePath, nil)
var cronTool *tools.CronTool
if cfg.Tools.IsToolEnabled("cron") {
var err error
cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
if err != nil {
return nil, fmt.Errorf("critical error during CronTool initialization: %w", err)
}
agentLoop.RegisterTool(cronTool)
}
if cronTool != nil {
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
result := cronTool.ExecuteJob(context.Background(), job)
return result, nil
})
}
return cronService, nil
}
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
return func(prompt, channel, chatID string) *tools.ToolResult {
if channel == "" || chatID == "" {
channel, chatID = "cli", "direct"
}
response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
if err != nil {
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
}
if response == "HEARTBEAT_OK" {
return tools.SilentResult("Heartbeat OK")
}
return tools.SilentResult(response)
}
}
+3
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"maps"
"net/http"
"os"
"sync"
"time"
)
@@ -29,6 +30,7 @@ type StatusResponse struct {
Status string `json:"status"`
Uptime string `json:"uptime"`
Checks map[string]Check `json:"checks,omitempty"`
Pid int `json:"pid"`
}
func NewServer(host string, port int) *Server {
@@ -112,6 +114,7 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
resp := StatusResponse{
Status: "ok",
Uptime: uptime.String(),
Pid: os.Getpid(),
}
json.NewEncoder(w).Encode(resp)
+25 -12
View File
@@ -2,7 +2,20 @@
package logger
import "fmt"
import (
"fmt"
"regexp"
)
// botTokenRe matches the bot ID prefix and the secret part of a Telegram bot token.
// Groups: 1 = "bot<id>:", 2 = first 4 chars of secret, 3 = middle, 4 = last 4 chars.
var botTokenRe = regexp.MustCompile(`(bot\d+:)([A-Za-z0-9_-]{4})[A-Za-z0-9_-]{12,}([A-Za-z0-9_-]{4})`)
// maskSecrets replaces any embedded bot tokens in s with a redacted placeholder
// that keeps the first and last 4 characters of the secret for identification.
func maskSecrets(s string) string {
return botTokenRe.ReplaceAllString(s, "${1}${2}****${3}")
}
// Logger implements common Logger interface
type Logger struct {
@@ -12,52 +25,52 @@ type Logger struct {
// Debug logs debug messages
func (b *Logger) Debug(v ...any) {
logMessage(DEBUG, b.component, fmt.Sprint(v...), nil)
logMessage(DEBUG, b.component, maskSecrets(fmt.Sprint(v...)), nil)
}
// Info logs info messages
func (b *Logger) Info(v ...any) {
logMessage(INFO, b.component, fmt.Sprint(v...), nil)
logMessage(INFO, b.component, maskSecrets(fmt.Sprint(v...)), nil)
}
// Warn logs warning messages
func (b *Logger) Warn(v ...any) {
logMessage(WARN, b.component, fmt.Sprint(v...), nil)
logMessage(WARN, b.component, maskSecrets(fmt.Sprint(v...)), nil)
}
// Error logs error messages
func (b *Logger) Error(v ...any) {
logMessage(ERROR, b.component, fmt.Sprint(v...), nil)
logMessage(ERROR, b.component, maskSecrets(fmt.Sprint(v...)), nil)
}
// Debugf logs formatted debug messages
func (b *Logger) Debugf(format string, v ...any) {
logMessage(DEBUG, b.component, fmt.Sprintf(format, v...), nil)
logMessage(DEBUG, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Infof logs formatted info messages
func (b *Logger) Infof(format string, v ...any) {
logMessage(INFO, b.component, fmt.Sprintf(format, v...), nil)
logMessage(INFO, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Warnf logs formatted warning messages
func (b *Logger) Warnf(format string, v ...any) {
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
logMessage(WARN, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Warningf logs formatted warning messages
func (b *Logger) Warningf(format string, v ...any) {
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
logMessage(WARN, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Errorf logs formatted error messages
func (b *Logger) Errorf(format string, v ...any) {
logMessage(ERROR, b.component, fmt.Sprintf(format, v...), nil)
logMessage(ERROR, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Fatalf logs formatted fatal messages and exits
func (b *Logger) Fatalf(format string, v ...any) {
logMessage(FATAL, b.component, fmt.Sprintf(format, v...), nil)
logMessage(FATAL, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Log logs a message at a given level with caller information
@@ -75,7 +88,7 @@ func (b *Logger) Log(msgL, caller int, format string, a ...any) {
level = lvl
}
}
logMessage(level, b.component, fmt.Sprintf(format, a...), nil)
logMessage(level, b.component, maskSecrets(fmt.Sprintf(format, a...)), nil)
}
// Sync flushes log buffer (no-op for this implementation)
+13
View File
@@ -0,0 +1,13 @@
package media
import (
"os"
"path/filepath"
)
const TempDirName = "picoclaw_media"
// TempDir returns the shared temporary directory used for downloaded media.
func TempDir() string {
return filepath.Join(os.TempDir(), TempDirName)
}
+7 -1
View File
@@ -221,11 +221,17 @@ func buildRequestBody(
// Add tool_use blocks
for _, tc := range msg.ToolCalls {
// Handle nil Arguments (GLM-4 may return null input)
input := tc.Arguments
if input == nil {
input = map[string]any{}
}
toolUse := map[string]any{
"type": "tool_use",
"id": tc.ID,
"name": tc.Name,
"input": tc.Arguments,
"input": input,
}
content = append(content, toolUse)
}
+53 -20
View File
@@ -20,10 +20,12 @@ type JobExecutor interface {
// CronTool provides scheduling capabilities for the agent
type CronTool struct {
cronService *cron.CronService
executor JobExecutor
msgBus *bus.MessageBus
execTool *ExecTool
cronService *cron.CronService
executor JobExecutor
msgBus *bus.MessageBus
execTool *ExecTool
allowCommand bool
execEnabled bool
}
// NewCronTool creates a new CronTool
@@ -32,17 +34,32 @@ func NewCronTool(
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
execTimeout time.Duration, config *config.Config,
) (*CronTool, error) {
execTool, err := NewExecToolWithConfig(workspace, restrict, config)
if err != nil {
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
allowCommand := true
execEnabled := true
if config != nil {
allowCommand = config.Tools.Cron.AllowCommand
execEnabled = config.Tools.Exec.Enabled
}
execTool.SetTimeout(execTimeout)
var execTool *ExecTool
if execEnabled {
var err error
execTool, err = NewExecToolWithConfig(workspace, restrict, config)
if err != nil {
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
}
}
if execTool != nil {
execTool.SetTimeout(execTimeout)
}
return &CronTool{
cronService: cronService,
executor: executor,
msgBus: msgBus,
execTool: execTool,
cronService: cronService,
executor: executor,
msgBus: msgBus,
execTool: execTool,
allowCommand: allowCommand,
execEnabled: execEnabled,
}, nil
}
@@ -76,7 +93,7 @@ func (t *CronTool) Parameters() map[string]any {
},
"command_confirm": map[string]any{
"type": "boolean",
"description": "Required when using command=true. Must be true to explicitly confirm scheduling a shell command.",
"description": "Optional explicit confirmation flag for scheduling a shell command. Command execution must also be enabled via tools.cron.allow_command.",
},
"at_seconds": map[string]any{
"type": "integer",
@@ -96,7 +113,7 @@ func (t *CronTool) Parameters() map[string]any {
},
"deliver": map[string]any{
"type": "boolean",
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: false",
},
},
"required": []string{"action"},
@@ -174,22 +191,26 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
}
// Read deliver parameter, default to true
deliver := true
// Read deliver parameter, default to false so scheduled tasks execute through the agent
deliver := false
if d, ok := args["deliver"].(bool); ok {
deliver = d
}
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel + explicit confirm.
// Non-command reminders (plain messages) remain open to all channels.
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel. When
// allow_command is disabled, explicit confirmation is required as an override.
// Non-command reminders remain open to all channels.
command, _ := args["command"].(string)
commandConfirm, _ := args["command_confirm"].(bool)
if command != "" {
if !t.execEnabled {
return ErrorResult("command execution is disabled")
}
if !constants.IsInternalChannel(channel) {
return ErrorResult("scheduling command execution is restricted to internal channels")
}
if !commandConfirm {
return ErrorResult("command_confirm=true is required to schedule command execution")
if !t.allowCommand && !commandConfirm {
return ErrorResult("command_confirm=true is required when allow_command is disabled")
}
deliver = false
}
@@ -290,6 +311,18 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// Execute command if present
if job.Payload.Command != "" {
if !t.execEnabled || t.execTool == nil {
output := "Error executing scheduled command: command execution is disabled"
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: output,
})
return "ok"
}
args := map[string]any{
"command": job.Payload.Command,
"__channel": channel,
+126 -6
View File
@@ -5,18 +5,18 @@ import (
"path/filepath"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
)
func newTestCronTool(t *testing.T) *CronTool {
func newTestCronToolWithConfig(t *testing.T, cfg *config.Config) *CronTool {
t.Helper()
storePath := filepath.Join(t.TempDir(), "cron.json")
cronService := cron.NewCronService(storePath, nil)
msgBus := bus.NewMessageBus()
cfg := config.DefaultConfig()
tool, err := NewCronTool(cronService, nil, msgBus, t.TempDir(), true, 0, cfg)
if err != nil {
t.Fatalf("NewCronTool() error: %v", err)
@@ -24,6 +24,11 @@ func newTestCronTool(t *testing.T) *CronTool {
return tool
}
func newTestCronTool(t *testing.T) *CronTool {
t.Helper()
return newTestCronToolWithConfig(t, config.DefaultConfig())
}
// TestCronTool_CommandBlockedFromRemoteChannel verifies command scheduling is restricted to internal channels
func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
tool := newTestCronTool(t)
@@ -44,8 +49,7 @@ func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
}
}
// TestCronTool_CommandRequiresConfirm verifies command_confirm=true is required
func TestCronTool_CommandRequiresConfirm(t *testing.T) {
func TestCronTool_CommandDoesNotRequireConfirmByDefault(t *testing.T) {
tool := newTestCronTool(t)
ctx := WithToolContext(context.Background(), "cli", "direct")
result := tool.Execute(ctx, map[string]any{
@@ -55,11 +59,79 @@ func TestCronTool_CommandRequiresConfirm(t *testing.T) {
"at_seconds": float64(60),
})
if result.IsError {
t.Fatalf("expected command scheduling without confirm to succeed by default, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Cron job added") {
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
}
}
func TestCronTool_CommandRequiresConfirmWhenAllowCommandDisabled(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Tools.Cron.AllowCommand = false
tool := newTestCronToolWithConfig(t, cfg)
ctx := WithToolContext(context.Background(), "cli", "direct")
result := tool.Execute(ctx, map[string]any{
"action": "add",
"message": "check disk",
"command": "df -h",
"at_seconds": float64(60),
})
if !result.IsError {
t.Fatal("expected error when command_confirm is missing")
t.Fatal("expected command scheduling to require confirm when allow_command is disabled")
}
if !strings.Contains(result.ForLLM, "command_confirm=true") {
t.Errorf("expected 'command_confirm=true' message, got: %s", result.ForLLM)
t.Errorf("expected command_confirm requirement message, got: %s", result.ForLLM)
}
}
func TestCronTool_CommandAllowedWithConfirmWhenAllowCommandDisabled(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Tools.Cron.AllowCommand = false
tool := newTestCronToolWithConfig(t, cfg)
ctx := WithToolContext(context.Background(), "cli", "direct")
result := tool.Execute(ctx, map[string]any{
"action": "add",
"message": "check disk",
"command": "df -h",
"command_confirm": true,
"at_seconds": float64(60),
})
if result.IsError {
t.Fatalf(
"expected command scheduling with confirm to succeed when allow_command is disabled, got: %s",
result.ForLLM,
)
}
if !strings.Contains(result.ForLLM, "Cron job added") {
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
}
}
func TestCronTool_CommandBlockedWhenExecDisabled(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Tools.Exec.Enabled = false
tool := newTestCronToolWithConfig(t, cfg)
ctx := WithToolContext(context.Background(), "cli", "direct")
result := tool.Execute(ctx, map[string]any{
"action": "add",
"message": "check disk",
"command": "df -h",
"command_confirm": true,
"at_seconds": float64(60),
})
if !result.IsError {
t.Fatal("expected command scheduling to be blocked when exec is disabled")
}
if !strings.Contains(result.ForLLM, "command execution is disabled") {
t.Errorf("expected exec disabled message, got: %s", result.ForLLM)
}
}
@@ -114,3 +186,51 @@ func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) {
t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM)
}
}
func TestCronTool_NonCommandJobDefaultsDeliverToFalse(t *testing.T) {
tool := newTestCronTool(t)
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
result := tool.Execute(ctx, map[string]any{
"action": "add",
"message": "send me a poem",
"at_seconds": float64(600),
})
if result.IsError {
t.Fatalf("expected non-command reminder to succeed, got: %s", result.ForLLM)
}
jobs := tool.cronService.ListJobs(false)
if len(jobs) != 1 {
t.Fatalf("expected 1 job, got %d", len(jobs))
}
if jobs[0].Payload.Deliver {
t.Fatal("expected deliver=false by default for non-command jobs")
}
}
func TestCronTool_ExecuteJobPublishesErrorWhenExecDisabled(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Tools.Exec.Enabled = false
tool := newTestCronToolWithConfig(t, cfg)
job := &cron.CronJob{}
job.Payload.Channel = "cli"
job.Payload.To = "direct"
job.Payload.Command = "df -h"
if got := tool.ExecuteJob(context.Background(), job); got != "ok" {
t.Fatalf("ExecuteJob() = %q, want ok", got)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
msg, ok := tool.msgBus.SubscribeOutbound(ctx)
if !ok {
t.Fatal("expected outbound message")
}
if !strings.Contains(msg.Content, "command execution is disabled") {
t.Fatalf("expected exec disabled message, got: %s", msg.Content)
}
}
+161 -9
View File
@@ -20,8 +20,7 @@ import (
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
// validatePath ensures the given path is within the workspace if restrict is true.
func validatePath(path, workspace string, restrict bool) (string, error) {
func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) {
if workspace == "" {
return path, fmt.Errorf("workspace is not defined")
}
@@ -42,6 +41,10 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
}
if restrict {
if isAllowedPath(absPath, patterns) {
return absPath, nil
}
if !isWithinWorkspace(absPath, absWorkspace) {
return "", fmt.Errorf("access denied: path is outside the workspace")
}
@@ -73,6 +76,137 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
return absPath, nil
}
func isAllowedPath(path string, patterns []*regexp.Regexp) bool {
if len(patterns) == 0 {
return false
}
cleaned := filepath.Clean(path)
if !filepath.IsAbs(cleaned) {
return false
}
if !matchesAllowedPath(cleaned, patterns) {
return false
}
resolved, err := resolvePathAgainstExistingAncestor(cleaned)
if err != nil {
return false
}
return matchesAllowedPath(resolved, patterns)
}
func matchesAllowedPath(path string, patterns []*regexp.Regexp) bool {
cleaned := filepath.Clean(path)
for _, pattern := range patterns {
if pattern.MatchString(cleaned) {
return true
}
if root, ok := extractAllowedPathRoot(pattern); ok && isWithinAllowedRoot(cleaned, root) {
return true
}
}
return false
}
func extractAllowedPathRoot(pattern *regexp.Regexp) (string, bool) {
raw := pattern.String()
if !strings.HasPrefix(raw, "^") {
return "", false
}
literal := strings.TrimPrefix(raw, "^")
// Recognize the common "directory prefix" form: ^<literal>(?:/|$)
literal = strings.TrimSuffix(literal, "(?:/|$)")
literal = strings.TrimSuffix(literal, `(?:\\|$)`)
// Reject patterns that still contain regex operators after removing the
// optional anchored-directory suffix. That keeps arbitrary regex behavior
// unchanged and only enables normalized prefix matching for literal paths.
if containsUnescapedRegexMeta(literal) {
return "", false
}
unescaped, ok := unescapeRegexLiteral(literal)
if !ok || unescaped == "" {
return "", false
}
return filepath.Clean(unescaped), filepath.IsAbs(unescaped)
}
func appendUniquePath(paths []string, path string) []string {
for _, existing := range paths {
if existing == path {
return paths
}
}
return append(paths, path)
}
func containsUnescapedRegexMeta(s string) bool {
escaped := false
for _, r := range s {
if escaped {
escaped = false
continue
}
if r == '\\' {
escaped = true
continue
}
switch r {
case '.', '+', '*', '?', '(', ')', '[', ']', '{', '}', '|':
return true
}
}
return escaped
}
func unescapeRegexLiteral(s string) (string, bool) {
var b strings.Builder
b.Grow(len(s))
escaped := false
for _, r := range s {
if escaped {
b.WriteRune(r)
escaped = false
continue
}
if r == '\\' {
escaped = true
continue
}
b.WriteRune(r)
}
if escaped {
return "", false
}
return b.String(), true
}
func isWithinAllowedRoot(path, root string) bool {
candidate := filepath.Clean(path)
allowedVariants := []string{filepath.Clean(root)}
if resolvedRoot, err := resolvePathAgainstExistingAncestor(root); err == nil {
allowedVariants = appendUniquePath(allowedVariants, filepath.Clean(resolvedRoot))
}
for _, allowedRoot := range allowedVariants {
if isWithinWorkspace(candidate, allowedRoot) {
return true
}
}
return false
}
func resolveExistingAncestor(path string) (string, error) {
for current := filepath.Clean(path); ; current = filepath.Dir(current) {
if resolved, err := filepath.EvalSymlinks(current); err == nil {
@@ -86,9 +220,32 @@ func resolveExistingAncestor(path string) (string, error) {
}
}
func resolvePathAgainstExistingAncestor(path string) (string, error) {
cleaned := filepath.Clean(path)
for current := cleaned; ; current = filepath.Dir(current) {
resolved, err := filepath.EvalSymlinks(current)
if err == nil {
suffix, relErr := filepath.Rel(current, cleaned)
if relErr != nil {
return "", relErr
}
if suffix == "." {
return filepath.Clean(resolved), nil
}
return filepath.Clean(filepath.Join(resolved, suffix)), nil
}
if !os.IsNotExist(err) {
return "", err
}
if filepath.Dir(current) == current {
return "", os.ErrNotExist
}
}
}
func isWithinWorkspace(candidate, workspace string) bool {
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
return err == nil && filepath.IsLocal(rel)
return err == nil && (rel == "." || filepath.IsLocal(rel))
}
type ReadFileTool struct {
@@ -625,12 +782,7 @@ type whitelistFs struct {
}
func (w *whitelistFs) matches(path string) bool {
for _, p := range w.patterns {
if p.MatchString(path) {
return true
}
}
return false
return isAllowedPath(path, w.patterns)
}
func (w *whitelistFs) ReadFile(path string) ([]byte, error) {
+84
View File
@@ -521,6 +521,90 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
}
}
func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) {
workspace := t.TempDir()
allowedDir := t.TempDir()
secretDir := t.TempDir()
secretFile := filepath.Join(secretDir, "secret.txt")
if err := os.WriteFile(secretFile, []byte("top secret"), 0o644); err != nil {
t.Fatalf("WriteFile(secretFile) error = %v", err)
}
linkPath := filepath.Join(allowedDir, "link_out")
if err := os.Symlink(secretDir, linkPath); err != nil {
t.Skipf("symlink not supported in this environment: %v", err)
}
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
result := tool.Execute(context.Background(), map[string]any{"path": filepath.Join(linkPath, "secret.txt")})
if !result.IsError {
t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", result.ForLLM)
}
}
func TestWhitelistFs_WriteAllowsNewFileUnderAllowedDir(t *testing.T) {
workspace := t.TempDir()
rootDir := t.TempDir()
allowedDir := filepath.Join(rootDir, "allowed")
targetFile := filepath.Join(allowedDir, "nested", "file.txt")
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
tool := NewWriteFileTool(workspace, true, patterns)
result := tool.Execute(context.Background(), map[string]any{
"path": targetFile,
"content": "outside write",
})
if result.IsError {
t.Fatalf("expected whitelisted write to succeed, got: %s", result.ForLLM)
}
data, err := os.ReadFile(targetFile)
if err != nil {
t.Fatalf("ReadFile(targetFile) error = %v", err)
}
if string(data) != "outside write" {
t.Fatalf("target file content = %q, want %q", string(data), "outside write")
}
}
func TestWhitelistFs_AllowsResolvedAllowedRootAlias(t *testing.T) {
workspace := t.TempDir()
realDir := t.TempDir()
linkParent := t.TempDir()
allowedAlias := filepath.Join(linkParent, "allowed-link")
if err := os.Symlink(realDir, allowedAlias); err != nil {
t.Skipf("symlink not supported in this environment: %v", err)
}
targetFile := filepath.Join(allowedAlias, "nested", "alias.txt")
if err := os.MkdirAll(filepath.Dir(targetFile), 0o755); err != nil {
t.Fatalf("MkdirAll(targetFile dir) error = %v", err)
}
if err := os.WriteFile(targetFile, []byte("through alias"), 0o644); err != nil {
t.Fatalf("WriteFile(targetFile) error = %v", err)
}
patterns := []*regexp.Regexp{
regexp.MustCompile(
"^" + regexp.QuoteMeta(filepath.Clean(allowedAlias)) +
"(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
),
}
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
result := tool.Execute(context.Background(), map[string]any{"path": targetFile})
if result.IsError {
t.Fatalf("expected symlink-backed allowed root to be readable, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "through alias") {
t.Fatalf("expected file content, got: %s", result.ForLLM)
}
}
// TestReadFileTool_ChunkedReading verifies the pagination logic of the tool
// by reading a file in multiple chunks using 'offset' and 'length'.
func TestReadFileTool_ChunkedReading(t *testing.T) {
+15 -2
View File
@@ -6,6 +6,7 @@ import (
"mime"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/h2non/filetype"
@@ -21,20 +22,32 @@ type SendFileTool struct {
restrict bool
maxFileSize int
mediaStore media.MediaStore
allowPaths []*regexp.Regexp
defaultChannel string
defaultChatID string
}
func NewSendFileTool(workspace string, restrict bool, maxFileSize int, store media.MediaStore) *SendFileTool {
func NewSendFileTool(
workspace string,
restrict bool,
maxFileSize int,
store media.MediaStore,
allowPaths ...[]*regexp.Regexp,
) *SendFileTool {
if maxFileSize <= 0 {
maxFileSize = config.DefaultMaxMediaSize
}
var patterns []*regexp.Regexp
if len(allowPaths) > 0 {
patterns = allowPaths[0]
}
return &SendFileTool{
workspace: workspace,
restrict: restrict,
maxFileSize: maxFileSize,
mediaStore: store,
allowPaths: patterns,
}
}
@@ -92,7 +105,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
return ErrorResult("media store not configured")
}
resolved, err := validatePath(path, t.workspace, t.restrict)
resolved, err := validatePathWithAllowPaths(path, t.workspace, t.restrict, t.allowPaths)
if err != nil {
return ErrorResult(fmt.Sprintf("invalid path: %v", err))
}
+39
View File
@@ -4,6 +4,7 @@ import (
"context"
"os"
"path/filepath"
"regexp"
"strings"
"testing"
@@ -128,6 +129,44 @@ func TestSendFileTool_CustomFilename(t *testing.T) {
}
}
func TestSendFileTool_AllowsWhitelistedMediaTempPath(t *testing.T) {
workspace := t.TempDir()
mediaDir := media.TempDir()
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
}
testFile, err := os.CreateTemp(mediaDir, "send-file-*.txt")
if err != nil {
t.Fatalf("CreateTemp(mediaDir) error = %v", err)
}
testPath := testFile.Name()
if _, err := testFile.WriteString("forward me"); err != nil {
testFile.Close()
t.Fatalf("WriteString(testFile) error = %v", err)
}
if err := testFile.Close(); err != nil {
t.Fatalf("Close(testFile) error = %v", err)
}
t.Cleanup(func() { _ = os.Remove(testPath) })
pattern := regexp.MustCompile(
"^" + regexp.QuoteMeta(filepath.Clean(mediaDir)) + "(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
)
store := media.NewFileMediaStore()
tool := NewSendFileTool(workspace, true, 0, store, []*regexp.Regexp{pattern})
tool.SetContext("feishu", "chat123")
result := tool.Execute(context.Background(), map[string]any{"path": testPath})
if result.IsError {
t.Fatalf("expected whitelisted temp media file to be sendable, got: %s", result.ForLLM)
}
if len(result.Media) != 1 {
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
}
}
func TestDetectMediaType_MagicBytes(t *testing.T) {
dir := t.TempDir()
+31 -13
View File
@@ -23,6 +23,7 @@ type ExecTool struct {
denyPatterns []*regexp.Regexp
allowPatterns []*regexp.Regexp
customAllowPatterns []*regexp.Regexp
allowedPathPatterns []*regexp.Regexp
restrictToWorkspace bool
allowRemote bool
}
@@ -95,14 +96,23 @@ var (
}
)
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
return NewExecToolWithConfig(workingDir, restrict, nil)
func NewExecTool(workingDir string, restrict bool, allowPaths ...[]*regexp.Regexp) (*ExecTool, error) {
return NewExecToolWithConfig(workingDir, restrict, nil, allowPaths...)
}
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
func NewExecToolWithConfig(
workingDir string,
restrict bool,
config *config.Config,
allowPaths ...[]*regexp.Regexp,
) (*ExecTool, error) {
denyPatterns := make([]*regexp.Regexp, 0)
customAllowPatterns := make([]*regexp.Regexp, 0)
var allowedPathPatterns []*regexp.Regexp
allowRemote := true
if len(allowPaths) > 0 {
allowedPathPatterns = allowPaths[0]
}
if config != nil {
execConfig := config.Tools.Exec
@@ -146,6 +156,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
denyPatterns: denyPatterns,
allowPatterns: nil,
customAllowPatterns: customAllowPatterns,
allowedPathPatterns: allowedPathPatterns,
restrictToWorkspace: restrict,
allowRemote: allowRemote,
}, nil
@@ -198,7 +209,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
cwd := t.workingDir
if wd, ok := args["working_dir"].(string); ok && wd != "" {
if t.restrictToWorkspace && t.workingDir != "" {
resolvedWD, err := validatePath(wd, t.workingDir, true)
resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns)
if err != nil {
return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
}
@@ -226,16 +237,20 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
if err != nil {
return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err))
}
absWorkspace, _ := filepath.Abs(t.workingDir)
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
if wsResolved == "" {
wsResolved = absWorkspace
if isAllowedPath(resolved, t.allowedPathPatterns) {
cwd = resolved
} else {
absWorkspace, _ := filepath.Abs(t.workingDir)
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
if wsResolved == "" {
wsResolved = absWorkspace
}
rel, err := filepath.Rel(wsResolved, resolved)
if err != nil || !filepath.IsLocal(rel) {
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
}
cwd = resolved
}
rel, err := filepath.Rel(wsResolved, resolved)
if err != nil || !filepath.IsLocal(rel) {
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
}
cwd = resolved
}
// timeout == 0 means no timeout
@@ -412,6 +427,9 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
if safePaths[p] {
continue
}
if isAllowedPath(p, t.allowedPathPatterns) {
continue
}
rel, err := filepath.Rel(cwdPath, p)
if err != nil {
+178
View File
@@ -0,0 +1,178 @@
package tools
import (
"context"
"fmt"
"sort"
"strings"
"time"
)
// SpawnStatusTool reports the status of subagents that were spawned via the
// spawn tool. It can query a specific task by ID, or list every known task with
// a summary count broken-down by status.
type SpawnStatusTool struct {
manager *SubagentManager
}
// NewSpawnStatusTool creates a SpawnStatusTool backed by the given manager.
func NewSpawnStatusTool(manager *SubagentManager) *SpawnStatusTool {
return &SpawnStatusTool{manager: manager}
}
func (t *SpawnStatusTool) Name() string {
return "spawn_status"
}
func (t *SpawnStatusTool) Description() string {
return "Get the status of spawned subagents. " +
"Returns a list of all subagents and their current state " +
"(running, completed, failed, or canceled), or retrieves details " +
"for a specific subagent task when task_id is provided. " +
"Results are scoped to the current conversation's channel and chat ID; " +
"all tasks are listed only when no channel/chat context is injected " +
"(e.g. direct programmatic calls via Execute)."
}
func (t *SpawnStatusTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"task_id": map[string]any{
"type": "string",
"description": "Optional task ID (e.g. \"subagent-1\") to inspect a specific " +
"subagent. When omitted, all visible subagents are listed.",
},
},
"required": []string{},
}
}
func (t *SpawnStatusTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
if t.manager == nil {
return ErrorResult("Subagent manager not configured")
}
// Derive the calling conversation's identity so we can scope results to the
// current chat only — preventing cross-conversation task leakage in
// multi-user deployments.
callerChannel := ToolChannel(ctx)
callerChatID := ToolChatID(ctx)
var taskID string
if rawTaskID, ok := args["task_id"]; ok && rawTaskID != nil {
taskIDStr, ok := rawTaskID.(string)
if !ok {
return ErrorResult("task_id must be a string")
}
taskID = strings.TrimSpace(taskIDStr)
}
if taskID != "" {
// GetTaskCopy returns a consistent snapshot under the manager lock,
// eliminating any data race with the concurrent subagent goroutine.
taskCopy, ok := t.manager.GetTaskCopy(taskID)
if !ok {
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
}
// Restrict lookup to tasks that belong to this conversation.
if callerChannel != "" && taskCopy.OriginChannel != "" && taskCopy.OriginChannel != callerChannel {
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
}
if callerChatID != "" && taskCopy.OriginChatID != "" && taskCopy.OriginChatID != callerChatID {
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
}
return NewToolResult(spawnStatusFormatTask(&taskCopy))
}
// ListTaskCopies returns consistent snapshots under the manager lock.
origTasks := t.manager.ListTaskCopies()
if len(origTasks) == 0 {
return NewToolResult("No subagents have been spawned yet.")
}
tasks := make([]*SubagentTask, 0, len(origTasks))
for i := range origTasks {
cpy := &origTasks[i]
// Filter to tasks that originate from the current conversation only.
if callerChannel != "" && cpy.OriginChannel != "" && cpy.OriginChannel != callerChannel {
continue
}
if callerChatID != "" && cpy.OriginChatID != "" && cpy.OriginChatID != callerChatID {
continue
}
tasks = append(tasks, cpy)
}
if len(tasks) == 0 {
return NewToolResult("No subagents found for this conversation.")
}
// Order by creation time (ascending) so spawning order is preserved.
// Fall back to ID string for tasks created in the same millisecond.
sort.Slice(tasks, func(i, j int) bool {
if tasks[i].Created != tasks[j].Created {
return tasks[i].Created < tasks[j].Created
}
return tasks[i].ID < tasks[j].ID
})
counts := map[string]int{}
for _, task := range tasks {
counts[task.Status]++
}
var sb strings.Builder
sb.WriteString(fmt.Sprintf("Subagent status report (%d total):\n", len(tasks)))
for _, status := range []string{"running", "completed", "failed", "canceled"} {
if n := counts[status]; n > 0 {
label := strings.ToUpper(status[:1]) + status[1:] + ":"
sb.WriteString(fmt.Sprintf(" %-10s %d\n", label, n))
}
}
sb.WriteString("\n")
for _, task := range tasks {
sb.WriteString(spawnStatusFormatTask(task))
sb.WriteString("\n\n")
}
return NewToolResult(strings.TrimRight(sb.String(), "\n"))
}
// spawnStatusFormatTask renders a single SubagentTask as a human-readable block.
func spawnStatusFormatTask(task *SubagentTask) string {
var sb strings.Builder
header := fmt.Sprintf("[%s] status=%s", task.ID, task.Status)
if task.Label != "" {
header += fmt.Sprintf(" label=%q", task.Label)
}
if task.AgentID != "" {
header += fmt.Sprintf(" agent=%s", task.AgentID)
}
if task.Created > 0 {
created := time.UnixMilli(task.Created).UTC().Format("2006-01-02 15:04:05 UTC")
header += fmt.Sprintf(" created=%s", created)
}
sb.WriteString(header)
if task.Task != "" {
sb.WriteString(fmt.Sprintf("\n task: %s", task.Task))
}
if task.Result != "" {
result := task.Result
const maxResultLen = 300
runes := []rune(result)
if len(runes) > maxResultLen {
result = string(runes[:maxResultLen]) + "…"
}
sb.WriteString(fmt.Sprintf("\n result: %s", result))
}
return sb.String()
}
+406
View File
@@ -0,0 +1,406 @@
package tools
import (
"context"
"fmt"
"strings"
"testing"
"time"
)
func TestSpawnStatusTool_Name(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
tool := NewSpawnStatusTool(manager)
if tool.Name() != "spawn_status" {
t.Errorf("Expected name 'spawn_status', got '%s'", tool.Name())
}
}
func TestSpawnStatusTool_Description(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
tool := NewSpawnStatusTool(manager)
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
if !strings.Contains(strings.ToLower(desc), "subagent") {
t.Errorf("Description should mention 'subagent', got: %s", desc)
}
}
func TestSpawnStatusTool_Parameters(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
tool := NewSpawnStatusTool(manager)
params := tool.Parameters()
if params["type"] != "object" {
t.Errorf("Expected type 'object', got: %v", params["type"])
}
props, ok := params["properties"].(map[string]any)
if !ok {
t.Fatal("Expected 'properties' to be a map")
}
if _, hasTaskID := props["task_id"]; !hasTaskID {
t.Error("Expected 'task_id' parameter in properties")
}
}
func TestSpawnStatusTool_NilManager(t *testing.T) {
tool := &SpawnStatusTool{manager: nil}
result := tool.Execute(context.Background(), map[string]any{})
if !result.IsError {
t.Error("Expected error result when manager is nil")
}
}
func TestSpawnStatusTool_Empty(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{})
if result.IsError {
t.Fatalf("Expected success, got error: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "No subagents") {
t.Errorf("Expected 'No subagents' message, got: %s", result.ForLLM)
}
}
func TestSpawnStatusTool_ListAll(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
now := time.Now().UnixMilli()
manager.mu.Lock()
manager.tasks["subagent-1"] = &SubagentTask{
ID: "subagent-1",
Task: "Do task A",
Label: "task-a",
Status: "running",
Created: now,
}
manager.tasks["subagent-2"] = &SubagentTask{
ID: "subagent-2",
Task: "Do task B",
Label: "task-b",
Status: "completed",
Result: "Done successfully",
Created: now,
}
manager.tasks["subagent-3"] = &SubagentTask{
ID: "subagent-3",
Task: "Do task C",
Status: "failed",
Result: "Error: something went wrong",
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{})
if result.IsError {
t.Fatalf("Expected success, got error: %s", result.ForLLM)
}
// Summary header
if !strings.Contains(result.ForLLM, "3 total") {
t.Errorf("Expected total count in header, got: %s", result.ForLLM)
}
// Individual task IDs
for _, id := range []string{"subagent-1", "subagent-2", "subagent-3"} {
if !strings.Contains(result.ForLLM, id) {
t.Errorf("Expected task %s in output, got:\n%s", id, result.ForLLM)
}
}
// Status values
for _, status := range []string{"running", "completed", "failed"} {
if !strings.Contains(result.ForLLM, status) {
t.Errorf("Expected status '%s' in output, got:\n%s", status, result.ForLLM)
}
}
// Result content
if !strings.Contains(result.ForLLM, "Done successfully") {
t.Errorf("Expected result text in output, got:\n%s", result.ForLLM)
}
}
func TestSpawnStatusTool_GetByID(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.mu.Lock()
manager.tasks["subagent-42"] = &SubagentTask{
ID: "subagent-42",
Task: "Specific task",
Label: "my-task",
Status: "failed",
Result: "Something went wrong",
Created: time.Now().UnixMilli(),
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-42"})
if result.IsError {
t.Fatalf("Expected success, got error: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "subagent-42") {
t.Errorf("Expected task ID in output, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "failed") {
t.Errorf("Expected status 'failed' in output, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Something went wrong") {
t.Errorf("Expected result text in output, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "my-task") {
t.Errorf("Expected label in output, got: %s", result.ForLLM)
}
}
func TestSpawnStatusTool_GetByID_NotFound(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{"task_id": "nonexistent-999"})
if !result.IsError {
t.Errorf("Expected error for nonexistent task, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "nonexistent-999") {
t.Errorf("Expected task ID in error message, got: %s", result.ForLLM)
}
}
func TestSpawnStatusTool_TaskID_NonString(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSpawnStatusTool(manager)
for _, badVal := range []any{42, 3.14, true, map[string]any{"x": 1}, []string{"a"}} {
result := tool.Execute(context.Background(), map[string]any{"task_id": badVal})
if !result.IsError {
t.Errorf("Expected error for task_id=%T(%v), got success: %s", badVal, badVal, result.ForLLM)
}
if !strings.Contains(result.ForLLM, "task_id must be a string") {
t.Errorf("Expected type-error message, got: %s", result.ForLLM)
}
}
}
func TestSpawnStatusTool_ResultTruncation(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
longResult := strings.Repeat("X", 500)
manager.mu.Lock()
manager.tasks["subagent-1"] = &SubagentTask{
ID: "subagent-1",
Task: "Long task",
Status: "completed",
Result: longResult,
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
if result.IsError {
t.Fatalf("Unexpected error: %s", result.ForLLM)
}
// Output should be shorter than the raw result due to truncation
if len(result.ForLLM) >= len(longResult) {
t.Errorf("Expected result to be truncated, but ForLLM is %d chars", len(result.ForLLM))
}
if !strings.Contains(result.ForLLM, "…") {
t.Errorf("Expected truncation indicator '…' in output, got: %s", result.ForLLM)
}
}
func TestSpawnStatusTool_ResultTruncation_Unicode(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
// Each CJK rune is 3 bytes; 400 runes = 1200 bytes — well over the 300-rune limit.
cjkChar := string(rune(0x5b57))
longResult := strings.Repeat(cjkChar, 400)
manager.mu.Lock()
manager.tasks["subagent-1"] = &SubagentTask{
ID: "subagent-1",
Task: "Unicode task",
Status: "completed",
Result: longResult,
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
if result.IsError {
t.Fatalf("Unexpected error: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "…") {
t.Errorf("Expected truncation indicator in output")
}
// The truncated result must be valid UTF-8 (no split rune boundaries).
if !strings.Contains(result.ForLLM, cjkChar) {
t.Errorf("Expected CJK runes to appear intact in output")
}
}
func TestSpawnStatusTool_StatusCounts(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.mu.Lock()
for i, status := range []string{"running", "running", "completed", "failed", "canceled"} {
id := fmt.Sprintf("subagent-%d", i+1)
manager.tasks[id] = &SubagentTask{ID: id, Task: "t", Status: status}
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{})
if result.IsError {
t.Fatalf("Unexpected error: %s", result.ForLLM)
}
// The summary line should mention all statuses that have counts
for _, want := range []string{"Running:", "Completed:", "Failed:", "Canceled:"} {
if !strings.Contains(result.ForLLM, want) {
t.Errorf("Expected %q in summary, got:\n%s", want, result.ForLLM)
}
}
}
func TestSpawnStatusTool_SortByCreatedTimestamp(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
now := time.Now().UnixMilli()
manager.mu.Lock()
// Intentionally insert with out-of-order IDs and timestamps that reflect
// true spawn order: subagent-2 was spawned first, subagent-10 second.
manager.tasks["subagent-10"] = &SubagentTask{
ID: "subagent-10", Task: "second", Status: "running",
Created: now + 1,
}
manager.tasks["subagent-2"] = &SubagentTask{
ID: "subagent-2", Task: "first", Status: "running",
Created: now,
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{})
if result.IsError {
t.Fatalf("Unexpected error: %s", result.ForLLM)
}
pos2 := strings.Index(result.ForLLM, "subagent-2")
pos10 := strings.Index(result.ForLLM, "subagent-10")
if pos2 < 0 || pos10 < 0 {
t.Fatalf("Both task IDs should appear in output:\n%s", result.ForLLM)
}
if pos2 > pos10 {
t.Errorf("Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s", result.ForLLM)
}
}
func TestSpawnStatusTool_ChannelFiltering_ListAll(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.mu.Lock()
manager.tasks["subagent-1"] = &SubagentTask{
ID: "subagent-1", Task: "mine", Status: "running",
OriginChannel: "telegram", OriginChatID: "chat-A",
}
manager.tasks["subagent-2"] = &SubagentTask{
ID: "subagent-2", Task: "other user", Status: "running",
OriginChannel: "telegram", OriginChatID: "chat-B",
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
// Caller is chat-A — should only see subagent-1.
ctx := WithToolContext(context.Background(), "telegram", "chat-A")
result := tool.Execute(ctx, map[string]any{})
if result.IsError {
t.Fatalf("Unexpected error: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "subagent-1") {
t.Errorf("Expected own task in output, got:\n%s", result.ForLLM)
}
if strings.Contains(result.ForLLM, "subagent-2") {
t.Errorf("Should NOT see other chat's task, got:\n%s", result.ForLLM)
}
}
func TestSpawnStatusTool_ChannelFiltering_GetByID(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.mu.Lock()
manager.tasks["subagent-99"] = &SubagentTask{
ID: "subagent-99", Task: "secret", Status: "completed", Result: "private data",
OriginChannel: "slack", OriginChatID: "room-Z",
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
// Different chat trying to look up subagent-99 by ID.
ctx := WithToolContext(context.Background(), "slack", "room-OTHER")
result := tool.Execute(ctx, map[string]any{"task_id": "subagent-99"})
if !result.IsError {
t.Errorf("Expected error (cross-chat lookup blocked), got: %s", result.ForLLM)
}
}
func TestSpawnStatusTool_ChannelFiltering_NoContext(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.mu.Lock()
manager.tasks["subagent-1"] = &SubagentTask{
ID: "subagent-1", Task: "t", Status: "completed",
OriginChannel: "telegram", OriginChatID: "chat-A",
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
// No ToolContext injected (e.g. a direct programmatic call that bypasses
// WithToolContext entirely) — callerChannel and callerChatID are both "".
// Note: the normal CLI path uses ProcessDirectWithChannel("cli", "direct"),
// which *does* inject a non-empty context; this test covers the case where
// no context injection happens at all.
// The filter conditions require a non-empty caller value, so all tasks pass through.
result := tool.Execute(context.Background(), map[string]any{})
if result.IsError {
t.Fatalf("Unexpected error: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "subagent-1") {
t.Errorf("Expected task visible from no-context caller, got:\n%s", result.ForLLM)
}
}
+25 -3
View File
@@ -109,9 +109,6 @@ func (sm *SubagentManager) Spawn(
}
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
task.Status = "running"
task.Created = time.Now().UnixMilli()
// Build system prompt for subagent
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
You have access to tools - use them as needed to complete your task.
@@ -219,6 +216,18 @@ func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {
return task, ok
}
// GetTaskCopy returns a copy of the task with the given ID, taken under the
// read lock, so the caller receives a consistent snapshot with no data race.
func (sm *SubagentManager) GetTaskCopy(taskID string) (SubagentTask, bool) {
sm.mu.RLock()
defer sm.mu.RUnlock()
task, ok := sm.tasks[taskID]
if !ok {
return SubagentTask{}, false
}
return *task, true
}
func (sm *SubagentManager) ListTasks() []*SubagentTask {
sm.mu.RLock()
defer sm.mu.RUnlock()
@@ -230,6 +239,19 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
return tasks
}
// ListTaskCopies returns value copies of all tasks, taken under the read lock,
// so callers receive consistent snapshots with no data race.
func (sm *SubagentManager) ListTaskCopies() []SubagentTask {
sm.mu.RLock()
defer sm.mu.RUnlock()
copies := make([]SubagentTask, 0, len(sm.tasks))
for _, task := range sm.tasks {
copies = append(copies, *task)
}
return copies
}
// SubagentTool executes a subagent task synchronously and returns the result.
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
// and returns the result directly in the ToolResult.
+95 -10
View File
@@ -780,11 +780,17 @@ type WebFetchTool struct {
client *http.Client
format string
fetchLimitBytes int64
whitelist *privateHostWhitelist
}
type privateHostWhitelist struct {
exact map[string]struct{}
cidrs []*net.IPNet
}
func NewWebFetchTool(maxChars int, format string, fetchLimitBytes int64) (*WebFetchTool, error) {
// createHTTPClient cannot fail with an empty proxy string.
return NewWebFetchToolWithProxy(maxChars, "", format, fetchLimitBytes)
return NewWebFetchToolWithProxy(maxChars, "", format, fetchLimitBytes, nil)
}
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
@@ -792,9 +798,22 @@ func NewWebFetchTool(maxChars int, format string, fetchLimitBytes int64) (*WebFe
var allowPrivateWebFetchHosts atomic.Bool
func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLimitBytes int64) (*WebFetchTool, error) {
return NewWebFetchToolWithConfig(maxChars, proxy, fetchLimitBytes, nil)
}
func NewWebFetchToolWithConfig(
maxChars int,
proxy string,
fetchLimitBytes int64,
privateHostWhitelist []string,
) (*WebFetchTool, error) {
if maxChars <= 0 {
maxChars = defaultMaxChars
}
whitelist, err := newPrivateHostWhitelist(privateHostWhitelist)
if err != nil {
return nil, fmt.Errorf("failed to parse web fetch private host whitelist: %w", err)
}
client, err := utils.CreateHTTPClient(proxy, fetchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
@@ -804,13 +823,13 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLi
Timeout: 15 * time.Second,
KeepAlive: 30 * time.Second,
}
transport.DialContext = newSafeDialContext(dialer)
transport.DialContext = newSafeDialContext(dialer, whitelist)
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirects {
return fmt.Errorf("stopped after %d redirects", maxRedirects)
}
if isObviousPrivateHost(req.URL.Hostname()) {
if isObviousPrivateHost(req.URL.Hostname(), whitelist) {
return fmt.Errorf("redirect target is private or local network host")
}
return nil
@@ -824,6 +843,7 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLi
client: client,
format: format,
fetchLimitBytes: fetchLimitBytes,
whitelist: whitelist,
}, nil
}
@@ -875,7 +895,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
// The real SSRF guard is newSafeDialContext at connect time.
hostname := parsedURL.Hostname()
if isObviousPrivateHost(hostname) {
if isObviousPrivateHost(hostname, t.whitelist) {
return ErrorResult("fetching private or local network hosts is not allowed")
}
@@ -1019,7 +1039,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
func newSafeDialContext(
dialer *net.Dialer,
whitelist *privateHostWhitelist,
) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
if allowPrivateWebFetchHosts.Load() {
return dialer.DialContext(ctx, network, address)
@@ -1034,7 +1057,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
if ip := net.ParseIP(host); ip != nil {
if isPrivateOrRestrictedIP(ip) {
if shouldBlockPrivateIP(ip, whitelist) {
return nil, fmt.Errorf("blocked private or local target: %s", host)
}
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
@@ -1048,7 +1071,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
attempted := 0
var lastErr error
for _, ipAddr := range ipAddrs {
if isPrivateOrRestrictedIP(ipAddr.IP) {
if shouldBlockPrivateIP(ipAddr.IP, whitelist) {
continue
}
attempted++
@@ -1060,7 +1083,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
if attempted == 0 {
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
return nil, fmt.Errorf("all resolved addresses for %s are private, restricted, or not whitelisted", host)
}
if lastErr != nil {
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
@@ -1069,10 +1092,72 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
}
func newPrivateHostWhitelist(entries []string) (*privateHostWhitelist, error) {
if len(entries) == 0 {
return nil, nil
}
whitelist := &privateHostWhitelist{
exact: make(map[string]struct{}),
cidrs: make([]*net.IPNet, 0, len(entries)),
}
for _, entry := range entries {
entry = strings.TrimSpace(entry)
if entry == "" {
continue
}
if ip := net.ParseIP(entry); ip != nil {
whitelist.exact[normalizeWhitelistIP(ip).String()] = struct{}{}
continue
}
_, network, err := net.ParseCIDR(entry)
if err != nil {
return nil, fmt.Errorf("invalid entry %q: expected IP or CIDR", entry)
}
whitelist.cidrs = append(whitelist.cidrs, network)
}
if len(whitelist.exact) == 0 && len(whitelist.cidrs) == 0 {
return nil, nil
}
return whitelist, nil
}
func (w *privateHostWhitelist) Contains(ip net.IP) bool {
if w == nil || ip == nil {
return false
}
normalized := normalizeWhitelistIP(ip)
if _, ok := w.exact[normalized.String()]; ok {
return true
}
for _, network := range w.cidrs {
if network.Contains(normalized) {
return true
}
}
return false
}
func normalizeWhitelistIP(ip net.IP) net.IP {
if ip == nil {
return nil
}
if ip4 := ip.To4(); ip4 != nil {
return ip4
}
return ip
}
func shouldBlockPrivateIP(ip net.IP, whitelist *privateHostWhitelist) bool {
return isPrivateOrRestrictedIP(ip) && !whitelist.Contains(ip)
}
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
func isObviousPrivateHost(host string) bool {
func isObviousPrivateHost(host string, whitelist *privateHostWhitelist) bool {
if allowPrivateWebFetchHosts.Load() {
return false
}
@@ -1088,7 +1173,7 @@ func isObviousPrivateHost(host string) bool {
}
if ip := net.ParseIP(h); ip != nil {
return isPrivateOrRestrictedIP(ip)
return shouldBlockPrivateIP(ip, whitelist)
}
return false
+147
View File
@@ -10,6 +10,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
@@ -425,6 +426,29 @@ func withPrivateWebFetchHostsAllowed(t *testing.T) {
})
}
func serverHostAndPort(t *testing.T, rawURL string) (string, string) {
t.Helper()
hostPort := strings.TrimPrefix(rawURL, "http://")
hostPort = strings.TrimPrefix(hostPort, "https://")
host, port, err := net.SplitHostPort(hostPort)
if err != nil {
t.Fatalf("failed to split host/port from %q: %v", rawURL, err)
}
return host, port
}
func singleHostCIDR(t *testing.T, host string) string {
t.Helper()
ip := net.ParseIP(host)
if ip == nil {
t.Fatalf("failed to parse IP %q", host)
}
if ip.To4() != nil {
return ip.String() + "/32"
}
return ip.String() + "/128"
}
func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
@@ -443,6 +467,56 @@ func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
}
}
func TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte("exact whitelist ok"))
}))
defer server.Close()
host, _ := serverHostAndPort(t, server.URL)
tool, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, []string{host})
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": server.URL,
})
if result.IsError {
t.Fatalf("expected success for exact whitelisted private IP, got %q", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "exact whitelist ok") {
t.Fatalf("expected fetched content, got %q", result.ForLLM)
}
}
func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte("cidr whitelist ok"))
}))
defer server.Close()
host, _ := serverHostAndPort(t, server.URL)
tool, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, []string{singleHostCIDR(t, host)})
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
result := tool.Execute(context.Background(), map[string]any{
"url": server.URL,
})
if result.IsError {
t.Fatalf("expected success for CIDR-whitelisted private IP, got %q", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "cidr whitelist ok") {
t.Fatalf("expected fetched content, got %q", result.ForLLM)
}
}
func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
withPrivateWebFetchHostsAllowed(t)
@@ -572,6 +646,69 @@ func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
}
}
func TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen on loopback: %v", err)
}
defer listener.Close()
_, port, err := net.SplitHostPort(listener.Addr().String())
if err != nil {
t.Fatalf("failed to split listener address: %v", err)
}
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, nil)
_, err = dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
if err == nil {
t.Fatal("expected localhost DNS resolution to be blocked without whitelist")
}
if !strings.Contains(err.Error(), "private") && !strings.Contains(err.Error(), "whitelisted") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen on loopback: %v", err)
}
defer listener.Close()
accepted := make(chan struct{}, 1)
go func() {
conn, acceptErr := listener.Accept()
if acceptErr != nil {
return
}
conn.Close()
accepted <- struct{}{}
}()
_, port, err := net.SplitHostPort(listener.Addr().String())
if err != nil {
t.Fatalf("failed to split listener address: %v", err)
}
whitelist, err := newPrivateHostWhitelist([]string{"127.0.0.0/8"})
if err != nil {
t.Fatalf("failed to parse whitelist: %v", err)
}
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, whitelist)
conn, err := dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
if err != nil {
t.Fatalf("expected localhost DNS resolution to succeed with whitelist, got %v", err)
}
conn.Close()
select {
case <-accepted:
case <-time.After(time.Second):
t.Fatal("expected localhost listener to accept a connection")
}
}
// TestIsPrivateOrRestrictedIP_Table tests IP classification logic
func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
tests := []struct {
@@ -662,6 +799,16 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
}
}
func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) {
_, err := NewWebFetchToolWithConfig(1024, "", testFetchLimit, []string{"not-an-ip-or-cidr"})
if err == nil {
t.Fatal("expected invalid whitelist entry to fail")
}
if !strings.Contains(err.Error(), "invalid entry") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
t.Run("perplexity", func(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{
+2 -1
View File
@@ -12,6 +12,7 @@ import (
"github.com/google/uuid"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
)
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
@@ -67,7 +68,7 @@ func DownloadFile(urlStr, filename string, opts DownloadOptions) string {
opts.LoggerPrefix = "utils"
}
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
mediaDir := media.TempDir()
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]any{
"error": err.Error(),