mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into feat/markdown-output-format-web-fetch
This commit is contained in:
+22
-3
@@ -458,7 +458,23 @@ func (cb *ContextBuilder) LoadBootstrapFiles() string {
|
||||
//
|
||||
// See: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
|
||||
// See: https://platform.openai.com/docs/guides/prompt-caching
|
||||
func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string {
|
||||
func formatCurrentSenderLine(senderID, senderDisplayName string) string {
|
||||
senderID = strings.TrimSpace(senderID)
|
||||
senderDisplayName = strings.TrimSpace(senderDisplayName)
|
||||
|
||||
switch {
|
||||
case senderDisplayName != "" && senderID != "":
|
||||
return fmt.Sprintf("Current sender: %s (ID: %s)", senderDisplayName, senderID)
|
||||
case senderDisplayName != "":
|
||||
return fmt.Sprintf("Current sender: %s", senderDisplayName)
|
||||
case senderID != "":
|
||||
return fmt.Sprintf("Current sender: %s", senderID)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) buildDynamicContext(channel, chatID, senderID, senderDisplayName string) string {
|
||||
now := time.Now().Format("2006-01-02 15:04 (Monday)")
|
||||
rt := fmt.Sprintf("%s %s, Go %s", runtime.GOOS, runtime.GOARCH, runtime.Version())
|
||||
|
||||
@@ -468,6 +484,9 @@ func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string {
|
||||
if channel != "" && chatID != "" {
|
||||
fmt.Fprintf(&sb, "\n\n## Current Session\nChannel: %s\nChat ID: %s", channel, chatID)
|
||||
}
|
||||
if senderLine := formatCurrentSenderLine(senderID, senderDisplayName); senderLine != "" {
|
||||
fmt.Fprintf(&sb, "\n\n## Current Sender\n%s", senderLine)
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -477,7 +496,7 @@ func (cb *ContextBuilder) BuildMessages(
|
||||
summary string,
|
||||
currentMessage string,
|
||||
media []string,
|
||||
channel, chatID string,
|
||||
channel, chatID, senderID, senderDisplayName string,
|
||||
) []providers.Message {
|
||||
messages := []providers.Message{}
|
||||
|
||||
@@ -493,7 +512,7 @@ func (cb *ContextBuilder) BuildMessages(
|
||||
staticPrompt := cb.BuildSystemPromptWithCache()
|
||||
|
||||
// Build short dynamic context (time, runtime, session) — changes per request
|
||||
dynamicCtx := cb.buildDynamicContext(channel, chatID)
|
||||
dynamicCtx := cb.buildDynamicContext(channel, chatID, senderID, senderDisplayName)
|
||||
|
||||
// Compose a single system message: static (cached) + dynamic + optional summary.
|
||||
// Keeping all system content in one message ensures every provider adapter can
|
||||
|
||||
@@ -82,7 +82,7 @@ func TestSingleSystemMessage(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1")
|
||||
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1", "", "")
|
||||
|
||||
systemCount := 0
|
||||
for _, m := range msgs {
|
||||
@@ -126,6 +126,68 @@ func TestSingleSystemMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessages_CurrentSenderDynamicContext(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"IDENTITY.md": "# Identity\nTest agent.",
|
||||
})
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
senderID string
|
||||
senderDisplayName string
|
||||
wantLine string
|
||||
wantSection bool
|
||||
}{
|
||||
{
|
||||
name: "both id and display name",
|
||||
senderID: "feishu:ou_xxx",
|
||||
senderDisplayName: "Zhang San",
|
||||
wantLine: "Current sender: Zhang San (ID: feishu:ou_xxx)",
|
||||
wantSection: true,
|
||||
},
|
||||
{
|
||||
name: "display name only",
|
||||
senderDisplayName: "Alice",
|
||||
wantLine: "Current sender: Alice",
|
||||
wantSection: true,
|
||||
},
|
||||
{
|
||||
name: "id only",
|
||||
senderID: "discord:123",
|
||||
wantLine: "Current sender: discord:123",
|
||||
wantSection: true,
|
||||
},
|
||||
{
|
||||
name: "no sender info",
|
||||
wantSection: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgs := cb.BuildMessages(nil, "", "hello", nil, "discord", "chat1", tt.senderID, tt.senderDisplayName)
|
||||
sys := msgs[0].Content
|
||||
|
||||
if tt.wantSection {
|
||||
if !strings.Contains(sys, "## Current Sender") {
|
||||
t.Fatalf("system prompt missing Current Sender section:\n%s", sys)
|
||||
}
|
||||
if !strings.Contains(sys, tt.wantLine) {
|
||||
t.Fatalf("system prompt missing sender line %q:\n%s", tt.wantLine, sys)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(sys, "## Current Sender") {
|
||||
t.Fatalf("system prompt should omit Current Sender section:\n%s", sys)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMtimeAutoInvalidation verifies that the cache detects source file changes
|
||||
// via mtime without requiring explicit InvalidateCache().
|
||||
// Fix: original implementation had no auto-invalidation — edits to bootstrap files,
|
||||
@@ -576,7 +638,7 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
|
||||
}
|
||||
|
||||
// Also exercise BuildMessages concurrently
|
||||
msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat")
|
||||
msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat", "", "")
|
||||
if len(msgs) < 2 {
|
||||
errs <- "BuildMessages returned fewer than 2 messages"
|
||||
return
|
||||
@@ -664,6 +726,6 @@ func BenchmarkBuildMessagesWithCache(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test")
|
||||
_ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test", "", "")
|
||||
}
|
||||
}
|
||||
|
||||
+25
-2
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
@@ -66,7 +67,7 @@ func NewAgentInstance(
|
||||
readRestrict := restrict && !defaults.AllowReadOutsideWorkspace
|
||||
|
||||
// Compile path whitelist patterns from config.
|
||||
allowReadPaths := compilePatterns(cfg.Tools.AllowReadPaths)
|
||||
allowReadPaths := buildAllowReadPatterns(cfg)
|
||||
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
|
||||
|
||||
toolsRegistry := tools.NewToolRegistry()
|
||||
@@ -82,7 +83,7 @@ func NewAgentInstance(
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("exec") {
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg, allowReadPaths)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
|
||||
}
|
||||
@@ -282,6 +283,28 @@ func compilePatterns(patterns []string) []*regexp.Regexp {
|
||||
return compiled
|
||||
}
|
||||
|
||||
func buildAllowReadPatterns(cfg *config.Config) []*regexp.Regexp {
|
||||
var configured []string
|
||||
if cfg != nil {
|
||||
configured = cfg.Tools.AllowReadPaths
|
||||
}
|
||||
|
||||
compiled := compilePatterns(configured)
|
||||
mediaDirPattern := regexp.MustCompile(mediaTempDirPattern())
|
||||
for _, pattern := range compiled {
|
||||
if pattern.String() == mediaDirPattern.String() {
|
||||
return compiled
|
||||
}
|
||||
}
|
||||
|
||||
return append(compiled, mediaDirPattern)
|
||||
}
|
||||
|
||||
func mediaTempDirPattern() string {
|
||||
sep := regexp.QuoteMeta(string(os.PathSeparator))
|
||||
return "^" + regexp.QuoteMeta(filepath.Clean(media.TempDir())) + "(?:" + sep + "|$)"
|
||||
}
|
||||
|
||||
// Close releases resources held by the agent's session store.
|
||||
func (a *AgentInstance) Close() error {
|
||||
if a.Sessions != nil {
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
|
||||
@@ -160,3 +164,85 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
|
||||
}
|
||||
|
||||
mediaFile, err := os.CreateTemp(mediaDir, "instance-tool-*.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp(mediaDir) error = %v", err)
|
||||
}
|
||||
mediaPath := mediaFile.Name()
|
||||
if _, err := mediaFile.WriteString("attachment content"); err != nil {
|
||||
mediaFile.Close()
|
||||
t.Fatalf("WriteString(mediaFile) error = %v", err)
|
||||
}
|
||||
if err := mediaFile.Close(); err != nil {
|
||||
t.Fatalf("Close(mediaFile) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.Remove(mediaPath) })
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "test-model",
|
||||
RestrictToWorkspace: true,
|
||||
},
|
||||
},
|
||||
Tools: config.ToolsConfig{
|
||||
ReadFile: config.ReadFileToolConfig{Enabled: true},
|
||||
ListDir: config.ToolConfig{Enabled: true},
|
||||
Exec: config.ExecConfig{
|
||||
ToolConfig: config.ToolConfig{Enabled: true},
|
||||
EnableDenyPatterns: true,
|
||||
AllowRemote: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
|
||||
|
||||
readTool, ok := agent.Tools.Get("read_file")
|
||||
if !ok {
|
||||
t.Fatal("read_file tool not registered")
|
||||
}
|
||||
readResult := readTool.Execute(context.Background(), map[string]any{"path": mediaPath})
|
||||
if readResult.IsError {
|
||||
t.Fatalf("read_file should allow media temp dir, got: %s", readResult.ForLLM)
|
||||
}
|
||||
if !strings.Contains(readResult.ForLLM, "attachment content") {
|
||||
t.Fatalf("read_file output missing media content: %s", readResult.ForLLM)
|
||||
}
|
||||
|
||||
listTool, ok := agent.Tools.Get("list_dir")
|
||||
if !ok {
|
||||
t.Fatal("list_dir tool not registered")
|
||||
}
|
||||
listResult := listTool.Execute(context.Background(), map[string]any{"path": mediaDir})
|
||||
if listResult.IsError {
|
||||
t.Fatalf("list_dir should allow media temp dir, got: %s", listResult.ForLLM)
|
||||
}
|
||||
if !strings.Contains(listResult.ForLLM, filepath.Base(mediaPath)) {
|
||||
t.Fatalf("list_dir output missing media file: %s", listResult.ForLLM)
|
||||
}
|
||||
|
||||
execTool, ok := agent.Tools.Get("exec")
|
||||
if !ok {
|
||||
t.Fatal("exec tool not registered")
|
||||
}
|
||||
execResult := execTool.Execute(context.Background(), map[string]any{
|
||||
"command": "cat " + filepath.Base(mediaPath),
|
||||
"working_dir": mediaDir,
|
||||
})
|
||||
if execResult.IsError {
|
||||
t.Fatalf("exec should allow media temp dir, got: %s", execResult.ForLLM)
|
||||
}
|
||||
if !strings.Contains(execResult.ForLLM, "attachment content") {
|
||||
t.Fatalf("exec output missing media content: %s", execResult.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
+42
-26
@@ -55,15 +55,17 @@ type AgentLoop struct {
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
type processOptions struct {
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
Media []string // media:// refs from inbound message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
SenderID string // Current sender ID for dynamic context
|
||||
SenderDisplayName string // Current sender display name for dynamic context
|
||||
UserMessage string // User message content (may include prefix)
|
||||
Media []string // media:// refs from inbound message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -117,6 +119,8 @@ func registerSharedTools(
|
||||
registry *AgentRegistry,
|
||||
provider providers.LLMProvider,
|
||||
) {
|
||||
allowReadPaths := buildAllowReadPatterns(cfg)
|
||||
|
||||
for _, agentID := range registry.ListAgentIDs() {
|
||||
agent, ok := registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
@@ -161,7 +165,8 @@ func registerSharedTools(
|
||||
50000,
|
||||
cfg.Tools.Web.Proxy,
|
||||
cfg.Tools.Web.Format,
|
||||
cfg.Tools.Web.FetchLimitBytes)
|
||||
cfg.Tools.Web.FetchLimitBytes,
|
||||
cfg.Tools.Web.PrivateHostWhitelist)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
@@ -199,6 +204,7 @@ func registerSharedTools(
|
||||
cfg.Agents.Defaults.RestrictToWorkspace,
|
||||
cfg.Agents.Defaults.GetMaxMediaSize(),
|
||||
nil,
|
||||
allowReadPaths,
|
||||
)
|
||||
agent.Tools.Register(sendFileTool)
|
||||
}
|
||||
@@ -226,20 +232,26 @@ func registerSharedTools(
|
||||
}
|
||||
}
|
||||
|
||||
// Spawn tool with allowlist checker
|
||||
if cfg.Tools.IsToolEnabled("spawn") {
|
||||
if cfg.Tools.IsToolEnabled("subagent") {
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
// Spawn and spawn_status tools share a SubagentManager.
|
||||
// Construct it when either tool is enabled (both require subagent).
|
||||
spawnEnabled := cfg.Tools.IsToolEnabled("spawn")
|
||||
spawnStatusEnabled := cfg.Tools.IsToolEnabled("spawn_status")
|
||||
if (spawnEnabled || spawnStatusEnabled) && cfg.Tools.IsToolEnabled("subagent") {
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
if spawnEnabled {
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
currentAgentID := agentID
|
||||
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
|
||||
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
|
||||
})
|
||||
agent.Tools.Register(spawnTool)
|
||||
} else {
|
||||
logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil)
|
||||
}
|
||||
if spawnStatusEnabled {
|
||||
agent.Tools.Register(tools.NewSpawnStatusTool(subagentManager))
|
||||
}
|
||||
} else if (spawnEnabled || spawnStatusEnabled) && !cfg.Tools.IsToolEnabled("subagent") {
|
||||
logger.WarnCF("agent", "spawn/spawn_status tools require subagent to be enabled", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -736,14 +748,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
})
|
||||
|
||||
opts := processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
UserMessage: msg.Content,
|
||||
Media: msg.Media,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
SessionKey: sessionKey,
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
SenderID: msg.SenderID,
|
||||
SenderDisplayName: msg.Sender.DisplayName,
|
||||
UserMessage: msg.Content,
|
||||
Media: msg.Media,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
}
|
||||
|
||||
// context-dependent commands check their own Runtime fields and report
|
||||
@@ -883,6 +897,8 @@ func (al *AgentLoop) runAgentLoop(
|
||||
opts.Media,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
opts.SenderID,
|
||||
opts.SenderDisplayName,
|
||||
)
|
||||
|
||||
// Resolve media:// refs: images→base64 data URLs, non-images→local paths in content
|
||||
@@ -1154,7 +1170,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
|
||||
messages = agent.ContextBuilder.BuildMessages(
|
||||
newHistory, newSummary, "",
|
||||
nil, opts.Channel, opts.ChatID,
|
||||
nil, opts.Channel, opts.ChatID, opts.SenderID, opts.SenderDisplayName,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -30,6 +30,28 @@ func (f *fakeChannel) IsAllowed(string) bool {
|
||||
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
|
||||
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
|
||||
|
||||
type recordingProvider struct {
|
||||
lastMessages []providers.Message
|
||||
}
|
||||
|
||||
func (r *recordingProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
r.lastMessages = append([]providers.Message(nil), messages...)
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *recordingProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
func newTestAgentLoop(
|
||||
t *testing.T,
|
||||
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
|
||||
@@ -54,6 +76,59 @@ func newTestAgentLoop(
|
||||
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
|
||||
}
|
||||
|
||||
func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "discord",
|
||||
SenderID: "discord:123",
|
||||
Sender: bus.SenderInfo{
|
||||
DisplayName: "Alice",
|
||||
},
|
||||
ChatID: "group-1",
|
||||
Content: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if len(provider.lastMessages) == 0 {
|
||||
t.Fatal("provider did not receive any messages")
|
||||
}
|
||||
|
||||
systemPrompt := provider.lastMessages[0].Content
|
||||
wantSender := "## Current Sender\nCurrent sender: Alice (ID: discord:123)"
|
||||
if !strings.Contains(systemPrompt, wantSender) {
|
||||
t.Fatalf("system prompt missing sender context %q:\n%s", wantSender, systemPrompt)
|
||||
}
|
||||
|
||||
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
|
||||
if lastMessage.Role != "user" || lastMessage.Content != "hello" {
|
||||
t.Fatalf("last provider message = %+v, want unchanged user message", lastMessage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
@@ -618,7 +618,7 @@ func (c *FeishuChannel) downloadResource(
|
||||
}
|
||||
|
||||
// Write to the shared picoclaw_media directory using a unique name to avoid collisions.
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
mediaDir := media.TempDir()
|
||||
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
|
||||
logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{
|
||||
"error": mkdirErr.Error(),
|
||||
|
||||
@@ -357,7 +357,6 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
|
||||
if len(m.channels) == 0 {
|
||||
logger.WarnC("channels", "No channels enabled")
|
||||
return errors.New("no channels enabled")
|
||||
}
|
||||
|
||||
logger.InfoC("channels", "Starting all channels")
|
||||
@@ -397,7 +396,7 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
"addr": m.httpServer.Addr,
|
||||
})
|
||||
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.ErrorCF("channels", "Shared HTTP server error", map[string]any{
|
||||
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -35,8 +35,6 @@ const (
|
||||
roomKindCacheTTL = 5 * time.Minute
|
||||
roomKindCacheCleanupPeriod = 1 * time.Minute
|
||||
roomKindCacheMaxEntries = 2048
|
||||
|
||||
matrixMediaTempDirName = "picoclaw_media"
|
||||
)
|
||||
|
||||
var matrixMentionHrefRegexp = regexp.MustCompile(`(?i)<a[^>]+href=["']([^"']+)["']`)
|
||||
@@ -1105,7 +1103,7 @@ func (c *MatrixChannel) stripSelfMention(text string) string {
|
||||
}
|
||||
|
||||
func matrixMediaTempDir() (string, error) {
|
||||
mediaDir := filepath.Join(os.TempDir(), matrixMediaTempDirName)
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestMatrixLocalpartMentionRegexp(t *testing.T) {
|
||||
@@ -165,7 +166,7 @@ func TestMatrixMediaTempDir(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("matrixMediaTempDir failed: %v", err)
|
||||
}
|
||||
if filepath.Base(dir) != matrixMediaTempDirName {
|
||||
if filepath.Base(dir) != media.TempDirName {
|
||||
t.Fatalf("unexpected media dir base: %q", filepath.Base(dir))
|
||||
}
|
||||
|
||||
|
||||
@@ -251,7 +251,13 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := c.upgrader.Upgrade(w, r, nil)
|
||||
// Echo the matched subprotocol back so the browser accepts the upgrade.
|
||||
var responseHeader http.Header
|
||||
if proto := c.matchedSubprotocol(r); proto != "" {
|
||||
responseHeader = http.Header{"Sec-WebSocket-Protocol": {proto}}
|
||||
}
|
||||
|
||||
conn, err := c.upgrader.Upgrade(w, r, responseHeader)
|
||||
if err != nil {
|
||||
logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
@@ -282,8 +288,10 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
go c.readLoop(pc)
|
||||
}
|
||||
|
||||
// authenticate checks the Bearer token from the Authorization header.
|
||||
// Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled.
|
||||
// authenticate checks the request for a valid token:
|
||||
// 1. Authorization: Bearer <token> header
|
||||
// 2. Sec-WebSocket-Protocol "token.<value>" (for browsers that can't set headers)
|
||||
// 3. Query parameter "token" (only when AllowTokenQuery is on)
|
||||
func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
token := c.config.Token
|
||||
if token == "" {
|
||||
@@ -298,6 +306,11 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Check Sec-WebSocket-Protocol subprotocol ("token.<value>")
|
||||
if c.matchedSubprotocol(r) != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check query parameter only when explicitly allowed
|
||||
if c.config.AllowTokenQuery {
|
||||
if r.URL.Query().Get("token") == token {
|
||||
@@ -308,6 +321,18 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// matchedSubprotocol returns the "token.<value>" subprotocol that matches
|
||||
// the configured token, or "" if none do.
|
||||
func (c *PicoChannel) matchedSubprotocol(r *http.Request) string {
|
||||
token := c.config.Token
|
||||
for _, proto := range websocket.Subprotocols(r) {
|
||||
if after, ok := strings.CutPrefix(proto, "token."); ok && after == token {
|
||||
return proto
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// readLoop reads messages from a WebSocket connection.
|
||||
func (c *PicoChannel) readLoop(pc *picoConn) {
|
||||
defer func() {
|
||||
|
||||
+80
-6
@@ -4,11 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
)
|
||||
|
||||
@@ -623,8 +625,9 @@ func (c *ModelConfig) Validate() error {
|
||||
}
|
||||
|
||||
type GatewayConfig struct {
|
||||
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
HotReload bool `json:"hot_reload" env:"PICOCLAW_GATEWAY_HOT_RELOAD"`
|
||||
}
|
||||
|
||||
type ToolDiscoveryConfig struct {
|
||||
@@ -695,11 +698,13 @@ type WebToolsConfig struct {
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
|
||||
FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
|
||||
Format string `json:"format,omitempty" env:"PICOCLAW_TOOLS_WEB_FORMAT"`
|
||||
PrivateHostWhitelist FlexibleStringSlice `json:"private_host_whitelist,omitempty" env:"PICOCLAW_TOOLS_WEB_PRIVATE_HOST_WHITELIST"`
|
||||
}
|
||||
|
||||
type CronToolsConfig struct {
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
|
||||
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
|
||||
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
|
||||
AllowCommand bool ` env:"PICOCLAW_TOOLS_CRON_ALLOW_COMMAND" json:"allow_command"`
|
||||
}
|
||||
|
||||
type ExecConfig struct {
|
||||
@@ -749,6 +754,7 @@ type ToolsConfig struct {
|
||||
ReadFile ReadFileToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
|
||||
SendFile ToolConfig `json:"send_file" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
|
||||
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
|
||||
SpawnStatus ToolConfig `json:"spawn_status" envPrefix:"PICOCLAW_TOOLS_SPAWN_STATUS_"`
|
||||
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
|
||||
Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"`
|
||||
WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"`
|
||||
@@ -838,10 +844,24 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if passphrase := credential.PassphraseProvider(); passphrase != "" {
|
||||
for _, m := range cfg.ModelList {
|
||||
if m.APIKey != "" && !strings.HasPrefix(m.APIKey, "enc://") && !strings.HasPrefix(m.APIKey, "file://") {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
"picoclaw: warning: model %q has a plaintext api_key; call SaveConfig to encrypt it\n",
|
||||
m.ModelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := env.Parse(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := resolveAPIKeys(cfg.ModelList, filepath.Dir(path)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Migrate legacy channel config fields to new unified structures
|
||||
cfg.migrateChannelConfigs()
|
||||
|
||||
@@ -858,6 +878,48 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextAPIKeys returns a copy of models with plaintext api_key values
|
||||
// encrypted. Returns (nil, nil) when nothing changed (all keys already sealed or
|
||||
// empty). Returns (nil, error) if any key fails to encrypt — callers must treat
|
||||
// this as a hard failure to prevent a mixed plaintext/ciphertext state on disk.
|
||||
// Symmetric counterpart of resolveAPIKeys: both operate purely on []ModelConfig
|
||||
// and leave JSON marshaling to the caller.
|
||||
func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelConfig, error) {
|
||||
sealed := make([]ModelConfig, len(models))
|
||||
copy(sealed, models)
|
||||
changed := false
|
||||
for i := range sealed {
|
||||
m := &sealed[i]
|
||||
if m.APIKey == "" || strings.HasPrefix(m.APIKey, "enc://") || strings.HasPrefix(m.APIKey, "file://") {
|
||||
continue
|
||||
}
|
||||
encrypted, err := credential.Encrypt(passphrase, "", m.APIKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot seal api_key for model %q: %w", m.ModelName, err)
|
||||
}
|
||||
m.APIKey = encrypted
|
||||
changed = true
|
||||
}
|
||||
if !changed {
|
||||
return nil, nil
|
||||
}
|
||||
return sealed, nil
|
||||
}
|
||||
|
||||
// resolveAPIKeys decrypts or dereferences each api_key in models in-place.
|
||||
// Supports plaintext (no-op), file:// (read from configDir), and enc:// (AES-GCM decrypt).
|
||||
func resolveAPIKeys(models []ModelConfig, configDir string) error {
|
||||
cr := credential.NewResolver(configDir)
|
||||
for i := range models {
|
||||
resolved, err := cr.Resolve(models[i].APIKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("model_list[%d] (%s): %w", i, models[i].ModelName, err)
|
||||
}
|
||||
models[i].APIKey = resolved
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) migrateChannelConfigs() {
|
||||
// Discord: mention_only -> group_trigger.mention_only
|
||||
if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly {
|
||||
@@ -872,12 +934,22 @@ func (c *Config) migrateChannelConfigs() {
|
||||
}
|
||||
|
||||
func SaveConfig(path string, cfg *Config) error {
|
||||
if passphrase := credential.PassphraseProvider(); passphrase != "" {
|
||||
sealed, err := encryptPlaintextAPIKeys(cfg.ModelList, passphrase)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sealed != nil {
|
||||
tmp := *cfg
|
||||
tmp.ModelList = sealed
|
||||
cfg = &tmp
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Use unified atomic write utility with explicit sync for flash storage reliability.
|
||||
return fileutil.WriteFileAtomic(path, data, 0o600)
|
||||
}
|
||||
|
||||
@@ -1044,6 +1116,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
|
||||
return t.ReadFile.Enabled
|
||||
case "spawn":
|
||||
return t.Spawn.Enabled
|
||||
case "spawn_status":
|
||||
return t.SpawnStatus.Enabled
|
||||
case "spi":
|
||||
return t.SPI.Enabled
|
||||
case "subagent":
|
||||
|
||||
+386
-5
@@ -7,8 +7,22 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
)
|
||||
|
||||
// mustSetupSSHKey generates a temporary Ed25519 SSH key in t.TempDir() and sets
|
||||
// PICOCLAW_SSH_KEY_PATH to its path for the duration of the test. This is required
|
||||
// whenever a test exercises encryption/decryption via credential.Encrypt or SaveConfig.
|
||||
func mustSetupSSHKey(t *testing.T) {
|
||||
t.Helper()
|
||||
keyPath := filepath.Join(t.TempDir(), "picoclaw_ed25519.key")
|
||||
if err := credential.GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("mustSetupSSHKey: %v", err)
|
||||
}
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", keyPath)
|
||||
}
|
||||
|
||||
func TestAgentModelConfig_UnmarshalString(t *testing.T) {
|
||||
var m AgentModelConfig
|
||||
if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil {
|
||||
@@ -253,6 +267,9 @@ func TestDefaultConfig_Gateway(t *testing.T) {
|
||||
if cfg.Gateway.Port == 0 {
|
||||
t.Error("Gateway port should have default value")
|
||||
}
|
||||
if cfg.Gateway.HotReload {
|
||||
t.Error("Gateway hot reload should be disabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Providers verifies provider structure
|
||||
@@ -391,6 +408,13 @@ func TestDefaultConfig_ExecAllowRemoteEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_CronAllowCommandEnabled(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if !cfg.Tools.Cron.AllowCommand {
|
||||
t.Fatal("DefaultConfig().Tools.Cron.AllowCommand should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
@@ -423,6 +447,22 @@ func TestLoadConfig_ExecAllowRemoteDefaultsTrueWhenUnset(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_CronAllowCommandDefaultsTrueWhenUnset(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(`{"tools":{"cron":{"exec_timeout_minutes":5}}}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if !cfg.Tools.Cron.AllowCommand {
|
||||
t.Fatal("tools.cron.allow_command should remain true when unset in config file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
@@ -482,13 +522,19 @@ func TestDefaultConfig_DMScope(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultConfig_WorkspacePath_Default(t *testing.T) {
|
||||
// Unset to ensure we test the default
|
||||
t.Setenv("PICOCLAW_HOME", "")
|
||||
// Set a known home for consistent test results
|
||||
t.Setenv("HOME", "/tmp/home")
|
||||
|
||||
var fakeHome string
|
||||
if runtime.GOOS == "windows" {
|
||||
fakeHome = `C:\tmp\home`
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
} else {
|
||||
fakeHome = "/tmp/home"
|
||||
t.Setenv("HOME", fakeHome)
|
||||
}
|
||||
|
||||
cfg := DefaultConfig()
|
||||
want := filepath.Join("/tmp/home", ".picoclaw", "workspace")
|
||||
want := filepath.Join(fakeHome, ".picoclaw", "workspace")
|
||||
|
||||
if cfg.Agents.Defaults.Workspace != want {
|
||||
t.Errorf("Default workspace path = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
@@ -499,7 +545,7 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw/home")
|
||||
|
||||
cfg := DefaultConfig()
|
||||
want := "/custom/picoclaw/home/workspace"
|
||||
want := filepath.Join("/custom/picoclaw/home", "workspace")
|
||||
|
||||
if cfg.Agents.Defaults.Workspace != want {
|
||||
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
@@ -621,3 +667,338 @@ func TestFlexibleStringSlice_UnmarshalText_EmptySliceConsistency(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLoadConfig_WarnsForPlaintextAPIKey verifies that LoadConfig resolves a plaintext
|
||||
// api_key into memory but does NOT rewrite the config file. File writes are the sole
|
||||
// responsibility of SaveConfig.
|
||||
func TestLoadConfig_WarnsForPlaintextAPIKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
const original = `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
|
||||
if err := os.WriteFile(cfgPath, []byte(original), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
cfg, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
// In-memory value must be the resolved plaintext.
|
||||
if cfg.ModelList[0].APIKey != "sk-plaintext" {
|
||||
t.Errorf("in-memory api_key = %q, want %q", cfg.ModelList[0].APIKey, "sk-plaintext")
|
||||
}
|
||||
// The file on disk must remain unchanged — LoadConfig must not write anything.
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if string(raw) != original {
|
||||
t.Errorf("LoadConfig must not modify the config file; got:\n%s", string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveConfig_EncryptsPlaintextAPIKey verifies that SaveConfig writes enc:// ciphertext
|
||||
// to disk and that a subsequent LoadConfig decrypts it back to the original plaintext.
|
||||
func TestSaveConfig_EncryptsPlaintextAPIKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.ModelList = []ModelConfig{
|
||||
{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
|
||||
}
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig: %v", err)
|
||||
}
|
||||
|
||||
// Disk must contain enc://, not the raw key.
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if !strings.Contains(string(raw), "enc://") {
|
||||
t.Errorf("saved file should contain enc://, got:\n%s", string(raw))
|
||||
}
|
||||
if strings.Contains(string(raw), "sk-plaintext") {
|
||||
t.Errorf("saved file must not contain the plaintext key")
|
||||
}
|
||||
|
||||
// A fresh load must decrypt back to the original plaintext.
|
||||
cfg2, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig after SaveConfig: %v", err)
|
||||
}
|
||||
if cfg2.ModelList[0].APIKey != "sk-plaintext" {
|
||||
t.Errorf("loaded api_key = %q, want %q", cfg2.ModelList[0].APIKey, "sk-plaintext")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_NoSealWithoutPassphrase verifies that api_key values are left
|
||||
// unchanged when PICOCLAW_KEY_PASSPHRASE is not set.
|
||||
func TestLoadConfig_NoSealWithoutPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
|
||||
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
if _, err := LoadConfig(cfgPath); err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if strings.Contains(string(raw), "enc://") {
|
||||
t.Error("config file must not be modified when no passphrase is set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_FileRefNotSealed verifies that file:// api_key references are not
|
||||
// converted to enc:// values (they are resolved at runtime by the Resolver).
|
||||
func TestLoadConfig_FileRefNotSealed(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
keyFile := filepath.Join(dir, "openai.key")
|
||||
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"file://openai.key"}]}`
|
||||
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
if _, err := LoadConfig(cfgPath); err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if !strings.Contains(string(raw), "file://openai.key") {
|
||||
t.Error("file:// reference should be preserved unchanged in the config file")
|
||||
}
|
||||
if strings.Contains(string(raw), "enc://") {
|
||||
t.Error("file:// reference must not be converted to enc://")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveConfig_MixedKeys verifies that SaveConfig encrypts only plaintext api_keys
|
||||
// and leaves already-encrypted (enc://) and file:// entries unchanged.
|
||||
func TestSaveConfig_MixedKeys(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
// Pre-encrypt one key so we have a genuine enc:// value to put in the config.
|
||||
if err := SaveConfig(cfgPath, &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "pre", Model: "openai/gpt-4", APIKey: "sk-already-plain"},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("setup SaveConfig: %v", err)
|
||||
}
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
// Extract the enc:// value from the saved file.
|
||||
var tmp struct {
|
||||
ModelList []struct {
|
||||
APIKey string `json:"api_key"`
|
||||
} `json:"model_list"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tmp); err != nil || len(tmp.ModelList) == 0 {
|
||||
t.Fatalf("setup: could not parse saved config: %v", err)
|
||||
}
|
||||
alreadyEncrypted := tmp.ModelList[0].APIKey
|
||||
if !strings.HasPrefix(alreadyEncrypted, "enc://") {
|
||||
t.Fatalf("setup: expected enc:// key, got %q", alreadyEncrypted)
|
||||
}
|
||||
|
||||
// Build a config with three models:
|
||||
// 1. plaintext → must be encrypted by SaveConfig
|
||||
// 2. enc:// → must be left unchanged (already encrypted)
|
||||
// 3. file:// → must be left unchanged (file reference)
|
||||
keyFile := filepath.Join(dir, "api.key")
|
||||
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "plain", Model: "openai/gpt-4", APIKey: "sk-new-plaintext"},
|
||||
{ModelName: "enc", Model: "openai/gpt-4", APIKey: alreadyEncrypted},
|
||||
{ModelName: "file", Model: "openai/gpt-4", APIKey: "file://api.key"},
|
||||
},
|
||||
}
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ = os.ReadFile(cfgPath)
|
||||
s := string(raw)
|
||||
|
||||
// 1. Plaintext must be encrypted.
|
||||
if strings.Contains(s, "sk-new-plaintext") {
|
||||
t.Error("plaintext key must not appear in saved file")
|
||||
}
|
||||
// 2. The pre-existing enc:// value must still be present (byte-for-byte unchanged).
|
||||
if !strings.Contains(s, alreadyEncrypted) {
|
||||
t.Error("pre-existing enc:// entry must be preserved unchanged")
|
||||
}
|
||||
// 3. file:// must be preserved.
|
||||
if !strings.Contains(s, "file://api.key") {
|
||||
t.Error("file:// reference must be preserved unchanged")
|
||||
}
|
||||
|
||||
// Now load and verify all three decrypt/resolve correctly.
|
||||
cfg2, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig after SaveConfig: %v", err)
|
||||
}
|
||||
byName := make(map[string]string)
|
||||
for _, m := range cfg2.ModelList {
|
||||
byName[m.ModelName] = m.APIKey
|
||||
}
|
||||
if byName["plain"] != "sk-new-plaintext" {
|
||||
t.Errorf("plain model api_key = %q, want %q", byName["plain"], "sk-new-plaintext")
|
||||
}
|
||||
if byName["enc"] != "sk-already-plain" {
|
||||
t.Errorf("enc model api_key = %q, want %q", byName["enc"], "sk-already-plain")
|
||||
}
|
||||
if byName["file"] != "sk-from-file" {
|
||||
t.Errorf("file model api_key = %q, want %q", byName["file"], "sk-from-file")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_MixedKeys_NoPassphrase verifies that when PICOCLAW_KEY_PASSPHRASE
|
||||
// is not set, enc:// entries cause LoadConfig to return an error, while plaintext
|
||||
// and file:// entries in the same config are not affected.
|
||||
func TestLoadConfig_MixedKeys_NoPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// First encrypt a key so we have a real enc:// value.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
mustSetupSSHKey(t)
|
||||
if err := SaveConfig(cfgPath, &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "m", Model: "openai/gpt-4", APIKey: "sk-secret"},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("setup SaveConfig: %v", err)
|
||||
}
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
var tmp struct {
|
||||
ModelList []struct {
|
||||
APIKey string `json:"api_key"`
|
||||
} `json:"model_list"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tmp); err != nil {
|
||||
t.Fatalf("setup parse: %v", err)
|
||||
}
|
||||
encValue := tmp.ModelList[0].APIKey
|
||||
|
||||
// Write a mixed config: enc:// + plaintext + file://
|
||||
keyFile := filepath.Join(dir, "api.key")
|
||||
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
mixed, _ := json.Marshal(map[string]any{
|
||||
"model_list": []map[string]any{
|
||||
{"model_name": "enc", "model": "openai/gpt-4", "api_key": encValue},
|
||||
{"model_name": "plain", "model": "openai/gpt-4", "api_key": "sk-plain"},
|
||||
{"model_name": "file", "model": "openai/gpt-4", "api_key": "file://api.key"},
|
||||
},
|
||||
})
|
||||
if err := os.WriteFile(cfgPath, mixed, 0o600); err != nil {
|
||||
t.Fatalf("setup write: %v", err)
|
||||
}
|
||||
|
||||
// Now clear the passphrase — LoadConfig must fail because enc:// cannot be decrypted.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
|
||||
_, err := LoadConfig(cfgPath)
|
||||
if err == nil {
|
||||
t.Fatal("LoadConfig should fail when enc:// key is present and no passphrase is set")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "passphrase required") {
|
||||
t.Errorf("error should mention passphrase required, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveConfig_UsesPassphraseProvider verifies that SaveConfig encrypts plaintext
|
||||
// api_keys using credential.PassphraseProvider() rather than os.Getenv directly.
|
||||
// This matters for the launcher, which clears the environment variable and redirects
|
||||
// PassphraseProvider to an in-memory SecureStore.
|
||||
func TestSaveConfig_UsesPassphraseProvider(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// Ensure the env var is empty — passphrase must come from PassphraseProvider only.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
// Replace PassphraseProvider with an in-memory function (simulating SecureStore).
|
||||
const testPassphrase = "provider-passphrase"
|
||||
orig := credential.PassphraseProvider
|
||||
credential.PassphraseProvider = func() string { return testPassphrase }
|
||||
t.Cleanup(func() { credential.PassphraseProvider = orig })
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.ModelList = []ModelConfig{
|
||||
{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
|
||||
}
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if !strings.Contains(string(raw), "enc://") {
|
||||
t.Errorf("SaveConfig should have encrypted plaintext key via PassphraseProvider; got:\n%s", raw)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_UsesPassphraseProvider verifies that LoadConfig decrypts enc:// keys
|
||||
// using credential.PassphraseProvider() rather than os.Getenv directly.
|
||||
func TestLoadConfig_UsesPassphraseProvider(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// Ensure the env var is empty throughout.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
const testPassphrase = "provider-passphrase"
|
||||
const plainKey = "sk-secret"
|
||||
|
||||
// First, encrypt the key using the same passphrase.
|
||||
encrypted, err := credential.Encrypt(testPassphrase, "", plainKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := json.Marshal(map[string]any{
|
||||
"model_list": []map[string]any{
|
||||
{"model_name": "test", "model": "openai/gpt-4", "api_key": encrypted},
|
||||
},
|
||||
})
|
||||
if err = os.WriteFile(cfgPath, raw, 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
// Redirect PassphraseProvider — env var is empty, so without this the load would fail.
|
||||
orig := credential.PassphraseProvider
|
||||
credential.PassphraseProvider = func() string { return testPassphrase }
|
||||
t.Cleanup(func() { credential.PassphraseProvider = orig })
|
||||
|
||||
cfg, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
if cfg.ModelList[0].APIKey != plainKey {
|
||||
t.Errorf("api_key = %q, want %q", cfg.ModelList[0].APIKey, plainKey)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -395,8 +395,9 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
},
|
||||
Gateway: GatewayConfig{
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
HotReload: false,
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
MediaCleanup: MediaCleanupConfig{
|
||||
@@ -453,6 +454,7 @@ func DefaultConfig() *Config {
|
||||
Enabled: true,
|
||||
},
|
||||
ExecTimeoutMinutes: 5,
|
||||
AllowCommand: true,
|
||||
},
|
||||
Exec: ExecConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
@@ -522,6 +524,9 @@ func DefaultConfig() *Config {
|
||||
Spawn: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
SpawnStatus: ToolConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
SPI: ToolConfig{
|
||||
Enabled: false, // Hardware tool - Linux only
|
||||
},
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
// Package credential resolves API credential values for model_list entries.
|
||||
//
|
||||
// An API key is a form of authorization credential. This package centralizes
|
||||
// how raw credential strings—plaintext or file references—are resolved into
|
||||
// their actual values, keeping that logic out of the config loader.
|
||||
//
|
||||
// Supported formats for the api_key field:
|
||||
//
|
||||
// - Plaintext: "sk-abc123" → returned as-is
|
||||
// - File ref: "file://filename.key" → content read from configDir/filename.key
|
||||
// - Encrypted: "enc://<base64>" → AES-256-GCM decrypt via PICOCLAW_KEY_PASSPHRASE
|
||||
// - Empty: "" → returned as-is (auth_method=oauth etc.)
|
||||
//
|
||||
// Encryption uses AES-256-GCM with HKDF-SHA256 key derivation (< 1ms, safe for embedded Linux).
|
||||
// An SSH private key is required for both encryption and decryption.
|
||||
// Key derivation:
|
||||
//
|
||||
// HKDF-SHA256(ikm=HMAC-SHA256(SHA256(sshKeyBytes), passphrase), salt, info)
|
||||
//
|
||||
// SSH key path resolution priority:
|
||||
//
|
||||
// 1. sshKeyPath argument to Encrypt (explicit)
|
||||
// 2. PICOCLAW_SSH_KEY_PATH env var
|
||||
// 3. ~/.ssh/picoclaw_ed25519.key (os.UserHomeDir is cross-platform)
|
||||
package credential
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hkdf"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PassphraseEnvVar is the environment variable that holds the encryption passphrase.
|
||||
// Other packages (e.g. config) reference this constant to avoid duplicating the string.
|
||||
const PassphraseEnvVar = "PICOCLAW_KEY_PASSPHRASE"
|
||||
|
||||
// PassphraseProvider is the function used to retrieve the passphrase for enc://
|
||||
// credential decryption. It defaults to reading PICOCLAW_KEY_PASSPHRASE from the
|
||||
// process environment. Replace it at startup to use a different source, such as
|
||||
// an in-memory SecureStore, so that all LoadConfig() calls everywhere share the
|
||||
// same passphrase source without needing os.Environ.
|
||||
//
|
||||
// Example (launcher main.go):
|
||||
//
|
||||
// credential.PassphraseProvider = apiHandler.passphraseStore.Get
|
||||
var PassphraseProvider func() string = func() string {
|
||||
return os.Getenv(PassphraseEnvVar)
|
||||
}
|
||||
|
||||
// ErrPassphraseRequired is returned when an enc:// credential is encountered but
|
||||
// no passphrase is available from PassphraseProvider. Callers can detect this
|
||||
// with errors.Is to distinguish a missing-passphrase condition from other errors.
|
||||
var ErrPassphraseRequired = errors.New("credential: enc:// passphrase required")
|
||||
|
||||
// ErrDecryptionFailed is returned when an enc:// credential cannot be decrypted,
|
||||
// indicating a wrong passphrase or SSH key. Callers can detect this with errors.Is.
|
||||
var ErrDecryptionFailed = errors.New("credential: enc:// decryption failed (wrong passphrase or SSH key?)")
|
||||
|
||||
const (
|
||||
fileScheme = "file://"
|
||||
encScheme = "enc://"
|
||||
hkdfInfo = "picoclaw-credential-v1"
|
||||
saltLen = 16
|
||||
nonceLen = 12
|
||||
keyLen = 32
|
||||
sshKeyEnv = "PICOCLAW_SSH_KEY_PATH"
|
||||
)
|
||||
|
||||
// Resolver resolves raw credential strings for model_list api_key fields.
|
||||
// File references are resolved relative to the directory of the config file.
|
||||
type Resolver struct {
|
||||
configDir string
|
||||
resolvedConfigDir string // symlink-resolved form of configDir
|
||||
}
|
||||
|
||||
// NewResolver returns a Resolver that resolves file:// references relative to
|
||||
// configDir (typically filepath.Dir of the config file path).
|
||||
func NewResolver(configDir string) *Resolver {
|
||||
resolved := configDir
|
||||
if configDir != "" {
|
||||
if linkedPath, err := filepath.EvalSymlinks(configDir); err == nil {
|
||||
resolved = linkedPath
|
||||
}
|
||||
}
|
||||
return &Resolver{configDir: configDir, resolvedConfigDir: resolved}
|
||||
}
|
||||
|
||||
// Resolve returns the actual credential value for raw:
|
||||
//
|
||||
// - "" → "" (no error; auth_method=oauth needs no key)
|
||||
// - "file://name.key" → trimmed content of configDir/name.key
|
||||
// - anything else → raw unchanged (plaintext credential)
|
||||
func (r *Resolver) Resolve(raw string) (string, error) {
|
||||
if raw == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(raw, fileScheme) {
|
||||
fileName := strings.TrimSpace(strings.TrimPrefix(raw, fileScheme))
|
||||
if fileName == "" {
|
||||
return "", fmt.Errorf("credential: file:// reference has no filename")
|
||||
}
|
||||
|
||||
baseDir := r.resolvedConfigDir
|
||||
if baseDir == "" {
|
||||
baseDir = r.configDir
|
||||
}
|
||||
keyPath := filepath.Join(baseDir, fileName)
|
||||
// Resolve symlinks before enforcing containment to prevent escaping via symlinks.
|
||||
realKeyPath, err := filepath.EvalSymlinks(keyPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: failed to resolve credential file path %q: %w", keyPath, err)
|
||||
}
|
||||
if !isWithinDir(realKeyPath, baseDir) {
|
||||
return "", fmt.Errorf("credential: file:// path escapes config directory")
|
||||
}
|
||||
data, err := os.ReadFile(realKeyPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: failed to read credential file %q: %w", realKeyPath, err)
|
||||
}
|
||||
|
||||
value := strings.TrimSpace(string(data))
|
||||
if value == "" {
|
||||
return "", fmt.Errorf("credential: credential file %q is empty", realKeyPath)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(raw, encScheme) {
|
||||
return resolveEncrypted(raw)
|
||||
}
|
||||
|
||||
// Plaintext credential — return unchanged.
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
// resolveEncrypted decrypts an enc:// credential using PassphraseProvider.
|
||||
func resolveEncrypted(raw string) (string, error) {
|
||||
passphrase := PassphraseProvider()
|
||||
if passphrase == "" {
|
||||
return "", ErrPassphraseRequired
|
||||
}
|
||||
|
||||
sshKeyPath := pickSSHKeyPath("") // override="": consult env then auto-detect
|
||||
|
||||
b64 := strings.TrimPrefix(raw, encScheme)
|
||||
blob, err := base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: enc:// invalid base64: %w", err)
|
||||
}
|
||||
if len(blob) < saltLen+nonceLen+1 {
|
||||
return "", fmt.Errorf("credential: enc:// payload too short")
|
||||
}
|
||||
|
||||
salt := blob[:saltLen]
|
||||
nonce := blob[saltLen : saltLen+nonceLen]
|
||||
ciphertext := blob[saltLen+nonceLen:]
|
||||
|
||||
key, err := deriveKey(passphrase, sshKeyPath, salt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: enc:// cipher init: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: enc:// gcm init: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %w", ErrDecryptionFailed, err)
|
||||
}
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext and returns an enc:// credential string.
|
||||
//
|
||||
// passphrase is required (PICOCLAW_KEY_PASSPHRASE value).
|
||||
// sshKeyPath is the SSH private key file to use; pass "" to auto-detect via
|
||||
// PICOCLAW_SSH_KEY_PATH env var or ~/.ssh/picoclaw_ed25519.key.
|
||||
// An SSH private key must be resolvable or Encrypt returns an error.
|
||||
func Encrypt(passphrase, sshKeyPath, plaintext string) (string, error) {
|
||||
if passphrase == "" {
|
||||
return "", fmt.Errorf("credential: passphrase must not be empty")
|
||||
}
|
||||
sshKeyPath = pickSSHKeyPath(sshKeyPath)
|
||||
|
||||
salt := make([]byte, saltLen)
|
||||
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||||
return "", fmt.Errorf("credential: failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
key, err := deriveKey(passphrase, sshKeyPath, salt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: cipher init: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: gcm init: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, nonceLen)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("credential: failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nil, nonce, []byte(plaintext), nil)
|
||||
blob := make([]byte, 0, saltLen+nonceLen+len(ciphertext))
|
||||
blob = append(blob, salt...)
|
||||
blob = append(blob, nonce...)
|
||||
blob = append(blob, ciphertext...)
|
||||
return encScheme + base64.StdEncoding.EncodeToString(blob), nil
|
||||
}
|
||||
|
||||
// isWithinDir reports whether path is contained within (or equal to) dir.
|
||||
// Uses filepath.IsLocal on the relative path for robust cross-platform traversal detection.
|
||||
func isWithinDir(path, dir string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(dir), filepath.Clean(path))
|
||||
return err == nil && filepath.IsLocal(rel)
|
||||
}
|
||||
|
||||
// allowedSSHKeyPath reports whether path is in a permitted location for SSH key files:
|
||||
// - exact match with PICOCLAW_SSH_KEY_PATH env var
|
||||
// - within the PICOCLAW_HOME env var directory
|
||||
// - within ~/.ssh/
|
||||
func allowedSSHKeyPath(path string) bool {
|
||||
if path == "" {
|
||||
return true // passphrase-only mode; no file will be read
|
||||
}
|
||||
clean := filepath.Clean(path)
|
||||
|
||||
// Exact match with PICOCLAW_SSH_KEY_PATH.
|
||||
if envPath, ok := os.LookupEnv(sshKeyEnv); ok && envPath != "" {
|
||||
if clean == filepath.Clean(envPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Within PICOCLAW_HOME.
|
||||
if picoHome := os.Getenv("PICOCLAW_HOME"); picoHome != "" {
|
||||
if isWithinDir(clean, picoHome) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Within ~/.ssh/.
|
||||
if userHome, err := os.UserHomeDir(); err == nil {
|
||||
if isWithinDir(clean, filepath.Join(userHome, ".ssh")) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// deriveKey derives a 32-byte AES-256 key from passphrase and SSH private key.
|
||||
//
|
||||
// ikm = HMAC-SHA256(key=SHA256(sshKeyBytes), msg=passphrase)
|
||||
// Final key: HKDF-SHA256(ikm, salt, info="picoclaw-credential-v1", 32 bytes)
|
||||
// sshKeyPath must be non-empty; returns an error otherwise.
|
||||
func deriveKey(passphrase, sshKeyPath string, salt []byte) ([]byte, error) {
|
||||
if sshKeyPath == "" {
|
||||
return nil, fmt.Errorf(
|
||||
"credential: SSH private key is required but not found" +
|
||||
" (set PICOCLAW_SSH_KEY_PATH or place key at ~/.ssh/picoclaw_ed25519.key)")
|
||||
}
|
||||
if !allowedSSHKeyPath(sshKeyPath) {
|
||||
return nil, fmt.Errorf(
|
||||
"credential: SSH key path %q is not in an allowed location (PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/)",
|
||||
sshKeyPath,
|
||||
)
|
||||
}
|
||||
sshBytes, err := os.ReadFile(sshKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("credential: cannot read SSH key %q: %w", sshKeyPath, err)
|
||||
}
|
||||
sshHash := sha256.Sum256(sshBytes)
|
||||
mac := hmac.New(sha256.New, sshHash[:])
|
||||
mac.Write([]byte(passphrase))
|
||||
ikm := mac.Sum(nil)
|
||||
|
||||
key, err := hkdf.Key(sha256.New, ikm, salt, hkdfInfo, keyLen)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("credential: HKDF expand failed: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// pickSSHKeyPath returns the SSH private key path to use for encryption/decryption.
|
||||
//
|
||||
// Priority:
|
||||
// 1. override (non-empty explicit argument)
|
||||
// 2. PICOCLAW_SSH_KEY_PATH env var
|
||||
// 3. ~/.ssh/picoclaw_ed25519.key (auto-detection)
|
||||
//
|
||||
// Returns "" when no key is found; deriveKey will return an error in that case.
|
||||
func pickSSHKeyPath(override string) string {
|
||||
if override != "" {
|
||||
return override
|
||||
}
|
||||
if p, ok := os.LookupEnv(sshKeyEnv); ok {
|
||||
return p // respect explicit setting, even if ""
|
||||
}
|
||||
return findDefaultSSHKey()
|
||||
}
|
||||
|
||||
// findDefaultSSHKey returns the picoclaw-specific SSH key path if it exists.
|
||||
func findDefaultSSHKey() string {
|
||||
p, err := DefaultSSHKeyPath()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package credential_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
)
|
||||
|
||||
func TestResolve_PlainKey(t *testing.T) {
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
got, err := r.Resolve("sk-plaintext-key")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "sk-plaintext-key" {
|
||||
t.Fatalf("got %q, want %q", got, "sk-plaintext-key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_FileKey_Success(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyFile := "openai_plain.key"
|
||||
if err := os.WriteFile(filepath.Join(dir, keyFile), []byte("sk-from-file\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(dir)
|
||||
got, err := r.Resolve("file://" + keyFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "sk-from-file" {
|
||||
t.Fatalf("got %q, want %q", got, "sk-from-file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_FileKey_NotFound(t *testing.T) {
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err := r.Resolve("file://missing.key")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_FileKey_Empty(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyFile := "empty.key"
|
||||
if err := os.WriteFile(filepath.Join(dir, keyFile), []byte(" \n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(dir)
|
||||
_, err := r.Resolve("file://" + keyFile)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty credential file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_EncKey_RoundTrip tests basic encryption/decryption round-trip with an SSH key.
|
||||
func TestResolve_EncKey_RoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key-material\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
const passphrase = "test-passphrase-32bytes-long-ok!"
|
||||
const plaintext = "sk-encrypted-secret"
|
||||
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt(passphrase, "", plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
got, err := r.Resolve(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("Resolve: %v", err)
|
||||
}
|
||||
if got != plaintext {
|
||||
t.Fatalf("got %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_EncKey_WithSSHKey tests that the SSH key file is incorporated into key derivation.
|
||||
func TestResolve_EncKey_WithSSHKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-private-key-material\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
const passphrase = "test-passphrase"
|
||||
const plaintext = "sk-ssh-protected-secret"
|
||||
|
||||
// Set PICOCLAW_SSH_KEY_PATH before Encrypt so the path passes allowedSSHKeyPath validation.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt(passphrase, sshKeyPath, plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
got, err := r.Resolve(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("Resolve: %v", err)
|
||||
}
|
||||
if got != plaintext {
|
||||
t.Fatalf("got %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_NoPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt("some-passphrase", "", "sk-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err = r.Resolve(enc)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when PICOCLAW_KEY_PASSPHRASE is unset, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_BadCiphertext(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err := r.Resolve("enc://!!not-valid-base64!!")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid enc:// payload, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_PayloadTooShort(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
// Valid base64 but fewer bytes than salt(16)+nonce(12)+1 minimum.
|
||||
import64 := "dG9vc2hvcnQ=" // "tooshort" = 8 bytes
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err := r.Resolve("enc://" + import64)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for too-short enc:// payload, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_WrongPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt("correct-passphrase", "", "sk-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "wrong-passphrase")
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err = r.Resolve(enc)
|
||||
if err == nil {
|
||||
t.Fatal("expected decryption error for wrong passphrase, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypt_EmptyPassphrase(t *testing.T) {
|
||||
_, err := credential.Encrypt("", "", "sk-secret")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty passphrase, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveKey_SSHKeyNotFound(t *testing.T) {
|
||||
// Encrypt with a real SSH key path, then try to decrypt with a missing path.
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
// Register the real key path so allowedSSHKeyPath validation passes for Encrypt.
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt("passphrase", sshKeyPath, "sk-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
// Point to a non-existent SSH key so deriveKey's ReadFile fails.
|
||||
// The path is still under the same dir, so allowedSSHKeyPath passes (exact env match).
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", filepath.Join(dir, "nonexistent_key"))
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err = r.Resolve(enc)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when SSH key file is missing, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_FileRef_PathTraversal verifies that file:// references cannot escape configDir
|
||||
// via relative traversal ("../../etc/passwd") or absolute paths ("/abs/path").
|
||||
func TestResolve_FileRef_PathTraversal(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
// Create a file outside configDir that the traversal would point to.
|
||||
outsideFile := filepath.Join(t.TempDir(), "secret.key")
|
||||
if err := os.WriteFile(outsideFile, []byte("stolen"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(filepath.Dir(cfgPath))
|
||||
|
||||
cases := []string{
|
||||
"file://../../secret.key",
|
||||
"file://../secret.key",
|
||||
"file://" + outsideFile, // absolute path
|
||||
}
|
||||
for _, raw := range cases {
|
||||
_, err := r.Resolve(raw)
|
||||
if err == nil {
|
||||
t.Errorf("Resolve(%q): expected path traversal error, got nil", raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_FileRef_withinConfigDir verifies that a legitimate relative file:// ref works.
|
||||
func TestResolve_FileRef_withinConfigDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "my.key"), []byte("sk-valid\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
r := credential.NewResolver(dir)
|
||||
got, err := r.Resolve("file://my.key")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "sk-valid" {
|
||||
t.Fatalf("got %q, want %q", got, "sk-valid")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncrypt_SSHKeyOutsideAllowedDirs verifies that Encrypt rejects SSH key paths
|
||||
// that are not under PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/.
|
||||
func TestEncrypt_SSHKeyOutsideAllowedDirs(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
// Make sure none of the allowed env vars point here.
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
t.Setenv("PICOCLAW_HOME", "")
|
||||
|
||||
_, err := credential.Encrypt("passphrase", sshKeyPath, "sk-secret")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for SSH key outside allowed directories, got nil")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package credential
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// DefaultSSHKeyPath returns the canonical path for the picoclaw-specific SSH key.
|
||||
// The path is always ~/.ssh/picoclaw_ed25519.key (os.UserHomeDir is cross-platform).
|
||||
func DefaultSSHKeyPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: cannot determine home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".ssh", "picoclaw_ed25519.key"), nil
|
||||
}
|
||||
|
||||
// GenerateSSHKey generates an Ed25519 SSH key pair and writes the private key
|
||||
// to path (permissions 0600) and the public key to path+".pub" (permissions 0644).
|
||||
// The ~/.ssh/ directory is created with 0700 if it does not exist.
|
||||
// If the files already exist they are overwritten.
|
||||
func GenerateSSHKey(path string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
|
||||
return fmt.Errorf("credential: keygen: cannot create directory %q: %w", filepath.Dir(path), err)
|
||||
}
|
||||
|
||||
pubRaw, privRaw, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("credential: keygen: ed25519 key generation failed: %w", err)
|
||||
}
|
||||
|
||||
// Marshal private key as OpenSSH PEM.
|
||||
block, err := ssh.MarshalPrivateKey(privRaw, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("credential: keygen: marshal private key: %w", err)
|
||||
}
|
||||
privPEM := pem.EncodeToMemory(block)
|
||||
|
||||
if err = os.WriteFile(path, privPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("credential: keygen: write private key %q: %w", path, err)
|
||||
}
|
||||
|
||||
// Marshal public key as authorized_keys line.
|
||||
sshPub, err := ssh.NewPublicKey(pubRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("credential: keygen: marshal public key: %w", err)
|
||||
}
|
||||
pubLine := ssh.MarshalAuthorizedKey(sshPub)
|
||||
|
||||
pubPath := path + ".pub"
|
||||
if err := os.WriteFile(pubPath, pubLine, 0o644); err != nil {
|
||||
return fmt.Errorf("credential: keygen: write public key %q: %w", pubPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package credential
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestGenerateSSHKey_CreatesFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyPath := filepath.Join(dir, "test_ed25519.key")
|
||||
|
||||
if err := GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
|
||||
// Private key must exist.
|
||||
privInfo, err := os.Stat(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("private key file missing: %v", err)
|
||||
}
|
||||
|
||||
// Check permissions on non-Windows (Windows does not support Unix permission bits).
|
||||
if runtime.GOOS != "windows" {
|
||||
if got := privInfo.Mode().Perm(); got != 0o600 {
|
||||
t.Errorf("private key permissions = %04o, want 0600", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Public key must exist.
|
||||
pubPath := keyPath + ".pub"
|
||||
pubInfo, err := os.Stat(pubPath)
|
||||
if err != nil {
|
||||
t.Fatalf("public key file missing: %v", err)
|
||||
}
|
||||
if runtime.GOOS != "windows" {
|
||||
if got := pubInfo.Mode().Perm(); got != 0o644 {
|
||||
t.Errorf("public key permissions = %04o, want 0644", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Private key must be parseable as an OpenSSH ed25519 key.
|
||||
privPEM, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read private key: %v", err)
|
||||
}
|
||||
privKey, err := ssh.ParseRawPrivateKey(privPEM)
|
||||
if err != nil {
|
||||
t.Fatalf("parse private key: %v", err)
|
||||
}
|
||||
if _, ok := privKey.(*ed25519.PrivateKey); !ok {
|
||||
t.Errorf("private key type = %T, want *ed25519.PrivateKey", privKey)
|
||||
}
|
||||
|
||||
// Public key must be parseable as authorized_keys line.
|
||||
pubBytes, err := os.ReadFile(pubPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read public key: %v", err)
|
||||
}
|
||||
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(pubBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("parse public key: %v", err)
|
||||
}
|
||||
if pubKey == nil {
|
||||
t.Fatal("expected non-nil public key")
|
||||
}
|
||||
if len(rest) > 0 {
|
||||
t.Errorf("unexpected trailing bytes after public key: %d bytes", len(rest))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSSHKey_OverwritesExisting(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyPath := filepath.Join(dir, "test_ed25519.key")
|
||||
|
||||
// Generate twice; second call must not error and must produce a different key.
|
||||
if err := GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("first GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
first, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read first key: %v", err)
|
||||
}
|
||||
|
||||
if err = GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("second GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
second, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read second key: %v", err)
|
||||
}
|
||||
|
||||
// Two independently generated Ed25519 keys must differ.
|
||||
if string(first) == string(second) {
|
||||
t.Error("expected overwritten key to differ from original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSSHKey_CreatesDirectory(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Nested directory that does not yet exist.
|
||||
keyPath := filepath.Join(dir, "subdir", ".ssh", "picoclaw_ed25519.key")
|
||||
|
||||
if err := GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(keyPath); err != nil {
|
||||
t.Fatalf("private key not created: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package credential
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// SecureStore holds a passphrase in memory.
|
||||
//
|
||||
// Uses atomic.Pointer so reads and writes are lock-free.
|
||||
// The passphrase is never written to disk; callers decide how to
|
||||
// transport it outside this store (e.g., via cmd.Env or os.Environ).
|
||||
type SecureStore struct {
|
||||
val atomic.Pointer[string]
|
||||
}
|
||||
|
||||
// NewSecureStore creates an empty SecureStore.
|
||||
func NewSecureStore() *SecureStore {
|
||||
return &SecureStore{}
|
||||
}
|
||||
|
||||
// SetString stores the passphrase. An empty string clears the store.
|
||||
func (s *SecureStore) SetString(passphrase string) {
|
||||
if passphrase == "" {
|
||||
s.val.Store(nil)
|
||||
return
|
||||
}
|
||||
s.val.Store(&passphrase)
|
||||
}
|
||||
|
||||
// Get returns the stored passphrase, or "" if not set.
|
||||
func (s *SecureStore) Get() string {
|
||||
if p := s.val.Load(); p != nil {
|
||||
return *p
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsSet reports whether a passphrase is currently stored.
|
||||
func (s *SecureStore) IsSet() bool {
|
||||
return s.val.Load() != nil
|
||||
}
|
||||
|
||||
// Clear removes the stored passphrase.
|
||||
func (s *SecureStore) Clear() {
|
||||
s.val.Store(nil)
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package credential
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSecureStore_SetGet(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
if s.IsSet() {
|
||||
t.Error("expected empty store")
|
||||
}
|
||||
|
||||
s.SetString("hunter2")
|
||||
if !s.IsSet() {
|
||||
t.Error("expected store to be set")
|
||||
}
|
||||
if got := s.Get(); got != "hunter2" {
|
||||
t.Errorf("Get() = %q, want %q", got, "hunter2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_Clear(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
s.SetString("secret")
|
||||
s.Clear()
|
||||
|
||||
if s.IsSet() {
|
||||
t.Error("expected store to be empty after Clear()")
|
||||
}
|
||||
if got := s.Get(); got != "" {
|
||||
t.Errorf("Get() after Clear() = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_SetOverwrites(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
s.SetString("first")
|
||||
s.SetString("second")
|
||||
|
||||
if got := s.Get(); got != "second" {
|
||||
t.Errorf("Get() = %q, want %q", got, "second")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_EmptyPassphrase(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
s.SetString("") // empty → should not mark as set
|
||||
|
||||
if s.IsSet() {
|
||||
t.Error("empty passphrase should not mark store as set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_ConcurrentSetGet(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
const goroutines = 10
|
||||
const iterations = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
if id%2 == 0 {
|
||||
s.SetString("even")
|
||||
} else {
|
||||
s.SetString("odd")
|
||||
}
|
||||
_ = s.Get()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
final := s.Get()
|
||||
if final != "" && final != "even" && final != "odd" {
|
||||
t.Errorf("Get() returned unexpected value %q after concurrent Set/Get", final)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,594 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/agent"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/dingtalk"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/discord"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/feishu"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/irc"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/line"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/maixcam"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/matrix"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/onebot"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/pico"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/qq"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/slack"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/telegram"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/wecom"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/whatsapp"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/whatsapp_native"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
"github.com/sipeed/picoclaw/pkg/devices"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
const (
|
||||
serviceShutdownTimeout = 30 * time.Second
|
||||
providerReloadTimeout = 30 * time.Second
|
||||
gracefulShutdownTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
type services struct {
|
||||
CronService *cron.CronService
|
||||
HeartbeatService *heartbeat.HeartbeatService
|
||||
MediaStore media.MediaStore
|
||||
ChannelManager *channels.Manager
|
||||
DeviceService *devices.Service
|
||||
HealthServer *health.Server
|
||||
}
|
||||
|
||||
type startupBlockedProvider struct {
|
||||
reason string
|
||||
}
|
||||
|
||||
func (p *startupBlockedProvider) Chat(
|
||||
_ context.Context,
|
||||
_ []providers.Message,
|
||||
_ []providers.ToolDefinition,
|
||||
_ string,
|
||||
_ map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return nil, fmt.Errorf("%s", p.reason)
|
||||
}
|
||||
|
||||
func (p *startupBlockedProvider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Run starts the gateway runtime using the configuration loaded from configPath.
|
||||
func Run(debug bool, configPath string, allowEmptyStartup bool) error {
|
||||
if debug {
|
||||
logger.SetLevel(logger.DEBUG)
|
||||
fmt.Println("🔍 Debug mode enabled")
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %w", err)
|
||||
}
|
||||
|
||||
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating provider: %w", err)
|
||||
}
|
||||
|
||||
if modelID != "" {
|
||||
cfg.Agents.Defaults.ModelName = modelID
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
fmt.Println("\n📦 Agent Status:")
|
||||
startupInfo := agentLoop.GetStartupInfo()
|
||||
toolsInfo := startupInfo["tools"].(map[string]any)
|
||||
skillsInfo := startupInfo["skills"].(map[string]any)
|
||||
fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"])
|
||||
fmt.Printf(" • Skills: %d/%d available\n", skillsInfo["available"], skillsInfo["total"])
|
||||
|
||||
logger.InfoCF("agent", "Agent initialized",
|
||||
map[string]any{
|
||||
"tools_count": toolsInfo["count"],
|
||||
"skills_total": skillsInfo["total"],
|
||||
"skills_available": skillsInfo["available"],
|
||||
})
|
||||
|
||||
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
fmt.Println("Press Ctrl+C to stop")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go agentLoop.Run(ctx)
|
||||
|
||||
var configReloadChan <-chan *config.Config
|
||||
stopWatch := func() {}
|
||||
if cfg.Gateway.HotReload {
|
||||
configReloadChan, stopWatch = setupConfigWatcherPolling(configPath, debug)
|
||||
logger.Info("Config hot reload enabled")
|
||||
}
|
||||
defer stopWatch()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sigChan:
|
||||
logger.Info("Shutting down...")
|
||||
shutdownGateway(runningServices, agentLoop, provider, true)
|
||||
return nil
|
||||
case newCfg := <-configReloadChan:
|
||||
err := handleConfigReload(ctx, agentLoop, newCfg, &provider, runningServices, msgBus, allowEmptyStartup)
|
||||
if err != nil {
|
||||
logger.Errorf("Config reload failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createStartupProvider(
|
||||
cfg *config.Config,
|
||||
allowEmptyStartup bool,
|
||||
) (providers.LLMProvider, string, error) {
|
||||
modelName := cfg.Agents.Defaults.GetModelName()
|
||||
if modelName == "" && allowEmptyStartup {
|
||||
reason := "no default model configured; gateway started in limited mode"
|
||||
fmt.Printf("⚠ Warning: %s\n", reason)
|
||||
logger.WarnCF("gateway", "Gateway started without default model", map[string]any{
|
||||
"limited_mode": true,
|
||||
})
|
||||
return &startupBlockedProvider{reason: reason}, "", nil
|
||||
}
|
||||
|
||||
return providers.CreateProvider(cfg)
|
||||
}
|
||||
|
||||
func setupAndStartServices(
|
||||
cfg *config.Config,
|
||||
agentLoop *agent.AgentLoop,
|
||||
msgBus *bus.MessageBus,
|
||||
) (*services, error) {
|
||||
runningServices := &services{}
|
||||
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
var err error
|
||||
runningServices.CronService, err = setupCronTool(
|
||||
agentLoop,
|
||||
msgBus,
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Agents.Defaults.RestrictToWorkspace,
|
||||
execTimeout,
|
||||
cfg,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error setting up cron service: %w", err)
|
||||
}
|
||||
if err = runningServices.CronService.Start(); err != nil {
|
||||
return nil, fmt.Errorf("error starting cron service: %w", err)
|
||||
}
|
||||
fmt.Println("✓ Cron service started")
|
||||
|
||||
runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Heartbeat.Interval,
|
||||
cfg.Heartbeat.Enabled,
|
||||
)
|
||||
runningServices.HeartbeatService.SetBus(msgBus)
|
||||
runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(agentLoop))
|
||||
if err = runningServices.HeartbeatService.Start(); err != nil {
|
||||
return nil, fmt.Errorf("error starting heartbeat service: %w", err)
|
||||
}
|
||||
fmt.Println("✓ Heartbeat service started")
|
||||
|
||||
runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
Enabled: cfg.Tools.MediaCleanup.Enabled,
|
||||
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
|
||||
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
|
||||
})
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Start()
|
||||
}
|
||||
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
if err != nil {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
return nil, fmt.Errorf("error creating channel manager: %w", err)
|
||||
}
|
||||
|
||||
agentLoop.SetChannelManager(runningServices.ChannelManager)
|
||||
agentLoop.SetMediaStore(runningServices.MediaStore)
|
||||
|
||||
if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
|
||||
agentLoop.SetTranscriber(transcriber)
|
||||
logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
|
||||
}
|
||||
|
||||
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
|
||||
if len(enabledChannels) > 0 {
|
||||
fmt.Printf("✓ Channels enabled: %s\n", enabledChannels)
|
||||
} else {
|
||||
fmt.Println("⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("error starting channels: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
runningServices.DeviceService = devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
MonitorUSB: cfg.Devices.MonitorUSB,
|
||||
}, stateManager)
|
||||
runningServices.DeviceService.SetBus(msgBus)
|
||||
if err = runningServices.DeviceService.Start(context.Background()); err != nil {
|
||||
logger.ErrorCF("device", "Error starting device service", map[string]any{"error": err.Error()})
|
||||
} else if cfg.Devices.Enabled {
|
||||
fmt.Println("✓ Device event service started")
|
||||
}
|
||||
|
||||
return runningServices, nil
|
||||
}
|
||||
|
||||
func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration) {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
defer shutdownCancel()
|
||||
|
||||
if runningServices.ChannelManager != nil {
|
||||
runningServices.ChannelManager.StopAll(shutdownCtx)
|
||||
}
|
||||
if runningServices.DeviceService != nil {
|
||||
runningServices.DeviceService.Stop()
|
||||
}
|
||||
if runningServices.HeartbeatService != nil {
|
||||
runningServices.HeartbeatService.Stop()
|
||||
}
|
||||
if runningServices.CronService != nil {
|
||||
runningServices.CronService.Stop()
|
||||
}
|
||||
if runningServices.MediaStore != nil {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shutdownGateway(
|
||||
runningServices *services,
|
||||
agentLoop *agent.AgentLoop,
|
||||
provider providers.LLMProvider,
|
||||
fullShutdown bool,
|
||||
) {
|
||||
if cp, ok := provider.(providers.StatefulProvider); ok && fullShutdown {
|
||||
cp.Close()
|
||||
}
|
||||
|
||||
stopAndCleanupServices(runningServices, gracefulShutdownTimeout)
|
||||
|
||||
agentLoop.Stop()
|
||||
agentLoop.Close()
|
||||
|
||||
logger.Info("✓ Gateway stopped")
|
||||
}
|
||||
|
||||
func handleConfigReload(
|
||||
ctx context.Context,
|
||||
al *agent.AgentLoop,
|
||||
newCfg *config.Config,
|
||||
providerRef *providers.LLMProvider,
|
||||
runningServices *services,
|
||||
msgBus *bus.MessageBus,
|
||||
allowEmptyStartup bool,
|
||||
) error {
|
||||
logger.Info("🔄 Config file changed, reloading...")
|
||||
|
||||
newModel := newCfg.Agents.Defaults.ModelName
|
||||
if newModel == "" {
|
||||
newModel = newCfg.Agents.Defaults.Model
|
||||
}
|
||||
|
||||
logger.Infof(" New model is '%s', recreating provider...", newModel)
|
||||
|
||||
logger.Info(" Stopping all services...")
|
||||
stopAndCleanupServices(runningServices, serviceShutdownTimeout)
|
||||
|
||||
newProvider, newModelID, err := createStartupProvider(newCfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
logger.Errorf(" ⚠ Error creating new provider: %v", err)
|
||||
logger.Warn(" Attempting to restart services with old provider and config...")
|
||||
if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
|
||||
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
|
||||
}
|
||||
return fmt.Errorf("error creating new provider: %w", err)
|
||||
}
|
||||
|
||||
if newModelID != "" {
|
||||
newCfg.Agents.Defaults.ModelName = newModelID
|
||||
}
|
||||
|
||||
reloadCtx, reloadCancel := context.WithTimeout(context.Background(), providerReloadTimeout)
|
||||
defer reloadCancel()
|
||||
|
||||
if err := al.ReloadProviderAndConfig(reloadCtx, newProvider, newCfg); err != nil {
|
||||
logger.Errorf(" ⚠ Error reloading agent loop: %v", err)
|
||||
if cp, ok := newProvider.(providers.StatefulProvider); ok {
|
||||
cp.Close()
|
||||
}
|
||||
logger.Warn(" Attempting to restart services with old provider and config...")
|
||||
if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
|
||||
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
|
||||
}
|
||||
return fmt.Errorf("error reloading agent loop: %w", err)
|
||||
}
|
||||
|
||||
*providerRef = newProvider
|
||||
|
||||
logger.Info(" Restarting all services with new configuration...")
|
||||
if err := restartServices(al, runningServices, msgBus); err != nil {
|
||||
logger.Errorf(" ⚠ Error restarting services: %v", err)
|
||||
return fmt.Errorf("error restarting services: %w", err)
|
||||
}
|
||||
|
||||
logger.Info(" ✓ Provider, configuration, and services reloaded successfully (thread-safe)")
|
||||
return nil
|
||||
}
|
||||
|
||||
func restartServices(
|
||||
al *agent.AgentLoop,
|
||||
runningServices *services,
|
||||
msgBus *bus.MessageBus,
|
||||
) error {
|
||||
cfg := al.GetConfig()
|
||||
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
var err error
|
||||
runningServices.CronService, err = setupCronTool(
|
||||
al,
|
||||
msgBus,
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Agents.Defaults.RestrictToWorkspace,
|
||||
execTimeout,
|
||||
cfg,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error restarting cron service: %w", err)
|
||||
}
|
||||
if err = runningServices.CronService.Start(); err != nil {
|
||||
return fmt.Errorf("error restarting cron service: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Cron service restarted")
|
||||
|
||||
runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Heartbeat.Interval,
|
||||
cfg.Heartbeat.Enabled,
|
||||
)
|
||||
runningServices.HeartbeatService.SetBus(msgBus)
|
||||
runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(al))
|
||||
if err = runningServices.HeartbeatService.Start(); err != nil {
|
||||
return fmt.Errorf("error restarting heartbeat service: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Heartbeat service restarted")
|
||||
|
||||
runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
Enabled: cfg.Tools.MediaCleanup.Enabled,
|
||||
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
|
||||
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
|
||||
})
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Start()
|
||||
}
|
||||
al.SetMediaStore(runningServices.MediaStore)
|
||||
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error recreating channel manager: %w", err)
|
||||
}
|
||||
al.SetChannelManager(runningServices.ChannelManager)
|
||||
|
||||
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
|
||||
if len(enabledChannels) > 0 {
|
||||
fmt.Printf(" ✓ Channels enabled: %s\n", enabledChannels)
|
||||
} else {
|
||||
fmt.Println(" ⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
return fmt.Errorf("error restarting channels: %w", err)
|
||||
}
|
||||
fmt.Printf(
|
||||
" ✓ Channels restarted, health endpoints at http://%s:%d/health and ready\n",
|
||||
cfg.Gateway.Host,
|
||||
cfg.Gateway.Port,
|
||||
)
|
||||
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
runningServices.DeviceService = devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
MonitorUSB: cfg.Devices.MonitorUSB,
|
||||
}, stateManager)
|
||||
runningServices.DeviceService.SetBus(msgBus)
|
||||
if err := runningServices.DeviceService.Start(context.Background()); err != nil {
|
||||
logger.WarnCF("device", "Failed to restart device service", map[string]any{"error": err.Error()})
|
||||
} else if cfg.Devices.Enabled {
|
||||
fmt.Println(" ✓ Device event service restarted")
|
||||
}
|
||||
|
||||
transcriber := voice.DetectTranscriber(cfg)
|
||||
al.SetTranscriber(transcriber)
|
||||
if transcriber != nil {
|
||||
logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
|
||||
} else {
|
||||
logger.InfoCF("voice", "Transcription disabled", nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Config, func()) {
|
||||
configChan := make(chan *config.Config, 1)
|
||||
stop := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
lastModTime := getFileModTime(configPath)
|
||||
lastSize := getFileSize(configPath)
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
currentModTime := getFileModTime(configPath)
|
||||
currentSize := getFileSize(configPath)
|
||||
|
||||
if currentModTime.After(lastModTime) || currentSize != lastSize {
|
||||
if debug {
|
||||
logger.Debugf("🔍 Config file change detected")
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
lastModTime = currentModTime
|
||||
lastSize = currentSize
|
||||
|
||||
newCfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
logger.Errorf("⚠ Error loading new config: %v", err)
|
||||
logger.Warn(" Using previous valid config")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := newCfg.ValidateModelList(); err != nil {
|
||||
logger.Errorf(" ⚠ New config validation failed: %v", err)
|
||||
logger.Warn(" Using previous valid config")
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("✓ Config file validated and loaded")
|
||||
|
||||
select {
|
||||
case configChan <- newCfg:
|
||||
default:
|
||||
logger.Warn("⚠ Previous config reload still in progress, skipping")
|
||||
}
|
||||
}
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
stopFunc := func() {
|
||||
close(stop)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
return configChan, stopFunc
|
||||
}
|
||||
|
||||
func getFileModTime(path string) time.Time {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return info.ModTime()
|
||||
}
|
||||
|
||||
func getFileSize(path string) int64 {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return info.Size()
|
||||
}
|
||||
|
||||
func setupCronTool(
|
||||
agentLoop *agent.AgentLoop,
|
||||
msgBus *bus.MessageBus,
|
||||
workspace string,
|
||||
restrict bool,
|
||||
execTimeout time.Duration,
|
||||
cfg *config.Config,
|
||||
) (*cron.CronService, error) {
|
||||
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
|
||||
|
||||
cronService := cron.NewCronService(cronStorePath, nil)
|
||||
|
||||
var cronTool *tools.CronTool
|
||||
if cfg.Tools.IsToolEnabled("cron") {
|
||||
var err error
|
||||
cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("critical error during CronTool initialization: %w", err)
|
||||
}
|
||||
|
||||
agentLoop.RegisterTool(cronTool)
|
||||
}
|
||||
|
||||
if cronTool != nil {
|
||||
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
|
||||
result := cronTool.ExecuteJob(context.Background(), job)
|
||||
return result, nil
|
||||
})
|
||||
}
|
||||
|
||||
return cronService, nil
|
||||
}
|
||||
|
||||
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
if channel == "" || chatID == "" {
|
||||
channel, chatID = "cli", "direct"
|
||||
}
|
||||
|
||||
response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
|
||||
if err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
|
||||
}
|
||||
if response == "HEARTBEAT_OK" {
|
||||
return tools.SilentResult("Heartbeat OK")
|
||||
}
|
||||
return tools.SilentResult(response)
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -29,6 +30,7 @@ type StatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Uptime string `json:"uptime"`
|
||||
Checks map[string]Check `json:"checks,omitempty"`
|
||||
Pid int `json:"pid"`
|
||||
}
|
||||
|
||||
func NewServer(host string, port int) *Server {
|
||||
@@ -112,6 +114,7 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
resp := StatusResponse{
|
||||
Status: "ok",
|
||||
Uptime: uptime.String(),
|
||||
Pid: os.Getpid(),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
|
||||
@@ -2,7 +2,20 @@
|
||||
|
||||
package logger
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// botTokenRe matches the bot ID prefix and the secret part of a Telegram bot token.
|
||||
// Groups: 1 = "bot<id>:", 2 = first 4 chars of secret, 3 = middle, 4 = last 4 chars.
|
||||
var botTokenRe = regexp.MustCompile(`(bot\d+:)([A-Za-z0-9_-]{4})[A-Za-z0-9_-]{12,}([A-Za-z0-9_-]{4})`)
|
||||
|
||||
// maskSecrets replaces any embedded bot tokens in s with a redacted placeholder
|
||||
// that keeps the first and last 4 characters of the secret for identification.
|
||||
func maskSecrets(s string) string {
|
||||
return botTokenRe.ReplaceAllString(s, "${1}${2}****${3}")
|
||||
}
|
||||
|
||||
// Logger implements common Logger interface
|
||||
type Logger struct {
|
||||
@@ -12,52 +25,52 @@ type Logger struct {
|
||||
|
||||
// Debug logs debug messages
|
||||
func (b *Logger) Debug(v ...any) {
|
||||
logMessage(DEBUG, b.component, fmt.Sprint(v...), nil)
|
||||
logMessage(DEBUG, b.component, maskSecrets(fmt.Sprint(v...)), nil)
|
||||
}
|
||||
|
||||
// Info logs info messages
|
||||
func (b *Logger) Info(v ...any) {
|
||||
logMessage(INFO, b.component, fmt.Sprint(v...), nil)
|
||||
logMessage(INFO, b.component, maskSecrets(fmt.Sprint(v...)), nil)
|
||||
}
|
||||
|
||||
// Warn logs warning messages
|
||||
func (b *Logger) Warn(v ...any) {
|
||||
logMessage(WARN, b.component, fmt.Sprint(v...), nil)
|
||||
logMessage(WARN, b.component, maskSecrets(fmt.Sprint(v...)), nil)
|
||||
}
|
||||
|
||||
// Error logs error messages
|
||||
func (b *Logger) Error(v ...any) {
|
||||
logMessage(ERROR, b.component, fmt.Sprint(v...), nil)
|
||||
logMessage(ERROR, b.component, maskSecrets(fmt.Sprint(v...)), nil)
|
||||
}
|
||||
|
||||
// Debugf logs formatted debug messages
|
||||
func (b *Logger) Debugf(format string, v ...any) {
|
||||
logMessage(DEBUG, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(DEBUG, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Infof logs formatted info messages
|
||||
func (b *Logger) Infof(format string, v ...any) {
|
||||
logMessage(INFO, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(INFO, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Warnf logs formatted warning messages
|
||||
func (b *Logger) Warnf(format string, v ...any) {
|
||||
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(WARN, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Warningf logs formatted warning messages
|
||||
func (b *Logger) Warningf(format string, v ...any) {
|
||||
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(WARN, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Errorf logs formatted error messages
|
||||
func (b *Logger) Errorf(format string, v ...any) {
|
||||
logMessage(ERROR, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(ERROR, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Fatalf logs formatted fatal messages and exits
|
||||
func (b *Logger) Fatalf(format string, v ...any) {
|
||||
logMessage(FATAL, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(FATAL, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Log logs a message at a given level with caller information
|
||||
@@ -75,7 +88,7 @@ func (b *Logger) Log(msgL, caller int, format string, a ...any) {
|
||||
level = lvl
|
||||
}
|
||||
}
|
||||
logMessage(level, b.component, fmt.Sprintf(format, a...), nil)
|
||||
logMessage(level, b.component, maskSecrets(fmt.Sprintf(format, a...)), nil)
|
||||
}
|
||||
|
||||
// Sync flushes log buffer (no-op for this implementation)
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package media
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
const TempDirName = "picoclaw_media"
|
||||
|
||||
// TempDir returns the shared temporary directory used for downloaded media.
|
||||
func TempDir() string {
|
||||
return filepath.Join(os.TempDir(), TempDirName)
|
||||
}
|
||||
@@ -221,11 +221,17 @@ func buildRequestBody(
|
||||
|
||||
// Add tool_use blocks
|
||||
for _, tc := range msg.ToolCalls {
|
||||
// Handle nil Arguments (GLM-4 may return null input)
|
||||
input := tc.Arguments
|
||||
if input == nil {
|
||||
input = map[string]any{}
|
||||
}
|
||||
|
||||
toolUse := map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": tc.ID,
|
||||
"name": tc.Name,
|
||||
"input": tc.Arguments,
|
||||
"input": input,
|
||||
}
|
||||
content = append(content, toolUse)
|
||||
}
|
||||
|
||||
+53
-20
@@ -20,10 +20,12 @@ type JobExecutor interface {
|
||||
|
||||
// CronTool provides scheduling capabilities for the agent
|
||||
type CronTool struct {
|
||||
cronService *cron.CronService
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
execTool *ExecTool
|
||||
cronService *cron.CronService
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
execTool *ExecTool
|
||||
allowCommand bool
|
||||
execEnabled bool
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
@@ -32,17 +34,32 @@ func NewCronTool(
|
||||
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
|
||||
execTimeout time.Duration, config *config.Config,
|
||||
) (*CronTool, error) {
|
||||
execTool, err := NewExecToolWithConfig(workspace, restrict, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
|
||||
allowCommand := true
|
||||
execEnabled := true
|
||||
if config != nil {
|
||||
allowCommand = config.Tools.Cron.AllowCommand
|
||||
execEnabled = config.Tools.Exec.Enabled
|
||||
}
|
||||
|
||||
execTool.SetTimeout(execTimeout)
|
||||
var execTool *ExecTool
|
||||
if execEnabled {
|
||||
var err error
|
||||
execTool, err = NewExecToolWithConfig(workspace, restrict, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if execTool != nil {
|
||||
execTool.SetTimeout(execTimeout)
|
||||
}
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: execTool,
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: execTool,
|
||||
allowCommand: allowCommand,
|
||||
execEnabled: execEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -76,7 +93,7 @@ func (t *CronTool) Parameters() map[string]any {
|
||||
},
|
||||
"command_confirm": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Required when using command=true. Must be true to explicitly confirm scheduling a shell command.",
|
||||
"description": "Optional explicit confirmation flag for scheduling a shell command. Command execution must also be enabled via tools.cron.allow_command.",
|
||||
},
|
||||
"at_seconds": map[string]any{
|
||||
"type": "integer",
|
||||
@@ -96,7 +113,7 @@ func (t *CronTool) Parameters() map[string]any {
|
||||
},
|
||||
"deliver": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: false",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
@@ -174,22 +191,26 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
|
||||
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
|
||||
}
|
||||
|
||||
// Read deliver parameter, default to true
|
||||
deliver := true
|
||||
// Read deliver parameter, default to false so scheduled tasks execute through the agent
|
||||
deliver := false
|
||||
if d, ok := args["deliver"].(bool); ok {
|
||||
deliver = d
|
||||
}
|
||||
|
||||
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel + explicit confirm.
|
||||
// Non-command reminders (plain messages) remain open to all channels.
|
||||
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel. When
|
||||
// allow_command is disabled, explicit confirmation is required as an override.
|
||||
// Non-command reminders remain open to all channels.
|
||||
command, _ := args["command"].(string)
|
||||
commandConfirm, _ := args["command_confirm"].(bool)
|
||||
if command != "" {
|
||||
if !t.execEnabled {
|
||||
return ErrorResult("command execution is disabled")
|
||||
}
|
||||
if !constants.IsInternalChannel(channel) {
|
||||
return ErrorResult("scheduling command execution is restricted to internal channels")
|
||||
}
|
||||
if !commandConfirm {
|
||||
return ErrorResult("command_confirm=true is required to schedule command execution")
|
||||
if !t.allowCommand && !commandConfirm {
|
||||
return ErrorResult("command_confirm=true is required when allow_command is disabled")
|
||||
}
|
||||
deliver = false
|
||||
}
|
||||
@@ -290,6 +311,18 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
|
||||
// Execute command if present
|
||||
if job.Payload.Command != "" {
|
||||
if !t.execEnabled || t.execTool == nil {
|
||||
output := "Error executing scheduled command: command execution is disabled"
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: output,
|
||||
})
|
||||
return "ok"
|
||||
}
|
||||
|
||||
args := map[string]any{
|
||||
"command": job.Payload.Command,
|
||||
"__channel": channel,
|
||||
|
||||
+126
-6
@@ -5,18 +5,18 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
)
|
||||
|
||||
func newTestCronTool(t *testing.T) *CronTool {
|
||||
func newTestCronToolWithConfig(t *testing.T, cfg *config.Config) *CronTool {
|
||||
t.Helper()
|
||||
storePath := filepath.Join(t.TempDir(), "cron.json")
|
||||
cronService := cron.NewCronService(storePath, nil)
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.DefaultConfig()
|
||||
tool, err := NewCronTool(cronService, nil, msgBus, t.TempDir(), true, 0, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCronTool() error: %v", err)
|
||||
@@ -24,6 +24,11 @@ func newTestCronTool(t *testing.T) *CronTool {
|
||||
return tool
|
||||
}
|
||||
|
||||
func newTestCronTool(t *testing.T) *CronTool {
|
||||
t.Helper()
|
||||
return newTestCronToolWithConfig(t, config.DefaultConfig())
|
||||
}
|
||||
|
||||
// TestCronTool_CommandBlockedFromRemoteChannel verifies command scheduling is restricted to internal channels
|
||||
func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
@@ -44,8 +49,7 @@ func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCronTool_CommandRequiresConfirm verifies command_confirm=true is required
|
||||
func TestCronTool_CommandRequiresConfirm(t *testing.T) {
|
||||
func TestCronTool_CommandDoesNotRequireConfirmByDefault(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
@@ -55,11 +59,79 @@ func TestCronTool_CommandRequiresConfirm(t *testing.T) {
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected command scheduling without confirm to succeed by default, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandRequiresConfirmWhenAllowCommandDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Cron.AllowCommand = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error when command_confirm is missing")
|
||||
t.Fatal("expected command scheduling to require confirm when allow_command is disabled")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "command_confirm=true") {
|
||||
t.Errorf("expected 'command_confirm=true' message, got: %s", result.ForLLM)
|
||||
t.Errorf("expected command_confirm requirement message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandAllowedWithConfirmWhenAllowCommandDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Cron.AllowCommand = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"command_confirm": true,
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf(
|
||||
"expected command scheduling with confirm to succeed when allow_command is disabled, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandBlockedWhenExecDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Exec.Enabled = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"command_confirm": true,
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected command scheduling to be blocked when exec is disabled")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "command execution is disabled") {
|
||||
t.Errorf("expected exec disabled message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,3 +186,51 @@ func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) {
|
||||
t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_NonCommandJobDefaultsDeliverToFalse(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "send me a poem",
|
||||
"at_seconds": float64(600),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected non-command reminder to succeed, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
jobs := tool.cronService.ListJobs(false)
|
||||
if len(jobs) != 1 {
|
||||
t.Fatalf("expected 1 job, got %d", len(jobs))
|
||||
}
|
||||
if jobs[0].Payload.Deliver {
|
||||
t.Fatal("expected deliver=false by default for non-command jobs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_ExecuteJobPublishesErrorWhenExecDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Exec.Enabled = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
job := &cron.CronJob{}
|
||||
job.Payload.Channel = "cli"
|
||||
job.Payload.To = "direct"
|
||||
job.Payload.Command = "df -h"
|
||||
|
||||
if got := tool.ExecuteJob(context.Background(), job); got != "ok" {
|
||||
t.Fatalf("ExecuteJob() = %q, want ok", got)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
msg, ok := tool.msgBus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected outbound message")
|
||||
}
|
||||
if !strings.Contains(msg.Content, "command execution is disabled") {
|
||||
t.Fatalf("expected exec disabled message, got: %s", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
+161
-9
@@ -20,8 +20,7 @@ import (
|
||||
|
||||
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
|
||||
|
||||
// validatePath ensures the given path is within the workspace if restrict is true.
|
||||
func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) {
|
||||
if workspace == "" {
|
||||
return path, fmt.Errorf("workspace is not defined")
|
||||
}
|
||||
@@ -42,6 +41,10 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
}
|
||||
|
||||
if restrict {
|
||||
if isAllowedPath(absPath, patterns) {
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
if !isWithinWorkspace(absPath, absWorkspace) {
|
||||
return "", fmt.Errorf("access denied: path is outside the workspace")
|
||||
}
|
||||
@@ -73,6 +76,137 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
func isAllowedPath(path string, patterns []*regexp.Regexp) bool {
|
||||
if len(patterns) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
cleaned := filepath.Clean(path)
|
||||
if !filepath.IsAbs(cleaned) {
|
||||
return false
|
||||
}
|
||||
if !matchesAllowedPath(cleaned, patterns) {
|
||||
return false
|
||||
}
|
||||
|
||||
resolved, err := resolvePathAgainstExistingAncestor(cleaned)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return matchesAllowedPath(resolved, patterns)
|
||||
}
|
||||
|
||||
func matchesAllowedPath(path string, patterns []*regexp.Regexp) bool {
|
||||
cleaned := filepath.Clean(path)
|
||||
for _, pattern := range patterns {
|
||||
if pattern.MatchString(cleaned) {
|
||||
return true
|
||||
}
|
||||
if root, ok := extractAllowedPathRoot(pattern); ok && isWithinAllowedRoot(cleaned, root) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractAllowedPathRoot(pattern *regexp.Regexp) (string, bool) {
|
||||
raw := pattern.String()
|
||||
if !strings.HasPrefix(raw, "^") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
literal := strings.TrimPrefix(raw, "^")
|
||||
|
||||
// Recognize the common "directory prefix" form: ^<literal>(?:/|$)
|
||||
literal = strings.TrimSuffix(literal, "(?:/|$)")
|
||||
literal = strings.TrimSuffix(literal, `(?:\\|$)`)
|
||||
|
||||
// Reject patterns that still contain regex operators after removing the
|
||||
// optional anchored-directory suffix. That keeps arbitrary regex behavior
|
||||
// unchanged and only enables normalized prefix matching for literal paths.
|
||||
if containsUnescapedRegexMeta(literal) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
unescaped, ok := unescapeRegexLiteral(literal)
|
||||
if !ok || unescaped == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return filepath.Clean(unescaped), filepath.IsAbs(unescaped)
|
||||
}
|
||||
|
||||
func appendUniquePath(paths []string, path string) []string {
|
||||
for _, existing := range paths {
|
||||
if existing == path {
|
||||
return paths
|
||||
}
|
||||
}
|
||||
return append(paths, path)
|
||||
}
|
||||
|
||||
func containsUnescapedRegexMeta(s string) bool {
|
||||
escaped := false
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
switch r {
|
||||
case '.', '+', '*', '?', '(', ')', '[', ']', '{', '}', '|':
|
||||
return true
|
||||
}
|
||||
}
|
||||
return escaped
|
||||
}
|
||||
|
||||
func unescapeRegexLiteral(s string) (string, bool) {
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
|
||||
escaped := false
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
b.WriteRune(r)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
b.WriteRune(r)
|
||||
}
|
||||
|
||||
if escaped {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return b.String(), true
|
||||
}
|
||||
|
||||
func isWithinAllowedRoot(path, root string) bool {
|
||||
candidate := filepath.Clean(path)
|
||||
allowedVariants := []string{filepath.Clean(root)}
|
||||
|
||||
if resolvedRoot, err := resolvePathAgainstExistingAncestor(root); err == nil {
|
||||
allowedVariants = appendUniquePath(allowedVariants, filepath.Clean(resolvedRoot))
|
||||
}
|
||||
|
||||
for _, allowedRoot := range allowedVariants {
|
||||
if isWithinWorkspace(candidate, allowedRoot) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func resolveExistingAncestor(path string) (string, error) {
|
||||
for current := filepath.Clean(path); ; current = filepath.Dir(current) {
|
||||
if resolved, err := filepath.EvalSymlinks(current); err == nil {
|
||||
@@ -86,9 +220,32 @@ func resolveExistingAncestor(path string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func resolvePathAgainstExistingAncestor(path string) (string, error) {
|
||||
cleaned := filepath.Clean(path)
|
||||
for current := cleaned; ; current = filepath.Dir(current) {
|
||||
resolved, err := filepath.EvalSymlinks(current)
|
||||
if err == nil {
|
||||
suffix, relErr := filepath.Rel(current, cleaned)
|
||||
if relErr != nil {
|
||||
return "", relErr
|
||||
}
|
||||
if suffix == "." {
|
||||
return filepath.Clean(resolved), nil
|
||||
}
|
||||
return filepath.Clean(filepath.Join(resolved, suffix)), nil
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
return "", err
|
||||
}
|
||||
if filepath.Dir(current) == current {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isWithinWorkspace(candidate, workspace string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
|
||||
return err == nil && filepath.IsLocal(rel)
|
||||
return err == nil && (rel == "." || filepath.IsLocal(rel))
|
||||
}
|
||||
|
||||
type ReadFileTool struct {
|
||||
@@ -625,12 +782,7 @@ type whitelistFs struct {
|
||||
}
|
||||
|
||||
func (w *whitelistFs) matches(path string) bool {
|
||||
for _, p := range w.patterns {
|
||||
if p.MatchString(path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return isAllowedPath(path, w.patterns)
|
||||
}
|
||||
|
||||
func (w *whitelistFs) ReadFile(path string) ([]byte, error) {
|
||||
|
||||
@@ -521,6 +521,90 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
allowedDir := t.TempDir()
|
||||
secretDir := t.TempDir()
|
||||
secretFile := filepath.Join(secretDir, "secret.txt")
|
||||
if err := os.WriteFile(secretFile, []byte("top secret"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(secretFile) error = %v", err)
|
||||
}
|
||||
|
||||
linkPath := filepath.Join(allowedDir, "link_out")
|
||||
if err := os.Symlink(secretDir, linkPath); err != nil {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": filepath.Join(linkPath, "secret.txt")})
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_WriteAllowsNewFileUnderAllowedDir(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
rootDir := t.TempDir()
|
||||
allowedDir := filepath.Join(rootDir, "allowed")
|
||||
targetFile := filepath.Join(allowedDir, "nested", "file.txt")
|
||||
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
|
||||
tool := NewWriteFileTool(workspace, true, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": targetFile,
|
||||
"content": "outside write",
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected whitelisted write to succeed, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(targetFile)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile(targetFile) error = %v", err)
|
||||
}
|
||||
if string(data) != "outside write" {
|
||||
t.Fatalf("target file content = %q, want %q", string(data), "outside write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_AllowsResolvedAllowedRootAlias(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
realDir := t.TempDir()
|
||||
linkParent := t.TempDir()
|
||||
allowedAlias := filepath.Join(linkParent, "allowed-link")
|
||||
|
||||
if err := os.Symlink(realDir, allowedAlias); err != nil {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
targetFile := filepath.Join(allowedAlias, "nested", "alias.txt")
|
||||
if err := os.MkdirAll(filepath.Dir(targetFile), 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(targetFile dir) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(targetFile, []byte("through alias"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(targetFile) error = %v", err)
|
||||
}
|
||||
|
||||
patterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(
|
||||
"^" + regexp.QuoteMeta(filepath.Clean(allowedAlias)) +
|
||||
"(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
|
||||
),
|
||||
}
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": targetFile})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected symlink-backed allowed root to be readable, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "through alias") {
|
||||
t.Fatalf("expected file content, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadFileTool_ChunkedReading verifies the pagination logic of the tool
|
||||
// by reading a file in multiple chunks using 'offset' and 'length'.
|
||||
func TestReadFileTool_ChunkedReading(t *testing.T) {
|
||||
|
||||
+15
-2
@@ -6,6 +6,7 @@ import (
|
||||
"mime"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
@@ -21,20 +22,32 @@ type SendFileTool struct {
|
||||
restrict bool
|
||||
maxFileSize int
|
||||
mediaStore media.MediaStore
|
||||
allowPaths []*regexp.Regexp
|
||||
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
}
|
||||
|
||||
func NewSendFileTool(workspace string, restrict bool, maxFileSize int, store media.MediaStore) *SendFileTool {
|
||||
func NewSendFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
maxFileSize int,
|
||||
store media.MediaStore,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *SendFileTool {
|
||||
if maxFileSize <= 0 {
|
||||
maxFileSize = config.DefaultMaxMediaSize
|
||||
}
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &SendFileTool{
|
||||
workspace: workspace,
|
||||
restrict: restrict,
|
||||
maxFileSize: maxFileSize,
|
||||
mediaStore: store,
|
||||
allowPaths: patterns,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,7 +105,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return ErrorResult("media store not configured")
|
||||
}
|
||||
|
||||
resolved, err := validatePath(path, t.workspace, t.restrict)
|
||||
resolved, err := validatePathWithAllowPaths(path, t.workspace, t.restrict, t.allowPaths)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid path: %v", err))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -128,6 +129,44 @@ func TestSendFileTool_CustomFilename(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileTool_AllowsWhitelistedMediaTempPath(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
|
||||
}
|
||||
|
||||
testFile, err := os.CreateTemp(mediaDir, "send-file-*.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp(mediaDir) error = %v", err)
|
||||
}
|
||||
testPath := testFile.Name()
|
||||
if _, err := testFile.WriteString("forward me"); err != nil {
|
||||
testFile.Close()
|
||||
t.Fatalf("WriteString(testFile) error = %v", err)
|
||||
}
|
||||
if err := testFile.Close(); err != nil {
|
||||
t.Fatalf("Close(testFile) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.Remove(testPath) })
|
||||
|
||||
pattern := regexp.MustCompile(
|
||||
"^" + regexp.QuoteMeta(filepath.Clean(mediaDir)) + "(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
|
||||
)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewSendFileTool(workspace, true, 0, store, []*regexp.Regexp{pattern})
|
||||
tool.SetContext("feishu", "chat123")
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": testPath})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected whitelisted temp media file to be sendable, got: %s", result.ForLLM)
|
||||
}
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectMediaType_MagicBytes(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
|
||||
+31
-13
@@ -23,6 +23,7 @@ type ExecTool struct {
|
||||
denyPatterns []*regexp.Regexp
|
||||
allowPatterns []*regexp.Regexp
|
||||
customAllowPatterns []*regexp.Regexp
|
||||
allowedPathPatterns []*regexp.Regexp
|
||||
restrictToWorkspace bool
|
||||
allowRemote bool
|
||||
}
|
||||
@@ -95,14 +96,23 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil)
|
||||
func NewExecTool(workingDir string, restrict bool, allowPaths ...[]*regexp.Regexp) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil, allowPaths...)
|
||||
}
|
||||
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
|
||||
func NewExecToolWithConfig(
|
||||
workingDir string,
|
||||
restrict bool,
|
||||
config *config.Config,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) (*ExecTool, error) {
|
||||
denyPatterns := make([]*regexp.Regexp, 0)
|
||||
customAllowPatterns := make([]*regexp.Regexp, 0)
|
||||
var allowedPathPatterns []*regexp.Regexp
|
||||
allowRemote := true
|
||||
if len(allowPaths) > 0 {
|
||||
allowedPathPatterns = allowPaths[0]
|
||||
}
|
||||
|
||||
if config != nil {
|
||||
execConfig := config.Tools.Exec
|
||||
@@ -146,6 +156,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
denyPatterns: denyPatterns,
|
||||
allowPatterns: nil,
|
||||
customAllowPatterns: customAllowPatterns,
|
||||
allowedPathPatterns: allowedPathPatterns,
|
||||
restrictToWorkspace: restrict,
|
||||
allowRemote: allowRemote,
|
||||
}, nil
|
||||
@@ -198,7 +209,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
cwd := t.workingDir
|
||||
if wd, ok := args["working_dir"].(string); ok && wd != "" {
|
||||
if t.restrictToWorkspace && t.workingDir != "" {
|
||||
resolvedWD, err := validatePath(wd, t.workingDir, true)
|
||||
resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns)
|
||||
if err != nil {
|
||||
return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
|
||||
}
|
||||
@@ -226,16 +237,20 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err))
|
||||
}
|
||||
absWorkspace, _ := filepath.Abs(t.workingDir)
|
||||
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
|
||||
if wsResolved == "" {
|
||||
wsResolved = absWorkspace
|
||||
if isAllowedPath(resolved, t.allowedPathPatterns) {
|
||||
cwd = resolved
|
||||
} else {
|
||||
absWorkspace, _ := filepath.Abs(t.workingDir)
|
||||
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
|
||||
if wsResolved == "" {
|
||||
wsResolved = absWorkspace
|
||||
}
|
||||
rel, err := filepath.Rel(wsResolved, resolved)
|
||||
if err != nil || !filepath.IsLocal(rel) {
|
||||
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
|
||||
}
|
||||
cwd = resolved
|
||||
}
|
||||
rel, err := filepath.Rel(wsResolved, resolved)
|
||||
if err != nil || !filepath.IsLocal(rel) {
|
||||
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
|
||||
}
|
||||
cwd = resolved
|
||||
}
|
||||
|
||||
// timeout == 0 means no timeout
|
||||
@@ -412,6 +427,9 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
if safePaths[p] {
|
||||
continue
|
||||
}
|
||||
if isAllowedPath(p, t.allowedPathPatterns) {
|
||||
continue
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(cwdPath, p)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SpawnStatusTool reports the status of subagents that were spawned via the
|
||||
// spawn tool. It can query a specific task by ID, or list every known task with
|
||||
// a summary count broken-down by status.
|
||||
type SpawnStatusTool struct {
|
||||
manager *SubagentManager
|
||||
}
|
||||
|
||||
// NewSpawnStatusTool creates a SpawnStatusTool backed by the given manager.
|
||||
func NewSpawnStatusTool(manager *SubagentManager) *SpawnStatusTool {
|
||||
return &SpawnStatusTool{manager: manager}
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Name() string {
|
||||
return "spawn_status"
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Description() string {
|
||||
return "Get the status of spawned subagents. " +
|
||||
"Returns a list of all subagents and their current state " +
|
||||
"(running, completed, failed, or canceled), or retrieves details " +
|
||||
"for a specific subagent task when task_id is provided. " +
|
||||
"Results are scoped to the current conversation's channel and chat ID; " +
|
||||
"all tasks are listed only when no channel/chat context is injected " +
|
||||
"(e.g. direct programmatic calls via Execute)."
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"task_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional task ID (e.g. \"subagent-1\") to inspect a specific " +
|
||||
"subagent. When omitted, all visible subagents are listed.",
|
||||
},
|
||||
},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
if t.manager == nil {
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
// Derive the calling conversation's identity so we can scope results to the
|
||||
// current chat only — preventing cross-conversation task leakage in
|
||||
// multi-user deployments.
|
||||
callerChannel := ToolChannel(ctx)
|
||||
callerChatID := ToolChatID(ctx)
|
||||
|
||||
var taskID string
|
||||
if rawTaskID, ok := args["task_id"]; ok && rawTaskID != nil {
|
||||
taskIDStr, ok := rawTaskID.(string)
|
||||
if !ok {
|
||||
return ErrorResult("task_id must be a string")
|
||||
}
|
||||
taskID = strings.TrimSpace(taskIDStr)
|
||||
}
|
||||
|
||||
if taskID != "" {
|
||||
// GetTaskCopy returns a consistent snapshot under the manager lock,
|
||||
// eliminating any data race with the concurrent subagent goroutine.
|
||||
taskCopy, ok := t.manager.GetTaskCopy(taskID)
|
||||
if !ok {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
|
||||
// Restrict lookup to tasks that belong to this conversation.
|
||||
if callerChannel != "" && taskCopy.OriginChannel != "" && taskCopy.OriginChannel != callerChannel {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
if callerChatID != "" && taskCopy.OriginChatID != "" && taskCopy.OriginChatID != callerChatID {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
|
||||
return NewToolResult(spawnStatusFormatTask(&taskCopy))
|
||||
}
|
||||
|
||||
// ListTaskCopies returns consistent snapshots under the manager lock.
|
||||
origTasks := t.manager.ListTaskCopies()
|
||||
if len(origTasks) == 0 {
|
||||
return NewToolResult("No subagents have been spawned yet.")
|
||||
}
|
||||
|
||||
tasks := make([]*SubagentTask, 0, len(origTasks))
|
||||
for i := range origTasks {
|
||||
cpy := &origTasks[i]
|
||||
|
||||
// Filter to tasks that originate from the current conversation only.
|
||||
if callerChannel != "" && cpy.OriginChannel != "" && cpy.OriginChannel != callerChannel {
|
||||
continue
|
||||
}
|
||||
if callerChatID != "" && cpy.OriginChatID != "" && cpy.OriginChatID != callerChatID {
|
||||
continue
|
||||
}
|
||||
|
||||
tasks = append(tasks, cpy)
|
||||
}
|
||||
|
||||
if len(tasks) == 0 {
|
||||
return NewToolResult("No subagents found for this conversation.")
|
||||
}
|
||||
|
||||
// Order by creation time (ascending) so spawning order is preserved.
|
||||
// Fall back to ID string for tasks created in the same millisecond.
|
||||
sort.Slice(tasks, func(i, j int) bool {
|
||||
if tasks[i].Created != tasks[j].Created {
|
||||
return tasks[i].Created < tasks[j].Created
|
||||
}
|
||||
return tasks[i].ID < tasks[j].ID
|
||||
})
|
||||
|
||||
counts := map[string]int{}
|
||||
for _, task := range tasks {
|
||||
counts[task.Status]++
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Subagent status report (%d total):\n", len(tasks)))
|
||||
for _, status := range []string{"running", "completed", "failed", "canceled"} {
|
||||
if n := counts[status]; n > 0 {
|
||||
label := strings.ToUpper(status[:1]) + status[1:] + ":"
|
||||
sb.WriteString(fmt.Sprintf(" %-10s %d\n", label, n))
|
||||
}
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
for _, task := range tasks {
|
||||
sb.WriteString(spawnStatusFormatTask(task))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
return NewToolResult(strings.TrimRight(sb.String(), "\n"))
|
||||
}
|
||||
|
||||
// spawnStatusFormatTask renders a single SubagentTask as a human-readable block.
|
||||
func spawnStatusFormatTask(task *SubagentTask) string {
|
||||
var sb strings.Builder
|
||||
|
||||
header := fmt.Sprintf("[%s] status=%s", task.ID, task.Status)
|
||||
if task.Label != "" {
|
||||
header += fmt.Sprintf(" label=%q", task.Label)
|
||||
}
|
||||
if task.AgentID != "" {
|
||||
header += fmt.Sprintf(" agent=%s", task.AgentID)
|
||||
}
|
||||
if task.Created > 0 {
|
||||
created := time.UnixMilli(task.Created).UTC().Format("2006-01-02 15:04:05 UTC")
|
||||
header += fmt.Sprintf(" created=%s", created)
|
||||
}
|
||||
sb.WriteString(header)
|
||||
|
||||
if task.Task != "" {
|
||||
sb.WriteString(fmt.Sprintf("\n task: %s", task.Task))
|
||||
}
|
||||
if task.Result != "" {
|
||||
result := task.Result
|
||||
const maxResultLen = 300
|
||||
runes := []rune(result)
|
||||
if len(runes) > maxResultLen {
|
||||
result = string(runes[:maxResultLen]) + "…"
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("\n result: %s", result))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSpawnStatusTool_Name(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
if tool.Name() != "spawn_status" {
|
||||
t.Errorf("Expected name 'spawn_status', got '%s'", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Description(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
desc := tool.Description()
|
||||
if desc == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(desc), "subagent") {
|
||||
t.Errorf("Description should mention 'subagent', got: %s", desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Parameters(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
params := tool.Parameters()
|
||||
if params["type"] != "object" {
|
||||
t.Errorf("Expected type 'object', got: %v", params["type"])
|
||||
}
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Expected 'properties' to be a map")
|
||||
}
|
||||
if _, hasTaskID := props["task_id"]; !hasTaskID {
|
||||
t.Error("Expected 'task_id' parameter in properties")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_NilManager(t *testing.T) {
|
||||
tool := &SpawnStatusTool{manager: nil}
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if !result.IsError {
|
||||
t.Error("Expected error result when manager is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Empty(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "No subagents") {
|
||||
t.Errorf("Expected 'No subagents' message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ListAll(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Do task A",
|
||||
Label: "task-a",
|
||||
Status: "running",
|
||||
Created: now,
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2",
|
||||
Task: "Do task B",
|
||||
Label: "task-b",
|
||||
Status: "completed",
|
||||
Result: "Done successfully",
|
||||
Created: now,
|
||||
}
|
||||
manager.tasks["subagent-3"] = &SubagentTask{
|
||||
ID: "subagent-3",
|
||||
Task: "Do task C",
|
||||
Status: "failed",
|
||||
Result: "Error: something went wrong",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Summary header
|
||||
if !strings.Contains(result.ForLLM, "3 total") {
|
||||
t.Errorf("Expected total count in header, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Individual task IDs
|
||||
for _, id := range []string{"subagent-1", "subagent-2", "subagent-3"} {
|
||||
if !strings.Contains(result.ForLLM, id) {
|
||||
t.Errorf("Expected task %s in output, got:\n%s", id, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// Status values
|
||||
for _, status := range []string{"running", "completed", "failed"} {
|
||||
if !strings.Contains(result.ForLLM, status) {
|
||||
t.Errorf("Expected status '%s' in output, got:\n%s", status, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// Result content
|
||||
if !strings.Contains(result.ForLLM, "Done successfully") {
|
||||
t.Errorf("Expected result text in output, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_GetByID(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-42"] = &SubagentTask{
|
||||
ID: "subagent-42",
|
||||
Task: "Specific task",
|
||||
Label: "my-task",
|
||||
Status: "failed",
|
||||
Result: "Something went wrong",
|
||||
Created: time.Now().UnixMilli(),
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-42"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-42") {
|
||||
t.Errorf("Expected task ID in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "failed") {
|
||||
t.Errorf("Expected status 'failed' in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Something went wrong") {
|
||||
t.Errorf("Expected result text in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "my-task") {
|
||||
t.Errorf("Expected label in output, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_GetByID_NotFound(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "nonexistent-999"})
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for nonexistent task, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "nonexistent-999") {
|
||||
t.Errorf("Expected task ID in error message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_TaskID_NonString(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
for _, badVal := range []any{42, 3.14, true, map[string]any{"x": 1}, []string{"a"}} {
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": badVal})
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for task_id=%T(%v), got success: %s", badVal, badVal, result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "task_id must be a string") {
|
||||
t.Errorf("Expected type-error message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ResultTruncation(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
longResult := strings.Repeat("X", 500)
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Long task",
|
||||
Status: "completed",
|
||||
Result: longResult,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
// Output should be shorter than the raw result due to truncation
|
||||
if len(result.ForLLM) >= len(longResult) {
|
||||
t.Errorf("Expected result to be truncated, but ForLLM is %d chars", len(result.ForLLM))
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "…") {
|
||||
t.Errorf("Expected truncation indicator '…' in output, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ResultTruncation_Unicode(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
// Each CJK rune is 3 bytes; 400 runes = 1200 bytes — well over the 300-rune limit.
|
||||
cjkChar := string(rune(0x5b57))
|
||||
longResult := strings.Repeat(cjkChar, 400)
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Unicode task",
|
||||
Status: "completed",
|
||||
Result: longResult,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "…") {
|
||||
t.Errorf("Expected truncation indicator in output")
|
||||
}
|
||||
// The truncated result must be valid UTF-8 (no split rune boundaries).
|
||||
if !strings.Contains(result.ForLLM, cjkChar) {
|
||||
t.Errorf("Expected CJK runes to appear intact in output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_StatusCounts(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
for i, status := range []string{"running", "running", "completed", "failed", "canceled"} {
|
||||
id := fmt.Sprintf("subagent-%d", i+1)
|
||||
manager.tasks[id] = &SubagentTask{ID: id, Task: "t", Status: status}
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
// The summary line should mention all statuses that have counts
|
||||
for _, want := range []string{"Running:", "Completed:", "Failed:", "Canceled:"} {
|
||||
if !strings.Contains(result.ForLLM, want) {
|
||||
t.Errorf("Expected %q in summary, got:\n%s", want, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_SortByCreatedTimestamp(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
manager.mu.Lock()
|
||||
// Intentionally insert with out-of-order IDs and timestamps that reflect
|
||||
// true spawn order: subagent-2 was spawned first, subagent-10 second.
|
||||
manager.tasks["subagent-10"] = &SubagentTask{
|
||||
ID: "subagent-10", Task: "second", Status: "running",
|
||||
Created: now + 1,
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2", Task: "first", Status: "running",
|
||||
Created: now,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
pos2 := strings.Index(result.ForLLM, "subagent-2")
|
||||
pos10 := strings.Index(result.ForLLM, "subagent-10")
|
||||
if pos2 < 0 || pos10 < 0 {
|
||||
t.Fatalf("Both task IDs should appear in output:\n%s", result.ForLLM)
|
||||
}
|
||||
if pos2 > pos10 {
|
||||
t.Errorf("Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_ListAll(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1", Task: "mine", Status: "running",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-A",
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2", Task: "other user", Status: "running",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-B",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// Caller is chat-A — should only see subagent-1.
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-A")
|
||||
result := tool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-1") {
|
||||
t.Errorf("Expected own task in output, got:\n%s", result.ForLLM)
|
||||
}
|
||||
if strings.Contains(result.ForLLM, "subagent-2") {
|
||||
t.Errorf("Should NOT see other chat's task, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_GetByID(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-99"] = &SubagentTask{
|
||||
ID: "subagent-99", Task: "secret", Status: "completed", Result: "private data",
|
||||
OriginChannel: "slack", OriginChatID: "room-Z",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// Different chat trying to look up subagent-99 by ID.
|
||||
ctx := WithToolContext(context.Background(), "slack", "room-OTHER")
|
||||
result := tool.Execute(ctx, map[string]any{"task_id": "subagent-99"})
|
||||
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error (cross-chat lookup blocked), got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_NoContext(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1", Task: "t", Status: "completed",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-A",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// No ToolContext injected (e.g. a direct programmatic call that bypasses
|
||||
// WithToolContext entirely) — callerChannel and callerChatID are both "".
|
||||
// Note: the normal CLI path uses ProcessDirectWithChannel("cli", "direct"),
|
||||
// which *does* inject a non-empty context; this test covers the case where
|
||||
// no context injection happens at all.
|
||||
// The filter conditions require a non-empty caller value, so all tasks pass through.
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-1") {
|
||||
t.Errorf("Expected task visible from no-context caller, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
+25
-3
@@ -109,9 +109,6 @@ func (sm *SubagentManager) Spawn(
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
|
||||
task.Status = "running"
|
||||
task.Created = time.Now().UnixMilli()
|
||||
|
||||
// Build system prompt for subagent
|
||||
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
|
||||
You have access to tools - use them as needed to complete your task.
|
||||
@@ -219,6 +216,18 @@ func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {
|
||||
return task, ok
|
||||
}
|
||||
|
||||
// GetTaskCopy returns a copy of the task with the given ID, taken under the
|
||||
// read lock, so the caller receives a consistent snapshot with no data race.
|
||||
func (sm *SubagentManager) GetTaskCopy(taskID string) (SubagentTask, bool) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
task, ok := sm.tasks[taskID]
|
||||
if !ok {
|
||||
return SubagentTask{}, false
|
||||
}
|
||||
return *task, true
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
@@ -230,6 +239,19 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
return tasks
|
||||
}
|
||||
|
||||
// ListTaskCopies returns value copies of all tasks, taken under the read lock,
|
||||
// so callers receive consistent snapshots with no data race.
|
||||
func (sm *SubagentManager) ListTaskCopies() []SubagentTask {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
copies := make([]SubagentTask, 0, len(sm.tasks))
|
||||
for _, task := range sm.tasks {
|
||||
copies = append(copies, *task)
|
||||
}
|
||||
return copies
|
||||
}
|
||||
|
||||
// SubagentTool executes a subagent task synchronously and returns the result.
|
||||
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
|
||||
// and returns the result directly in the ToolResult.
|
||||
|
||||
+95
-10
@@ -780,11 +780,17 @@ type WebFetchTool struct {
|
||||
client *http.Client
|
||||
format string
|
||||
fetchLimitBytes int64
|
||||
whitelist *privateHostWhitelist
|
||||
}
|
||||
|
||||
type privateHostWhitelist struct {
|
||||
exact map[string]struct{}
|
||||
cidrs []*net.IPNet
|
||||
}
|
||||
|
||||
func NewWebFetchTool(maxChars int, format string, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
// createHTTPClient cannot fail with an empty proxy string.
|
||||
return NewWebFetchToolWithProxy(maxChars, "", format, fetchLimitBytes)
|
||||
return NewWebFetchToolWithProxy(maxChars, "", format, fetchLimitBytes, nil)
|
||||
}
|
||||
|
||||
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
|
||||
@@ -792,9 +798,22 @@ func NewWebFetchTool(maxChars int, format string, fetchLimitBytes int64) (*WebFe
|
||||
var allowPrivateWebFetchHosts atomic.Bool
|
||||
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
return NewWebFetchToolWithConfig(maxChars, proxy, fetchLimitBytes, nil)
|
||||
}
|
||||
|
||||
func NewWebFetchToolWithConfig(
|
||||
maxChars int,
|
||||
proxy string,
|
||||
fetchLimitBytes int64,
|
||||
privateHostWhitelist []string,
|
||||
) (*WebFetchTool, error) {
|
||||
if maxChars <= 0 {
|
||||
maxChars = defaultMaxChars
|
||||
}
|
||||
whitelist, err := newPrivateHostWhitelist(privateHostWhitelist)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse web fetch private host whitelist: %w", err)
|
||||
}
|
||||
client, err := utils.CreateHTTPClient(proxy, fetchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
|
||||
@@ -804,13 +823,13 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLi
|
||||
Timeout: 15 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
transport.DialContext = newSafeDialContext(dialer)
|
||||
transport.DialContext = newSafeDialContext(dialer, whitelist)
|
||||
}
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("stopped after %d redirects", maxRedirects)
|
||||
}
|
||||
if isObviousPrivateHost(req.URL.Hostname()) {
|
||||
if isObviousPrivateHost(req.URL.Hostname(), whitelist) {
|
||||
return fmt.Errorf("redirect target is private or local network host")
|
||||
}
|
||||
return nil
|
||||
@@ -824,6 +843,7 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, format string, fetchLi
|
||||
client: client,
|
||||
format: format,
|
||||
fetchLimitBytes: fetchLimitBytes,
|
||||
whitelist: whitelist,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -875,7 +895,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
|
||||
// The real SSRF guard is newSafeDialContext at connect time.
|
||||
hostname := parsedURL.Hostname()
|
||||
if isObviousPrivateHost(hostname) {
|
||||
if isObviousPrivateHost(hostname, t.whitelist) {
|
||||
return ErrorResult("fetching private or local network hosts is not allowed")
|
||||
}
|
||||
|
||||
@@ -1019,7 +1039,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
|
||||
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
|
||||
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
|
||||
func newSafeDialContext(
|
||||
dialer *net.Dialer,
|
||||
whitelist *privateHostWhitelist,
|
||||
) func(context.Context, string, string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return dialer.DialContext(ctx, network, address)
|
||||
@@ -1034,7 +1057,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isPrivateOrRestrictedIP(ip) {
|
||||
if shouldBlockPrivateIP(ip, whitelist) {
|
||||
return nil, fmt.Errorf("blocked private or local target: %s", host)
|
||||
}
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
|
||||
@@ -1048,7 +1071,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
attempted := 0
|
||||
var lastErr error
|
||||
for _, ipAddr := range ipAddrs {
|
||||
if isPrivateOrRestrictedIP(ipAddr.IP) {
|
||||
if shouldBlockPrivateIP(ipAddr.IP, whitelist) {
|
||||
continue
|
||||
}
|
||||
attempted++
|
||||
@@ -1060,7 +1083,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
|
||||
if attempted == 0 {
|
||||
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
|
||||
return nil, fmt.Errorf("all resolved addresses for %s are private, restricted, or not whitelisted", host)
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
|
||||
@@ -1069,10 +1092,72 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
}
|
||||
|
||||
func newPrivateHostWhitelist(entries []string) (*privateHostWhitelist, error) {
|
||||
if len(entries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
whitelist := &privateHostWhitelist{
|
||||
exact: make(map[string]struct{}),
|
||||
cidrs: make([]*net.IPNet, 0, len(entries)),
|
||||
}
|
||||
for _, entry := range entries {
|
||||
entry = strings.TrimSpace(entry)
|
||||
if entry == "" {
|
||||
continue
|
||||
}
|
||||
if ip := net.ParseIP(entry); ip != nil {
|
||||
whitelist.exact[normalizeWhitelistIP(ip).String()] = struct{}{}
|
||||
continue
|
||||
}
|
||||
_, network, err := net.ParseCIDR(entry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid entry %q: expected IP or CIDR", entry)
|
||||
}
|
||||
whitelist.cidrs = append(whitelist.cidrs, network)
|
||||
}
|
||||
|
||||
if len(whitelist.exact) == 0 && len(whitelist.cidrs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return whitelist, nil
|
||||
}
|
||||
|
||||
func (w *privateHostWhitelist) Contains(ip net.IP) bool {
|
||||
if w == nil || ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
normalized := normalizeWhitelistIP(ip)
|
||||
if _, ok := w.exact[normalized.String()]; ok {
|
||||
return true
|
||||
}
|
||||
for _, network := range w.cidrs {
|
||||
if network.Contains(normalized) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeWhitelistIP(ip net.IP) net.IP {
|
||||
if ip == nil {
|
||||
return nil
|
||||
}
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return ip4
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
func shouldBlockPrivateIP(ip net.IP, whitelist *privateHostWhitelist) bool {
|
||||
return isPrivateOrRestrictedIP(ip) && !whitelist.Contains(ip)
|
||||
}
|
||||
|
||||
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
|
||||
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
|
||||
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
|
||||
func isObviousPrivateHost(host string) bool {
|
||||
func isObviousPrivateHost(host string, whitelist *privateHostWhitelist) bool {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return false
|
||||
}
|
||||
@@ -1088,7 +1173,7 @@ func isObviousPrivateHost(host string) bool {
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
return isPrivateOrRestrictedIP(ip)
|
||||
return shouldBlockPrivateIP(ip, whitelist)
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
@@ -425,6 +426,29 @@ func withPrivateWebFetchHostsAllowed(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func serverHostAndPort(t *testing.T, rawURL string) (string, string) {
|
||||
t.Helper()
|
||||
hostPort := strings.TrimPrefix(rawURL, "http://")
|
||||
hostPort = strings.TrimPrefix(hostPort, "https://")
|
||||
host, port, err := net.SplitHostPort(hostPort)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split host/port from %q: %v", rawURL, err)
|
||||
}
|
||||
return host, port
|
||||
}
|
||||
|
||||
func singleHostCIDR(t *testing.T, host string) string {
|
||||
t.Helper()
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
t.Fatalf("failed to parse IP %q", host)
|
||||
}
|
||||
if ip.To4() != nil {
|
||||
return ip.String() + "/32"
|
||||
}
|
||||
return ip.String() + "/128"
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
@@ -443,6 +467,56 @@ func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("exact whitelist ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
host, _ := serverHostAndPort(t, server.URL)
|
||||
tool, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, []string{host})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": server.URL,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success for exact whitelisted private IP, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "exact whitelist ok") {
|
||||
t.Fatalf("expected fetched content, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("cidr whitelist ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
host, _ := serverHostAndPort(t, server.URL)
|
||||
tool, err := NewWebFetchToolWithConfig(50000, "", testFetchLimit, []string{singleHostCIDR(t, host)})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": server.URL,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success for CIDR-whitelisted private IP, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "cidr whitelist ok") {
|
||||
t.Fatalf("expected fetched content, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
@@ -572,6 +646,69 @@ func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on loopback: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
_, port, err := net.SplitHostPort(listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split listener address: %v", err)
|
||||
}
|
||||
|
||||
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, nil)
|
||||
_, err = dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
|
||||
if err == nil {
|
||||
t.Fatal("expected localhost DNS resolution to be blocked without whitelist")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "private") && !strings.Contains(err.Error(), "whitelisted") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on loopback: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
accepted := make(chan struct{}, 1)
|
||||
go func() {
|
||||
conn, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
conn.Close()
|
||||
accepted <- struct{}{}
|
||||
}()
|
||||
|
||||
_, port, err := net.SplitHostPort(listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split listener address: %v", err)
|
||||
}
|
||||
|
||||
whitelist, err := newPrivateHostWhitelist([]string{"127.0.0.0/8"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse whitelist: %v", err)
|
||||
}
|
||||
|
||||
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, whitelist)
|
||||
conn, err := dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
|
||||
if err != nil {
|
||||
t.Fatalf("expected localhost DNS resolution to succeed with whitelist, got %v", err)
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
select {
|
||||
case <-accepted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected localhost listener to accept a connection")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsPrivateOrRestrictedIP_Table tests IP classification logic
|
||||
func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -662,6 +799,16 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) {
|
||||
_, err := NewWebFetchToolWithConfig(1024, "", testFetchLimit, []string{"not-an-ip-or-cidr"})
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid whitelist entry to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid entry") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("perplexity", func(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
|
||||
+2
-1
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
|
||||
@@ -67,7 +68,7 @@ func DownloadFile(urlStr, filename string, opts DownloadOptions) string {
|
||||
opts.LoggerPrefix = "utils"
|
||||
}
|
||||
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]any{
|
||||
"error": err.Error(),
|
||||
|
||||
Reference in New Issue
Block a user