Merge upstream/main into feat/subturn-poc

Includes JSONL session persistence (#1170), spawn_status tool, Azure provider,
credential encryption, and various fixes. SubTurn features preserved and
integrated with new spawn_status functionality.
This commit is contained in:
Administrator
2026-03-17 21:55:20 +08:00
110 changed files with 7413 additions and 1547 deletions
+25 -2
View File
@@ -10,6 +10,7 @@ import (
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/memory"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
@@ -66,7 +67,7 @@ func NewAgentInstance(
readRestrict := restrict && !defaults.AllowReadOutsideWorkspace
// Compile path whitelist patterns from config.
allowReadPaths := compilePatterns(cfg.Tools.AllowReadPaths)
allowReadPaths := buildAllowReadPatterns(cfg)
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
toolsRegistry := tools.NewToolRegistry()
@@ -82,7 +83,7 @@ func NewAgentInstance(
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
}
if cfg.Tools.IsToolEnabled("exec") {
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg, allowReadPaths)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
}
@@ -282,6 +283,28 @@ func compilePatterns(patterns []string) []*regexp.Regexp {
return compiled
}
// buildAllowReadPatterns compiles the configured read-path allow list and makes
// sure the media temp directory pattern is part of the result.
//
// A nil cfg is treated as having no configured patterns. The media pattern is
// appended only when no compiled pattern already has identical source text, so
// a user who configured the same expression does not get a duplicate entry.
func buildAllowReadPatterns(cfg *config.Config) []*regexp.Regexp {
	var userPatterns []string
	if cfg != nil {
		userPatterns = cfg.Tools.AllowReadPaths
	}

	patterns := compilePatterns(userPatterns)
	mediaRe := regexp.MustCompile(mediaTempDirPattern())
	mediaSrc := mediaRe.String()

	alreadyPresent := false
	for i := range patterns {
		if patterns[i].String() == mediaSrc {
			alreadyPresent = true
			break
		}
	}
	if alreadyPresent {
		return patterns
	}
	return append(patterns, mediaRe)
}
// mediaTempDirPattern returns the source of an anchored regular expression that
// matches the cleaned media temp directory itself, or that directory followed
// by the OS path separator. Both the directory and the separator are quoted so
// they are matched literally.
func mediaTempDirPattern() string {
	dir := filepath.Clean(media.TempDir())
	separator := string(os.PathSeparator)
	boundary := "(?:" + regexp.QuoteMeta(separator) + "|$)"
	return "^" + regexp.QuoteMeta(dir) + boundary
}
// Close releases resources held by the agent's session store.
func (a *AgentInstance) Close() error {
if a.Sessions != nil {
+86
View File
@@ -1,10 +1,14 @@
package agent
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
@@ -160,3 +164,85 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
})
}
}
// TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec checks that an
// agent restricted to its workspace can still reach the media temp directory
// through the read_file, list_dir and exec tools.
func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
	workspace := t.TempDir()
	mediaDir := media.TempDir()
	if err := os.MkdirAll(mediaDir, 0o700); err != nil {
		t.Fatalf("MkdirAll(mediaDir) error = %v", err)
	}

	mediaFile, err := os.CreateTemp(mediaDir, "instance-tool-*.txt")
	if err != nil {
		t.Fatalf("CreateTemp(mediaDir) error = %v", err)
	}
	mediaPath := mediaFile.Name()
	// Register removal as soon as the file exists. mediaDir is not a
	// t.TempDir(), so a failure in the write/close steps below would otherwise
	// leave the fixture behind.
	t.Cleanup(func() { _ = os.Remove(mediaPath) })

	if _, err := mediaFile.WriteString("attachment content"); err != nil {
		mediaFile.Close()
		t.Fatalf("WriteString(mediaFile) error = %v", err)
	}
	if err := mediaFile.Close(); err != nil {
		t.Fatalf("Close(mediaFile) error = %v", err)
	}

	cfg := &config.Config{
		Agents: config.AgentsConfig{
			Defaults: config.AgentDefaults{
				Workspace:           workspace,
				ModelName:           "test-model",
				RestrictToWorkspace: true,
			},
		},
		Tools: config.ToolsConfig{
			ReadFile: config.ReadFileToolConfig{Enabled: true},
			ListDir:  config.ToolConfig{Enabled: true},
			Exec: config.ExecConfig{
				ToolConfig:         config.ToolConfig{Enabled: true},
				EnableDenyPatterns: true,
				AllowRemote:        true,
			},
		},
	}

	agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})

	// read_file: absolute path inside the media dir must be readable.
	readTool, ok := agent.Tools.Get("read_file")
	if !ok {
		t.Fatal("read_file tool not registered")
	}
	readResult := readTool.Execute(context.Background(), map[string]any{"path": mediaPath})
	if readResult.IsError {
		t.Fatalf("read_file should allow media temp dir, got: %s", readResult.ForLLM)
	}
	if !strings.Contains(readResult.ForLLM, "attachment content") {
		t.Fatalf("read_file output missing media content: %s", readResult.ForLLM)
	}

	// list_dir: the media dir listing must include the fixture file.
	listTool, ok := agent.Tools.Get("list_dir")
	if !ok {
		t.Fatal("list_dir tool not registered")
	}
	listResult := listTool.Execute(context.Background(), map[string]any{"path": mediaDir})
	if listResult.IsError {
		t.Fatalf("list_dir should allow media temp dir, got: %s", listResult.ForLLM)
	}
	if !strings.Contains(listResult.ForLLM, filepath.Base(mediaPath)) {
		t.Fatalf("list_dir output missing media file: %s", listResult.ForLLM)
	}

	// exec: running with the media dir as working directory must be allowed.
	execTool, ok := agent.Tools.Get("exec")
	if !ok {
		t.Fatal("exec tool not registered")
	}
	execResult := execTool.Execute(context.Background(), map[string]any{
		"command":     "cat " + filepath.Base(mediaPath),
		"working_dir": mediaDir,
	})
	if execResult.IsError {
		t.Fatalf("exec should allow media temp dir, got: %s", execResult.ForLLM)
	}
	if !strings.Contains(execResult.ForLLM, "attachment content") {
		t.Fatalf("exec output missing media content: %s", execResult.ForLLM)
	}
}
+70 -61
View File
@@ -124,6 +124,8 @@ func registerSharedTools(
registry *AgentRegistry,
provider providers.LLMProvider,
) {
allowReadPaths := buildAllowReadPatterns(cfg)
for _, agentID := range registry.ListAgentIDs() {
agent, ok := registry.GetAgent(agentID)
if !ok {
@@ -202,6 +204,7 @@ func registerSharedTools(
cfg.Agents.Defaults.RestrictToWorkspace,
cfg.Agents.Defaults.GetMaxMediaSize(),
nil,
allowReadPaths,
)
agent.Tools.Register(sendFileTool)
}
@@ -229,72 +232,75 @@ func registerSharedTools(
}
}
// Spawn tool with allowlist checker
if cfg.Tools.IsToolEnabled("spawn") {
if cfg.Tools.IsToolEnabled("subagent") {
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
// Set the spawner that links into AgentLoop's turnState
subagentManager.SetSpawner(func(
ctx context.Context,
task, label, targetAgentID string,
tls *tools.ToolRegistry,
maxTokens int,
temperature float64,
hasMaxTokens, hasTemperature bool,
) (*tools.ToolResult, error) {
// 1. Recover parent Turn State from Context
parentTS := turnStateFromContext(ctx)
if parentTS == nil {
// Fallback: If no turnState exists in context, create an isolated ad-hoc root turn state
// so that the tool can still function outside of an agent loop (e.g. tests, raw invocations).
parentTS = &turnState{
ctx: ctx,
turnID: "adhoc-root",
depth: 0,
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5),
}
}
// 2. Build Tools slice from registry
var tlSlice []tools.Tool
for _, name := range tls.List() {
if t, ok := tls.Get(name); ok {
tlSlice = append(tlSlice, t)
}
}
// Spawn and spawn_status tools share a SubagentManager.
// Construct it when either tool is enabled (both require subagent).
spawnEnabled := cfg.Tools.IsToolEnabled("spawn")
spawnStatusEnabled := cfg.Tools.IsToolEnabled("spawn_status")
if (spawnEnabled || spawnStatusEnabled) && cfg.Tools.IsToolEnabled("subagent") {
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
// 3. System Prompt
systemPrompt := "You are a subagent. Complete the given task independently and report the result.\n" +
"You have access to tools - use them as needed to complete your task.\n" +
"After completing the task, provide a clear summary of what was done.\n\n" +
"Task: " + task
// 4. Resolve Model
modelToUse := agent.Model
if targetAgentID != "" {
if targetAgent, ok := al.GetRegistry().GetAgent(targetAgentID); ok {
modelToUse = targetAgent.Model
}
// Set the spawner that links into AgentLoop's turnState
subagentManager.SetSpawner(func(
ctx context.Context,
task, label, targetAgentID string,
tls *tools.ToolRegistry,
maxTokens int,
temperature float64,
hasMaxTokens, hasTemperature bool,
) (*tools.ToolResult, error) {
// 1. Recover parent Turn State from Context
parentTS := turnStateFromContext(ctx)
if parentTS == nil {
// Fallback: If no turnState exists in context, create an isolated ad-hoc root turn state
// so that the tool can still function outside of an agent loop (e.g. tests, raw invocations).
parentTS = &turnState{
ctx: ctx,
turnID: "adhoc-root",
depth: 0,
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5),
}
}
// 5. Build SubTurnConfig
cfg := SubTurnConfig{
Model: modelToUse,
Tools: tlSlice,
SystemPrompt: systemPrompt,
// 2. Build Tools slice from registry
var tlSlice []tools.Tool
for _, name := range tls.List() {
if t, ok := tls.Get(name); ok {
tlSlice = append(tlSlice, t)
}
if hasMaxTokens {
cfg.MaxTokens = maxTokens
}
// 3. System Prompt
systemPrompt := "You are a subagent. Complete the given task independently and report the result.\n" +
"You have access to tools - use them as needed to complete your task.\n" +
"After completing the task, provide a clear summary of what was done.\n\n" +
"Task: " + task
// 4. Resolve Model
modelToUse := agent.Model
if targetAgentID != "" {
if targetAgent, ok := al.GetRegistry().GetAgent(targetAgentID); ok {
modelToUse = targetAgent.Model
}
}
// 6. Spawn SubTurn
return spawnSubTurn(ctx, al, parentTS, cfg)
})
// 5. Build SubTurnConfig
cfg := SubTurnConfig{
Model: modelToUse,
Tools: tlSlice,
SystemPrompt: systemPrompt,
}
if hasMaxTokens {
cfg.MaxTokens = maxTokens
}
// 6. Spawn SubTurn
return spawnSubTurn(ctx, al, parentTS, cfg)
})
if spawnEnabled {
spawnTool := tools.NewSpawnTool(subagentManager)
currentAgentID := agentID
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
@@ -311,9 +317,12 @@ func registerSharedTools(
subagentTool := tools.NewSubagentTool(subagentManager)
subagentTool.SetSpawner(spawner)
agent.Tools.Register(subagentTool)
} else {
logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil)
}
if spawnStatusEnabled {
agent.Tools.Register(tools.NewSpawnStatusTool(subagentManager))
}
} else if (spawnEnabled || spawnStatusEnabled) && !cfg.Tools.IsToolEnabled("subagent") {
logger.WarnCF("agent", "spawn/spawn_status tools require subagent to be enabled", nil)
}
}
}
+1 -1
View File
@@ -618,7 +618,7 @@ func (c *FeishuChannel) downloadResource(
}
// Write to the shared picoclaw_media directory using a unique name to avoid collisions.
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
mediaDir := media.TempDir()
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{
"error": mkdirErr.Error(),
+1 -2
View File
@@ -357,7 +357,6 @@ func (m *Manager) StartAll(ctx context.Context) error {
if len(m.channels) == 0 {
logger.WarnC("channels", "No channels enabled")
return errors.New("no channels enabled")
}
logger.InfoC("channels", "Starting all channels")
@@ -397,7 +396,7 @@ func (m *Manager) StartAll(ctx context.Context) error {
"addr": m.httpServer.Addr,
})
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.ErrorCF("channels", "Shared HTTP server error", map[string]any{
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
"error": err.Error(),
})
}
+1 -3
View File
@@ -35,8 +35,6 @@ const (
roomKindCacheTTL = 5 * time.Minute
roomKindCacheCleanupPeriod = 1 * time.Minute
roomKindCacheMaxEntries = 2048
matrixMediaTempDirName = "picoclaw_media"
)
var matrixMentionHrefRegexp = regexp.MustCompile(`(?i)<a[^>]+href=["']([^"']+)["']`)
@@ -1105,7 +1103,7 @@ func (c *MatrixChannel) stripSelfMention(text string) string {
}
func matrixMediaTempDir() (string, error) {
mediaDir := filepath.Join(os.TempDir(), matrixMediaTempDirName)
mediaDir := media.TempDir()
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
return "", err
}
+2 -1
View File
@@ -15,6 +15,7 @@ import (
"maunium.net/go/mautrix/id"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
func TestMatrixLocalpartMentionRegexp(t *testing.T) {
@@ -165,7 +166,7 @@ func TestMatrixMediaTempDir(t *testing.T) {
if err != nil {
t.Fatalf("matrixMediaTempDir failed: %v", err)
}
if filepath.Base(dir) != matrixMediaTempDirName {
if filepath.Base(dir) != media.TempDirName {
t.Fatalf("unexpected media dir base: %q", filepath.Base(dir))
}
+28 -3
View File
@@ -251,7 +251,13 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
return
}
conn, err := c.upgrader.Upgrade(w, r, nil)
// Echo the matched subprotocol back so the browser accepts the upgrade.
var responseHeader http.Header
if proto := c.matchedSubprotocol(r); proto != "" {
responseHeader = http.Header{"Sec-WebSocket-Protocol": {proto}}
}
conn, err := c.upgrader.Upgrade(w, r, responseHeader)
if err != nil {
logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{
"error": err.Error(),
@@ -282,8 +288,10 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
go c.readLoop(pc)
}
// authenticate checks the Bearer token from the Authorization header.
// Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled.
// authenticate checks the request for a valid token:
// 1. Authorization: Bearer <token> header
// 2. Sec-WebSocket-Protocol "token.<value>" (for browsers that can't set headers)
// 3. Query parameter "token" (only when AllowTokenQuery is on)
func (c *PicoChannel) authenticate(r *http.Request) bool {
token := c.config.Token
if token == "" {
@@ -298,6 +306,11 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
}
}
// Check Sec-WebSocket-Protocol subprotocol ("token.<value>")
if c.matchedSubprotocol(r) != "" {
return true
}
// Check query parameter only when explicitly allowed
if c.config.AllowTokenQuery {
if r.URL.Query().Get("token") == token {
@@ -308,6 +321,18 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
return false
}
// matchedSubprotocol scans the subprotocols offered by the request and returns
// the first one of the form "token.<value>" whose value equals the configured
// token. It returns "" when no offered subprotocol matches.
func (c *PicoChannel) matchedSubprotocol(r *http.Request) string {
	want := c.config.Token
	offered := websocket.Subprotocols(r)
	for i := range offered {
		value, hasPrefix := strings.CutPrefix(offered[i], "token.")
		if !hasPrefix || value != want {
			continue
		}
		return offered[i]
	}
	return ""
}
// readLoop reads messages from a WebSocket connection.
func (c *PicoChannel) readLoop(pc *picoConn) {
defer func() {
+79 -6
View File
@@ -4,11 +4,13 @@ import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync/atomic"
"github.com/caarlos0/env/v11"
"github.com/sipeed/picoclaw/pkg/credential"
"github.com/sipeed/picoclaw/pkg/fileutil"
)
@@ -624,8 +626,9 @@ func (c *ModelConfig) Validate() error {
}
type GatewayConfig struct {
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
HotReload bool `json:"hot_reload" env:"PICOCLAW_GATEWAY_HOT_RELOAD"`
}
type ToolDiscoveryConfig struct {
@@ -698,8 +701,9 @@ type WebToolsConfig struct {
}
type CronToolsConfig struct {
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
AllowCommand bool ` env:"PICOCLAW_TOOLS_CRON_ALLOW_COMMAND" json:"allow_command"`
}
type ExecConfig struct {
@@ -749,6 +753,7 @@ type ToolsConfig struct {
ReadFile ReadFileToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
SendFile ToolConfig `json:"send_file" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
SpawnStatus ToolConfig `json:"spawn_status" envPrefix:"PICOCLAW_TOOLS_SPAWN_STATUS_"`
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"`
WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"`
@@ -838,10 +843,24 @@ func LoadConfig(path string) (*Config, error) {
return nil, err
}
if passphrase := credential.PassphraseProvider(); passphrase != "" {
for _, m := range cfg.ModelList {
if m.APIKey != "" && !strings.HasPrefix(m.APIKey, "enc://") && !strings.HasPrefix(m.APIKey, "file://") {
fmt.Fprintf(os.Stderr,
"picoclaw: warning: model %q has a plaintext api_key; call SaveConfig to encrypt it\n",
m.ModelName)
}
}
}
if err := env.Parse(cfg); err != nil {
return nil, err
}
if err := resolveAPIKeys(cfg.ModelList, filepath.Dir(path)); err != nil {
return nil, err
}
// Migrate legacy channel config fields to new unified structures
cfg.migrateChannelConfigs()
@@ -858,6 +877,48 @@ func LoadConfig(path string) (*Config, error) {
return cfg, nil
}
// encryptPlaintextAPIKeys produces a copy of models in which every plaintext
// api_key has been replaced by its encrypted form. Entries that are empty or
// already carry an "enc://" or "file://" prefix are copied through untouched.
//
// Return values:
//   - (sealed, nil) when at least one key was encrypted;
//   - (nil, nil) when there was nothing to encrypt;
//   - (nil, err) as soon as any key fails to encrypt. Callers must treat this
//     as a hard failure so that a mixed plaintext/ciphertext state is never
//     written to disk.
//
// The input slice is not modified. Like resolveAPIKeys, this function works
// only on []ModelConfig and leaves JSON marshaling to the caller.
func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelConfig, error) {
	out := make([]ModelConfig, len(models))
	copy(out, models)

	encryptedAny := false
	for idx := range out {
		key := out[idx].APIKey
		switch {
		case key == "",
			strings.HasPrefix(key, "enc://"),
			strings.HasPrefix(key, "file://"):
			// Nothing to seal for this entry.
			continue
		}

		ciphertext, err := credential.Encrypt(passphrase, "", key)
		if err != nil {
			return nil, fmt.Errorf("cannot seal api_key for model %q: %w", out[idx].ModelName, err)
		}
		out[idx].APIKey = ciphertext
		encryptedAny = true
	}

	if encryptedAny {
		return out, nil
	}
	return nil, nil
}
// resolveAPIKeys replaces each api_key in models, in place, with the value
// returned by a credential resolver rooted at configDir. Per the resolver's
// contract this covers plaintext (no-op), file:// (read from configDir) and
// enc:// (AES-GCM decrypt) values.
//
// Processing stops at the first entry that fails to resolve; the returned
// error names that entry's index and model name. Entries before it have
// already been updated.
func resolveAPIKeys(models []ModelConfig, configDir string) error {
	resolver := credential.NewResolver(configDir)
	for idx := range models {
		entry := &models[idx]
		plain, err := resolver.Resolve(entry.APIKey)
		if err != nil {
			return fmt.Errorf("model_list[%d] (%s): %w", idx, entry.ModelName, err)
		}
		entry.APIKey = plain
	}
	return nil
}
func (c *Config) migrateChannelConfigs() {
// Discord: mention_only -> group_trigger.mention_only
if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly {
@@ -872,12 +933,22 @@ func (c *Config) migrateChannelConfigs() {
}
// SaveConfig serializes cfg as indented JSON and writes it to path with mode
// 0o600 via the atomic write utility.
//
// When credential.PassphraseProvider yields a non-empty passphrase, plaintext
// api_key values in cfg.ModelList are encrypted first. The encrypted list is
// attached to a shallow copy of cfg, so the caller's cfg keeps its original
// keys. An encryption failure aborts the save before anything is written.
func SaveConfig(path string, cfg *Config) error {
	toWrite := cfg
	passphrase := credential.PassphraseProvider()
	if passphrase != "" {
		sealed, err := encryptPlaintextAPIKeys(cfg.ModelList, passphrase)
		if err != nil {
			return err
		}
		// sealed is nil when no key needed encrypting; write cfg as-is then.
		if sealed != nil {
			shallow := *cfg
			shallow.ModelList = sealed
			toWrite = &shallow
		}
	}

	data, err := json.MarshalIndent(toWrite, "", " ")
	if err != nil {
		return err
	}

	// Use unified atomic write utility with explicit sync for flash storage reliability.
	return fileutil.WriteFileAtomic(path, data, 0o600)
}
@@ -1044,6 +1115,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
return t.ReadFile.Enabled
case "spawn":
return t.Spawn.Enabled
case "spawn_status":
return t.SpawnStatus.Enabled
case "spi":
return t.SPI.Enabled
case "subagent":
+386 -5
View File
@@ -7,8 +7,22 @@ import (
"runtime"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/credential"
)
// mustSetupSSHKey creates a throwaway Ed25519 SSH key inside t.TempDir() and
// points PICOCLAW_SSH_KEY_PATH at it for the duration of the test. Call it from
// any test that exercises encryption/decryption through credential.Encrypt or
// SaveConfig. The test is failed immediately if key generation fails.
func mustSetupSSHKey(t *testing.T) {
	t.Helper()

	dir := t.TempDir()
	keyFile := filepath.Join(dir, "picoclaw_ed25519.key")
	err := credential.GenerateSSHKey(keyFile)
	if err != nil {
		t.Fatalf("mustSetupSSHKey: %v", err)
	}
	t.Setenv("PICOCLAW_SSH_KEY_PATH", keyFile)
}
func TestAgentModelConfig_UnmarshalString(t *testing.T) {
var m AgentModelConfig
if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil {
@@ -253,6 +267,9 @@ func TestDefaultConfig_Gateway(t *testing.T) {
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
if cfg.Gateway.HotReload {
t.Error("Gateway hot reload should be disabled by default")
}
}
// TestDefaultConfig_Providers verifies provider structure
@@ -391,6 +408,13 @@ func TestDefaultConfig_ExecAllowRemoteEnabled(t *testing.T) {
}
}
// TestDefaultConfig_CronAllowCommandEnabled checks that the cron tool's
// AllowCommand flag is on in the default configuration.
func TestDefaultConfig_CronAllowCommandEnabled(t *testing.T) {
	if allowed := DefaultConfig().Tools.Cron.AllowCommand; !allowed {
		t.Fatal("DefaultConfig().Tools.Cron.AllowCommand should be true")
	}
}
func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
@@ -423,6 +447,22 @@ func TestLoadConfig_ExecAllowRemoteDefaultsTrueWhenUnset(t *testing.T) {
}
}
// TestLoadConfig_CronAllowCommandDefaultsTrueWhenUnset loads a config file that
// sets another cron field but omits allow_command, and checks that the loaded
// value of tools.cron.allow_command is still true.
func TestLoadConfig_CronAllowCommandDefaultsTrueWhenUnset(t *testing.T) {
	const contents = `{"tools":{"cron":{"exec_timeout_minutes":5}}}`

	configPath := filepath.Join(t.TempDir(), "config.json")
	writeErr := os.WriteFile(configPath, []byte(contents), 0o600)
	if writeErr != nil {
		t.Fatalf("WriteFile() error: %v", writeErr)
	}

	loaded, loadErr := LoadConfig(configPath)
	if loadErr != nil {
		t.Fatalf("LoadConfig() error: %v", loadErr)
	}
	if allowed := loaded.Tools.Cron.AllowCommand; !allowed {
		t.Fatal("tools.cron.allow_command should remain true when unset in config file")
	}
}
func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
@@ -482,13 +522,19 @@ func TestDefaultConfig_DMScope(t *testing.T) {
}
func TestDefaultConfig_WorkspacePath_Default(t *testing.T) {
// Unset to ensure we test the default
t.Setenv("PICOCLAW_HOME", "")
// Set a known home for consistent test results
t.Setenv("HOME", "/tmp/home")
var fakeHome string
if runtime.GOOS == "windows" {
fakeHome = `C:\tmp\home`
t.Setenv("USERPROFILE", fakeHome)
} else {
fakeHome = "/tmp/home"
t.Setenv("HOME", fakeHome)
}
cfg := DefaultConfig()
want := filepath.Join("/tmp/home", ".picoclaw", "workspace")
want := filepath.Join(fakeHome, ".picoclaw", "workspace")
if cfg.Agents.Defaults.Workspace != want {
t.Errorf("Default workspace path = %q, want %q", cfg.Agents.Defaults.Workspace, want)
@@ -499,7 +545,7 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw/home")
cfg := DefaultConfig()
want := "/custom/picoclaw/home/workspace"
want := filepath.Join("/custom/picoclaw/home", "workspace")
if cfg.Agents.Defaults.Workspace != want {
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
@@ -621,3 +667,338 @@ func TestFlexibleStringSlice_UnmarshalText_EmptySliceConsistency(t *testing.T) {
}
})
}
// TestLoadConfig_WarnsForPlaintextAPIKey verifies that LoadConfig resolves a plaintext
// api_key into memory but does NOT rewrite the config file. File writes are the sole
// responsibility of SaveConfig.
func TestLoadConfig_WarnsForPlaintextAPIKey(t *testing.T) {
	dir := t.TempDir()
	cfgPath := filepath.Join(dir, "config.json")
	const original = `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
	if err := os.WriteFile(cfgPath, []byte(original), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}

	t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
	t.Setenv("PICOCLAW_SSH_KEY_PATH", "")

	cfg, err := LoadConfig(cfgPath)
	if err != nil {
		t.Fatalf("LoadConfig: %v", err)
	}
	// Guard the index below so a regression reports a failure instead of panicking.
	if len(cfg.ModelList) == 0 {
		t.Fatal("LoadConfig returned an empty ModelList")
	}

	// In-memory value must be the resolved plaintext.
	if cfg.ModelList[0].APIKey != "sk-plaintext" {
		t.Errorf("in-memory api_key = %q, want %q", cfg.ModelList[0].APIKey, "sk-plaintext")
	}

	// The file on disk must remain unchanged — LoadConfig must not write anything.
	raw, err := os.ReadFile(cfgPath)
	if err != nil {
		t.Fatalf("ReadFile after LoadConfig: %v", err)
	}
	if string(raw) != original {
		t.Errorf("LoadConfig must not modify the config file; got:\n%s", string(raw))
	}
}
// TestSaveConfig_EncryptsPlaintextAPIKey verifies that SaveConfig writes enc:// ciphertext
// to disk and that a subsequent LoadConfig decrypts it back to the original plaintext.
func TestSaveConfig_EncryptsPlaintextAPIKey(t *testing.T) {
	dir := t.TempDir()
	cfgPath := filepath.Join(dir, "config.json")

	t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
	mustSetupSSHKey(t)

	cfg := DefaultConfig()
	cfg.ModelList = []ModelConfig{
		{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
	}
	if err := SaveConfig(cfgPath, cfg); err != nil {
		t.Fatalf("SaveConfig: %v", err)
	}

	// Disk must contain enc://, not the raw key.
	raw, err := os.ReadFile(cfgPath)
	if err != nil {
		t.Fatalf("ReadFile after SaveConfig: %v", err)
	}
	if !strings.Contains(string(raw), "enc://") {
		t.Errorf("saved file should contain enc://, got:\n%s", string(raw))
	}
	if strings.Contains(string(raw), "sk-plaintext") {
		t.Errorf("saved file must not contain the plaintext key")
	}

	// A fresh load must decrypt back to the original plaintext.
	cfg2, err := LoadConfig(cfgPath)
	if err != nil {
		t.Fatalf("LoadConfig after SaveConfig: %v", err)
	}
	// Guard the index below so a regression reports a failure instead of panicking.
	if len(cfg2.ModelList) == 0 {
		t.Fatal("LoadConfig after SaveConfig returned an empty ModelList")
	}
	if cfg2.ModelList[0].APIKey != "sk-plaintext" {
		t.Errorf("loaded api_key = %q, want %q", cfg2.ModelList[0].APIKey, "sk-plaintext")
	}
}
// TestLoadConfig_NoSealWithoutPassphrase verifies that api_key values are left
// unchanged when PICOCLAW_KEY_PASSPHRASE is not set.
func TestLoadConfig_NoSealWithoutPassphrase(t *testing.T) {
	dir := t.TempDir()
	cfgPath := filepath.Join(dir, "config.json")
	data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
	if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}

	t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
	t.Setenv("PICOCLAW_SSH_KEY_PATH", "")

	if _, err := LoadConfig(cfgPath); err != nil {
		t.Fatalf("LoadConfig: %v", err)
	}

	// A failed read yields empty data, which would satisfy the "no enc://"
	// check below without proving anything — fail loudly instead.
	raw, err := os.ReadFile(cfgPath)
	if err != nil {
		t.Fatalf("ReadFile after LoadConfig: %v", err)
	}
	if strings.Contains(string(raw), "enc://") {
		t.Error("config file must not be modified when no passphrase is set")
	}
}
// TestLoadConfig_FileRefNotSealed verifies that file:// api_key references are not
// converted to enc:// values (they are resolved at runtime by the Resolver).
func TestLoadConfig_FileRefNotSealed(t *testing.T) {
	dir := t.TempDir()
	cfgPath := filepath.Join(dir, "config.json")
	keyFile := filepath.Join(dir, "openai.key")
	if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}

	data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"file://openai.key"}]}`
	if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}

	t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
	t.Setenv("PICOCLAW_SSH_KEY_PATH", "")

	if _, err := LoadConfig(cfgPath); err != nil {
		t.Fatalf("LoadConfig: %v", err)
	}

	// Report a read failure as such rather than as a missing file:// reference.
	raw, err := os.ReadFile(cfgPath)
	if err != nil {
		t.Fatalf("ReadFile after LoadConfig: %v", err)
	}
	if !strings.Contains(string(raw), "file://openai.key") {
		t.Error("file:// reference should be preserved unchanged in the config file")
	}
	if strings.Contains(string(raw), "enc://") {
		t.Error("file:// reference must not be converted to enc://")
	}
}
// TestSaveConfig_MixedKeys verifies that SaveConfig encrypts only plaintext api_keys
// and leaves already-encrypted (enc://) and file:// entries unchanged.
func TestSaveConfig_MixedKeys(t *testing.T) {
	dir := t.TempDir()
	cfgPath := filepath.Join(dir, "config.json")

	t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
	mustSetupSSHKey(t)

	// Pre-encrypt one key so we have a genuine enc:// value to put in the config.
	if err := SaveConfig(cfgPath, &Config{
		ModelList: []ModelConfig{
			{ModelName: "pre", Model: "openai/gpt-4", APIKey: "sk-already-plain"},
		},
	}); err != nil {
		t.Fatalf("setup SaveConfig: %v", err)
	}
	raw, err := os.ReadFile(cfgPath)
	if err != nil {
		t.Fatalf("setup ReadFile: %v", err)
	}

	// Extract the enc:// value from the saved file.
	var tmp struct {
		ModelList []struct {
			APIKey string `json:"api_key"`
		} `json:"model_list"`
	}
	if err := json.Unmarshal(raw, &tmp); err != nil || len(tmp.ModelList) == 0 {
		t.Fatalf("setup: could not parse saved config: %v", err)
	}
	alreadyEncrypted := tmp.ModelList[0].APIKey
	if !strings.HasPrefix(alreadyEncrypted, "enc://") {
		t.Fatalf("setup: expected enc:// key, got %q", alreadyEncrypted)
	}

	// Build a config with three models:
	//  1. plaintext → must be encrypted by SaveConfig
	//  2. enc://    → must be left unchanged (already encrypted)
	//  3. file://   → must be left unchanged (file reference)
	keyFile := filepath.Join(dir, "api.key")
	if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}

	cfg := &Config{
		ModelList: []ModelConfig{
			{ModelName: "plain", Model: "openai/gpt-4", APIKey: "sk-new-plaintext"},
			{ModelName: "enc", Model: "openai/gpt-4", APIKey: alreadyEncrypted},
			{ModelName: "file", Model: "openai/gpt-4", APIKey: "file://api.key"},
		},
	}
	if err := SaveConfig(cfgPath, cfg); err != nil {
		t.Fatalf("SaveConfig: %v", err)
	}

	// A failed read would make the "must not appear" check pass on empty data.
	raw, err = os.ReadFile(cfgPath)
	if err != nil {
		t.Fatalf("ReadFile after SaveConfig: %v", err)
	}
	s := string(raw)

	// 1. Plaintext must be encrypted.
	if strings.Contains(s, "sk-new-plaintext") {
		t.Error("plaintext key must not appear in saved file")
	}
	// 2. The pre-existing enc:// value must still be present (byte-for-byte unchanged).
	if !strings.Contains(s, alreadyEncrypted) {
		t.Error("pre-existing enc:// entry must be preserved unchanged")
	}
	// 3. file:// must be preserved.
	if !strings.Contains(s, "file://api.key") {
		t.Error("file:// reference must be preserved unchanged")
	}

	// Now load and verify all three decrypt/resolve correctly.
	cfg2, err := LoadConfig(cfgPath)
	if err != nil {
		t.Fatalf("LoadConfig after SaveConfig: %v", err)
	}
	byName := make(map[string]string)
	for _, m := range cfg2.ModelList {
		byName[m.ModelName] = m.APIKey
	}
	if byName["plain"] != "sk-new-plaintext" {
		t.Errorf("plain model api_key = %q, want %q", byName["plain"], "sk-new-plaintext")
	}
	if byName["enc"] != "sk-already-plain" {
		t.Errorf("enc model api_key = %q, want %q", byName["enc"], "sk-already-plain")
	}
	if byName["file"] != "sk-from-file" {
		t.Errorf("file model api_key = %q, want %q", byName["file"], "sk-from-file")
	}
}
// TestLoadConfig_MixedKeys_NoPassphrase verifies that when PICOCLAW_KEY_PASSPHRASE
// is not set, a config containing an enc:// entry alongside plaintext and file://
// entries causes LoadConfig to return a "passphrase required" error.
func TestLoadConfig_MixedKeys_NoPassphrase(t *testing.T) {
	dir := t.TempDir()
	cfgPath := filepath.Join(dir, "config.json")

	// First encrypt a key so we have a real enc:// value.
	t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
	mustSetupSSHKey(t)
	if err := SaveConfig(cfgPath, &Config{
		ModelList: []ModelConfig{
			{ModelName: "m", Model: "openai/gpt-4", APIKey: "sk-secret"},
		},
	}); err != nil {
		t.Fatalf("setup SaveConfig: %v", err)
	}
	raw, err := os.ReadFile(cfgPath)
	if err != nil {
		t.Fatalf("setup ReadFile: %v", err)
	}
	var tmp struct {
		ModelList []struct {
			APIKey string `json:"api_key"`
		} `json:"model_list"`
	}
	if err := json.Unmarshal(raw, &tmp); err != nil {
		t.Fatalf("setup parse: %v", err)
	}
	// Same guard as TestSaveConfig_MixedKeys: fail cleanly instead of panicking
	// on the index below if the saved file has no models.
	if len(tmp.ModelList) == 0 {
		t.Fatal("setup parse: saved config contains no models")
	}
	encValue := tmp.ModelList[0].APIKey

	// Write a mixed config: enc:// + plaintext + file://
	keyFile := filepath.Join(dir, "api.key")
	if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}
	mixed, err := json.Marshal(map[string]any{
		"model_list": []map[string]any{
			{"model_name": "enc", "model": "openai/gpt-4", "api_key": encValue},
			{"model_name": "plain", "model": "openai/gpt-4", "api_key": "sk-plain"},
			{"model_name": "file", "model": "openai/gpt-4", "api_key": "file://api.key"},
		},
	})
	if err != nil {
		t.Fatalf("setup marshal: %v", err)
	}
	if err := os.WriteFile(cfgPath, mixed, 0o600); err != nil {
		t.Fatalf("setup write: %v", err)
	}

	// Now clear the passphrase — LoadConfig must fail because enc:// cannot be decrypted.
	t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
	_, err = LoadConfig(cfgPath)
	if err == nil {
		t.Fatal("LoadConfig should fail when enc:// key is present and no passphrase is set")
	}
	if !strings.Contains(err.Error(), "passphrase required") {
		t.Errorf("error should mention passphrase required, got: %v", err)
	}
}
// TestSaveConfig_UsesPassphraseProvider verifies that SaveConfig encrypts plaintext
// api_keys using credential.PassphraseProvider() rather than os.Getenv directly.
// This matters for the launcher, which clears the environment variable and redirects
// PassphraseProvider to an in-memory SecureStore.
func TestSaveConfig_UsesPassphraseProvider(t *testing.T) {
	dir := t.TempDir()
	cfgPath := filepath.Join(dir, "config.json")

	// Ensure the env var is empty — passphrase must come from PassphraseProvider only.
	t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
	mustSetupSSHKey(t)

	// Replace PassphraseProvider with an in-memory function (simulating SecureStore).
	const testPassphrase = "provider-passphrase"
	orig := credential.PassphraseProvider
	credential.PassphraseProvider = func() string { return testPassphrase }
	t.Cleanup(func() { credential.PassphraseProvider = orig })

	cfg := DefaultConfig()
	cfg.ModelList = []ModelConfig{
		{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
	}
	if err := SaveConfig(cfgPath, cfg); err != nil {
		t.Fatalf("SaveConfig: %v", err)
	}

	// Report a read failure as such rather than as a missing enc:// value.
	raw, err := os.ReadFile(cfgPath)
	if err != nil {
		t.Fatalf("ReadFile after SaveConfig: %v", err)
	}
	if !strings.Contains(string(raw), "enc://") {
		t.Errorf("SaveConfig should have encrypted plaintext key via PassphraseProvider; got:\n%s", raw)
	}
}
// TestLoadConfig_UsesPassphraseProvider verifies that LoadConfig decrypts enc:// keys
// using credential.PassphraseProvider() rather than os.Getenv directly.
func TestLoadConfig_UsesPassphraseProvider(t *testing.T) {
	dir := t.TempDir()
	cfgPath := filepath.Join(dir, "config.json")

	// Ensure the env var is empty throughout.
	t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
	mustSetupSSHKey(t)

	const testPassphrase = "provider-passphrase"
	const plainKey = "sk-secret"

	// First, encrypt the key using the same passphrase.
	encrypted, err := credential.Encrypt(testPassphrase, "", plainKey)
	if err != nil {
		t.Fatalf("Encrypt: %v", err)
	}
	raw, err := json.Marshal(map[string]any{
		"model_list": []map[string]any{
			{"model_name": "test", "model": "openai/gpt-4", "api_key": encrypted},
		},
	})
	if err != nil {
		t.Fatalf("setup marshal: %v", err)
	}
	if err = os.WriteFile(cfgPath, raw, 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}

	// Redirect PassphraseProvider — env var is empty, so without this the load would fail.
	orig := credential.PassphraseProvider
	credential.PassphraseProvider = func() string { return testPassphrase }
	t.Cleanup(func() { credential.PassphraseProvider = orig })

	cfg, err := LoadConfig(cfgPath)
	if err != nil {
		t.Fatalf("LoadConfig: %v", err)
	}
	// Guard the index so an empty model list fails with a message instead of a panic.
	if len(cfg.ModelList) == 0 {
		t.Fatal("LoadConfig returned an empty ModelList")
	}
	if cfg.ModelList[0].APIKey != plainKey {
		t.Errorf("api_key = %q, want %q", cfg.ModelList[0].APIKey, plainKey)
	}
}
+16 -2
View File
@@ -385,10 +385,20 @@ func DefaultConfig() *Config {
APIBase: "http://localhost:8000/v1",
APIKey: "",
},
// Azure OpenAI - https://portal.azure.com
// model_name is a user-friendly alias; the model field's path after "azure/" is your deployment name
{
ModelName: "azure-gpt5",
Model: "azure/my-gpt5-deployment",
APIBase: "https://your-resource.openai.azure.com",
APIKey: "",
},
},
Gateway: GatewayConfig{
Host: "127.0.0.1",
Port: 18790,
Host: "127.0.0.1",
Port: 18790,
HotReload: false,
},
Tools: ToolsConfig{
MediaCleanup: MediaCleanupConfig{
@@ -444,6 +454,7 @@ func DefaultConfig() *Config {
Enabled: true,
},
ExecTimeoutMinutes: 5,
AllowCommand: true,
},
Exec: ExecConfig{
ToolConfig: ToolConfig{
@@ -513,6 +524,9 @@ func DefaultConfig() *Config {
Spawn: ToolConfig{
Enabled: true,
},
SpawnStatus: ToolConfig{
Enabled: false,
},
SPI: ToolConfig{
Enabled: false, // Hardware tool - Linux only
},
+335
View File
@@ -0,0 +1,335 @@
// Package credential resolves API credential values for model_list entries.
//
// An API key is a form of authorization credential. This package centralizes
// how raw credential strings—plaintext, file references, or encrypted values—are
// resolved into their actual values, keeping that logic out of the config loader.
//
// Supported formats for the api_key field:
//
// - Plaintext: "sk-abc123" → returned as-is
// - File ref: "file://filename.key" → content read from configDir/filename.key
// - Encrypted: "enc://<base64>" → AES-256-GCM decrypt via PICOCLAW_KEY_PASSPHRASE
// - Empty: "" → returned as-is (auth_method=oauth etc.)
//
// Encryption uses AES-256-GCM with HKDF-SHA256 key derivation (< 1ms, safe for embedded Linux).
// An SSH private key is required for both encryption and decryption.
// Key derivation:
//
// HKDF-SHA256(ikm=HMAC-SHA256(SHA256(sshKeyBytes), passphrase), salt, info)
//
// SSH key path resolution priority:
//
// 1. sshKeyPath argument to Encrypt (explicit)
// 2. PICOCLAW_SSH_KEY_PATH env var
// 3. ~/.ssh/picoclaw_ed25519.key (os.UserHomeDir is cross-platform)
package credential
import (
"crypto/aes"
"crypto/cipher"
"crypto/hkdf"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// PassphraseEnvVar is the environment variable that holds the encryption passphrase.
// Other packages (e.g. config) reference this constant to avoid duplicating the string.
const PassphraseEnvVar = "PICOCLAW_KEY_PASSPHRASE"
// PassphraseProvider is the function used to retrieve the passphrase for enc://
// credential decryption. It defaults to reading PICOCLAW_KEY_PASSPHRASE from the
// process environment. Replace it at startup to use a different source, such as
// an in-memory SecureStore, so that all LoadConfig() calls everywhere share the
// same passphrase source without needing os.Environ.
//
// An empty return value means "no passphrase available": resolveEncrypted then
// returns ErrPassphraseRequired without attempting decryption.
//
// NOTE(review): this is a plain package-level variable with no synchronization
// visible here — presumably it must be replaced before concurrent use; confirm.
//
// Example (launcher main.go):
//
// credential.PassphraseProvider = apiHandler.passphraseStore.Get
var PassphraseProvider func() string = func() string {
return os.Getenv(PassphraseEnvVar)
}
// ErrPassphraseRequired is returned when an enc:// credential is encountered but
// no passphrase is available from PassphraseProvider. Callers can detect this
// with errors.Is to distinguish a missing-passphrase condition from other errors.
var ErrPassphraseRequired = errors.New("credential: enc:// passphrase required")
// ErrDecryptionFailed is returned when an enc:// credential cannot be decrypted,
// indicating a wrong passphrase or SSH key. Callers can detect this with errors.Is.
// resolveEncrypted wraps it together with the underlying AES-GCM open error.
var ErrDecryptionFailed = errors.New("credential: enc:// decryption failed (wrong passphrase or SSH key?)")
// Scheme prefixes, payload layout sizes and key-derivation parameters shared by
// Resolve, resolveEncrypted, Encrypt and deriveKey.
const (
// fileScheme marks a credential whose value is read from a file under the config directory.
fileScheme = "file://"
// encScheme marks a base64 payload laid out as salt || nonce || AES-GCM ciphertext.
encScheme = "enc://"
// hkdfInfo is the HKDF "info" parameter passed by deriveKey.
hkdfInfo = "picoclaw-credential-v1"
// saltLen is the number of random salt bytes at the start of an enc:// payload.
saltLen = 16
// nonceLen is the number of AES-GCM nonce bytes that follow the salt.
nonceLen = 12
// keyLen is the derived key size in bytes (32 bytes selects AES-256).
keyLen = 32
// sshKeyEnv names the environment variable holding the SSH private key path.
sshKeyEnv = "PICOCLAW_SSH_KEY_PATH"
)
// Resolver resolves raw credential strings for model_list api_key fields.
// File references are resolved relative to the directory of the config file.
type Resolver struct {
// configDir is the directory exactly as it was passed to NewResolver.
configDir string
// NewResolver sets this to configDir itself when configDir is empty or
// filepath.EvalSymlinks fails; Resolve prefers it over configDir when non-empty.
resolvedConfigDir string // symlink-resolved form of configDir
}
// NewResolver builds a Resolver whose file:// references are looked up under
// configDir (typically filepath.Dir of the config file path).
//
// The symlink-resolved form of configDir is computed once here. When configDir
// is empty or cannot be resolved, the directory is kept exactly as given.
func NewResolver(configDir string) *Resolver {
	r := &Resolver{configDir: configDir, resolvedConfigDir: configDir}
	if configDir == "" {
		return r
	}
	if target, err := filepath.EvalSymlinks(configDir); err == nil {
		r.resolvedConfigDir = target
	}
	return r
}
// Resolve turns the raw api_key value into the credential to use:
//
//   - ""                → "" (no error; auth_method=oauth needs no key)
//   - "file://name.key" → trimmed content of configDir/name.key
//   - "enc://<base64>"  → value decrypted by resolveEncrypted
//   - anything else     → raw unchanged (plaintext credential)
//
// A file reference is rejected when it names no file, cannot be resolved, ends
// up outside the config directory after symlink resolution, cannot be read, or
// contains only whitespace.
func (r *Resolver) Resolve(raw string) (string, error) {
	if raw == "" {
		return "", nil
	}

	ref, isFileRef := strings.CutPrefix(raw, fileScheme)
	if !isFileRef {
		if strings.HasPrefix(raw, encScheme) {
			return resolveEncrypted(raw)
		}
		// Plaintext credential — return unchanged.
		return raw, nil
	}

	fileName := strings.TrimSpace(ref)
	if fileName == "" {
		return "", fmt.Errorf("credential: file:// reference has no filename")
	}

	// Prefer the symlink-resolved directory; fall back to the directory as given.
	root := r.configDir
	if r.resolvedConfigDir != "" {
		root = r.resolvedConfigDir
	}
	keyPath := filepath.Join(root, fileName)

	// Resolve symlinks before enforcing containment to prevent escaping via symlinks.
	realKeyPath, err := filepath.EvalSymlinks(keyPath)
	if err != nil {
		return "", fmt.Errorf("credential: failed to resolve credential file path %q: %w", keyPath, err)
	}
	if !isWithinDir(realKeyPath, root) {
		return "", fmt.Errorf("credential: file:// path escapes config directory")
	}

	content, err := os.ReadFile(realKeyPath)
	if err != nil {
		return "", fmt.Errorf("credential: failed to read credential file %q: %w", realKeyPath, err)
	}
	if secret := strings.TrimSpace(string(content)); secret != "" {
		return secret, nil
	}
	return "", fmt.Errorf("credential: credential file %q is empty", realKeyPath)
}
// resolveEncrypted decrypts an enc:// credential with the passphrase supplied
// by PassphraseProvider and the SSH key chosen by pickSSHKeyPath.
//
// The base64 payload is laid out as salt || nonce || AES-GCM ciphertext.
// It returns ErrPassphraseRequired when no passphrase is available and an error
// wrapping ErrDecryptionFailed when the ciphertext cannot be authenticated.
func resolveEncrypted(raw string) (string, error) {
	passphrase := PassphraseProvider()
	if passphrase == "" {
		return "", ErrPassphraseRequired
	}
	// override="": consult env then auto-detect
	keyPath := pickSSHKeyPath("")

	payload, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(raw, encScheme))
	if err != nil {
		return "", fmt.Errorf("credential: enc:// invalid base64: %w", err)
	}

	const headerLen = saltLen + nonceLen
	if len(payload) < headerLen+1 {
		return "", fmt.Errorf("credential: enc:// payload too short")
	}
	salt, nonce, sealed := payload[:saltLen], payload[saltLen:headerLen], payload[headerLen:]

	key, err := deriveKey(passphrase, keyPath, salt)
	if err != nil {
		return "", err
	}
	blockCipher, err := aes.NewCipher(key)
	if err != nil {
		return "", fmt.Errorf("credential: enc:// cipher init: %w", err)
	}
	aead, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return "", fmt.Errorf("credential: enc:// gcm init: %w", err)
	}

	opened, err := aead.Open(nil, nonce, sealed, nil)
	if err != nil {
		return "", fmt.Errorf("%w: %w", ErrDecryptionFailed, err)
	}
	return string(opened), nil
}
// Encrypt seals plaintext with AES-256-GCM and returns it as an enc:// string.
//
// passphrase must be non-empty (it is the PICOCLAW_KEY_PASSPHRASE value).
// sshKeyPath selects the SSH private key file; pass "" to let pickSSHKeyPath
// consult PICOCLAW_SSH_KEY_PATH and then ~/.ssh/picoclaw_ed25519.key.
// An error is returned when no usable SSH key can be found.
//
// The encoded payload is base64(salt || nonce || ciphertext), with a fresh
// random salt and nonce for every call.
func Encrypt(passphrase, sshKeyPath, plaintext string) (string, error) {
	if passphrase == "" {
		return "", fmt.Errorf("credential: passphrase must not be empty")
	}
	keyPath := pickSSHKeyPath(sshKeyPath)

	salt := make([]byte, saltLen)
	if _, err := io.ReadFull(rand.Reader, salt); err != nil {
		return "", fmt.Errorf("credential: failed to generate salt: %w", err)
	}

	key, err := deriveKey(passphrase, keyPath, salt)
	if err != nil {
		return "", err
	}
	blockCipher, err := aes.NewCipher(key)
	if err != nil {
		return "", fmt.Errorf("credential: cipher init: %w", err)
	}
	aead, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return "", fmt.Errorf("credential: gcm init: %w", err)
	}

	nonce := make([]byte, nonceLen)
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", fmt.Errorf("credential: failed to generate nonce: %w", err)
	}

	// Seal appends the ciphertext (including the auth tag) to the header, which
	// produces the salt || nonce || ciphertext layout directly.
	payload := make([]byte, 0, saltLen+nonceLen+len(plaintext)+aead.Overhead())
	payload = append(payload, salt...)
	payload = append(payload, nonce...)
	payload = aead.Seal(payload, nonce, []byte(plaintext), nil)

	return encScheme + base64.StdEncoding.EncodeToString(payload), nil
}
// isWithinDir reports whether path lies inside dir, or is dir itself.
//
// Both arguments are cleaned, path is expressed relative to dir, and the result
// is accepted only when filepath.IsLocal says it does not climb out of dir.
// If no relative path can be computed the answer is false.
func isWithinDir(path, dir string) bool {
	base, target := filepath.Clean(dir), filepath.Clean(path)
	rel, err := filepath.Rel(base, target)
	if err != nil {
		return false
	}
	return filepath.IsLocal(rel)
}
// allowedSSHKeyPath reports whether path is an acceptable place to read an SSH
// key from. A path is accepted when it:
//   - equals the cleaned PICOCLAW_SSH_KEY_PATH value,
//   - lies within the PICOCLAW_HOME directory, or
//   - lies within ~/.ssh/.
//
// An empty path is accepted as well, because it names no file to read.
func allowedSSHKeyPath(path string) bool {
	if path == "" {
		return true
	}
	candidate := filepath.Clean(path)

	// Exact match with PICOCLAW_SSH_KEY_PATH (ignored when unset or empty).
	if envPath := os.Getenv(sshKeyEnv); envPath != "" && candidate == filepath.Clean(envPath) {
		return true
	}

	// Anywhere below PICOCLAW_HOME.
	if picoHome := os.Getenv("PICOCLAW_HOME"); picoHome != "" && isWithinDir(candidate, picoHome) {
		return true
	}

	// Anywhere below ~/.ssh/; skipped when the home directory is unknown.
	userHome, err := os.UserHomeDir()
	return err == nil && isWithinDir(candidate, filepath.Join(userHome, ".ssh"))
}
// deriveKey produces the 32-byte AES-256 key that protects enc:// credentials.
//
//	ikm = HMAC-SHA256(key=SHA256(sshKeyBytes), msg=passphrase)
//	key = HKDF-SHA256(ikm, salt, info="picoclaw-credential-v1", 32 bytes)
//
// sshKeyPath must be non-empty and accepted by allowedSSHKeyPath. An error is
// returned otherwise, or when the key file cannot be read.
func deriveKey(passphrase, sshKeyPath string, salt []byte) ([]byte, error) {
	switch {
	case sshKeyPath == "":
		return nil, fmt.Errorf(
			"credential: SSH private key is required but not found" +
				" (set PICOCLAW_SSH_KEY_PATH or place key at ~/.ssh/picoclaw_ed25519.key)")
	case !allowedSSHKeyPath(sshKeyPath):
		return nil, fmt.Errorf(
			"credential: SSH key path %q is not in an allowed location (PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/)",
			sshKeyPath,
		)
	}

	keyFile, err := os.ReadFile(sshKeyPath)
	if err != nil {
		return nil, fmt.Errorf("credential: cannot read SSH key %q: %w", sshKeyPath, err)
	}

	// Mix the key file and the passphrase: the file's digest keys the HMAC and
	// the passphrase is the message.
	digest := sha256.Sum256(keyFile)
	mixer := hmac.New(sha256.New, digest[:])
	mixer.Write([]byte(passphrase))

	derived, err := hkdf.Key(sha256.New, mixer.Sum(nil), salt, hkdfInfo, keyLen)
	if err != nil {
		return nil, fmt.Errorf("credential: HKDF expand failed: %w", err)
	}
	return derived, nil
}
// pickSSHKeyPath chooses the SSH private key path for encryption/decryption.
//
// Sources are tried in this order:
//  1. override, when non-empty
//  2. the PICOCLAW_SSH_KEY_PATH env var, when set (even if set to "")
//  3. ~/.ssh/picoclaw_ed25519.key, when that file exists
//
// The result is "" when no key is found; deriveKey reports that as an error.
func pickSSHKeyPath(override string) string {
	if override == "" {
		// A variable that is set but empty is an explicit choice and is
		// returned as-is instead of falling through to auto-detection.
		if fromEnv, isSet := os.LookupEnv(sshKeyEnv); isSet {
			return fromEnv
		}
		return findDefaultSSHKey()
	}
	return override
}
// findDefaultSSHKey returns the path from DefaultSSHKeyPath when a file exists
// there, and "" when the path cannot be determined or the file cannot be stat'ed.
func findDefaultSSHKey() string {
	candidate, err := DefaultSSHKeyPath()
	if err == nil {
		if _, statErr := os.Stat(candidate); statErr == nil {
			return candidate
		}
	}
	return ""
}
+283
View File
@@ -0,0 +1,283 @@
package credential_test
import (
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/credential"
)
// TestResolve_PlainKey checks that a value without a scheme prefix comes back unchanged.
func TestResolve_PlainKey(t *testing.T) {
	const want = "sk-plaintext-key"

	got, err := credential.NewResolver(t.TempDir()).Resolve(want)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}
// TestResolve_FileKey_Success checks that a file:// reference yields the file's
// content with the trailing newline trimmed.
func TestResolve_FileKey_Success(t *testing.T) {
	const (
		keyFile = "openai_plain.key"
		want    = "sk-from-file"
	)
	dir := t.TempDir()
	if err := os.WriteFile(filepath.Join(dir, keyFile), []byte(want+"\n"), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}

	got, err := credential.NewResolver(dir).Resolve("file://" + keyFile)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}
// TestResolve_FileKey_NotFound checks that a file:// reference to a missing file is an error.
func TestResolve_FileKey_NotFound(t *testing.T) {
	resolver := credential.NewResolver(t.TempDir())
	if _, err := resolver.Resolve("file://missing.key"); err == nil {
		t.Fatal("expected error for missing file, got nil")
	}
}
// TestResolve_FileKey_Empty checks that a credential file holding only
// whitespace is rejected.
func TestResolve_FileKey_Empty(t *testing.T) {
	const keyFile = "empty.key"
	dir := t.TempDir()
	if err := os.WriteFile(filepath.Join(dir, keyFile), []byte(" \n"), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}

	if _, err := credential.NewResolver(dir).Resolve("file://" + keyFile); err == nil {
		t.Fatal("expected error for empty credential file, got nil")
	}
}
// TestResolve_EncKey_RoundTrip encrypts a secret with an SSH key found through
// PICOCLAW_SSH_KEY_PATH and checks that Resolve recovers the original value.
func TestResolve_EncKey_RoundTrip(t *testing.T) {
	const (
		passphrase = "test-passphrase-32bytes-long-ok!"
		plaintext  = "sk-encrypted-secret"
	)
	keyPath := filepath.Join(t.TempDir(), "picoclaw_ed25519.key")
	if err := os.WriteFile(keyPath, []byte("fake-ssh-key-material\n"), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}
	t.Setenv("PICOCLAW_SSH_KEY_PATH", keyPath)

	enc, err := credential.Encrypt(passphrase, "", plaintext)
	if err != nil {
		t.Fatalf("Encrypt: %v", err)
	}

	// Decryption takes the passphrase from the environment by default.
	t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
	got, err := credential.NewResolver(t.TempDir()).Resolve(enc)
	if err != nil {
		t.Fatalf("Resolve: %v", err)
	}
	if got != plaintext {
		t.Fatalf("got %q, want %q", got, plaintext)
	}
}
// TestResolve_EncKey_WithSSHKey round-trips a secret when the SSH key path is
// handed to Encrypt explicitly instead of being discovered.
func TestResolve_EncKey_WithSSHKey(t *testing.T) {
	const (
		passphrase = "test-passphrase"
		plaintext  = "sk-ssh-protected-secret"
	)
	keyPath := filepath.Join(t.TempDir(), "picoclaw_ed25519.key")
	if err := os.WriteFile(keyPath, []byte("fake-ssh-private-key-material\n"), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}

	// Set PICOCLAW_SSH_KEY_PATH before Encrypt so the path passes allowedSSHKeyPath validation.
	t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
	t.Setenv("PICOCLAW_SSH_KEY_PATH", keyPath)

	enc, err := credential.Encrypt(passphrase, keyPath, plaintext)
	if err != nil {
		t.Fatalf("Encrypt: %v", err)
	}

	got, err := credential.NewResolver(t.TempDir()).Resolve(enc)
	if err != nil {
		t.Fatalf("Resolve: %v", err)
	}
	if got != plaintext {
		t.Fatalf("got %q, want %q", got, plaintext)
	}
}
// TestResolve_EncKey_NoPassphrase checks that an enc:// value cannot be
// resolved when PICOCLAW_KEY_PASSPHRASE is empty.
func TestResolve_EncKey_NoPassphrase(t *testing.T) {
	keyPath := filepath.Join(t.TempDir(), "picoclaw_ed25519.key")
	if err := os.WriteFile(keyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}
	t.Setenv("PICOCLAW_SSH_KEY_PATH", keyPath)

	enc, err := credential.Encrypt("some-passphrase", "", "sk-secret")
	if err != nil {
		t.Fatalf("Encrypt: %v", err)
	}

	t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
	if _, err = credential.NewResolver(t.TempDir()).Resolve(enc); err == nil {
		t.Fatal("expected error when PICOCLAW_KEY_PASSPHRASE is unset, got nil")
	}
}
// TestResolve_EncKey_BadCiphertext checks that an enc:// value that is not
// valid base64 is rejected.
func TestResolve_EncKey_BadCiphertext(t *testing.T) {
	t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
	t.Setenv("PICOCLAW_SSH_KEY_PATH", "")

	resolver := credential.NewResolver(t.TempDir())
	if _, err := resolver.Resolve("enc://!!not-valid-base64!!"); err == nil {
		t.Fatal("expected error for invalid enc:// payload, got nil")
	}
}
// TestResolve_EncKey_PayloadTooShort checks that a well-formed base64 payload
// smaller than salt(16)+nonce(12)+1 bytes is rejected.
func TestResolve_EncKey_PayloadTooShort(t *testing.T) {
	// "tooshort" = 8 bytes once decoded.
	const shortPayload = "dG9vc2hvcnQ="

	t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
	t.Setenv("PICOCLAW_SSH_KEY_PATH", "")

	resolver := credential.NewResolver(t.TempDir())
	if _, err := resolver.Resolve("enc://" + shortPayload); err == nil {
		t.Fatal("expected error for too-short enc:// payload, got nil")
	}
}
// TestResolve_EncKey_WrongPassphrase checks that decrypting with a passphrase
// other than the one used for encryption fails.
func TestResolve_EncKey_WrongPassphrase(t *testing.T) {
	keyPath := filepath.Join(t.TempDir(), "picoclaw_ed25519.key")
	if err := os.WriteFile(keyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}
	t.Setenv("PICOCLAW_SSH_KEY_PATH", keyPath)

	enc, err := credential.Encrypt("correct-passphrase", "", "sk-secret")
	if err != nil {
		t.Fatalf("Encrypt: %v", err)
	}

	t.Setenv("PICOCLAW_KEY_PASSPHRASE", "wrong-passphrase")
	if _, err = credential.NewResolver(t.TempDir()).Resolve(enc); err == nil {
		t.Fatal("expected decryption error for wrong passphrase, got nil")
	}
}
// TestEncrypt_EmptyPassphrase checks that Encrypt refuses an empty passphrase.
func TestEncrypt_EmptyPassphrase(t *testing.T) {
	if _, err := credential.Encrypt("", "", "sk-secret"); err == nil {
		t.Fatal("expected error for empty passphrase, got nil")
	}
}
// TestDeriveKey_SSHKeyNotFound encrypts with a real SSH key file and then
// resolves with PICOCLAW_SSH_KEY_PATH pointing at a file that does not exist.
func TestDeriveKey_SSHKeyNotFound(t *testing.T) {
	const passphrase = "passphrase"

	dir := t.TempDir()
	keyPath := filepath.Join(dir, "picoclaw_ed25519.key")
	if err := os.WriteFile(keyPath, []byte("fake-key\n"), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}

	// Register the real key path so allowedSSHKeyPath validation passes for Encrypt.
	t.Setenv("PICOCLAW_SSH_KEY_PATH", keyPath)
	enc, err := credential.Encrypt(passphrase, keyPath, "sk-secret")
	if err != nil {
		t.Fatalf("Encrypt: %v", err)
	}

	// Point to a non-existent SSH key so deriveKey's ReadFile fails.
	// The path is still under the same dir, so allowedSSHKeyPath passes (exact env match).
	t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
	t.Setenv("PICOCLAW_SSH_KEY_PATH", filepath.Join(dir, "nonexistent_key"))

	if _, err = credential.NewResolver(t.TempDir()).Resolve(enc); err == nil {
		t.Fatal("expected error when SSH key file is missing, got nil")
	}
}
// TestResolve_FileRef_PathTraversal verifies that file:// references cannot escape configDir
// via relative traversal ("../../secret.key") or absolute paths ("/abs/path").
//
// The relative references point at files that really exist, so they can only
// be rejected by the containment check and not by a "file not found" error.
func TestResolve_FileRef_PathTraversal(t *testing.T) {
	// Layout:
	//   root/secret.key            <- target of "../../secret.key"
	//   root/outer/secret.key      <- target of "../secret.key"
	//   root/outer/cfg/            <- config directory
	root := t.TempDir()
	outer := filepath.Join(root, "outer")
	cfgDir := filepath.Join(outer, "cfg")
	if err := os.MkdirAll(cfgDir, 0o700); err != nil {
		t.Fatalf("setup: %v", err)
	}

	// A file in an unrelated directory, referenced by its absolute path.
	outsideFile := filepath.Join(t.TempDir(), "secret.key")

	for _, p := range []string{
		filepath.Join(root, "secret.key"),
		filepath.Join(outer, "secret.key"),
		outsideFile,
	} {
		if err := os.WriteFile(p, []byte("stolen"), 0o600); err != nil {
			t.Fatalf("setup: %v", err)
		}
	}

	r := credential.NewResolver(cfgDir)
	cases := []string{
		"file://../../secret.key",
		"file://../secret.key",
		"file://" + outsideFile, // absolute path
	}
	for _, raw := range cases {
		got, err := r.Resolve(raw)
		if err == nil {
			t.Errorf("Resolve(%q): expected path traversal error, got nil (value %q)", raw, got)
		}
	}
}
// TestResolve_FileRef_withinConfigDir checks that a plain relative file://
// reference inside the config directory resolves to the trimmed file content.
func TestResolve_FileRef_withinConfigDir(t *testing.T) {
	const want = "sk-valid"

	dir := t.TempDir()
	if err := os.WriteFile(filepath.Join(dir, "my.key"), []byte(want+"\n"), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}

	got, err := credential.NewResolver(dir).Resolve("file://my.key")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}
// TestEncrypt_SSHKeyOutsideAllowedDirs checks that Encrypt refuses an SSH key
// path that is not covered by PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/.
func TestEncrypt_SSHKeyOutsideAllowedDirs(t *testing.T) {
	keyPath := filepath.Join(t.TempDir(), "picoclaw_ed25519.key")
	if err := os.WriteFile(keyPath, []byte("fake-key\n"), 0o600); err != nil {
		t.Fatalf("setup: %v", err)
	}

	// Make sure none of the allowed env vars point here.
	for _, name := range []string{"PICOCLAW_SSH_KEY_PATH", "PICOCLAW_HOME"} {
		t.Setenv(name, "")
	}

	if _, err := credential.Encrypt("passphrase", keyPath, "sk-secret"); err == nil {
		t.Fatal("expected error for SSH key outside allowed directories, got nil")
	}
}
+62
View File
@@ -0,0 +1,62 @@
package credential
import (
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"fmt"
"os"
"path/filepath"
"golang.org/x/crypto/ssh"
)
// DefaultSSHKeyPath reports where the picoclaw-specific SSH key is expected:
// <home>/.ssh/picoclaw_ed25519.key, with <home> taken from os.UserHomeDir.
// It fails only when the home directory cannot be determined.
func DefaultSSHKeyPath() (string, error) {
	home, err := os.UserHomeDir()
	if err == nil {
		return filepath.Join(home, ".ssh", "picoclaw_ed25519.key"), nil
	}
	return "", fmt.Errorf("credential: cannot determine home directory: %w", err)
}
// GenerateSSHKey generates an Ed25519 SSH key pair and writes the private key
// to path (permissions 0600) and the public key to path+".pub" (permissions 0644).
// The parent directory of path is created with 0700 if it does not exist.
// If the files already exist they are overwritten and their permissions are
// reset to the values above, so a previously world-readable file cannot expose
// the new private key.
func GenerateSSHKey(path string) error {
	if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
		return fmt.Errorf("credential: keygen: cannot create directory %q: %w", filepath.Dir(path), err)
	}

	pubRaw, privRaw, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return fmt.Errorf("credential: keygen: ed25519 key generation failed: %w", err)
	}

	// Marshal private key as OpenSSH PEM.
	block, err := ssh.MarshalPrivateKey(privRaw, "")
	if err != nil {
		return fmt.Errorf("credential: keygen: marshal private key: %w", err)
	}
	privPEM := pem.EncodeToMemory(block)
	if err = writeFileExactPerm(path, privPEM, 0o600); err != nil {
		return fmt.Errorf("credential: keygen: write private key %q: %w", path, err)
	}

	// Marshal public key as authorized_keys line.
	sshPub, err := ssh.NewPublicKey(pubRaw)
	if err != nil {
		return fmt.Errorf("credential: keygen: marshal public key: %w", err)
	}
	pubLine := ssh.MarshalAuthorizedKey(sshPub)
	pubPath := path + ".pub"
	if err := writeFileExactPerm(pubPath, pubLine, 0o644); err != nil {
		return fmt.Errorf("credential: keygen: write public key %q: %w", pubPath, err)
	}
	return nil
}

// writeFileExactPerm truncates or creates path, forces its mode to perm, and
// only then writes data. os.WriteFile alone leaves the mode of an existing
// file untouched, which is not acceptable for private key material.
func writeFileExactPerm(path string, data []byte, perm os.FileMode) error {
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return err
	}
	// Fix the mode while the file is still empty so the content is never
	// stored under looser permissions.
	if err := os.Chmod(path, perm); err != nil {
		_ = f.Close() // the chmod failure is the error worth reporting
		return err
	}
	if _, err := f.Write(data); err != nil {
		_ = f.Close() // the write failure is the error worth reporting
		return err
	}
	return f.Close()
}
+115
View File
@@ -0,0 +1,115 @@
package credential
import (
"crypto/ed25519"
"os"
"path/filepath"
"runtime"
"testing"
"golang.org/x/crypto/ssh"
)
// TestGenerateSSHKey_CreatesFiles checks that GenerateSSHKey writes a private
// and a public key file, that both have the expected permissions where the
// platform supports them, and that both parse as SSH key material.
func TestGenerateSSHKey_CreatesFiles(t *testing.T) {
	keyPath := filepath.Join(t.TempDir(), "test_ed25519.key")
	pubPath := keyPath + ".pub"
	// Windows does not support Unix permission bits, so mode checks are skipped there.
	checkPerms := runtime.GOOS != "windows"

	if err := GenerateSSHKey(keyPath); err != nil {
		t.Fatalf("GenerateSSHKey() error = %v", err)
	}

	// Private key must exist.
	privInfo, err := os.Stat(keyPath)
	if err != nil {
		t.Fatalf("private key file missing: %v", err)
	}
	if mode := privInfo.Mode().Perm(); checkPerms && mode != 0o600 {
		t.Errorf("private key permissions = %04o, want 0600", mode)
	}

	// Public key must exist.
	pubInfo, err := os.Stat(pubPath)
	if err != nil {
		t.Fatalf("public key file missing: %v", err)
	}
	if mode := pubInfo.Mode().Perm(); checkPerms && mode != 0o644 {
		t.Errorf("public key permissions = %04o, want 0644", mode)
	}

	// Private key must be parseable as an OpenSSH ed25519 key.
	privPEM, err := os.ReadFile(keyPath)
	if err != nil {
		t.Fatalf("read private key: %v", err)
	}
	privKey, err := ssh.ParseRawPrivateKey(privPEM)
	if err != nil {
		t.Fatalf("parse private key: %v", err)
	}
	switch privKey.(type) {
	case *ed25519.PrivateKey:
		// expected key type
	default:
		t.Errorf("private key type = %T, want *ed25519.PrivateKey", privKey)
	}

	// Public key must be parseable as authorized_keys line.
	pubBytes, err := os.ReadFile(pubPath)
	if err != nil {
		t.Fatalf("read public key: %v", err)
	}
	pubKey, _, _, trailing, err := ssh.ParseAuthorizedKey(pubBytes)
	if err != nil {
		t.Fatalf("parse public key: %v", err)
	}
	if pubKey == nil {
		t.Fatal("expected non-nil public key")
	}
	if extra := len(trailing); extra > 0 {
		t.Errorf("unexpected trailing bytes after public key: %d bytes", extra)
	}
}
// TestGenerateSSHKey_OverwritesExisting generates a key twice at the same path
// and checks that the second call succeeds and replaces the private key.
func TestGenerateSSHKey_OverwritesExisting(t *testing.T) {
	keyPath := filepath.Join(t.TempDir(), "test_ed25519.key")

	if err := GenerateSSHKey(keyPath); err != nil {
		t.Fatalf("first GenerateSSHKey() error = %v", err)
	}
	before, err := os.ReadFile(keyPath)
	if err != nil {
		t.Fatalf("read first key: %v", err)
	}

	if err = GenerateSSHKey(keyPath); err != nil {
		t.Fatalf("second GenerateSSHKey() error = %v", err)
	}
	after, err := os.ReadFile(keyPath)
	if err != nil {
		t.Fatalf("read second key: %v", err)
	}

	// Two independently generated Ed25519 keys must differ.
	if string(before) == string(after) {
		t.Error("expected overwritten key to differ from original")
	}
}
// TestGenerateSSHKey_CreatesDirectory checks that missing parent directories
// of the key path are created on demand.
func TestGenerateSSHKey_CreatesDirectory(t *testing.T) {
	// Nested directory that does not yet exist.
	keyPath := filepath.Join(t.TempDir(), "subdir", ".ssh", "picoclaw_ed25519.key")

	err := GenerateSSHKey(keyPath)
	if err != nil {
		t.Fatalf("GenerateSSHKey() error = %v", err)
	}
	if _, err = os.Stat(keyPath); err != nil {
		t.Fatalf("private key not created: %v", err)
	}
}
+44
View File
@@ -0,0 +1,44 @@
package credential
import "sync/atomic"
// SecureStore keeps a passphrase in process memory only.
//
// The value lives behind an atomic.Pointer, so concurrent reads and writes
// need no lock. Nothing here writes the passphrase to disk; how it leaves the
// store (e.g., via cmd.Env or os.Environ) is up to the caller.
type SecureStore struct {
	val atomic.Pointer[string]
}

// NewSecureStore returns a store with no passphrase set.
func NewSecureStore() *SecureStore {
	return new(SecureStore)
}

// SetString replaces the stored passphrase. Passing "" clears the store.
func (s *SecureStore) SetString(passphrase string) {
	var next *string
	if passphrase != "" {
		next = &passphrase
	}
	s.val.Store(next)
}

// Get returns the current passphrase, or "" when none is stored.
func (s *SecureStore) Get() string {
	stored := s.val.Load()
	if stored == nil {
		return ""
	}
	return *stored
}

// IsSet reports whether the store currently holds a passphrase.
func (s *SecureStore) IsSet() bool {
	stored := s.val.Load()
	return stored != nil
}

// Clear drops the stored passphrase, if any.
func (s *SecureStore) Clear() {
	s.val.Store(nil)
}
+81
View File
@@ -0,0 +1,81 @@
package credential
import (
"sync"
"testing"
)
// TestSecureStore_SetGet checks that a fresh store is empty and that a stored
// passphrase is reported by both IsSet and Get.
func TestSecureStore_SetGet(t *testing.T) {
	const want = "hunter2"

	store := NewSecureStore()
	if store.IsSet() {
		t.Error("expected empty store")
	}

	store.SetString(want)
	if !store.IsSet() {
		t.Error("expected store to be set")
	}
	if got := store.Get(); got != want {
		t.Errorf("Get() = %q, want %q", got, want)
	}
}
// TestSecureStore_Clear checks that Clear leaves the store unset and empty.
func TestSecureStore_Clear(t *testing.T) {
	store := NewSecureStore()
	store.SetString("secret")
	store.Clear()

	if store.IsSet() {
		t.Error("expected store to be empty after Clear()")
	}
	got := store.Get()
	if got != "" {
		t.Errorf("Get() after Clear() = %q, want empty", got)
	}
}
// TestSecureStore_SetOverwrites checks that the most recent SetString wins.
func TestSecureStore_SetOverwrites(t *testing.T) {
	store := NewSecureStore()
	for _, value := range []string{"first", "second"} {
		store.SetString(value)
	}

	got := store.Get()
	if got != "second" {
		t.Errorf("Get() = %q, want %q", got, "second")
	}
}
// TestSecureStore_EmptyPassphrase checks that storing "" leaves the store unset.
func TestSecureStore_EmptyPassphrase(t *testing.T) {
	store := NewSecureStore()
	store.SetString("") // empty → should not mark as set

	if isSet := store.IsSet(); isSet {
		t.Error("empty passphrase should not mark store as set")
	}
}
// TestSecureStore_ConcurrentSetGet hammers one store from several goroutines
// that alternately write "even" or "odd" and read the value back, then checks
// that the final value is one of the values that could have been stored.
func TestSecureStore_ConcurrentSetGet(t *testing.T) {
	const (
		goroutines = 10
		iterations = 1000
	)
	store := NewSecureStore()

	var wg sync.WaitGroup
	wg.Add(goroutines)
	for id := 0; id < goroutines; id++ {
		// Each worker writes a fixed value chosen by the parity of its id.
		value := "odd"
		if id%2 == 0 {
			value = "even"
		}
		go func() {
			defer wg.Done()
			for n := 0; n < iterations; n++ {
				store.SetString(value)
				_ = store.Get()
			}
		}()
	}
	wg.Wait()

	switch final := store.Get(); final {
	case "", "even", "odd":
		// acceptable outcomes
	default:
		t.Errorf("Get() returned unexpected value %q after concurrent Set/Get", final)
	}
}
+594
View File
@@ -0,0 +1,594 @@
package gateway
import (
"context"
"fmt"
"os"
"os/signal"
"path/filepath"
"sync"
"syscall"
"time"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
_ "github.com/sipeed/picoclaw/pkg/channels/dingtalk"
_ "github.com/sipeed/picoclaw/pkg/channels/discord"
_ "github.com/sipeed/picoclaw/pkg/channels/feishu"
_ "github.com/sipeed/picoclaw/pkg/channels/irc"
_ "github.com/sipeed/picoclaw/pkg/channels/line"
_ "github.com/sipeed/picoclaw/pkg/channels/maixcam"
_ "github.com/sipeed/picoclaw/pkg/channels/matrix"
_ "github.com/sipeed/picoclaw/pkg/channels/onebot"
_ "github.com/sipeed/picoclaw/pkg/channels/pico"
_ "github.com/sipeed/picoclaw/pkg/channels/qq"
_ "github.com/sipeed/picoclaw/pkg/channels/slack"
_ "github.com/sipeed/picoclaw/pkg/channels/telegram"
_ "github.com/sipeed/picoclaw/pkg/channels/wecom"
_ "github.com/sipeed/picoclaw/pkg/channels/whatsapp"
_ "github.com/sipeed/picoclaw/pkg/channels/whatsapp_native"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/devices"
"github.com/sipeed/picoclaw/pkg/health"
"github.com/sipeed/picoclaw/pkg/heartbeat"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/voice"
)
// Timeouts for the gateway runtime. Their call sites fall outside this part of
// the file, so the notes below are inferred from the names — confirm against usage.
const (
// NOTE(review): presumably bounds how long stopping the background services may take.
serviceShutdownTimeout = 30 * time.Second
// NOTE(review): presumably bounds provider re-creation during a config reload.
providerReloadTimeout = 30 * time.Second
// NOTE(review): presumably bounds the overall graceful shutdown.
gracefulShutdownTimeout = 15 * time.Second
)
// services groups the long-lived components of the gateway runtime.
// NOTE(review): presumably this is what setupAndStartServices returns and what
// shutdownGateway and handleConfigReload receive — confirm against those functions.
type services struct {
CronService *cron.CronService
HeartbeatService *heartbeat.HeartbeatService
MediaStore media.MediaStore
ChannelManager *channels.Manager
DeviceService *devices.Service
HealthServer *health.Server
}
// startupBlockedProvider is a placeholder provider that never produces a
// response: every Chat call fails with the stored reason.
// NOTE(review): presumably installed by createStartupProvider when the gateway
// may start without a usable model — confirm against that function.
type startupBlockedProvider struct {
	// reason is the message returned as the error from every Chat call.
	reason string
}

// Chat ignores all of its arguments and always returns a nil response together
// with an error whose text is the provider's reason.
func (blocked *startupBlockedProvider) Chat(
	_ context.Context,
	_ []providers.Message,
	_ []providers.ToolDefinition,
	_ string,
	_ map[string]any,
) (*providers.LLMResponse, error) {
	err := fmt.Errorf("%s", blocked.reason)
	return nil, err
}

// GetDefaultModel always returns the empty string; this provider has no model.
func (blocked *startupBlockedProvider) GetDefaultModel() string {
	const noModel = ""
	return noModel
}
// Run starts the gateway runtime using the configuration loaded from configPath.
//
// It loads the config, creates the LLM provider (or a blocked placeholder when
// allowEmptyStartup is set and no default model is configured), builds the
// agent loop, starts the background services, and then blocks until an
// interrupt or SIGTERM arrives, at which point it shuts everything down and
// returns nil. When hot reload is enabled in the config, reloaded
// configurations delivered by the polling watcher are applied in the same
// loop. debug switches the logger to DEBUG level.
//
// An error is returned only for failures during startup (config load,
// provider creation, service setup).
func Run(debug bool, configPath string, allowEmptyStartup bool) error {
	if debug {
		logger.SetLevel(logger.DEBUG)
		fmt.Println("🔍 Debug mode enabled")
	}
	cfg, err := config.LoadConfig(configPath)
	if err != nil {
		return fmt.Errorf("error loading config: %w", err)
	}
	provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
	if err != nil {
		return fmt.Errorf("error creating provider: %w", err)
	}
	// A non-empty model ID reported by provider creation overrides the
	// configured model name.
	if modelID != "" {
		cfg.Agents.Defaults.ModelName = modelID
	}
	msgBus := bus.NewMessageBus()
	agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
	fmt.Println("\n📦 Agent Status:")
	startupInfo := agentLoop.GetStartupInfo()
	// Unchecked assertions: these panic if the startup info does not contain
	// map[string]any values under "tools" and "skills".
	toolsInfo := startupInfo["tools"].(map[string]any)
	skillsInfo := startupInfo["skills"].(map[string]any)
	fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"])
	fmt.Printf(" • Skills: %d/%d available\n", skillsInfo["available"], skillsInfo["total"])
	logger.InfoCF("agent", "Agent initialized",
		map[string]any{
			"tools_count": toolsInfo["count"],
			"skills_total": skillsInfo["total"],
			"skills_available": skillsInfo["available"],
		})
	runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus)
	if err != nil {
		return err
	}
	fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port)
	fmt.Println("Press Ctrl+C to stop")
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go agentLoop.Run(ctx)
	// configReloadChan stays nil when hot reload is disabled; receiving from a
	// nil channel blocks forever, so that select case simply never fires.
	var configReloadChan <-chan *config.Config
	stopWatch := func() {}
	if cfg.Gateway.HotReload {
		configReloadChan, stopWatch = setupConfigWatcherPolling(configPath, debug)
		logger.Info("Config hot reload enabled")
	}
	defer stopWatch()
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
	for {
		select {
		case <-sigChan:
			logger.Info("Shutting down...")
			shutdownGateway(runningServices, agentLoop, provider, true)
			return nil
		case newCfg := <-configReloadChan:
			// provider is passed by address so a successful reload can
			// replace it; a failed reload is logged and the loop continues.
			err := handleConfigReload(ctx, agentLoop, newCfg, &provider, runningServices, msgBus, allowEmptyStartup)
			if err != nil {
				logger.Errorf("Config reload failed: %v", err)
			}
		}
	}
}
// createStartupProvider builds the LLM provider used by the gateway.
//
// When no default model name is configured and allowEmptyStartup is true, it
// prints and logs a warning and returns a startupBlockedProvider together
// with an empty model ID. In every other case it defers to
// providers.CreateProvider and returns its results unchanged.
func createStartupProvider(
	cfg *config.Config,
	allowEmptyStartup bool,
) (providers.LLMProvider, string, error) {
	hasModel := cfg.Agents.Defaults.GetModelName() != ""
	if hasModel || !allowEmptyStartup {
		return providers.CreateProvider(cfg)
	}

	const limitedModeReason = "no default model configured; gateway started in limited mode"
	fmt.Printf("⚠ Warning: %s\n", limitedModeReason)
	logger.WarnCF("gateway", "Gateway started without default model", map[string]any{
		"limited_mode": true,
	})
	blocked := &startupBlockedProvider{reason: limitedModeReason}
	return blocked, "", nil
}
// setupAndStartServices creates and starts the gateway's background services
// in order: cron, heartbeat, media store, channel manager (with its HTTP
// server and the health server), and the device service. It also wires the
// channel manager, media store and, when one is detected, the voice
// transcriber into the agent loop.
//
// It returns the populated services value, or an error from the first step
// that fails. A device service start failure is only logged and does not abort
// startup.
func setupAndStartServices(
	cfg *config.Config,
	agentLoop *agent.AgentLoop,
	msgBus *bus.MessageBus,
) (*services, error) {
	runningServices := &services{}
	execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
	var err error
	runningServices.CronService, err = setupCronTool(
		agentLoop,
		msgBus,
		cfg.WorkspacePath(),
		cfg.Agents.Defaults.RestrictToWorkspace,
		execTimeout,
		cfg,
	)
	if err != nil {
		return nil, fmt.Errorf("error setting up cron service: %w", err)
	}
	if err = runningServices.CronService.Start(); err != nil {
		return nil, fmt.Errorf("error starting cron service: %w", err)
	}
	fmt.Println("✓ Cron service started")
	runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
		cfg.WorkspacePath(),
		cfg.Heartbeat.Interval,
		cfg.Heartbeat.Enabled,
	)
	runningServices.HeartbeatService.SetBus(msgBus)
	runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(agentLoop))
	if err = runningServices.HeartbeatService.Start(); err != nil {
		return nil, fmt.Errorf("error starting heartbeat service: %w", err)
	}
	fmt.Println("✓ Heartbeat service started")
	// MaxAge and Interval are configured in minutes.
	runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
		Enabled: cfg.Tools.MediaCleanup.Enabled,
		MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
		Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
	})
	if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
		fms.Start()
	}
	runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
	if err != nil {
		// Stop the media store that was just started before bailing out.
		if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
			fms.Stop()
		}
		return nil, fmt.Errorf("error creating channel manager: %w", err)
	}
	agentLoop.SetChannelManager(runningServices.ChannelManager)
	agentLoop.SetMediaStore(runningServices.MediaStore)
	if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
		agentLoop.SetTranscriber(transcriber)
		logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
	}
	enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
	if len(enabledChannels) > 0 {
		fmt.Printf("✓ Channels enabled: %s\n", enabledChannels)
	} else {
		fmt.Println("⚠ Warning: No channels enabled")
	}
	// The health server is created for the gateway address and handed to the
	// channel manager's HTTP server setup.
	addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
	runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
	runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
	if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
		return nil, fmt.Errorf("error starting channels: %w", err)
	}
	fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
	stateManager := state.NewManager(cfg.WorkspacePath())
	runningServices.DeviceService = devices.NewService(devices.Config{
		Enabled: cfg.Devices.Enabled,
		MonitorUSB: cfg.Devices.MonitorUSB,
	}, stateManager)
	runningServices.DeviceService.SetBus(msgBus)
	// A device service failure is logged but does not fail startup.
	if err = runningServices.DeviceService.Start(context.Background()); err != nil {
		logger.ErrorCF("device", "Error starting device service", map[string]any{"error": err.Error()})
	} else if cfg.Devices.Enabled {
		fmt.Println("✓ Device event service started")
	}
	return runningServices, nil
}
// stopAndCleanupServices stops every service held in runningServices that is
// set, in the order: channels, devices, heartbeat, cron, media store.
// shutdownTimeout is the deadline of the context handed to the channel
// manager's StopAll.
func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration) {
	ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
	defer cancel()

	if mgr := runningServices.ChannelManager; mgr != nil {
		mgr.StopAll(ctx)
	}
	if dev := runningServices.DeviceService; dev != nil {
		dev.Stop()
	}
	if hb := runningServices.HeartbeatService; hb != nil {
		hb.Stop()
	}
	if cronSvc := runningServices.CronService; cronSvc != nil {
		cronSvc.Stop()
	}
	// A nil MediaStore interface fails the type assertion, so no separate nil
	// check is needed before it.
	if store, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
		store.Stop()
	}
}
// shutdownGateway tears the gateway down: it closes the provider (only when
// fullShutdown is true and the provider implements providers.StatefulProvider),
// stops the background services with gracefulShutdownTimeout, and then stops
// and closes the agent loop.
func shutdownGateway(
	runningServices *services,
	agentLoop *agent.AgentLoop,
	provider providers.LLMProvider,
	fullShutdown bool,
) {
	if fullShutdown {
		if stateful, ok := provider.(providers.StatefulProvider); ok {
			stateful.Close()
		}
	}

	stopAndCleanupServices(runningServices, gracefulShutdownTimeout)

	agentLoop.Stop()
	agentLoop.Close()

	logger.Info("✓ Gateway stopped")
}
// handleConfigReload applies a freshly loaded configuration to a running
// gateway. It stops all services, creates a provider for newCfg, swaps the
// provider and config into the agent loop, stores the new provider through
// providerRef, and restarts the services.
//
// If creating the provider or reloading the agent loop fails, it attempts to
// restart the services (restartServices reads the config from the agent loop)
// and returns the original error. providerRef is only updated after the agent
// loop reload succeeded.
//
// NOTE(review): ctx is not referenced in the body; the reload deadline is
// derived from context.Background().
func handleConfigReload(
	ctx context.Context,
	al *agent.AgentLoop,
	newCfg *config.Config,
	providerRef *providers.LLMProvider,
	runningServices *services,
	msgBus *bus.MessageBus,
	allowEmptyStartup bool,
) error {
	logger.Info("🔄 Config file changed, reloading...")
	// Resolved only for the log line below.
	newModel := newCfg.Agents.Defaults.ModelName
	if newModel == "" {
		newModel = newCfg.Agents.Defaults.Model
	}
	logger.Infof(" New model is '%s', recreating provider...", newModel)
	logger.Info(" Stopping all services...")
	stopAndCleanupServices(runningServices, serviceShutdownTimeout)
	newProvider, newModelID, err := createStartupProvider(newCfg, allowEmptyStartup)
	if err != nil {
		logger.Errorf(" ⚠ Error creating new provider: %v", err)
		logger.Warn(" Attempting to restart services with old provider and config...")
		if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
			logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
		}
		return fmt.Errorf("error creating new provider: %w", err)
	}
	if newModelID != "" {
		newCfg.Agents.Defaults.ModelName = newModelID
	}
	reloadCtx, reloadCancel := context.WithTimeout(context.Background(), providerReloadTimeout)
	defer reloadCancel()
	if err := al.ReloadProviderAndConfig(reloadCtx, newProvider, newCfg); err != nil {
		logger.Errorf(" ⚠ Error reloading agent loop: %v", err)
		// The new provider was never installed, so release it here.
		if cp, ok := newProvider.(providers.StatefulProvider); ok {
			cp.Close()
		}
		logger.Warn(" Attempting to restart services with old provider and config...")
		if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
			logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
		}
		return fmt.Errorf("error reloading agent loop: %w", err)
	}
	*providerRef = newProvider
	logger.Info(" Restarting all services with new configuration...")
	if err := restartServices(al, runningServices, msgBus); err != nil {
		logger.Errorf(" ⚠ Error restarting services: %v", err)
		return fmt.Errorf("error restarting services: %w", err)
	}
	logger.Info(" ✓ Provider, configuration, and services reloaded successfully (thread-safe)")
	return nil
}
// restartServices recreates and starts every background service using the
// configuration currently held by the agent loop (al.GetConfig), storing the
// new instances in runningServices. The order is: cron, heartbeat, media
// store, channel manager (with HTTP and health servers), device service, and
// finally the voice transcriber, which is set on the agent loop even when none
// is detected.
//
// It returns an error from the first step that fails. A device service start
// failure is only logged.
func restartServices(
	al *agent.AgentLoop,
	runningServices *services,
	msgBus *bus.MessageBus,
) error {
	cfg := al.GetConfig()
	execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute

	var err error
	runningServices.CronService, err = setupCronTool(
		al,
		msgBus,
		cfg.WorkspacePath(),
		cfg.Agents.Defaults.RestrictToWorkspace,
		execTimeout,
		cfg,
	)
	if err != nil {
		return fmt.Errorf("error restarting cron service: %w", err)
	}
	if err = runningServices.CronService.Start(); err != nil {
		return fmt.Errorf("error restarting cron service: %w", err)
	}
	fmt.Println(" ✓ Cron service restarted")

	runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
		cfg.WorkspacePath(),
		cfg.Heartbeat.Interval,
		cfg.Heartbeat.Enabled,
	)
	runningServices.HeartbeatService.SetBus(msgBus)
	runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(al))
	if err = runningServices.HeartbeatService.Start(); err != nil {
		return fmt.Errorf("error restarting heartbeat service: %w", err)
	}
	fmt.Println(" ✓ Heartbeat service restarted")

	// MaxAge and Interval are configured in minutes.
	runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
		Enabled:  cfg.Tools.MediaCleanup.Enabled,
		MaxAge:   time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
		Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
	})
	if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
		fms.Start()
	}
	al.SetMediaStore(runningServices.MediaStore)

	runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
	if err != nil {
		return fmt.Errorf("error recreating channel manager: %w", err)
	}
	al.SetChannelManager(runningServices.ChannelManager)

	enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
	if len(enabledChannels) > 0 {
		fmt.Printf(" ✓ Channels enabled: %s\n", enabledChannels)
	} else {
		fmt.Println(" ⚠ Warning: No channels enabled")
	}

	addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
	runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
	runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
	if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
		return fmt.Errorf("error restarting channels: %w", err)
	}
	// The second endpoint is a path ("/ready"), matching the startup message.
	fmt.Printf(
		" ✓ Channels restarted, health endpoints at http://%s:%d/health and /ready\n",
		cfg.Gateway.Host,
		cfg.Gateway.Port,
	)

	stateManager := state.NewManager(cfg.WorkspacePath())
	runningServices.DeviceService = devices.NewService(devices.Config{
		Enabled:    cfg.Devices.Enabled,
		MonitorUSB: cfg.Devices.MonitorUSB,
	}, stateManager)
	runningServices.DeviceService.SetBus(msgBus)
	// A device service failure is logged but does not fail the restart.
	if err := runningServices.DeviceService.Start(context.Background()); err != nil {
		logger.WarnCF("device", "Failed to restart device service", map[string]any{"error": err.Error()})
	} else if cfg.Devices.Enabled {
		fmt.Println(" ✓ Device event service restarted")
	}

	// Always push the detection result to the agent loop so a transcriber
	// removed from the config is cleared as well.
	transcriber := voice.DetectTranscriber(cfg)
	al.SetTranscriber(transcriber)
	if transcriber != nil {
		logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
	} else {
		logger.InfoCF("voice", "Transcription disabled", nil)
	}
	return nil
}
// setupConfigWatcherPolling watches configPath by polling its modification
// time and size every two seconds. When either changes it waits briefly for
// the write to settle, loads and validates the file, and offers the resulting
// config on the returned channel (buffer size 1; a config is dropped with a
// warning when the previous one has not been consumed yet). Files that fail
// to load or validate are logged and skipped.
//
// The returned function stops the watcher goroutine and waits for it to exit.
// It must be called at most once. debug enables a debug log line on each
// detected change.
func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Config, func()) {
	const (
		pollInterval = 2 * time.Second
		settleDelay  = 500 * time.Millisecond
	)

	configChan := make(chan *config.Config, 1)
	stop := make(chan struct{})
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		lastModTime := getFileModTime(configPath)
		lastSize := getFileSize(configPath)
		ticker := time.NewTicker(pollInterval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				currentModTime := getFileModTime(configPath)
				currentSize := getFileSize(configPath)
				if !currentModTime.After(lastModTime) && currentSize == lastSize {
					continue
				}
				if debug {
					logger.Debugf("🔍 Config file change detected")
				}

				// Debounce so a multi-step write can finish, but stay
				// responsive to stop so shutdown is not delayed.
				select {
				case <-time.After(settleDelay):
				case <-stop:
					return
				}

				// Take the baseline after the debounce and before loading:
				// writes that landed during the wait are covered by this load
				// and do not trigger a second, redundant reload, while a write
				// racing with the load is still detected on the next tick.
				lastModTime = getFileModTime(configPath)
				lastSize = getFileSize(configPath)

				newCfg, err := config.LoadConfig(configPath)
				if err != nil {
					logger.Errorf("⚠ Error loading new config: %v", err)
					logger.Warn(" Using previous valid config")
					continue
				}
				if err := newCfg.ValidateModelList(); err != nil {
					logger.Errorf(" ⚠ New config validation failed: %v", err)
					logger.Warn(" Using previous valid config")
					continue
				}
				logger.Info("✓ Config file validated and loaded")

				// Non-blocking send: never stall the watcher on a slow consumer.
				select {
				case configChan <- newCfg:
				default:
					logger.Warn("⚠ Previous config reload still in progress, skipping")
				}
			case <-stop:
				return
			}
		}
	}()

	stopFunc := func() {
		close(stop)
		wg.Wait()
	}
	return configChan, stopFunc
}
// getFileModTime returns the modification time of the file at path, or the
// zero time.Time when the file cannot be stat'ed.
func getFileModTime(path string) time.Time {
	if info, err := os.Stat(path); err == nil {
		return info.ModTime()
	}
	return time.Time{}
}
// getFileSize returns the size in bytes reported by os.Stat for path, or 0
// when the file cannot be stat'ed.
func getFileSize(path string) int64 {
	if info, err := os.Stat(path); err == nil {
		return info.Size()
	}
	return 0
}
// setupCronTool creates the cron service backed by <workspace>/cron/jobs.json.
// When the "cron" tool is enabled in cfg it also builds the cron tool,
// registers it with the agent loop, and routes job execution from the service
// to the tool. The service is returned unstarted.
//
// An error is returned only when the cron tool cannot be constructed.
func setupCronTool(
	agentLoop *agent.AgentLoop,
	msgBus *bus.MessageBus,
	workspace string,
	restrict bool,
	execTimeout time.Duration,
	cfg *config.Config,
) (*cron.CronService, error) {
	storePath := filepath.Join(workspace, "cron", "jobs.json")
	service := cron.NewCronService(storePath, nil)

	if !cfg.Tools.IsToolEnabled("cron") {
		return service, nil
	}

	tool, err := tools.NewCronTool(service, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
	if err != nil {
		return nil, fmt.Errorf("critical error during CronTool initialization: %w", err)
	}
	agentLoop.RegisterTool(tool)

	if tool != nil {
		service.SetOnJob(func(job *cron.CronJob) (string, error) {
			return tool.ExecuteJob(context.Background(), job), nil
		})
	}
	return service, nil
}
// createHeartbeatHandler returns the callback installed on the heartbeat
// service. The callback forwards the prompt to agentLoop.ProcessHeartbeat,
// falling back to channel "cli" / chat "direct" when either identifier is
// empty. Failures become an error result; everything else is a silent result,
// with the literal reply "HEARTBEAT_OK" reported as "Heartbeat OK".
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
	handler := func(prompt, channel, chatID string) *tools.ToolResult {
		if channel == "" || chatID == "" {
			channel, chatID = "cli", "direct"
		}

		reply, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
		switch {
		case err != nil:
			return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
		case reply == "HEARTBEAT_OK":
			return tools.SilentResult("Heartbeat OK")
		default:
			return tools.SilentResult(reply)
		}
	}
	return handler
}
+3
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"maps"
"net/http"
"os"
"sync"
"time"
)
@@ -29,6 +30,7 @@ type StatusResponse struct {
Status string `json:"status"`
Uptime string `json:"uptime"`
Checks map[string]Check `json:"checks,omitempty"`
Pid int `json:"pid"`
}
func NewServer(host string, port int) *Server {
@@ -112,6 +114,7 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
resp := StatusResponse{
Status: "ok",
Uptime: uptime.String(),
Pid: os.Getpid(),
}
json.NewEncoder(w).Encode(resp)
+58 -7
View File
@@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
@@ -45,6 +46,9 @@ func init() {
consoleWriter := zerolog.ConsoleWriter{
Out: os.Stdout,
TimeFormat: "15:04:05", // TODO: make it configurable???
// Custom formatter to handle multiline strings and JSON objects
FormatFieldValue: formatFieldValue,
}
logger = zerolog.New(consoleWriter).With().Timestamp().Logger()
@@ -52,6 +56,37 @@ func init() {
})
}
// formatFieldValue renders a log field value for console output.
//
// Values other than string and []byte are formatted with %v. Text values are
// first unquoted when they form a valid Go quoted literal, then:
//   - text containing a newline is returned with a leading newline;
//   - text containing a space is returned as-is when it is wrapped in {...}
//     or [...], and Go-quoted otherwise;
//   - anything else is returned unchanged.
func formatFieldValue(i any) string {
	var text string
	switch v := i.(type) {
	case string:
		text = v
	case []byte:
		text = string(v)
	default:
		return fmt.Sprint(i)
	}

	if raw, err := strconv.Unquote(text); err == nil {
		text = raw
	}

	switch {
	case strings.Contains(text, "\n"):
		return "\n" + text
	case !strings.Contains(text, " "):
		return text
	}

	isObject := strings.HasPrefix(text, "{") && strings.HasSuffix(text, "}")
	isArray := strings.HasPrefix(text, "[") && strings.HasSuffix(text, "]")
	if isObject || isArray {
		return text
	}
	return strconv.Quote(text)
}
func SetLevel(level LogLevel) {
mu.Lock()
defer mu.Unlock()
@@ -163,10 +198,7 @@ func logMessage(level LogLevel, component string, message string, fields map[str
event.Str("caller", fmt.Sprintf("<none> %s:%d (%s)", callerFile, callerLine, callerFunc))
}
for k, v := range fields {
event.Interface(k, v)
}
appendFields(event, fields)
event.Msg(message)
// Also log to file if enabled
@@ -176,9 +208,8 @@ func logMessage(level LogLevel, component string, message string, fields map[str
if component != "" {
fileEvent.Str("component", component)
}
for k, v := range fields {
fileEvent.Interface(k, v)
}
// Attach the fields to the file event; the console event was already sent by Msg above.
appendFields(fileEvent, fields)
fileEvent.Msg(message)
}
@@ -187,6 +218,26 @@ func logMessage(level LogLevel, component string, message string, fields map[str
}
}
// appendFields adds every entry of fields to event. Strings, int, int64,
// float64 and bool values use the matching typed zerolog setter (which avoids
// double JSON serialization of strings); all other values, such as structs,
// slices and maps, fall back to event.Interface.
func appendFields(event *zerolog.Event, fields map[string]any) {
	for key, value := range fields {
		switch typed := value.(type) {
		case string:
			event.Str(key, typed)
		case int:
			event.Int(key, typed)
		case int64:
			event.Int64(key, typed)
		case float64:
			event.Float64(key, typed)
		case bool:
			event.Bool(key, typed)
		default:
			event.Interface(key, value)
		}
	}
}
// Debug logs message at DEBUG level with no component and no fields.
func Debug(message string) {
	logMessage(DEBUG, "", message, nil)
}
+25 -12
View File
@@ -2,7 +2,20 @@
package logger
import "fmt"
import (
"fmt"
"regexp"
)
// botTokenRe matches a Telegram bot token embedded in text: the "bot<id>:"
// prefix followed by a secret of at least 20 token characters.
// Capture groups: 1 = "bot<id>:", 2 = first 4 chars of the secret,
// 3 = last 4 chars of the secret. The middle of the secret is matched but
// not captured.
var botTokenRe = regexp.MustCompile(`(bot\d+:)([A-Za-z0-9_-]{4})[A-Za-z0-9_-]{12,}([A-Za-z0-9_-]{4})`)

// maskSecrets returns s with the middle of every embedded bot token replaced
// by "****", keeping the bot ID prefix and the first and last 4 characters of
// the secret for identification. Text without a token is returned unchanged.
func maskSecrets(s string) string {
	const redacted = "${1}${2}****${3}"
	return botTokenRe.ReplaceAllString(s, redacted)
}
// Logger implements common Logger interface
type Logger struct {
@@ -12,52 +25,52 @@ type Logger struct {
// Debug logs debug messages
func (b *Logger) Debug(v ...any) {
logMessage(DEBUG, b.component, fmt.Sprint(v...), nil)
logMessage(DEBUG, b.component, maskSecrets(fmt.Sprint(v...)), nil)
}
// Info logs info messages
func (b *Logger) Info(v ...any) {
logMessage(INFO, b.component, fmt.Sprint(v...), nil)
logMessage(INFO, b.component, maskSecrets(fmt.Sprint(v...)), nil)
}
// Warn logs warning messages
func (b *Logger) Warn(v ...any) {
logMessage(WARN, b.component, fmt.Sprint(v...), nil)
logMessage(WARN, b.component, maskSecrets(fmt.Sprint(v...)), nil)
}
// Error logs error messages
func (b *Logger) Error(v ...any) {
logMessage(ERROR, b.component, fmt.Sprint(v...), nil)
logMessage(ERROR, b.component, maskSecrets(fmt.Sprint(v...)), nil)
}
// Debugf logs formatted debug messages
func (b *Logger) Debugf(format string, v ...any) {
logMessage(DEBUG, b.component, fmt.Sprintf(format, v...), nil)
logMessage(DEBUG, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Infof logs formatted info messages
func (b *Logger) Infof(format string, v ...any) {
logMessage(INFO, b.component, fmt.Sprintf(format, v...), nil)
logMessage(INFO, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Warnf logs formatted warning messages
func (b *Logger) Warnf(format string, v ...any) {
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
logMessage(WARN, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Warningf logs formatted warning messages
func (b *Logger) Warningf(format string, v ...any) {
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
logMessage(WARN, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Errorf logs formatted error messages
func (b *Logger) Errorf(format string, v ...any) {
logMessage(ERROR, b.component, fmt.Sprintf(format, v...), nil)
logMessage(ERROR, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Fatalf logs formatted fatal messages and exits
func (b *Logger) Fatalf(format string, v ...any) {
logMessage(FATAL, b.component, fmt.Sprintf(format, v...), nil)
logMessage(FATAL, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Log logs a message at a given level with caller information
@@ -75,7 +88,7 @@ func (b *Logger) Log(msgL, caller int, format string, a ...any) {
level = lvl
}
}
logMessage(level, b.component, fmt.Sprintf(format, a...), nil)
logMessage(level, b.component, maskSecrets(fmt.Sprintf(format, a...)), nil)
}
// Sync flushes log buffer (no-op for this implementation)
+111
View File
@@ -141,3 +141,114 @@ func TestLoggerHelperFunctions(t *testing.T) {
Debugf("test from %v", "Debugf")
WarnF("Warning with fields", map[string]any{"key": "value"})
}
// TestFormatFieldValue pins the console rendering rules of formatFieldValue:
// non-text values, plain text, unquoting, newline handling, quoting of text
// with spaces, JSON-looking text, and edge cases.
func TestFormatFieldValue(t *testing.T) {
	type testCase struct {
		desc string
		in   any
		want string
	}
	cases := []testCase{
		// Non-text values (default branch of the type switch).
		{desc: "Integer Type", in: 42, want: "42"},
		{desc: "Boolean Type", in: true, want: "true"},
		{desc: "Unsupported Struct Type", in: struct{ A int }{A: 1}, want: "{1}"},

		// Plain strings and byte slices.
		{desc: "Simple string without spaces", in: "simple_value", want: "simple_value"},
		{desc: "Simple byte slice", in: []byte("byte_value"), want: "byte_value"},

		// Unquoting via strconv.Unquote.
		{desc: "Quoted string", in: `"quoted_value"`, want: "quoted_value"},

		// Text containing a newline.
		{desc: "String with newline", in: "line1\nline2", want: "\nline1\nline2"},
		{
			desc: "Quoted string with newline (Unquote -> newline)",
			in:   `"line1\nline2"`, // escaped \n that Unquote resolves
			want: "\nline1\nline2",
		},

		// Text containing spaces is quoted.
		{desc: "String with spaces", in: "hello world", want: `"hello world"`},
		{
			desc: "Quoted string with spaces (Unquote -> has spaces -> Re-quote)",
			in:   `"hello world"`,
			want: `"hello world"`,
		},

		// JSON-looking text (has spaces, wrapped in brackets) is kept as-is.
		{desc: "Valid JSON object", in: `{"key": "value"}`, want: `{"key": "value"}`},
		{desc: "Valid JSON array", in: `[1, 2, "three"]`, want: `[1, 2, "three"]`},
		{
			desc: "Fake JSON (starts with { but doesn't end with })",
			in:   `{"key": "value"`, // missing closing bracket, has spaces
			want: `"{\"key\": \"value\""`,
		},
		{desc: "Empty JSON (object)", in: `{ }`, want: `{ }`},

		// Edge cases.
		{desc: "Empty string", in: "", want: ""},
		{desc: "Whitespace only string", in: " ", want: `" "`},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			if got := formatFieldValue(tc.in); got != tc.want {
				t.Errorf("formatFieldValue() = %q, expected %q", got, tc.want)
			}
		})
	}
}
+13
View File
@@ -0,0 +1,13 @@
package media
import (
"os"
"path/filepath"
)
// TempDirName is the name of the media directory created under the system
// temporary directory.
const TempDirName = "picoclaw_media"

// TempDir returns the shared temporary directory used for downloaded media:
// TempDirName joined onto os.TempDir(). It does not create the directory.
func TempDir() string {
	base := os.TempDir()
	return filepath.Join(base, TempDirName)
}
+7 -1
View File
@@ -221,11 +221,17 @@ func buildRequestBody(
// Add tool_use blocks
for _, tc := range msg.ToolCalls {
// Handle nil Arguments (GLM-4 may return null input)
input := tc.Arguments
if input == nil {
input = map[string]any{}
}
toolUse := map[string]any{
"type": "tool_use",
"id": tc.ID,
"name": tc.Name,
"input": tc.Arguments,
"input": input,
}
content = append(content, toolUse)
}
+150
View File
@@ -0,0 +1,150 @@
package azure
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// Aliases for the shared protocol types so this package can refer to them
// without the protocoltypes qualifier.
type (
	LLMResponse = protocoltypes.LLMResponse
	Message = protocoltypes.Message
	ToolDefinition = protocoltypes.ToolDefinition
)
const (
	// azureAPIVersion is the Azure OpenAI API version used for all requests.
	azureAPIVersion = "2024-10-21"
	// defaultRequestTimeout mirrors the shared provider default
	// (common.DefaultRequestTimeout).
	defaultRequestTimeout = common.DefaultRequestTimeout
)
// Provider implements the LLM provider interface for Azure OpenAI endpoints.
// It handles Azure-specific authentication (api-key header), URL construction
// (deployment-based), and request body formatting (max_completion_tokens, no model field).
type Provider struct {
	apiKey string // sent in the Api-Key header when non-empty
	apiBase string // endpoint base URL, stored without trailing slashes
	httpClient *http.Client // client used for every request
}
// Option is a functional option applied to a Provider by NewProvider.
type Option func(*Provider)

// WithRequestTimeout returns an Option that sets the timeout of the
// provider's HTTP client. Non-positive timeouts yield an Option that leaves
// the client untouched.
func WithRequestTimeout(timeout time.Duration) Option {
	if timeout <= 0 {
		return func(*Provider) {}
	}
	return func(p *Provider) {
		p.httpClient.Timeout = timeout
	}
}
// NewProvider builds an Azure OpenAI provider for the given API key and base
// URL. Trailing slashes are trimmed from apiBase, the HTTP client is created
// through common.NewHTTPClient with proxy, and each non-nil option is applied
// in order.
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
	provider := &Provider{
		apiKey:     apiKey,
		apiBase:    strings.TrimRight(apiBase, "/"),
		httpClient: common.NewHTTPClient(proxy),
	}
	for _, apply := range opts {
		if apply == nil {
			continue
		}
		apply(provider)
	}
	return provider
}
// NewProviderWithTimeout is a convenience wrapper around NewProvider that
// applies WithRequestTimeout with requestTimeoutSeconds converted to a
// duration in seconds.
func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds int) *Provider {
	timeout := time.Duration(requestTimeoutSeconds) * time.Second
	return NewProvider(apiKey, apiBase, proxy, WithRequestTimeout(timeout))
}
// Chat sends a chat completion request to the Azure OpenAI endpoint and
// returns the parsed response.
//
// model is used as the Azure deployment name in the request path; the request
// body carries no "model" field. options may provide "max_tokens" (sent as
// max_completion_tokens) and "temperature". When tools is non-empty the tools
// are included with tool_choice "auto".
//
// It fails when the API base is not configured, when the URL, body or request
// cannot be built, when the HTTP call fails, or when the endpoint answers with
// a status other than 200.
func (p *Provider) Chat(
	ctx context.Context,
	messages []Message,
	tools []ToolDefinition,
	model string,
	options map[string]any,
) (*LLMResponse, error) {
	if p.apiBase == "" {
		return nil, fmt.Errorf("Azure API base not configured")
	}

	// url.JoinPath joins the segments onto the base and cleans the resulting
	// path; the deployment name (model) is one of those segments.
	endpoint, err := url.JoinPath(p.apiBase, "openai/deployments", model, "chat/completions")
	if err != nil {
		return nil, fmt.Errorf("failed to build Azure request URL: %w", err)
	}
	endpoint += "?api-version=" + azureAPIVersion

	payload := map[string]any{
		"messages": common.SerializeMessages(messages),
	}
	if len(tools) > 0 {
		payload["tools"] = tools
		payload["tool_choice"] = "auto"
	}
	if limit, ok := common.AsInt(options["max_tokens"]); ok {
		payload["max_completion_tokens"] = limit
	}
	if temp, ok := common.AsFloat(options["temperature"]); ok {
		payload["temperature"] = temp
	}

	encoded, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	// Azure authenticates with an Api-Key header rather than a bearer token.
	if p.apiKey != "" {
		req.Header.Set("Api-Key", p.apiKey)
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusOK {
		return common.ReadAndParseResponse(resp, p.apiBase)
	}
	return nil, common.HandleErrorResponse(resp, p.apiBase)
}
// GetDefaultModel returns an empty string as Azure deployments are
// user-configured; the deployment is supplied per call through Chat's model
// argument.
func (p *Provider) GetDefaultModel() string {
	return ""
}
+232
View File
@@ -0,0 +1,232 @@
package azure
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// writeValidResponse replies with a minimal chat completion payload: a single
// choice whose message content is "ok" and whose finish_reason is "stop",
// sent as application/json.
func writeValidResponse(w http.ResponseWriter) {
	choice := map[string]any{
		"message":       map[string]any{"content": "ok"},
		"finish_reason": "stop",
	}
	payload := map[string]any{
		"choices": []map[string]any{choice},
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(payload)
}
// TestProviderChat_AzureURLConstruction checks that Chat requests the
// deployment-based path and passes the api-version query parameter.
func TestProviderChat_AzureURLConstruction(t *testing.T) {
	var gotPath, gotVersion string
	handler := func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		gotVersion = r.URL.Query().Get("api-version")
		writeValidResponse(w)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	provider := NewProvider("test-key", srv.URL, "")
	msgs := []Message{{Role: "user", Content: "hi"}}
	if _, err := provider.Chat(t.Context(), msgs, nil, "my-gpt5-deployment", nil); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	const wantPath = "/openai/deployments/my-gpt5-deployment/chat/completions"
	if gotPath != wantPath {
		t.Errorf("URL path = %q, want %q", gotPath, wantPath)
	}
	if gotVersion != azureAPIVersion {
		t.Errorf("api-version = %q, want %q", gotVersion, azureAPIVersion)
	}
}
// TestProviderChat_AzureAuthHeader checks that the key is sent in the Api-Key
// header and that no Authorization header is set.
func TestProviderChat_AzureAuthHeader(t *testing.T) {
	var gotKey, gotAuth string
	handler := func(w http.ResponseWriter, r *http.Request) {
		gotKey = r.Header.Get("Api-Key")
		gotAuth = r.Header.Get("Authorization")
		writeValidResponse(w)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	const wantKey = "test-azure-key"
	provider := NewProvider(wantKey, srv.URL, "")
	msgs := []Message{{Role: "user", Content: "hi"}}
	if _, err := provider.Chat(t.Context(), msgs, nil, "deployment", nil); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if gotKey != wantKey {
		t.Errorf("api-key header = %q, want %q", gotKey, wantKey)
	}
	if gotAuth != "" {
		t.Errorf("Authorization header should be empty, got %q", gotAuth)
	}
}
// TestProviderChat_AzureOmitsModelFromBody checks that the request body sent
// by Chat has no "model" field.
func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
	var body map[string]any
	handler := func(w http.ResponseWriter, r *http.Request) {
		json.NewDecoder(r.Body).Decode(&body)
		writeValidResponse(w)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	provider := NewProvider("test-key", srv.URL, "")
	msgs := []Message{{Role: "user", Content: "hi"}}
	if _, err := provider.Chat(t.Context(), msgs, nil, "deployment", nil); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if _, found := body["model"]; found {
		t.Error("request body should not contain 'model' field for Azure OpenAI")
	}
}
// TestProviderChat_AzureUsesMaxCompletionTokens checks that the "max_tokens"
// option is sent as "max_completion_tokens" and never as "max_tokens".
func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
	var body map[string]any
	handler := func(w http.ResponseWriter, r *http.Request) {
		json.NewDecoder(r.Body).Decode(&body)
		writeValidResponse(w)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	provider := NewProvider("test-key", srv.URL, "")
	msgs := []Message{{Role: "user", Content: "hi"}}
	opts := map[string]any{"max_tokens": 2048}
	if _, err := provider.Chat(t.Context(), msgs, nil, "deployment", opts); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if _, found := body["max_completion_tokens"]; !found {
		t.Error("request body should contain 'max_completion_tokens'")
	}
	if _, found := body["max_tokens"]; found {
		t.Error("request body should not contain 'max_tokens'")
	}
}
// TestProviderChat_AzureHTTPError checks that a 401 response from the
// endpoint makes Chat return an error.
func TestProviderChat_AzureHTTPError(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	provider := NewProvider("bad-key", srv.URL, "")
	msgs := []Message{{Role: "user", Content: "hi"}}
	if _, err := provider.Chat(t.Context(), msgs, nil, "deployment", nil); err == nil {
		t.Fatal("expected error, got nil")
	}
}
// TestProviderChat_AzureParseToolCalls serves a chat completion containing a
// single function tool call and checks that Chat surfaces it with its name.
func TestProviderChat_AzureParseToolCalls(t *testing.T) {
	toolCall := map[string]any{
		"id":   "call_1",
		"type": "function",
		"function": map[string]any{
			"name":      "get_weather",
			"arguments": `{"city":"Seattle"}`,
		},
	}
	payload := map[string]any{
		"choices": []map[string]any{
			{
				"message": map[string]any{
					"content":    "",
					"tool_calls": []map[string]any{toolCall},
				},
				"finish_reason": "tool_calls",
			},
		},
	}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(payload)
	}))
	defer server.Close()

	provider := NewProvider("test-key", server.URL, "")
	reply, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}
	if got := len(reply.ToolCalls); got != 1 {
		t.Fatalf("len(ToolCalls) = %d, want 1", got)
	}
	if got := reply.ToolCalls[0].Name; got != "get_weather" {
		t.Errorf("ToolCalls[0].Name = %q, want %q", got, "get_weather")
	}
}
// TestProvider_AzureEmptyAPIBase checks that Chat fails when the provider was
// constructed with an empty API base.
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
	provider := NewProvider("test-key", "", "")
	userTurn := []Message{{Role: "user", Content: "hi"}}
	if _, err := provider.Chat(t.Context(), userTurn, nil, "deployment", nil); err == nil {
		t.Fatal("expected error for empty API base")
	}
}
// TestProvider_AzureRequestTimeoutDefault checks that a provider built without
// options uses defaultRequestTimeout on its HTTP client.
func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
	got := NewProvider("test-key", "https://example.com", "").httpClient.Timeout
	if got != defaultRequestTimeout {
		t.Errorf("timeout = %v, want %v", got, defaultRequestTimeout)
	}
}
// TestProvider_AzureRequestTimeoutOverride checks that WithRequestTimeout
// replaces the HTTP client timeout.
func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
	const want = 300 * time.Second
	got := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(want)).httpClient.Timeout
	if got != want {
		t.Errorf("timeout = %v, want %v", got, want)
	}
}
// TestProvider_AzureNewProviderWithTimeout checks that a timeout argument of
// 180 results in a 180-second HTTP client timeout.
func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
	const want = 180 * time.Second
	got := NewProviderWithTimeout("test-key", "https://example.com", "", 180).httpClient.Timeout
	if got != want {
		t.Errorf("timeout = %v, want %v", got, want)
	}
}
// TestProviderChat_AzureDeploymentNameEscaped sends a deployment name that
// contains a space and "../" segments and fails if the request path shows the
// name interpolated verbatim.
func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) {
	var seenPath string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Prefer RawPath because it keeps percent-encoding; it is empty when
		// the default encoding of Path applies, so fall back to Path.
		if seenPath = r.URL.RawPath; seenPath == "" {
			seenPath = r.URL.Path
		}
		writeValidResponse(w)
	}))
	defer server.Close()

	provider := NewProvider("test-key", server.URL, "")
	// Deployment name with characters that could cause path injection.
	hostileName := "my deploy/../../admin"
	if _, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, hostileName, nil); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	// Slashes and special characters in the name must be escaped rather than
	// treated as path separators.
	const unescaped = "/openai/deployments/my deploy/../../admin/chat/completions"
	if seenPath == unescaped {
		t.Fatal("deployment name was interpolated without escaping — path injection possible")
	}
}
+380
View File
@@ -0,0 +1,380 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
// Package common provides shared utilities used by multiple LLM provider
// implementations (openai_compat, azure, etc.).
package common
import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// Re-export protocol types used across providers.
//
// Each name below is a type alias (declared with "="), so values are
// interchangeable with the corresponding protocoltypes type without conversion.
type (
ToolCall = protocoltypes.ToolCall
FunctionCall = protocoltypes.FunctionCall
LLMResponse = protocoltypes.LLMResponse
UsageInfo = protocoltypes.UsageInfo
Message = protocoltypes.Message
ToolDefinition = protocoltypes.ToolDefinition
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
ExtraContent = protocoltypes.ExtraContent
GoogleExtra = protocoltypes.GoogleExtra
ReasoningDetail = protocoltypes.ReasoningDetail
)
// DefaultRequestTimeout is the timeout NewHTTPClient sets on every client it
// returns.
const DefaultRequestTimeout = 120 * time.Second

// NewHTTPClient returns an *http.Client that uses DefaultRequestTimeout.
//
// When proxy is non-empty and parses as a URL, the client gets a transport that
// routes requests through it. When proxy is empty, or fails to parse (which is
// logged), the client is returned with no custom transport.
func NewHTTPClient(proxy string) *http.Client {
	client := &http.Client{Timeout: DefaultRequestTimeout}
	if proxy == "" {
		return client
	}

	proxyURL, err := url.Parse(proxy)
	if err != nil {
		log.Printf("common: invalid proxy URL %q: %v", proxy, err)
		return client
	}

	// Clone http.DefaultTransport when possible so its settings (TLS, HTTP/2,
	// timeouts, etc.) are preserved and only the proxy differs.
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		cloned := base.Clone()
		cloned.Proxy = http.ProxyURL(proxyURL)
		client.Transport = cloned
		return client
	}

	// Fallback: minimal transport if DefaultTransport is not *http.Transport.
	client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
	return client
}
// --- Message serialization ---
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
// It mirrors protocoltypes.Message but omits SystemParts, which is an
// internal field that would be unknown to third-party endpoints.
//
// Role and Content are always emitted; the remaining fields are dropped from
// the JSON when empty (omitempty).
type openaiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
// SerializeMessages converts internal Message values to the OpenAI wire format.
//
// Messages without Media become an openaiMessage with plain string content
// (SystemParts is not carried over). Messages with Media become a map whose
// "content" is a list of parts: an optional text part for non-empty Content,
// followed by one image_url part per media entry that starts with
// "data:image/"; other media entries are skipped. ToolCallID, ToolCalls and
// ReasoningContent are kept in both forms when set.
func SerializeMessages(messages []Message) []any {
	serialized := make([]any, 0, len(messages))
	for idx := range messages {
		msg := messages[idx]

		if len(msg.Media) != 0 {
			// Multipart content format for messages with media.
			content := make([]map[string]any, 0, 1+len(msg.Media))
			if msg.Content != "" {
				content = append(content, map[string]any{"type": "text", "text": msg.Content})
			}
			for _, item := range msg.Media {
				if !strings.HasPrefix(item, "data:image/") {
					continue
				}
				content = append(content, map[string]any{
					"type":      "image_url",
					"image_url": map[string]any{"url": item},
				})
			}

			wire := map[string]any{"role": msg.Role, "content": content}
			if msg.ToolCallID != "" {
				wire["tool_call_id"] = msg.ToolCallID
			}
			if len(msg.ToolCalls) > 0 {
				wire["tool_calls"] = msg.ToolCalls
			}
			if msg.ReasoningContent != "" {
				wire["reasoning_content"] = msg.ReasoningContent
			}
			serialized = append(serialized, wire)
			continue
		}

		serialized = append(serialized, openaiMessage{
			Role:             msg.Role,
			Content:          msg.Content,
			ReasoningContent: msg.ReasoningContent,
			ToolCalls:        msg.ToolCalls,
			ToolCallID:       msg.ToolCallID,
		})
	}
	return serialized
}
// --- Response parsing ---
// ParseResponse decodes a JSON chat completion body into an LLMResponse.
//
// Only the first choice is used. A body with no choices yields an empty
// response with FinishReason "stop". Each tool call's arguments are decoded
// with DecodeToolCallArguments; a tool call without a "function" object gets an
// empty name and an empty argument map. A decode failure is returned wrapped.
func ParseResponse(body io.Reader) (*LLMResponse, error) {
	var payload struct {
		Choices []struct {
			Message struct {
				Content          string            `json:"content"`
				ReasoningContent string            `json:"reasoning_content"`
				Reasoning        string            `json:"reasoning"`
				ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
				ToolCalls        []struct {
					ID       string `json:"id"`
					Type     string `json:"type"`
					Function *struct {
						Name      string          `json:"name"`
						Arguments json.RawMessage `json:"arguments"`
					} `json:"function"`
					ExtraContent *struct {
						Google *struct {
							ThoughtSignature string `json:"thought_signature"`
						} `json:"google"`
					} `json:"extra_content"`
				} `json:"tool_calls"`
			} `json:"message"`
			FinishReason string `json:"finish_reason"`
		} `json:"choices"`
		Usage *UsageInfo `json:"usage"`
	}

	if err := json.NewDecoder(body).Decode(&payload); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	if len(payload.Choices) == 0 {
		return &LLMResponse{FinishReason: "stop"}, nil
	}

	first := payload.Choices[0]
	calls := make([]ToolCall, 0, len(first.Message.ToolCalls))
	for _, rawCall := range first.Message.ToolCalls {
		call := ToolCall{ID: rawCall.ID, Arguments: make(map[string]any)}

		if fn := rawCall.Function; fn != nil {
			call.Name = fn.Name
			call.Arguments = DecodeToolCallArguments(fn.Arguments, fn.Name)
		}

		// Gemini/Google-specific extra content carries a thought signature;
		// mirror it both on the flat field and inside ExtraContent.
		if extra := rawCall.ExtraContent; extra != nil && extra.Google != nil {
			call.ThoughtSignature = extra.Google.ThoughtSignature
		}
		if call.ThoughtSignature != "" {
			call.ExtraContent = &ExtraContent{
				Google: &GoogleExtra{ThoughtSignature: call.ThoughtSignature},
			}
		}

		calls = append(calls, call)
	}

	return &LLMResponse{
		Content:          first.Message.Content,
		ReasoningContent: first.Message.ReasoningContent,
		Reasoning:        first.Message.Reasoning,
		ReasoningDetails: first.Message.ReasoningDetails,
		ToolCalls:        calls,
		FinishReason:     first.FinishReason,
		Usage:            payload.Usage,
	}, nil
}
// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
//
// raw may be a JSON object, or a JSON string that itself contains an encoded
// object (the common OpenAI form). name is used only in log messages.
//
// The result is never nil. Empty input, JSON null, a blank string and a string
// containing "null" all yield an empty map. Input that cannot be decoded into
// an object is logged and returned under the single key "raw".
func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
	arguments := make(map[string]any)
	raw = bytes.TrimSpace(raw)
	if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
		return arguments
	}

	var decoded any
	if err := json.Unmarshal(raw, &decoded); err != nil {
		log.Printf("common: failed to decode tool call arguments payload for %q: %v", name, err)
		arguments["raw"] = string(raw)
		return arguments
	}

	switch v := decoded.(type) {
	case string:
		if strings.TrimSpace(v) == "" {
			return arguments
		}
		// Decode into a separate map: unmarshalling JSON null into a map sets
		// it to nil, which previously leaked out as a nil return value.
		var parsed map[string]any
		if err := json.Unmarshal([]byte(v), &parsed); err != nil {
			log.Printf("common: failed to decode tool call arguments for %q: %v", name, err)
			arguments["raw"] = v
			return arguments
		}
		if parsed == nil {
			return arguments
		}
		return parsed
	case map[string]any:
		return v
	default:
		log.Printf("common: unsupported tool call arguments type for %q: %T", name, decoded)
		arguments["raw"] = string(raw)
		return arguments
	}
}
// --- HTTP response helpers ---
// HandleErrorResponse turns an unsuccessful HTTP response into an error.
//
// It reads at most 256 bytes of the body. If the body or Content-Type looks
// like HTML, the error comes from WrapHTMLResponseError; otherwise it reports
// the status code and a preview of the body. The body is not closed here.
func HandleErrorResponse(resp *http.Response, apiBase string) error {
	snippet, err := io.ReadAll(io.LimitReader(resp.Body, 256))
	if err != nil {
		return fmt.Errorf("failed to read response: %w", err)
	}

	if ct := resp.Header.Get("Content-Type"); LooksLikeHTML(snippet, ct) {
		return WrapHTMLResponseError(resp.StatusCode, snippet, ct, apiBase)
	}

	preview := ResponsePreview(snippet, 128)
	return fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, preview)
}
// ReadAndParseResponse parses a chat completion response body.
//
// It first peeks at up to 256 bytes so an HTML page can be reported through
// WrapHTMLResponseError instead of a JSON syntax error, then hands the full,
// unconsumed stream to ParseResponse.
func ReadAndParseResponse(resp *http.Response, apiBase string) (*LLMResponse, error) {
	buffered := bufio.NewReader(resp.Body)

	head, peekErr := buffered.Peek(256)
	switch peekErr {
	case nil, io.EOF, bufio.ErrBufferFull:
		// A short body is fine; inspect whatever bytes were available.
	default:
		return nil, fmt.Errorf("failed to inspect response: %w", peekErr)
	}

	if ct := resp.Header.Get("Content-Type"); LooksLikeHTML(head, ct) {
		return nil, WrapHTMLResponseError(resp.StatusCode, head, ct, apiBase)
	}

	parsed, err := ParseResponse(buffered)
	if err != nil {
		return nil, fmt.Errorf("failed to parse JSON response: %w", err)
	}
	return parsed, nil
}
// LooksLikeHTML reports whether a response appears to be HTML.
//
// It returns true when contentType mentions text/html or
// application/xhtml+xml, or when body, after skipping leading whitespace,
// starts with a doctype, <html, <head or <body tag (case-insensitive).
func LooksLikeHTML(body []byte, contentType string) bool {
	ct := strings.ToLower(strings.TrimSpace(contentType))
	for _, mediaType := range []string{"text/html", "application/xhtml+xml"} {
		if strings.Contains(ct, mediaType) {
			return true
		}
	}

	head := bytes.ToLower(leadingTrimmedPrefix(body, 128))
	for _, opener := range []string{"<!doctype html", "<html", "<head", "<body"} {
		if bytes.HasPrefix(head, []byte(opener)) {
			return true
		}
	}
	return false
}
// WrapHTMLResponseError builds the error returned when an endpoint answers
// with HTML. The message names apiBase, the content type, the status code and
// a preview of at most 128 bytes of body.
func WrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
	const format = "API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s"
	return fmt.Errorf(format, apiBase, contentType, statusCode, ResponsePreview(body, 128))
}
// ResponsePreview returns a preview of a response body for error messages.
//
// Surrounding whitespace is trimmed. An empty result is reported as "<empty>".
// Bodies longer than maxLen bytes are cut to at most maxLen bytes and suffixed
// with "..."; the cut is moved back by up to three bytes so it never lands in
// the middle of a multi-byte UTF-8 character. A negative maxLen is treated as 0.
func ResponsePreview(body []byte, maxLen int) string {
	trimmed := bytes.TrimSpace(body)
	if len(trimmed) == 0 {
		return "<empty>"
	}
	if maxLen < 0 {
		maxLen = 0
	}
	if len(trimmed) <= maxLen {
		return string(trimmed)
	}

	// trimmed[cut] is the first byte that would be dropped. If it is a UTF-8
	// continuation byte the cut splits a rune, so step back to the rune start.
	// The back-off is bounded so non-UTF-8 data cannot shrink the preview
	// arbitrarily.
	cut := maxLen
	for cut > 0 && maxLen-cut < utf8.UTFMax-1 && !utf8.RuneStart(trimmed[cut]) {
		cut--
	}
	return string(trimmed[:cut]) + "..."
}
// leadingTrimmedPrefix skips leading ASCII whitespace in body and returns up
// to maxLen of the bytes that follow, as a sub-slice of body. It returns nil
// when body is empty or contains only whitespace.
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
	for start, b := range body {
		if b == ' ' || b == '\t' || b == '\n' || b == '\r' || b == '\f' || b == '\v' {
			continue
		}
		stop := start + maxLen
		if stop > len(body) {
			stop = len(body)
		}
		return body[start:stop]
	}
	return nil
}
// --- Numeric helpers ---
// AsInt converts v to int when it holds an int, int64, float64 or float32.
// Floating-point values are converted with int(), which drops the fractional
// part. For any other type it returns (0, false).
func AsInt(v any) (int, bool) {
	if n, ok := v.(int); ok {
		return n, true
	}
	if n, ok := v.(int64); ok {
		return int(n), true
	}
	if f, ok := v.(float64); ok {
		return int(f), true
	}
	if f, ok := v.(float32); ok {
		return int(f), true
	}
	return 0, false
}
// AsFloat converts v to float64 when it holds a float64, float32, int or
// int64. For any other type it returns (0, false).
func AsFloat(v any) (float64, bool) {
	if f, ok := v.(float64); ok {
		return f, true
	}
	if f, ok := v.(float32); ok {
		return float64(f), true
	}
	if n, ok := v.(int); ok {
		return float64(n), true
	}
	if n, ok := v.(int64); ok {
		return float64(n), true
	}
	return 0, false
}
+558
View File
@@ -0,0 +1,558 @@
package common
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// --- NewHTTPClient tests ---
// TestNewHTTPClient_DefaultTimeout checks that a client built without a proxy
// carries DefaultRequestTimeout.
func TestNewHTTPClient_DefaultTimeout(t *testing.T) {
	if got := NewHTTPClient("").Timeout; got != DefaultRequestTimeout {
		t.Errorf("timeout = %v, want %v", got, DefaultRequestTimeout)
	}
}
// TestNewHTTPClient_WithProxy checks that a valid proxy URL yields an
// *http.Transport whose Proxy function resolves requests to that URL.
func TestNewHTTPClient_WithProxy(t *testing.T) {
	client := NewHTTPClient("http://127.0.0.1:8080")

	tr, isTransport := client.Transport.(*http.Transport)
	if tr == nil || !isTransport {
		t.Fatalf("expected http.Transport with proxy, got %T", client.Transport)
	}

	target := &url.URL{Scheme: "https", Host: "api.example.com"}
	resolved, err := tr.Proxy(&http.Request{URL: target})
	if err != nil {
		t.Fatalf("proxy function error: %v", err)
	}
	if resolved == nil || resolved.String() != "http://127.0.0.1:8080" {
		t.Errorf("proxy = %v, want http://127.0.0.1:8080", resolved)
	}
}
// TestNewHTTPClient_NoProxy checks that no custom transport is installed when
// the proxy argument is empty.
func TestNewHTTPClient_NoProxy(t *testing.T) {
	if tr := NewHTTPClient("").Transport; tr != nil {
		t.Errorf("expected nil transport without proxy, got %T", tr)
	}
}
// TestNewHTTPClient_InvalidProxy checks that an unparsable proxy URL does not
// panic and produces a usable client with no proxy transport and the default
// timeout.
func TestNewHTTPClient_InvalidProxy(t *testing.T) {
	// Should not panic, just log and return client without proxy.
	client := NewHTTPClient("://bad-url")
	if client == nil {
		t.Fatal("expected non-nil client even with invalid proxy")
	}
	// The nil check alone can never fail; also verify the documented outcome.
	if client.Transport != nil {
		t.Errorf("expected nil transport for invalid proxy, got %T", client.Transport)
	}
	if client.Timeout != DefaultRequestTimeout {
		t.Errorf("timeout = %v, want %v", client.Timeout, DefaultRequestTimeout)
	}
}
// --- SerializeMessages tests ---
// TestSerializeMessages_PlainText checks that messages without media are
// serialized with plain string content and that reasoning_content survives a
// JSON round trip.
func TestSerializeMessages_PlainText(t *testing.T) {
	messages := []Message{
		{Role: "user", Content: "hello"},
		{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
	}
	result := SerializeMessages(messages)

	data, err := json.Marshal(result)
	if err != nil {
		t.Fatalf("json.Marshal() error = %v", err)
	}
	var msgs []map[string]any
	if err := json.Unmarshal(data, &msgs); err != nil {
		t.Fatalf("json.Unmarshal() error = %v", err)
	}
	// Guard the indexing below so a wrong count fails instead of panicking.
	if len(msgs) != len(messages) {
		t.Fatalf("len(msgs) = %d, want %d", len(msgs), len(messages))
	}

	if msgs[0]["content"] != "hello" {
		t.Errorf("expected plain string content, got %v", msgs[0]["content"])
	}
	if msgs[1]["reasoning_content"] != "thinking..." {
		t.Errorf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
	}
}
// TestSerializeMessages_WithMedia checks that a message with text and one
// data-URL image is serialized with array content holding two parts.
func TestSerializeMessages_WithMedia(t *testing.T) {
	messages := []Message{
		{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
	}
	result := SerializeMessages(messages)

	data, err := json.Marshal(result)
	if err != nil {
		t.Fatalf("json.Marshal() error = %v", err)
	}
	var msgs []map[string]any
	if err := json.Unmarshal(data, &msgs); err != nil {
		t.Fatalf("json.Unmarshal() error = %v", err)
	}
	// Guard the indexing below so a wrong count fails instead of panicking.
	if len(msgs) != 1 {
		t.Fatalf("len(msgs) = %d, want 1", len(msgs))
	}

	content, ok := msgs[0]["content"].([]any)
	if !ok {
		t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
	}
	if len(content) != 2 {
		t.Fatalf("expected 2 content parts, got %d", len(content))
	}
}
// TestSerializeMessages_MediaWithToolCallID checks that tool_call_id is kept
// when a message also carries media.
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
	messages := []Message{
		{Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
	}
	result := SerializeMessages(messages)

	data, err := json.Marshal(result)
	if err != nil {
		t.Fatalf("json.Marshal() error = %v", err)
	}
	var msgs []map[string]any
	if err := json.Unmarshal(data, &msgs); err != nil {
		t.Fatalf("json.Unmarshal() error = %v", err)
	}
	// Guard the indexing below so a wrong count fails instead of panicking.
	if len(msgs) != 1 {
		t.Fatalf("len(msgs) = %d, want 1", len(msgs))
	}

	if msgs[0]["tool_call_id"] != "call_1" {
		t.Errorf("tool_call_id not preserved, got %v", msgs[0]["tool_call_id"])
	}
}
// TestSerializeMessages_StripsSystemParts checks that the serialized JSON of a
// message with SystemParts never mentions "system_parts".
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
	messages := []Message{
		{
			Role:    "system",
			Content: "you are helpful",
			SystemParts: []protocoltypes.ContentBlock{
				{Type: "text", Text: "you are helpful"},
			},
		},
	}
	result := SerializeMessages(messages)

	// A marshal failure would leave data empty and make the check below pass
	// vacuously, so fail loudly instead.
	data, err := json.Marshal(result)
	if err != nil {
		t.Fatalf("json.Marshal() error = %v", err)
	}
	if strings.Contains(string(data), "system_parts") {
		t.Error("system_parts should not appear in serialized output")
	}
}
// --- ParseResponse tests ---
// TestParseResponse_BasicContent checks that content and finish_reason of a
// single choice are copied to the result.
func TestParseResponse_BasicContent(t *testing.T) {
	const payload = `{"choices":[{"message":{"content":"hello world"},"finish_reason":"stop"}]}`

	resp, err := ParseResponse(strings.NewReader(payload))
	if err != nil {
		t.Fatalf("ParseResponse() error = %v", err)
	}
	if got := resp.Content; got != "hello world" {
		t.Errorf("Content = %q, want %q", got, "hello world")
	}
	if got := resp.FinishReason; got != "stop" {
		t.Errorf("FinishReason = %q, want %q", got, "stop")
	}
}
// TestParseResponse_EmptyChoices checks that an empty choices array yields
// empty content with finish reason "stop".
func TestParseResponse_EmptyChoices(t *testing.T) {
	const payload = `{"choices":[]}`

	resp, err := ParseResponse(strings.NewReader(payload))
	if err != nil {
		t.Fatalf("ParseResponse() error = %v", err)
	}
	if got := resp.Content; got != "" {
		t.Errorf("Content = %q, want empty", got)
	}
	if got := resp.FinishReason; got != "stop" {
		t.Errorf("FinishReason = %q, want %q", got, "stop")
	}
}
// TestParseResponse_WithToolCalls checks that a function tool call with
// string-encoded arguments is parsed into a name and an argument map.
func TestParseResponse_WithToolCalls(t *testing.T) {
	const payload = `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"SF\"}"}}]},"finish_reason":"tool_calls"}]}`

	resp, err := ParseResponse(strings.NewReader(payload))
	if err != nil {
		t.Fatalf("ParseResponse() error = %v", err)
	}
	if n := len(resp.ToolCalls); n != 1 {
		t.Fatalf("len(ToolCalls) = %d, want 1", n)
	}

	call := resp.ToolCalls[0]
	if call.Name != "get_weather" {
		t.Errorf("ToolCalls[0].Name = %q, want %q", call.Name, "get_weather")
	}
	if call.Arguments["city"] != "SF" {
		t.Errorf("ToolCalls[0].Arguments[city] = %v, want SF", call.Arguments["city"])
	}
}
// TestParseResponse_WithUsage checks that the usage object is decoded and its
// prompt token count preserved.
func TestParseResponse_WithUsage(t *testing.T) {
	const payload = `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}`

	resp, err := ParseResponse(strings.NewReader(payload))
	if err != nil {
		t.Fatalf("ParseResponse() error = %v", err)
	}
	usage := resp.Usage
	if usage == nil {
		t.Fatal("Usage is nil")
	}
	if usage.PromptTokens != 10 {
		t.Errorf("PromptTokens = %d, want 10", usage.PromptTokens)
	}
}
// TestParseResponse_WithReasoningContent checks that reasoning_content is
// copied to the result.
func TestParseResponse_WithReasoningContent(t *testing.T) {
	const payload = `{"choices":[{"message":{"content":"2","reasoning_content":"Let me think... 1+1=2"},"finish_reason":"stop"}]}`

	resp, err := ParseResponse(strings.NewReader(payload))
	if err != nil {
		t.Fatalf("ParseResponse() error = %v", err)
	}
	if got := resp.ReasoningContent; got != "Let me think... 1+1=2" {
		t.Errorf("ReasoningContent = %q, want %q", got, "Let me think... 1+1=2")
	}
}
// TestParseResponse_InvalidJSON checks that a non-JSON body produces an error.
func TestParseResponse_InvalidJSON(t *testing.T) {
	if _, err := ParseResponse(strings.NewReader("not json")); err == nil {
		t.Fatal("expected error for invalid JSON")
	}
}
// --- DecodeToolCallArguments tests ---
// TestDecodeToolCallArguments_ObjectJSON checks that a raw JSON object is
// decoded key by key.
func TestDecodeToolCallArguments_ObjectJSON(t *testing.T) {
	decoded := DecodeToolCallArguments(json.RawMessage(`{"city":"Seattle","units":"metric"}`), "test")
	if got := decoded["city"]; got != "Seattle" {
		t.Errorf("city = %v, want Seattle", got)
	}
	if got := decoded["units"]; got != "metric" {
		t.Errorf("units = %v, want metric", got)
	}
}
// TestDecodeToolCallArguments_StringJSON checks that a JSON string containing
// an encoded object is decoded into that object.
func TestDecodeToolCallArguments_StringJSON(t *testing.T) {
	decoded := DecodeToolCallArguments(json.RawMessage(`"{\"city\":\"SF\"}"`), "test")
	if got := decoded["city"]; got != "SF" {
		t.Errorf("city = %v, want SF", got)
	}
}
// TestDecodeToolCallArguments_EmptyInput checks that nil input yields an empty
// map.
func TestDecodeToolCallArguments_EmptyInput(t *testing.T) {
	if decoded := DecodeToolCallArguments(nil, "test"); len(decoded) != 0 {
		t.Errorf("expected empty map, got %v", decoded)
	}
}
// TestDecodeToolCallArguments_NullInput checks that JSON null yields an empty
// map.
func TestDecodeToolCallArguments_NullInput(t *testing.T) {
	if decoded := DecodeToolCallArguments(json.RawMessage(`null`), "test"); len(decoded) != 0 {
		t.Errorf("expected empty map, got %v", decoded)
	}
}
// TestDecodeToolCallArguments_InvalidJSON checks that undecodable input is
// returned under the "raw" key.
func TestDecodeToolCallArguments_InvalidJSON(t *testing.T) {
	decoded := DecodeToolCallArguments(json.RawMessage(`not-json`), "test")
	if _, hasRaw := decoded["raw"]; !hasRaw {
		t.Error("expected 'raw' fallback key for invalid JSON")
	}
}
// TestDecodeToolCallArguments_EmptyStringJSON checks that a JSON string holding
// only whitespace yields an empty map.
func TestDecodeToolCallArguments_EmptyStringJSON(t *testing.T) {
	if decoded := DecodeToolCallArguments(json.RawMessage(`" "`), "test"); len(decoded) != 0 {
		t.Errorf("expected empty map for whitespace string, got %v", decoded)
	}
}
// --- HandleErrorResponse tests ---
// TestHandleErrorResponse_JSONError checks that a JSON 400 response produces an
// error that contains the status code and does not mention HTML.
func TestHandleErrorResponse_JSONError(t *testing.T) {
	badRequest := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(`{"error":"bad request"}`))
	}
	server := httptest.NewServer(http.HandlerFunc(badRequest))
	defer server.Close()

	resp, err := http.Get(server.URL)
	if err != nil {
		t.Fatalf("http.Get() error = %v", err)
	}
	defer resp.Body.Close()

	got := HandleErrorResponse(resp, server.URL)
	if got == nil {
		t.Fatal("expected error")
	}
	msg := got.Error()
	if !strings.Contains(msg, "400") {
		t.Errorf("error should contain status code, got %v", got)
	}
	if strings.Contains(msg, "HTML") {
		t.Errorf("should not mention HTML for JSON error, got %v", got)
	}
}
// TestHandleErrorResponse_HTMLError checks that an HTML 502 response produces
// the "HTML instead of JSON" error.
func TestHandleErrorResponse_HTMLError(t *testing.T) {
	badGateway := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/html")
		w.WriteHeader(http.StatusBadGateway)
		w.Write([]byte("<!DOCTYPE html><html><body>bad gateway</body></html>"))
	}
	server := httptest.NewServer(http.HandlerFunc(badGateway))
	defer server.Close()

	resp, err := http.Get(server.URL)
	if err != nil {
		t.Fatalf("http.Get() error = %v", err)
	}
	defer resp.Body.Close()

	got := HandleErrorResponse(resp, server.URL)
	if got == nil {
		t.Fatal("expected error")
	}
	if !strings.Contains(got.Error(), "HTML instead of JSON") {
		t.Errorf("expected HTML error message, got %v", got)
	}
}
// --- ReadAndParseResponse tests ---
// TestReadAndParseResponse_ValidJSON checks that a JSON completion served over
// HTTP is parsed and its content returned.
func TestReadAndParseResponse_ValidJSON(t *testing.T) {
	completion := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
	}
	server := httptest.NewServer(http.HandlerFunc(completion))
	defer server.Close()

	resp, err := http.Get(server.URL)
	if err != nil {
		t.Fatalf("http.Get() error = %v", err)
	}
	defer resp.Body.Close()

	parsed, err := ReadAndParseResponse(resp, server.URL)
	if err != nil {
		t.Fatalf("ReadAndParseResponse() error = %v", err)
	}
	if got := parsed.Content; got != "ok" {
		t.Errorf("Content = %q, want %q", got, "ok")
	}
}
// TestReadAndParseResponse_HTMLResponse checks that an HTML page served with
// status 200 is rejected with the "HTML instead of JSON" error.
func TestReadAndParseResponse_HTMLResponse(t *testing.T) {
	loginPage := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/html")
		w.Write([]byte("<!DOCTYPE html><html><body>login page</body></html>"))
	}
	server := httptest.NewServer(http.HandlerFunc(loginPage))
	defer server.Close()

	resp, err := http.Get(server.URL)
	if err != nil {
		t.Fatalf("http.Get() error = %v", err)
	}
	defer resp.Body.Close()

	_, parseErr := ReadAndParseResponse(resp, server.URL)
	if parseErr == nil {
		t.Fatal("expected error for HTML response")
	}
	if !strings.Contains(parseErr.Error(), "HTML instead of JSON") {
		t.Errorf("expected HTML error, got %v", parseErr)
	}
}
// --- LooksLikeHTML tests ---
// TestLooksLikeHTML_ContentTypeHTML checks that a text/html content type alone
// is enough, even with a nil body.
func TestLooksLikeHTML_ContentTypeHTML(t *testing.T) {
	detected := LooksLikeHTML(nil, "text/html; charset=utf-8")
	if !detected {
		t.Error("expected true for text/html content type")
	}
}
// TestLooksLikeHTML_ContentTypeXHTML checks that an XHTML content type alone is
// enough, even with a nil body.
func TestLooksLikeHTML_ContentTypeXHTML(t *testing.T) {
	detected := LooksLikeHTML(nil, "application/xhtml+xml")
	if !detected {
		t.Error("expected true for xhtml content type")
	}
}
// TestLooksLikeHTML_BodyPrefix checks that HTML-looking bodies are detected
// even when the content type claims JSON.
func TestLooksLikeHTML_BodyPrefix(t *testing.T) {
	type bodyCase struct {
		label   string
		payload string
	}
	cases := []bodyCase{
		{label: "doctype", payload: "<!DOCTYPE html><html>"},
		{label: "html tag", payload: "<html><body>"},
		{label: "head tag", payload: "<head><title>"},
		{label: "body tag", payload: "<body>content"},
		{label: "whitespace before", payload: " \n\t<!DOCTYPE html>"},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			if detected := LooksLikeHTML([]byte(tc.payload), "application/json"); !detected {
				t.Errorf("expected true for body %q", tc.payload)
			}
		})
	}
}
// TestLooksLikeHTML_NotHTML checks that a JSON body with a JSON content type is
// not flagged.
func TestLooksLikeHTML_NotHTML(t *testing.T) {
	detected := LooksLikeHTML([]byte(`{"error":"bad"}`), "application/json")
	if detected {
		t.Error("expected false for JSON body")
	}
}
// --- ResponsePreview tests ---
// TestResponsePreview_Short checks that a body shorter than the limit is
// returned unchanged.
func TestResponsePreview_Short(t *testing.T) {
	if preview := ResponsePreview([]byte("hello"), 128); preview != "hello" {
		t.Errorf("got %q, want %q", preview, "hello")
	}
}
// TestResponsePreview_Truncated checks that a 200-byte body is cut to 128 bytes
// and suffixed with "...".
func TestResponsePreview_Truncated(t *testing.T) {
	preview := ResponsePreview([]byte(strings.Repeat("a", 200)), 128)
	if n := len(preview); n != 131 { // 128 + "..."
		t.Errorf("len = %d, want 131", n)
	}
	if !strings.HasSuffix(preview, "...") {
		t.Error("expected ... suffix")
	}
}
// TestResponsePreview_Empty checks that an empty body is reported as "<empty>".
func TestResponsePreview_Empty(t *testing.T) {
	if preview := ResponsePreview([]byte(""), 128); preview != "<empty>" {
		t.Errorf("got %q, want %q", preview, "<empty>")
	}
}
// TestResponsePreview_Whitespace checks that a whitespace-only body is reported
// as "<empty>".
func TestResponsePreview_Whitespace(t *testing.T) {
	if preview := ResponsePreview([]byte(" \n\t "), 128); preview != "<empty>" {
		t.Errorf("got %q, want %q for whitespace-only body", preview, "<empty>")
	}
}
// --- AsInt tests ---
// TestAsInt covers each supported numeric type plus two unsupported inputs
// (a string and nil).
func TestAsInt(t *testing.T) {
	type intCase struct {
		name   string
		input  any
		want   int
		wantOK bool
	}
	cases := []intCase{
		{name: "int", input: 42, want: 42, wantOK: true},
		{name: "int64", input: int64(99), want: 99, wantOK: true},
		{name: "float64", input: float64(512), want: 512, wantOK: true},
		{name: "float32", input: float32(256), want: 256, wantOK: true},
		{name: "string", input: "nope", want: 0, wantOK: false},
		{name: "nil", input: nil, want: 0, wantOK: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			n, ok := AsInt(tc.input)
			if n != tc.want || ok != tc.wantOK {
				t.Errorf("AsInt(%v) = (%d, %v), want (%d, %v)", tc.input, n, ok, tc.want, tc.wantOK)
			}
		})
	}
}
// --- AsFloat tests ---
// TestAsFloat covers each supported numeric type plus two unsupported inputs
// (a string and nil).
func TestAsFloat(t *testing.T) {
	type floatCase struct {
		name   string
		input  any
		want   float64
		wantOK bool
	}
	cases := []floatCase{
		{name: "float64", input: float64(0.7), want: 0.7, wantOK: true},
		{name: "float32", input: float32(0.5), want: float64(float32(0.5)), wantOK: true},
		{name: "int", input: 1, want: 1.0, wantOK: true},
		{name: "int64", input: int64(100), want: 100.0, wantOK: true},
		{name: "string", input: "nope", want: 0, wantOK: false},
		{name: "nil", input: nil, want: 0, wantOK: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			f, ok := AsFloat(tc.input)
			if f != tc.want || ok != tc.wantOK {
				t.Errorf("AsFloat(%v) = (%f, %v), want (%f, %v)", tc.input, f, ok, tc.want, tc.wantOK)
			}
		})
	}
}
// --- WrapHTMLResponseError tests ---
// TestWrapHTMLResponseError checks that the error text includes the status
// code, the API base and the "HTML instead of JSON" hint.
func TestWrapHTMLResponseError(t *testing.T) {
	wrapped := WrapHTMLResponseError(502, []byte("<html>bad</html>"), "text/html", "https://api.example.com")
	if wrapped == nil {
		t.Fatal("expected error")
	}

	text := wrapped.Error()
	if !strings.Contains(text, "502") {
		t.Errorf("expected status code in error, got %v", text)
	}
	if !strings.Contains(text, "https://api.example.com") {
		t.Errorf("expected api base in error, got %v", text)
	}
	if !strings.Contains(text, "HTML instead of JSON") {
		t.Errorf("expected HTML mention in error, got %v", text)
	}
}
// --- HandleErrorResponse with empty body ---
// TestHandleErrorResponse_EmptyBody checks that a 500 response with no body
// still produces an error that contains the status code.
func TestHandleErrorResponse_EmptyBody(t *testing.T) {
	emptyFailure := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusInternalServerError)
		// empty body
	}
	server := httptest.NewServer(http.HandlerFunc(emptyFailure))
	defer server.Close()

	resp, err := http.Get(server.URL)
	if err != nil {
		t.Fatalf("http.Get() error = %v", err)
	}
	defer resp.Body.Close()

	got := HandleErrorResponse(resp, server.URL)
	if got == nil {
		t.Fatal("expected error")
	}
	if !strings.Contains(got.Error(), "500") {
		t.Errorf("expected status code, got %v", got)
	}
}
// --- ReadAndParseResponse with invalid JSON ---
// TestReadAndParseResponse_InvalidJSON checks that a body that is not JSON,
// served with a JSON content type, results in an error.
func TestReadAndParseResponse_InvalidJSON(t *testing.T) {
	garbage := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte("not valid json"))
	}
	server := httptest.NewServer(http.HandlerFunc(garbage))
	defer server.Close()

	resp, err := http.Get(server.URL)
	if err != nil {
		t.Fatalf("http.Get() error = %v", err)
	}
	defer resp.Body.Close()

	if _, parseErr := ReadAndParseResponse(resp, server.URL); parseErr == nil {
		t.Fatal("expected error for invalid JSON")
	}
}
// --- ParseResponse with thought_signature (Google/Gemini) ---
// TestParseResponse_WithThoughtSignature checks that a Google thought_signature
// on a tool call is exposed both on ToolCall.ThoughtSignature and inside
// ToolCall.ExtraContent.Google.
func TestParseResponse_WithThoughtSignature(t *testing.T) {
	const payload = `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"test_tool","arguments":"{}"},"extra_content":{"google":{"thought_signature":"sig123"}}}]},"finish_reason":"tool_calls"}]}`

	resp, err := ParseResponse(strings.NewReader(payload))
	if err != nil {
		t.Fatalf("ParseResponse() error = %v", err)
	}
	if n := len(resp.ToolCalls); n != 1 {
		t.Fatalf("len(ToolCalls) = %d, want 1", n)
	}

	call := resp.ToolCalls[0]
	if call.ThoughtSignature != "sig123" {
		t.Errorf("ThoughtSignature = %q, want %q", call.ThoughtSignature, "sig123")
	}
	if call.ExtraContent == nil || call.ExtraContent.Google == nil {
		t.Fatal("ExtraContent.Google is nil")
	}
	if got := call.ExtraContent.Google.ThoughtSignature; got != "sig123" {
		t.Errorf("ExtraContent.Google.ThoughtSignature = %q, want %q",
			got, "sig123")
	}
}
+19
View File
@@ -11,6 +11,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
"github.com/sipeed/picoclaw/pkg/providers/azure"
)
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
@@ -94,6 +95,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.RequestTimeout,
), modelID, nil
case "azure", "azure-openai":
// Azure OpenAI uses deployment-based URLs, api-key header auth,
// and always sends max_completion_tokens.
if cfg.APIKey == "" {
return nil, "", fmt.Errorf("api_key is required for azure protocol")
}
if cfg.APIBase == "" {
return nil, "", fmt.Errorf(
"api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)",
)
}
return azure.NewProviderWithTimeout(
cfg.APIKey,
cfg.APIBase,
cfg.Proxy,
cfg.RequestTimeout,
), modelID, nil
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
+72
View File
@@ -64,6 +64,12 @@ func TestExtractProtocol(t *testing.T) {
wantProtocol: "nvidia",
wantModelID: "meta/llama-3.1-8b",
},
{
name: "azure with prefix",
model: "azure/my-gpt5-deployment",
wantProtocol: "azure",
wantModelID: "my-gpt5-deployment",
},
}
for _, tt := range tests {
@@ -371,3 +377,69 @@ func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) {
t.Fatalf("Chat() error = %q, want timeout-related error", errMsg)
}
}
// TestCreateProviderFromConfig_Azure checks that an "azure/" model yields a
// non-nil provider and a model ID with the prefix stripped.
func TestCreateProviderFromConfig_Azure(t *testing.T) {
	modelCfg := &config.ModelConfig{
		ModelName: "azure-gpt5",
		Model:     "azure/my-gpt5-deployment",
		APIKey:    "test-azure-key",
		APIBase:   "https://my-resource.openai.azure.com",
	}

	created, gotID, err := CreateProviderFromConfig(modelCfg)
	if err != nil {
		t.Fatalf("CreateProviderFromConfig() error = %v", err)
	}
	if created == nil {
		t.Fatal("CreateProviderFromConfig() returned nil provider")
	}
	if gotID != "my-gpt5-deployment" {
		t.Errorf("modelID = %q, want %q", gotID, "my-gpt5-deployment")
	}
}
// TestCreateProviderFromConfig_AzureOpenAIAlias checks that the
// "azure-openai/" prefix is accepted and stripped from the model ID.
func TestCreateProviderFromConfig_AzureOpenAIAlias(t *testing.T) {
	modelCfg := &config.ModelConfig{
		ModelName: "azure-gpt4",
		Model:     "azure-openai/my-deployment",
		APIKey:    "test-azure-key",
		APIBase:   "https://my-resource.openai.azure.com",
	}

	created, gotID, err := CreateProviderFromConfig(modelCfg)
	if err != nil {
		t.Fatalf("CreateProviderFromConfig() error = %v", err)
	}
	if created == nil {
		t.Fatal("CreateProviderFromConfig() returned nil provider")
	}
	if gotID != "my-deployment" {
		t.Errorf("modelID = %q, want %q", gotID, "my-deployment")
	}
}
// TestCreateProviderFromConfig_AzureMissingAPIKey checks that an azure config
// without an API key is rejected.
func TestCreateProviderFromConfig_AzureMissingAPIKey(t *testing.T) {
	modelCfg := &config.ModelConfig{
		ModelName: "azure-gpt5",
		Model:     "azure/my-gpt5-deployment",
		APIBase:   "https://my-resource.openai.azure.com",
	}

	if _, _, err := CreateProviderFromConfig(modelCfg); err == nil {
		t.Fatal("CreateProviderFromConfig() expected error for missing API key")
	}
}
// TestCreateProviderFromConfig_AzureMissingAPIBase checks that an azure config
// without an API base is rejected.
func TestCreateProviderFromConfig_AzureMissingAPIBase(t *testing.T) {
	modelCfg := &config.ModelConfig{
		ModelName: "azure-gpt5",
		Model:     "azure/my-gpt5-deployment",
		APIKey:    "test-azure-key",
	}

	if _, _, err := CreateProviderFromConfig(modelCfg); err == nil {
		t.Fatal("CreateProviderFromConfig() expected error for missing API base")
	}
}
+8 -319
View File
@@ -1,18 +1,16 @@
package openai_compat
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -38,7 +36,7 @@ type Provider struct {
type Option func(*Provider)
const defaultRequestTimeout = 120 * time.Second
const defaultRequestTimeout = common.DefaultRequestTimeout
func WithMaxTokensField(maxTokensField string) Option {
return func(p *Provider) {
@@ -55,25 +53,10 @@ func WithRequestTimeout(timeout time.Duration) Option {
}
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
client := &http.Client{
Timeout: defaultRequestTimeout,
}
if proxy != "" {
parsed, err := url.Parse(proxy)
if err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(parsed),
}
} else {
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
}
}
p := &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
httpClient: common.NewHTTPClient(proxy),
}
for _, opt := range opts {
@@ -117,7 +100,7 @@ func (p *Provider) Chat(
requestBody := map[string]any{
"model": model,
"messages": serializeMessages(messages),
"messages": common.SerializeMessages(messages),
}
if len(tools) > 0 {
@@ -125,7 +108,7 @@ func (p *Provider) Chat(
requestBody["tool_choice"] = "auto"
}
if maxTokens, ok := asInt(options["max_tokens"]); ok {
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
// Use configured maxTokensField if specified, otherwise fallback to model-based detection
fieldName := p.maxTokensField
if fieldName == "" {
@@ -141,7 +124,7 @@ func (p *Provider) Chat(
requestBody[fieldName] = maxTokens
}
if temperature, ok := asFloat(options["temperature"]); ok {
if temperature, ok := common.AsFloat(options["temperature"]); ok {
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1.
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
@@ -185,275 +168,11 @@ func (p *Provider) Chat(
}
defer resp.Body.Close()
contentType := resp.Header.Get("Content-Type")
// Non-200: read a prefix to tell HTML error page apart from JSON error body.
if resp.StatusCode != http.StatusOK {
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
if readErr != nil {
return nil, fmt.Errorf("failed to read response: %w", readErr)
}
if looksLikeHTML(body, contentType) {
return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase)
}
return nil, fmt.Errorf(
"API request failed:\n Status: %d\n Body: %s",
resp.StatusCode,
responsePreview(body, 128),
)
return nil, common.HandleErrorResponse(resp, p.apiBase)
}
// Peek without consuming so the full stream reaches the JSON decoder.
reader := bufio.NewReader(resp.Body)
prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
return nil, fmt.Errorf("failed to inspect response: %w", err)
}
if looksLikeHTML(prefix, contentType) {
return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase)
}
out, err := parseResponse(reader)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}
return out, nil
}
// wrapHTMLResponseError builds the error reported when an endpoint answers
// with HTML instead of JSON. The message names the API base, the content type,
// the HTTP status and a preview of at most 128 bytes of the body.
func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
	const htmlErrFormat = "API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s"

	return fmt.Errorf(htmlErrFormat, apiBase, contentType, statusCode, responsePreview(body, 128))
}
// looksLikeHTML reports whether a response appears to be an HTML document.
// It is true when the content type mentions an HTML media type, or when the
// body (after leading whitespace, case-insensitively) starts with a common
// HTML opening tag.
func looksLikeHTML(body []byte, contentType string) bool {
	normalizedType := strings.ToLower(strings.TrimSpace(contentType))
	for _, htmlType := range []string{"text/html", "application/xhtml+xml"} {
		if strings.Contains(normalizedType, htmlType) {
			return true
		}
	}

	head := bytes.ToLower(leadingTrimmedPrefix(body, 128))
	for _, marker := range []string{"<!doctype html", "<html", "<head", "<body"} {
		if bytes.HasPrefix(head, []byte(marker)) {
			return true
		}
	}
	return false
}
// leadingTrimmedPrefix skips leading ASCII whitespace in body and returns at
// most maxLen bytes starting at the first non-whitespace byte. It returns nil
// when body is empty or consists only of whitespace. The result aliases body.
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
	const asciiSpace = " \t\n\r\f\v"

	for start, c := range body {
		if strings.IndexByte(asciiSpace, c) >= 0 {
			continue
		}
		rest := body[start:]
		if len(rest) > maxLen {
			rest = rest[:maxLen]
		}
		return rest
	}
	return nil
}
// responsePreview renders a response body for inclusion in an error message.
// Surrounding whitespace is removed; an empty result is shown as "<empty>",
// and anything longer than maxLen bytes is cut and suffixed with "...".
func responsePreview(body []byte, maxLen int) string {
	text := bytes.TrimSpace(body)
	switch n := len(text); {
	case n == 0:
		return "<empty>"
	case n > maxLen:
		return string(text[:maxLen]) + "..."
	default:
		return string(text)
	}
}
// parseResponse decodes a chat-completion JSON document from body and maps it
// onto an LLMResponse.
//
// Only the first entry of "choices" is used. When "choices" is empty the
// result has empty content and FinishReason "stop". A JSON decoding failure is
// returned as a wrapped error.
func parseResponse(body io.Reader) (*LLMResponse, error) {
// apiResponse mirrors just the parts of the wire format that are consumed below.
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content"`
Reasoning string `json:"reasoning"`
ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
// Function is a pointer so a tool call without a "function" object can be detected.
Function *struct {
Name string `json:"name"`
// Arguments is kept raw; decodeToolCallArguments interprets it later.
Arguments json.RawMessage `json:"arguments"`
} `json:"function"`
ExtraContent *struct {
Google *struct {
ThoughtSignature string `json:"thought_signature"`
} `json:"google"`
} `json:"extra_content"`
} `json:"tool_calls"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage *UsageInfo `json:"usage"`
}
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
// No choices: report an empty, normally-finished response rather than an error.
if len(apiResponse.Choices) == 0 {
return &LLMResponse{
Content: "",
FinishReason: "stop",
}, nil
}
choice := apiResponse.Choices[0]
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
// Defaults used when the tool call carries no "function" object.
arguments := make(map[string]any)
name := ""
// Extract thought_signature from Gemini/Google-specific extra content
thoughtSignature := ""
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
}
if tc.Function != nil {
name = tc.Function.Name
arguments = decodeToolCallArguments(tc.Function.Arguments, name)
}
// Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence
toolCall := ToolCall{
ID: tc.ID,
Name: name,
Arguments: arguments,
ThoughtSignature: thoughtSignature,
}
// ExtraContent is only populated when a signature was actually present.
if thoughtSignature != "" {
toolCall.ExtraContent = &ExtraContent{
Google: &GoogleExtra{
ThoughtSignature: thoughtSignature,
},
}
}
toolCalls = append(toolCalls, toolCall)
}
return &LLMResponse{
Content: choice.Message.Content,
ReasoningContent: choice.Message.ReasoningContent,
Reasoning: choice.Message.Reasoning,
ReasoningDetails: choice.Message.ReasoningDetails,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
Usage: apiResponse.Usage,
}, nil
}
// decodeToolCallArguments converts the raw "arguments" payload of a tool call
// into a map. name is only used in log messages.
//
// Accepted shapes:
//   - empty payload or JSON null: an empty map
//   - a JSON object: that object
//   - a JSON string: the string is parsed as JSON into the map; a blank string
//     or a string holding the literal null yields an empty map
//
// Anything that cannot be decoded is logged and returned under the "raw" key.
// The result is never nil, so callers may index or assign into it.
func decodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
	arguments := make(map[string]any)
	raw = bytes.TrimSpace(raw)
	if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
		return arguments
	}

	var decoded any
	if err := json.Unmarshal(raw, &decoded); err != nil {
		log.Printf("openai_compat: failed to decode tool call arguments payload for %q: %v", name, err)
		arguments["raw"] = string(raw)
		return arguments
	}

	switch v := decoded.(type) {
	case string:
		if strings.TrimSpace(v) == "" {
			return arguments
		}
		if err := json.Unmarshal([]byte(v), &arguments); err != nil {
			log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
			arguments["raw"] = v
		}
		// Unmarshalling the literal null sets the map to nil without an
		// error; normalise so this path matches the other "no arguments" paths.
		if arguments == nil {
			arguments = make(map[string]any)
		}
		return arguments
	case map[string]any:
		return v
	default:
		log.Printf("openai_compat: unsupported tool call arguments type for %q: %T", name, decoded)
		arguments["raw"] = string(raw)
		return arguments
	}
}
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
// It mirrors protocoltypes.Message but omits SystemParts, which is an
// internal field that would be unknown to third-party endpoints.
//
// Role and Content are always emitted; the remaining fields are dropped from
// the JSON when empty (omitempty).
type openaiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
// serializeMessages converts internal Message structs to the OpenAI wire format.
// - Strips SystemParts (unknown to third-party endpoints)
// - Converts messages with Media to multipart content format (text + image_url parts)
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
func serializeMessages(messages []Message) []any {
out := make([]any, 0, len(messages))
for _, m := range messages {
if len(m.Media) == 0 {
out = append(out, openaiMessage{
Role: m.Role,
Content: m.Content,
ReasoningContent: m.ReasoningContent,
ToolCalls: m.ToolCalls,
ToolCallID: m.ToolCallID,
})
continue
}
// Multipart content format for messages with media
parts := make([]map[string]any, 0, 1+len(m.Media))
if m.Content != "" {
parts = append(parts, map[string]any{
"type": "text",
"text": m.Content,
})
}
for _, mediaURL := range m.Media {
if strings.HasPrefix(mediaURL, "data:image/") {
parts = append(parts, map[string]any{
"type": "image_url",
"image_url": map[string]any{
"url": mediaURL,
},
})
}
}
msg := map[string]any{
"role": m.Role,
"content": parts,
}
if m.ToolCallID != "" {
msg["tool_call_id"] = m.ToolCallID
}
if len(m.ToolCalls) > 0 {
msg["tool_calls"] = m.ToolCalls
}
if m.ReasoningContent != "" {
msg["reasoning_content"] = m.ReasoningContent
}
out = append(out, msg)
}
return out
return common.ReadAndParseResponse(resp, p.apiBase)
}
func normalizeModel(model, apiBase string) string {
@@ -476,36 +195,6 @@ func normalizeModel(model, apiBase string) string {
}
}
// asInt converts a loosely-typed option value to int. It accepts int, int64,
// float64 and float32 (floats are truncated toward zero by the conversion)
// and reports false for every other dynamic type, including nil.
func asInt(v any) (int, bool) {
	if n, ok := v.(int); ok {
		return n, true
	}
	if n, ok := v.(int64); ok {
		return int(n), true
	}
	if f, ok := v.(float64); ok {
		return int(f), true
	}
	if f, ok := v.(float32); ok {
		return int(f), true
	}
	return 0, false
}
// asFloat converts a loosely-typed option value to float64. It accepts
// float64, float32, int and int64 and reports false for every other dynamic
// type, including nil.
func asFloat(v any) (float64, bool) {
	if f, ok := v.(float64); ok {
		return f, true
	}
	if f, ok := v.(float32); ok {
		return float64(f), true
	}
	if n, ok := v.(int); ok {
		return float64(n), true
	}
	if n, ok := v.(int64); ok {
		return float64(n), true
	}
	return 0, false
}
// supportsPromptCacheKey reports whether the given API base is known to
// support the prompt_cache_key request field. Currently only OpenAI's own
// API and Azure OpenAI support this. All other OpenAI-compatible providers
+5 -4
View File
@@ -12,6 +12,7 @@ import (
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -648,7 +649,7 @@ func TestSerializeMessages_PlainText(t *testing.T) {
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
}
result := serializeMessages(messages)
result := common.SerializeMessages(messages)
data, err := json.Marshal(result)
if err != nil {
@@ -670,7 +671,7 @@ func TestSerializeMessages_WithMedia(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
}
result := serializeMessages(messages)
result := common.SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
@@ -703,7 +704,7 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
}
result := serializeMessages(messages)
result := common.SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
@@ -833,7 +834,7 @@ func TestSerializeMessages_StripsSystemParts(t *testing.T) {
},
},
}
result := serializeMessages(messages)
result := common.SerializeMessages(messages)
data, _ := json.Marshal(result)
raw := string(data)
+53 -20
View File
@@ -20,10 +20,12 @@ type JobExecutor interface {
// CronTool provides scheduling capabilities for the agent.
type CronTool struct {
	cronService *cron.CronService
	executor    JobExecutor
	// msgBus is used to publish outbound messages for executed jobs.
	msgBus *bus.MessageBus
	// execTool runs scheduled shell commands; NewCronTool leaves it nil when
	// exec is disabled in the configuration.
	execTool *ExecTool
	// allowCommand mirrors tools.cron.allow_command. When false, scheduling a
	// command additionally requires command_confirm=true.
	allowCommand bool
	// execEnabled mirrors the exec tool's enabled flag. When false, commands
	// can be neither scheduled nor executed.
	execEnabled bool
}
// NewCronTool creates a new CronTool
@@ -32,17 +34,32 @@ func NewCronTool(
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
execTimeout time.Duration, config *config.Config,
) (*CronTool, error) {
execTool, err := NewExecToolWithConfig(workspace, restrict, config)
if err != nil {
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
allowCommand := true
execEnabled := true
if config != nil {
allowCommand = config.Tools.Cron.AllowCommand
execEnabled = config.Tools.Exec.Enabled
}
execTool.SetTimeout(execTimeout)
var execTool *ExecTool
if execEnabled {
var err error
execTool, err = NewExecToolWithConfig(workspace, restrict, config)
if err != nil {
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
}
}
if execTool != nil {
execTool.SetTimeout(execTimeout)
}
return &CronTool{
cronService: cronService,
executor: executor,
msgBus: msgBus,
execTool: execTool,
cronService: cronService,
executor: executor,
msgBus: msgBus,
execTool: execTool,
allowCommand: allowCommand,
execEnabled: execEnabled,
}, nil
}
@@ -76,7 +93,7 @@ func (t *CronTool) Parameters() map[string]any {
},
"command_confirm": map[string]any{
"type": "boolean",
"description": "Required when using command=true. Must be true to explicitly confirm scheduling a shell command.",
"description": "Optional explicit confirmation flag for scheduling a shell command. Command execution must also be enabled via tools.cron.allow_command.",
},
"at_seconds": map[string]any{
"type": "integer",
@@ -96,7 +113,7 @@ func (t *CronTool) Parameters() map[string]any {
},
"deliver": map[string]any{
"type": "boolean",
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: false",
},
},
"required": []string{"action"},
@@ -174,22 +191,26 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
}
// Read deliver parameter, default to true
deliver := true
// Read deliver parameter, default to false so scheduled tasks execute through the agent
deliver := false
if d, ok := args["deliver"].(bool); ok {
deliver = d
}
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel + explicit confirm.
// Non-command reminders (plain messages) remain open to all channels.
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel. When
// allow_command is disabled, explicit confirmation is required as an override.
// Non-command reminders remain open to all channels.
command, _ := args["command"].(string)
commandConfirm, _ := args["command_confirm"].(bool)
if command != "" {
if !t.execEnabled {
return ErrorResult("command execution is disabled")
}
if !constants.IsInternalChannel(channel) {
return ErrorResult("scheduling command execution is restricted to internal channels")
}
if !commandConfirm {
return ErrorResult("command_confirm=true is required to schedule command execution")
if !t.allowCommand && !commandConfirm {
return ErrorResult("command_confirm=true is required when allow_command is disabled")
}
deliver = false
}
@@ -290,6 +311,18 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// Execute command if present
if job.Payload.Command != "" {
if !t.execEnabled || t.execTool == nil {
output := "Error executing scheduled command: command execution is disabled"
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: output,
})
return "ok"
}
args := map[string]any{
"command": job.Payload.Command,
"__channel": channel,
+126 -6
View File
@@ -5,18 +5,18 @@ import (
"path/filepath"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
)
func newTestCronTool(t *testing.T) *CronTool {
func newTestCronToolWithConfig(t *testing.T, cfg *config.Config) *CronTool {
t.Helper()
storePath := filepath.Join(t.TempDir(), "cron.json")
cronService := cron.NewCronService(storePath, nil)
msgBus := bus.NewMessageBus()
cfg := config.DefaultConfig()
tool, err := NewCronTool(cronService, nil, msgBus, t.TempDir(), true, 0, cfg)
if err != nil {
t.Fatalf("NewCronTool() error: %v", err)
@@ -24,6 +24,11 @@ func newTestCronTool(t *testing.T) *CronTool {
return tool
}
// newTestCronTool returns a CronTool built from the default configuration.
func newTestCronTool(t *testing.T) *CronTool {
	t.Helper()

	defaults := config.DefaultConfig()
	return newTestCronToolWithConfig(t, defaults)
}
// TestCronTool_CommandBlockedFromRemoteChannel verifies command scheduling is restricted to internal channels
func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
tool := newTestCronTool(t)
@@ -44,8 +49,7 @@ func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
}
}
// TestCronTool_CommandRequiresConfirm verifies command_confirm=true is required
func TestCronTool_CommandRequiresConfirm(t *testing.T) {
func TestCronTool_CommandDoesNotRequireConfirmByDefault(t *testing.T) {
tool := newTestCronTool(t)
ctx := WithToolContext(context.Background(), "cli", "direct")
result := tool.Execute(ctx, map[string]any{
@@ -55,11 +59,79 @@ func TestCronTool_CommandRequiresConfirm(t *testing.T) {
"at_seconds": float64(60),
})
if result.IsError {
t.Fatalf("expected command scheduling without confirm to succeed by default, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Cron job added") {
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
}
}
// TestCronTool_CommandRequiresConfirmWhenAllowCommandDisabled verifies that,
// with tools.cron.allow_command turned off, scheduling a command from an
// internal channel is rejected unless command_confirm is supplied, and that
// the rejection mentions command_confirm=true.
func TestCronTool_CommandRequiresConfirmWhenAllowCommandDisabled(t *testing.T) {
	cfg := config.DefaultConfig()
	cfg.Tools.Cron.AllowCommand = false
	tool := newTestCronToolWithConfig(t, cfg)

	ctx := WithToolContext(context.Background(), "cli", "direct")
	result := tool.Execute(ctx, map[string]any{
		"action":     "add",
		"message":    "check disk",
		"command":    "df -h",
		"at_seconds": float64(60),
	})

	if !result.IsError {
		t.Fatal("expected command scheduling to require confirm when allow_command is disabled")
	}
	if !strings.Contains(result.ForLLM, "command_confirm=true") {
		t.Errorf("expected command_confirm requirement message, got: %s", result.ForLLM)
	}
}
// TestCronTool_CommandAllowedWithConfirmWhenAllowCommandDisabled verifies that
// command_confirm=true lets an internal channel schedule a command even when
// tools.cron.allow_command is turned off.
func TestCronTool_CommandAllowedWithConfirmWhenAllowCommandDisabled(t *testing.T) {
	restricted := config.DefaultConfig()
	restricted.Tools.Cron.AllowCommand = false
	cronTool := newTestCronToolWithConfig(t, restricted)

	request := map[string]any{
		"action":          "add",
		"message":         "check disk",
		"command":         "df -h",
		"command_confirm": true,
		"at_seconds":      float64(60),
	}
	internalCtx := WithToolContext(context.Background(), "cli", "direct")
	outcome := cronTool.Execute(internalCtx, request)

	if outcome.IsError {
		t.Fatalf(
			"expected command scheduling with confirm to succeed when allow_command is disabled, got: %s",
			outcome.ForLLM,
		)
	}
	if added := strings.Contains(outcome.ForLLM, "Cron job added"); !added {
		t.Errorf("expected 'Cron job added', got: %s", outcome.ForLLM)
	}
}
// TestCronTool_CommandBlockedWhenExecDisabled verifies that scheduling a
// command is refused when the exec tool is disabled, even with
// command_confirm=true from an internal channel.
func TestCronTool_CommandBlockedWhenExecDisabled(t *testing.T) {
	noExec := config.DefaultConfig()
	noExec.Tools.Exec.Enabled = false
	cronTool := newTestCronToolWithConfig(t, noExec)

	request := map[string]any{
		"action":          "add",
		"message":         "check disk",
		"command":         "df -h",
		"command_confirm": true,
		"at_seconds":      float64(60),
	}
	internalCtx := WithToolContext(context.Background(), "cli", "direct")
	outcome := cronTool.Execute(internalCtx, request)

	if !outcome.IsError {
		t.Fatal("expected command scheduling to be blocked when exec is disabled")
	}
	if mentioned := strings.Contains(outcome.ForLLM, "command execution is disabled"); !mentioned {
		t.Errorf("expected exec disabled message, got: %s", outcome.ForLLM)
	}
}
@@ -114,3 +186,51 @@ func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) {
t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM)
}
}
// TestCronTool_NonCommandJobDefaultsDeliverToFalse verifies that a plain
// reminder added without an explicit "deliver" argument is stored with
// Deliver set to false.
func TestCronTool_NonCommandJobDefaultsDeliverToFalse(t *testing.T) {
	cronTool := newTestCronTool(t)

	request := map[string]any{
		"action":     "add",
		"message":    "send me a poem",
		"at_seconds": float64(600),
	}
	remoteCtx := WithToolContext(context.Background(), "telegram", "chat-1")
	if outcome := cronTool.Execute(remoteCtx, request); outcome.IsError {
		t.Fatalf("expected non-command reminder to succeed, got: %s", outcome.ForLLM)
	}

	stored := cronTool.cronService.ListJobs(false)
	if count := len(stored); count != 1 {
		t.Fatalf("expected 1 job, got %d", count)
	}
	if stored[0].Payload.Deliver {
		t.Fatal("expected deliver=false by default for non-command jobs")
	}
}
// TestCronTool_ExecuteJobPublishesErrorWhenExecDisabled verifies that running
// a stored command job while exec is disabled still returns "ok" and publishes
// an outbound message explaining that command execution is disabled.
func TestCronTool_ExecuteJobPublishesErrorWhenExecDisabled(t *testing.T) {
	noExec := config.DefaultConfig()
	noExec.Tools.Exec.Enabled = false
	cronTool := newTestCronToolWithConfig(t, noExec)

	commandJob := &cron.CronJob{}
	commandJob.Payload.Command = "df -h"
	commandJob.Payload.Channel = "cli"
	commandJob.Payload.To = "direct"

	status := cronTool.ExecuteJob(context.Background(), commandJob)
	if status != "ok" {
		t.Fatalf("ExecuteJob() = %q, want ok", status)
	}

	waitCtx, stopWaiting := context.WithTimeout(context.Background(), time.Second)
	defer stopWaiting()

	published, received := cronTool.msgBus.SubscribeOutbound(waitCtx)
	if !received {
		t.Fatal("expected outbound message")
	}
	if mentioned := strings.Contains(published.Content, "command execution is disabled"); !mentioned {
		t.Fatalf("expected exec disabled message, got: %s", published.Content)
	}
}
+161 -9
View File
@@ -20,8 +20,7 @@ import (
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
// validatePath ensures the given path is within the workspace if restrict is true.
func validatePath(path, workspace string, restrict bool) (string, error) {
func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) {
if workspace == "" {
return path, fmt.Errorf("workspace is not defined")
}
@@ -42,6 +41,10 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
}
if restrict {
if isAllowedPath(absPath, patterns) {
return absPath, nil
}
if !isWithinWorkspace(absPath, absWorkspace) {
return "", fmt.Errorf("access denied: path is outside the workspace")
}
@@ -73,6 +76,137 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
return absPath, nil
}
// isAllowedPath reports whether path is covered by the allow-list patterns.
// The path must be absolute and must match both lexically (after cleaning)
// and after its existing ancestors have been resolved through symlinks, so a
// symlink inside an allowed directory cannot be used to reach outside it.
func isAllowedPath(path string, patterns []*regexp.Regexp) bool {
	if len(patterns) == 0 {
		return false
	}

	lexical := filepath.Clean(path)
	if !filepath.IsAbs(lexical) || !matchesAllowedPath(lexical, patterns) {
		return false
	}

	physical, err := resolvePathAgainstExistingAncestor(lexical)
	return err == nil && matchesAllowedPath(physical, patterns)
}
// matchesAllowedPath reports whether the cleaned path satisfies any pattern,
// either by a direct regex match or, for patterns that describe a literal
// directory root, by lying inside that root.
func matchesAllowedPath(path string, patterns []*regexp.Regexp) bool {
	target := filepath.Clean(path)
	for i := range patterns {
		re := patterns[i]
		if re.MatchString(target) {
			return true
		}
		root, isLiteralRoot := extractAllowedPathRoot(re)
		if isLiteralRoot && isWithinAllowedRoot(target, root) {
			return true
		}
	}
	return false
}
// extractAllowedPathRoot recovers the directory a pattern stands for when the
// pattern is an anchored literal path, optionally followed by the
// "(?:<sep>|$)" directory suffix. It returns the cleaned path and true only
// when the literal is an absolute path; any pattern that still contains regex
// operators is reported as not being a literal root.
func extractAllowedPathRoot(pattern *regexp.Regexp) (string, bool) {
	expr := pattern.String()
	if len(expr) == 0 || expr[0] != '^' {
		return "", false
	}

	// Drop the anchor, then the "directory prefix" tail in either its
	// forward-slash or (escaped) backslash form.
	body := expr[1:]
	for _, dirSuffix := range []string{"(?:/|$)", `(?:\\|$)`} {
		body = strings.TrimSuffix(body, dirSuffix)
	}

	// Anything with remaining operators keeps pure regex semantics.
	if containsUnescapedRegexMeta(body) {
		return "", false
	}

	root, ok := unescapeRegexLiteral(body)
	if !ok || root == "" {
		return "", false
	}
	return filepath.Clean(root), filepath.IsAbs(root)
}
// appendUniquePath appends path to paths unless an equal entry is already
// present, in which case paths is returned unchanged.
func appendUniquePath(paths []string, path string) []string {
	present := false
	for i := 0; i < len(paths) && !present; i++ {
		present = paths[i] == path
	}
	if present {
		return paths
	}
	return append(paths, path)
}
// containsUnescapedRegexMeta reports whether s contains a regular-expression
// metacharacter that is not preceded by a backslash, or ends with a dangling
// backslash. A false result means every special character in s is escaped.
//
// The set checked is the one regexp.QuoteMeta escapes (besides the backslash
// itself), including the anchors '^' and '$'. Treating an unescaped anchor as
// a literal path character would yield a root the pattern never described.
func containsUnescapedRegexMeta(s string) bool {
	escaped := false
	for _, r := range s {
		if escaped {
			// The previous backslash neutralises this rune, whatever it is.
			escaped = false
			continue
		}
		if r == '\\' {
			escaped = true
			continue
		}
		switch r {
		case '.', '+', '*', '?', '(', ')', '[', ']', '{', '}', '|', '^', '$':
			return true
		}
	}
	// A trailing backslash escapes nothing; report it as non-literal.
	return escaped
}
// unescapeRegexLiteral removes the backslash escapes from a regex fragment
// that is expected to be a quoted literal and returns the plain text.
//
// It reports false when s is not a pure literal: when it ends with a dangling
// backslash, or when a backslash precedes an ASCII letter or digit. In Go
// regular expressions only escaped punctuation stands for itself; sequences
// such as \d, \w, \b or \A are character classes or assertions, so turning
// them into the letters 'd', 'w', ... would produce a path the pattern never
// matched.
func unescapeRegexLiteral(s string) (string, bool) {
	var b strings.Builder
	b.Grow(len(s))

	escaped := false
	for _, r := range s {
		switch {
		case escaped:
			if (r >= '0' && r <= '9') || (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') {
				return "", false
			}
			b.WriteRune(r)
			escaped = false
		case r == '\\':
			escaped = true
		default:
			b.WriteRune(r)
		}
	}

	if escaped {
		return "", false
	}
	return b.String(), true
}
func isWithinAllowedRoot(path, root string) bool {
candidate := filepath.Clean(path)
allowedVariants := []string{filepath.Clean(root)}
if resolvedRoot, err := resolvePathAgainstExistingAncestor(root); err == nil {
allowedVariants = appendUniquePath(allowedVariants, filepath.Clean(resolvedRoot))
}
for _, allowedRoot := range allowedVariants {
if isWithinWorkspace(candidate, allowedRoot) {
return true
}
}
return false
}
func resolveExistingAncestor(path string) (string, error) {
for current := filepath.Clean(path); ; current = filepath.Dir(current) {
if resolved, err := filepath.EvalSymlinks(current); err == nil {
@@ -86,9 +220,32 @@ func resolveExistingAncestor(path string) (string, error) {
}
}
// resolvePathAgainstExistingAncestor resolves symlinks in the longest prefix
// of path that exists and re-attaches the remaining, not-yet-existing
// components unchanged. It returns os.ErrNotExist when not even the topmost
// ancestor can be resolved, and any other EvalSymlinks error as-is.
func resolvePathAgainstExistingAncestor(path string) (string, error) {
	target := filepath.Clean(path)
	ancestor := target
	for {
		physical, err := filepath.EvalSymlinks(ancestor)
		switch {
		case err == nil:
			tail, relErr := filepath.Rel(ancestor, target)
			if relErr != nil {
				return "", relErr
			}
			if tail == "." {
				return filepath.Clean(physical), nil
			}
			return filepath.Clean(filepath.Join(physical, tail)), nil
		case !os.IsNotExist(err):
			return "", err
		}

		// This level does not exist yet; retry one directory up unless the
		// top of the tree has been reached.
		parent := filepath.Dir(ancestor)
		if parent == ancestor {
			return "", os.ErrNotExist
		}
		ancestor = parent
	}
}
// isWithinWorkspace reports whether candidate is the workspace directory
// itself or a path beneath it. Both arguments are cleaned lexically first;
// symlinks are not resolved here.
func isWithinWorkspace(candidate, workspace string) bool {
	rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
	// rel == "." is the workspace root itself; IsLocal rejects anything that
	// climbs out through "..".
	return err == nil && (rel == "." || filepath.IsLocal(rel))
}
type ReadFileTool struct {
@@ -625,12 +782,7 @@ type whitelistFs struct {
}
// matches reports whether path is permitted by the whitelist patterns. The
// decision is delegated to isAllowedPath, which requires the path to match
// both as written and after symlink resolution, so a symlink placed inside a
// whitelisted directory cannot be used to reach files outside it.
func (w *whitelistFs) matches(path string) bool {
	return isAllowedPath(path, w.patterns)
}
func (w *whitelistFs) ReadFile(path string) ([]byte, error) {
+84
View File
@@ -521,6 +521,90 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
}
}
// TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir verifies that a symlink
// located inside a whitelisted directory but pointing outside it cannot be
// used to read the target's files.
func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) {
	workspace, allowedDir, secretDir := t.TempDir(), t.TempDir(), t.TempDir()

	secretFile := filepath.Join(secretDir, "secret.txt")
	if err := os.WriteFile(secretFile, []byte("top secret"), 0o644); err != nil {
		t.Fatalf("WriteFile(secretFile) error = %v", err)
	}

	escapeLink := filepath.Join(allowedDir, "link_out")
	if err := os.Symlink(secretDir, escapeLink); err != nil {
		t.Skipf("symlink not supported in this environment: %v", err)
	}

	allowOnlyDir := regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))
	reader := NewReadFileTool(workspace, true, MaxReadFileSize, []*regexp.Regexp{allowOnlyDir})

	request := map[string]any{"path": filepath.Join(escapeLink, "secret.txt")}
	if outcome := reader.Execute(context.Background(), request); !outcome.IsError {
		t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", outcome.ForLLM)
	}
}
// TestWhitelistFs_WriteAllowsNewFileUnderAllowedDir verifies that a write to a
// not-yet-existing nested path under a whitelisted directory succeeds and
// that the content lands on disk.
func TestWhitelistFs_WriteAllowsNewFileUnderAllowedDir(t *testing.T) {
	const payload = "outside write"

	workspace := t.TempDir()
	allowedDir := filepath.Join(t.TempDir(), "allowed")
	destination := filepath.Join(allowedDir, "nested", "file.txt")

	allowOnlyDir := regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))
	writer := NewWriteFileTool(workspace, true, []*regexp.Regexp{allowOnlyDir})

	request := map[string]any{
		"path":    destination,
		"content": payload,
	}
	if outcome := writer.Execute(context.Background(), request); outcome.IsError {
		t.Fatalf("expected whitelisted write to succeed, got: %s", outcome.ForLLM)
	}

	written, err := os.ReadFile(destination)
	if err != nil {
		t.Fatalf("ReadFile(targetFile) error = %v", err)
	}
	if got := string(written); got != payload {
		t.Fatalf("target file content = %q, want %q", got, payload)
	}
}
// TestWhitelistFs_AllowsResolvedAllowedRootAlias verifies that when the
// whitelisted root is itself a symlink, files reached through that alias
// remain readable.
func TestWhitelistFs_AllowsResolvedAllowedRootAlias(t *testing.T) {
	workspace, realDir, linkParent := t.TempDir(), t.TempDir(), t.TempDir()

	aliasRoot := filepath.Join(linkParent, "allowed-link")
	if err := os.Symlink(realDir, aliasRoot); err != nil {
		t.Skipf("symlink not supported in this environment: %v", err)
	}

	aliasedFile := filepath.Join(aliasRoot, "nested", "alias.txt")
	if err := os.MkdirAll(filepath.Dir(aliasedFile), 0o755); err != nil {
		t.Fatalf("MkdirAll(targetFile dir) error = %v", err)
	}
	if err := os.WriteFile(aliasedFile, []byte("through alias"), 0o644); err != nil {
		t.Fatalf("WriteFile(targetFile) error = %v", err)
	}

	separator := regexp.QuoteMeta(string(os.PathSeparator))
	aliasPattern := regexp.MustCompile(
		"^" + regexp.QuoteMeta(filepath.Clean(aliasRoot)) + "(?:" + separator + "|$)",
	)
	reader := NewReadFileTool(workspace, true, MaxReadFileSize, []*regexp.Regexp{aliasPattern})

	outcome := reader.Execute(context.Background(), map[string]any{"path": aliasedFile})
	if outcome.IsError {
		t.Fatalf("expected symlink-backed allowed root to be readable, got: %s", outcome.ForLLM)
	}
	if found := strings.Contains(outcome.ForLLM, "through alias"); !found {
		t.Fatalf("expected file content, got: %s", outcome.ForLLM)
	}
}
// TestReadFileTool_ChunkedReading verifies the pagination logic of the tool
// by reading a file in multiple chunks using 'offset' and 'length'.
func TestReadFileTool_ChunkedReading(t *testing.T) {
+15 -2
View File
@@ -6,6 +6,7 @@ import (
"mime"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/h2non/filetype"
@@ -21,20 +22,32 @@ type SendFileTool struct {
restrict bool
maxFileSize int
mediaStore media.MediaStore
allowPaths []*regexp.Regexp
defaultChannel string
defaultChatID string
}
func NewSendFileTool(workspace string, restrict bool, maxFileSize int, store media.MediaStore) *SendFileTool {
func NewSendFileTool(
workspace string,
restrict bool,
maxFileSize int,
store media.MediaStore,
allowPaths ...[]*regexp.Regexp,
) *SendFileTool {
if maxFileSize <= 0 {
maxFileSize = config.DefaultMaxMediaSize
}
var patterns []*regexp.Regexp
if len(allowPaths) > 0 {
patterns = allowPaths[0]
}
return &SendFileTool{
workspace: workspace,
restrict: restrict,
maxFileSize: maxFileSize,
mediaStore: store,
allowPaths: patterns,
}
}
@@ -92,7 +105,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
return ErrorResult("media store not configured")
}
resolved, err := validatePath(path, t.workspace, t.restrict)
resolved, err := validatePathWithAllowPaths(path, t.workspace, t.restrict, t.allowPaths)
if err != nil {
return ErrorResult(fmt.Sprintf("invalid path: %v", err))
}
+39
View File
@@ -4,6 +4,7 @@ import (
"context"
"os"
"path/filepath"
"regexp"
"strings"
"testing"
@@ -128,6 +129,44 @@ func TestSendFileTool_CustomFilename(t *testing.T) {
}
}
// TestSendFileTool_AllowsWhitelistedMediaTempPath verifies that a file in the
// media temp directory, which is outside the workspace, can be sent when that
// directory is whitelisted, and that exactly one media reference is returned.
func TestSendFileTool_AllowsWhitelistedMediaTempPath(t *testing.T) {
	workspace := t.TempDir()

	mediaDir := media.TempDir()
	if err := os.MkdirAll(mediaDir, 0o700); err != nil {
		t.Fatalf("MkdirAll(mediaDir) error = %v", err)
	}

	tempFile, err := os.CreateTemp(mediaDir, "send-file-*.txt")
	if err != nil {
		t.Fatalf("CreateTemp(mediaDir) error = %v", err)
	}
	tempPath := tempFile.Name()
	if _, err := tempFile.WriteString("forward me"); err != nil {
		tempFile.Close()
		t.Fatalf("WriteString(testFile) error = %v", err)
	}
	if err := tempFile.Close(); err != nil {
		t.Fatalf("Close(testFile) error = %v", err)
	}
	t.Cleanup(func() { _ = os.Remove(tempPath) })

	separator := regexp.QuoteMeta(string(os.PathSeparator))
	mediaPattern := regexp.MustCompile(
		"^" + regexp.QuoteMeta(filepath.Clean(mediaDir)) + "(?:" + separator + "|$)",
	)

	sender := NewSendFileTool(workspace, true, 0, media.NewFileMediaStore(), []*regexp.Regexp{mediaPattern})
	sender.SetContext("feishu", "chat123")

	outcome := sender.Execute(context.Background(), map[string]any{"path": tempPath})
	if outcome.IsError {
		t.Fatalf("expected whitelisted temp media file to be sendable, got: %s", outcome.ForLLM)
	}
	if refs := len(outcome.Media); refs != 1 {
		t.Fatalf("expected 1 media ref, got %d", refs)
	}
}
func TestDetectMediaType_MagicBytes(t *testing.T) {
dir := t.TempDir()
+31 -13
View File
@@ -23,6 +23,7 @@ type ExecTool struct {
denyPatterns []*regexp.Regexp
allowPatterns []*regexp.Regexp
customAllowPatterns []*regexp.Regexp
allowedPathPatterns []*regexp.Regexp
restrictToWorkspace bool
allowRemote bool
}
@@ -95,14 +96,23 @@ var (
}
)
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
return NewExecToolWithConfig(workingDir, restrict, nil)
func NewExecTool(workingDir string, restrict bool, allowPaths ...[]*regexp.Regexp) (*ExecTool, error) {
return NewExecToolWithConfig(workingDir, restrict, nil, allowPaths...)
}
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
func NewExecToolWithConfig(
workingDir string,
restrict bool,
config *config.Config,
allowPaths ...[]*regexp.Regexp,
) (*ExecTool, error) {
denyPatterns := make([]*regexp.Regexp, 0)
customAllowPatterns := make([]*regexp.Regexp, 0)
var allowedPathPatterns []*regexp.Regexp
allowRemote := true
if len(allowPaths) > 0 {
allowedPathPatterns = allowPaths[0]
}
if config != nil {
execConfig := config.Tools.Exec
@@ -146,6 +156,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
denyPatterns: denyPatterns,
allowPatterns: nil,
customAllowPatterns: customAllowPatterns,
allowedPathPatterns: allowedPathPatterns,
restrictToWorkspace: restrict,
allowRemote: allowRemote,
}, nil
@@ -198,7 +209,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
cwd := t.workingDir
if wd, ok := args["working_dir"].(string); ok && wd != "" {
if t.restrictToWorkspace && t.workingDir != "" {
resolvedWD, err := validatePath(wd, t.workingDir, true)
resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns)
if err != nil {
return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
}
@@ -226,16 +237,20 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
if err != nil {
return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err))
}
absWorkspace, _ := filepath.Abs(t.workingDir)
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
if wsResolved == "" {
wsResolved = absWorkspace
if isAllowedPath(resolved, t.allowedPathPatterns) {
cwd = resolved
} else {
absWorkspace, _ := filepath.Abs(t.workingDir)
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
if wsResolved == "" {
wsResolved = absWorkspace
}
rel, err := filepath.Rel(wsResolved, resolved)
if err != nil || !filepath.IsLocal(rel) {
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
}
cwd = resolved
}
rel, err := filepath.Rel(wsResolved, resolved)
if err != nil || !filepath.IsLocal(rel) {
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
}
cwd = resolved
}
// timeout == 0 means no timeout
@@ -412,6 +427,9 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
if safePaths[p] {
continue
}
if isAllowedPath(p, t.allowedPathPatterns) {
continue
}
rel, err := filepath.Rel(cwdPath, p)
if err != nil {
+178
View File
@@ -0,0 +1,178 @@
package tools
import (
"context"
"fmt"
"sort"
"strings"
"time"
)
// SpawnStatusTool reports the status of subagents that were spawned via the
// spawn tool. It can look up a single task by ID, or list every task visible to
// the caller together with a summary count broken down by status.
type SpawnStatusTool struct {
// manager supplies the task snapshots; Execute returns an error result when it is nil.
manager *SubagentManager
}
// NewSpawnStatusTool returns a SpawnStatusTool that reads task state from the
// given manager.
func NewSpawnStatusTool(manager *SubagentManager) *SpawnStatusTool {
	tool := new(SpawnStatusTool)
	tool.manager = manager
	return tool
}
// Name returns the identifier under which this tool is exposed.
func (t *SpawnStatusTool) Name() string {
	const toolName = "spawn_status"
	return toolName
}
// Description returns the human-readable summary of what the tool does and how
// its results are scoped.
func (t *SpawnStatusTool) Description() string {
	const description = "Get the status of spawned subagents. " +
		"Returns a list of all subagents and their current state " +
		"(running, completed, failed, or canceled), or retrieves details " +
		"for a specific subagent task when task_id is provided. " +
		"Results are scoped to the current conversation's channel and chat ID; " +
		"all tasks are listed only when no channel/chat context is injected " +
		"(e.g. direct programmatic calls via Execute)."
	return description
}
// Parameters returns the argument schema for the tool: an object with a single
// optional string property, task_id.
func (t *SpawnStatusTool) Parameters() map[string]any {
	taskIDSchema := map[string]any{
		"type": "string",
		"description": "Optional task ID (e.g. \"subagent-1\") to inspect a specific " +
			"subagent. When omitted, all visible subagents are listed.",
	}

	schema := map[string]any{"type": "object"}
	schema["properties"] = map[string]any{"task_id": taskIDSchema}
	// Nothing is mandatory; the list is kept as an empty, non-nil slice.
	schema["required"] = []string{}
	return schema
}
// Execute reports subagent status.
//
// With a non-empty "task_id" argument it returns the formatted details of that
// one task; otherwise it returns a report listing every visible task, preceded
// by per-status counts. A task is visible unless both the caller and the task
// carry a channel (or chat ID) and the two values differ. An error result is
// returned when the manager is nil, when task_id is not a string, or when the
// requested task does not exist or is not visible to the caller.
func (t *SpawnStatusTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
	if t.manager == nil {
		return ErrorResult("Subagent manager not configured")
	}

	// Identity of the calling conversation, used to keep one chat from seeing
	// tasks that originated in another.
	channel, chatID := ToolChannel(ctx), ToolChatID(ctx)
	inScope := func(task *SubagentTask) bool {
		if channel != "" && task.OriginChannel != "" && task.OriginChannel != channel {
			return false
		}
		return chatID == "" || task.OriginChatID == "" || task.OriginChatID == chatID
	}

	taskID := ""
	if raw := args["task_id"]; raw != nil {
		text, isString := raw.(string)
		if !isString {
			return ErrorResult("task_id must be a string")
		}
		taskID = strings.TrimSpace(text)
	}

	if taskID != "" {
		// Work on a by-value copy obtained from the manager. An out-of-scope
		// task is reported exactly like a missing one.
		snapshot, found := t.manager.GetTaskCopy(taskID)
		if !found || !inScope(&snapshot) {
			return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
		}
		return NewToolResult(spawnStatusFormatTask(&snapshot))
	}

	snapshots := t.manager.ListTaskCopies()
	if len(snapshots) == 0 {
		return NewToolResult("No subagents have been spawned yet.")
	}

	visible := make([]*SubagentTask, 0, len(snapshots))
	for i := range snapshots {
		if inScope(&snapshots[i]) {
			visible = append(visible, &snapshots[i])
		}
	}
	if len(visible) == 0 {
		return NewToolResult("No subagents found for this conversation.")
	}

	// Oldest first; equal creation times are ordered by ID string.
	sort.Slice(visible, func(a, b int) bool {
		left, right := visible[a], visible[b]
		if left.Created == right.Created {
			return left.ID < right.ID
		}
		return left.Created < right.Created
	})

	perStatus := make(map[string]int)
	for _, task := range visible {
		perStatus[task.Status]++
	}

	var report strings.Builder
	fmt.Fprintf(&report, "Subagent status report (%d total):\n", len(visible))
	for _, status := range []string{"running", "completed", "failed", "canceled"} {
		n := perStatus[status]
		if n <= 0 {
			continue
		}
		fmt.Fprintf(&report, " %-10s %d\n", strings.ToUpper(status[:1])+status[1:]+":", n)
	}
	report.WriteString("\n")

	for _, task := range visible {
		report.WriteString(spawnStatusFormatTask(task))
		report.WriteString("\n\n")
	}
	return NewToolResult(strings.TrimRight(report.String(), "\n"))
}
// spawnStatusFormatTask renders one SubagentTask as a text block: a header line
// with the ID and status (plus label, agent and UTC creation time when set),
// followed by optional task and result lines. Results longer than 300 runes are
// cut at a rune boundary and suffixed with "…".
func spawnStatusFormatTask(task *SubagentTask) string {
	var out strings.Builder

	fmt.Fprintf(&out, "[%s] status=%s", task.ID, task.Status)
	if task.Label != "" {
		fmt.Fprintf(&out, " label=%q", task.Label)
	}
	if task.AgentID != "" {
		fmt.Fprintf(&out, " agent=%s", task.AgentID)
	}
	if task.Created > 0 {
		stamp := time.UnixMilli(task.Created).UTC().Format("2006-01-02 15:04:05 UTC")
		fmt.Fprintf(&out, " created=%s", stamp)
	}

	if task.Task != "" {
		fmt.Fprintf(&out, "\n task: %s", task.Task)
	}

	if task.Result != "" {
		const maxResultLen = 300
		shown := task.Result
		// Truncate by runes, not bytes, so multi-byte characters are never split.
		if runes := []rune(shown); len(runes) > maxResultLen {
			shown = string(runes[:maxResultLen]) + "…"
		}
		fmt.Fprintf(&out, "\n result: %s", shown)
	}

	return out.String()
}
+406
View File
@@ -0,0 +1,406 @@
package tools
import (
"context"
"fmt"
"strings"
"testing"
"time"
)
func TestSpawnStatusTool_Name(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
tool := NewSpawnStatusTool(manager)
if tool.Name() != "spawn_status" {
t.Errorf("Expected name 'spawn_status', got '%s'", tool.Name())
}
}
func TestSpawnStatusTool_Description(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
tool := NewSpawnStatusTool(manager)
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
if !strings.Contains(strings.ToLower(desc), "subagent") {
t.Errorf("Description should mention 'subagent', got: %s", desc)
}
}
func TestSpawnStatusTool_Parameters(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
tool := NewSpawnStatusTool(manager)
params := tool.Parameters()
if params["type"] != "object" {
t.Errorf("Expected type 'object', got: %v", params["type"])
}
props, ok := params["properties"].(map[string]any)
if !ok {
t.Fatal("Expected 'properties' to be a map")
}
if _, hasTaskID := props["task_id"]; !hasTaskID {
t.Error("Expected 'task_id' parameter in properties")
}
}
// TestSpawnStatusTool_NilManager checks that Execute returns an error result,
// rather than panicking, when the tool has no manager.
func TestSpawnStatusTool_NilManager(t *testing.T) {
	// The zero value leaves manager nil.
	tool := new(SpawnStatusTool)

	if result := tool.Execute(context.Background(), map[string]any{}); !result.IsError {
		t.Error("Expected error result when manager is nil")
	}
}
func TestSpawnStatusTool_Empty(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{})
if result.IsError {
t.Fatalf("Expected success, got error: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "No subagents") {
t.Errorf("Expected 'No subagents' message, got: %s", result.ForLLM)
}
}
func TestSpawnStatusTool_ListAll(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
now := time.Now().UnixMilli()
manager.mu.Lock()
manager.tasks["subagent-1"] = &SubagentTask{
ID: "subagent-1",
Task: "Do task A",
Label: "task-a",
Status: "running",
Created: now,
}
manager.tasks["subagent-2"] = &SubagentTask{
ID: "subagent-2",
Task: "Do task B",
Label: "task-b",
Status: "completed",
Result: "Done successfully",
Created: now,
}
manager.tasks["subagent-3"] = &SubagentTask{
ID: "subagent-3",
Task: "Do task C",
Status: "failed",
Result: "Error: something went wrong",
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{})
if result.IsError {
t.Fatalf("Expected success, got error: %s", result.ForLLM)
}
// Summary header
if !strings.Contains(result.ForLLM, "3 total") {
t.Errorf("Expected total count in header, got: %s", result.ForLLM)
}
// Individual task IDs
for _, id := range []string{"subagent-1", "subagent-2", "subagent-3"} {
if !strings.Contains(result.ForLLM, id) {
t.Errorf("Expected task %s in output, got:\n%s", id, result.ForLLM)
}
}
// Status values
for _, status := range []string{"running", "completed", "failed"} {
if !strings.Contains(result.ForLLM, status) {
t.Errorf("Expected status '%s' in output, got:\n%s", status, result.ForLLM)
}
}
// Result content
if !strings.Contains(result.ForLLM, "Done successfully") {
t.Errorf("Expected result text in output, got:\n%s", result.ForLLM)
}
}
// TestSpawnStatusTool_GetByID checks that looking up a task by ID returns its
// ID, status, result and label.
func TestSpawnStatusTool_GetByID(t *testing.T) {
	provider := &MockLLMProvider{}
	// Use a per-test directory instead of a hard-coded /tmp path so the test is
	// portable and isolated, matching the other tests in this file.
	manager := NewSubagentManager(provider, "test-model", t.TempDir())

	manager.mu.Lock()
	manager.tasks["subagent-42"] = &SubagentTask{
		ID:      "subagent-42",
		Task:    "Specific task",
		Label:   "my-task",
		Status:  "failed",
		Result:  "Something went wrong",
		Created: time.Now().UnixMilli(),
	}
	manager.mu.Unlock()

	tool := NewSpawnStatusTool(manager)
	result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-42"})
	if result.IsError {
		t.Fatalf("Expected success, got error: %s", result.ForLLM)
	}

	expectations := []struct{ what, want string }{
		{"task ID", "subagent-42"},
		{"status 'failed'", "failed"},
		{"result text", "Something went wrong"},
		{"label", "my-task"},
	}
	for _, exp := range expectations {
		if !strings.Contains(result.ForLLM, exp.want) {
			t.Errorf("Expected %s in output, got: %s", exp.what, result.ForLLM)
		}
	}
}
// TestSpawnStatusTool_GetByID_NotFound checks that an unknown task ID yields an
// error result that names the requested ID.
func TestSpawnStatusTool_GetByID_NotFound(t *testing.T) {
	provider := &MockLLMProvider{}
	// Per-test directory instead of a hard-coded /tmp path (portable, isolated).
	manager := NewSubagentManager(provider, "test-model", t.TempDir())
	tool := NewSpawnStatusTool(manager)

	result := tool.Execute(context.Background(), map[string]any{"task_id": "nonexistent-999"})
	if !result.IsError {
		t.Errorf("Expected error for nonexistent task, got: %s", result.ForLLM)
	}
	if !strings.Contains(result.ForLLM, "nonexistent-999") {
		t.Errorf("Expected task ID in error message, got: %s", result.ForLLM)
	}
}
// TestSpawnStatusTool_TaskID_NonString checks that every non-string task_id
// value is rejected with the type-error message.
func TestSpawnStatusTool_TaskID_NonString(t *testing.T) {
	provider := &MockLLMProvider{}
	// Per-test directory instead of a hard-coded /tmp path (portable, isolated).
	manager := NewSubagentManager(provider, "test-model", t.TempDir())
	tool := NewSpawnStatusTool(manager)

	for _, badVal := range []any{42, 3.14, true, map[string]any{"x": 1}, []string{"a"}} {
		result := tool.Execute(context.Background(), map[string]any{"task_id": badVal})
		if !result.IsError {
			t.Errorf("Expected error for task_id=%T(%v), got success: %s", badVal, badVal, result.ForLLM)
		}
		if !strings.Contains(result.ForLLM, "task_id must be a string") {
			t.Errorf("Expected type-error message, got: %s", result.ForLLM)
		}
	}
}
// TestSpawnStatusTool_ResultTruncation checks that an over-long result is cut
// to exactly 300 characters and marked with the "…" indicator.
func TestSpawnStatusTool_ResultTruncation(t *testing.T) {
	provider := &MockLLMProvider{}
	// Per-test directory instead of a hard-coded /tmp path (portable, isolated).
	manager := NewSubagentManager(provider, "test-model", t.TempDir())

	longResult := strings.Repeat("X", 500)
	manager.mu.Lock()
	manager.tasks["subagent-1"] = &SubagentTask{
		ID:     "subagent-1",
		Task:   "Long task",
		Status: "completed",
		Result: longResult,
	}
	manager.mu.Unlock()

	tool := NewSpawnStatusTool(manager)
	result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
	if result.IsError {
		t.Fatalf("Unexpected error: %s", result.ForLLM)
	}

	// Pin the limit itself rather than just "shorter than the input". No other
	// field of this fixture contains an 'X', so every 'X' in the output comes
	// from the truncated result.
	if got := strings.Count(result.ForLLM, "X"); got != 300 {
		t.Errorf("Expected result to be truncated to 300 characters, got %d", got)
	}
	if !strings.Contains(result.ForLLM, "…") {
		t.Errorf("Expected truncation indicator '…' in output, got: %s", result.ForLLM)
	}
}
// TestSpawnStatusTool_ResultTruncation_Unicode checks that truncation counts
// runes rather than bytes, and never splits a multi-byte character.
func TestSpawnStatusTool_ResultTruncation_Unicode(t *testing.T) {
	provider := &MockLLMProvider{}
	// Per-test directory instead of a hard-coded /tmp path (portable, isolated).
	manager := NewSubagentManager(provider, "test-model", t.TempDir())

	// Each CJK rune is 3 bytes; 400 runes = 1200 bytes — well over the 300-rune limit.
	cjkChar := string(rune(0x5b57))
	longResult := strings.Repeat(cjkChar, 400)

	manager.mu.Lock()
	manager.tasks["subagent-1"] = &SubagentTask{
		ID:     "subagent-1",
		Task:   "Unicode task",
		Status: "completed",
		Result: longResult,
	}
	manager.mu.Unlock()

	tool := NewSpawnStatusTool(manager)
	result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
	if result.IsError {
		t.Fatalf("Unexpected error: %s", result.ForLLM)
	}

	if !strings.Contains(result.ForLLM, "…") {
		t.Errorf("Expected truncation indicator in output")
	}
	// A byte-based cut at 300 would keep only 100 of these 3-byte runes, so the
	// exact rune count is what distinguishes rune-wise from byte-wise truncation.
	if got := strings.Count(result.ForLLM, cjkChar); got != 300 {
		t.Errorf("Expected 300 intact CJK runes after truncation, got %d", got)
	}
	// The output must be valid UTF-8 (no split rune boundaries).
	if strings.ToValidUTF8(result.ForLLM, "") != result.ForLLM {
		t.Errorf("Expected output to be valid UTF-8, got: %q", result.ForLLM)
	}
}
// TestSpawnStatusTool_StatusCounts checks that the summary contains a labelled
// line for every status that has at least one task.
func TestSpawnStatusTool_StatusCounts(t *testing.T) {
	provider := &MockLLMProvider{}
	// Per-test directory instead of a hard-coded /tmp path (portable, isolated).
	manager := NewSubagentManager(provider, "test-model", t.TempDir())

	manager.mu.Lock()
	for i, status := range []string{"running", "running", "completed", "failed", "canceled"} {
		id := fmt.Sprintf("subagent-%d", i+1)
		manager.tasks[id] = &SubagentTask{ID: id, Task: "t", Status: status}
	}
	manager.mu.Unlock()

	tool := NewSpawnStatusTool(manager)
	result := tool.Execute(context.Background(), map[string]any{})
	if result.IsError {
		t.Fatalf("Unexpected error: %s", result.ForLLM)
	}

	// The summary should mention all statuses that have counts.
	for _, want := range []string{"Running:", "Completed:", "Failed:", "Canceled:"} {
		if !strings.Contains(result.ForLLM, want) {
			t.Errorf("Expected %q in summary, got:\n%s", want, result.ForLLM)
		}
	}
}
// TestSpawnStatusTool_SortByCreatedTimestamp checks that the listing is ordered
// by creation time rather than by ID string.
func TestSpawnStatusTool_SortByCreatedTimestamp(t *testing.T) {
	provider := &MockLLMProvider{}
	// Per-test directory instead of a hard-coded /tmp path (portable, isolated).
	manager := NewSubagentManager(provider, "test-model", t.TempDir())

	now := time.Now().UnixMilli()
	manager.mu.Lock()
	// The IDs sort lexically in the opposite order to the timestamps:
	// subagent-2 was spawned first, subagent-10 second.
	manager.tasks["subagent-10"] = &SubagentTask{
		ID: "subagent-10", Task: "second", Status: "running",
		Created: now + 1,
	}
	manager.tasks["subagent-2"] = &SubagentTask{
		ID: "subagent-2", Task: "first", Status: "running",
		Created: now,
	}
	manager.mu.Unlock()

	tool := NewSpawnStatusTool(manager)
	result := tool.Execute(context.Background(), map[string]any{})
	if result.IsError {
		t.Fatalf("Unexpected error: %s", result.ForLLM)
	}

	pos2 := strings.Index(result.ForLLM, "subagent-2")
	pos10 := strings.Index(result.ForLLM, "subagent-10")
	if pos2 < 0 || pos10 < 0 {
		t.Fatalf("Both task IDs should appear in output:\n%s", result.ForLLM)
	}
	if pos2 > pos10 {
		t.Errorf("Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s", result.ForLLM)
	}
}
// TestSpawnStatusTool_ChannelFiltering_ListAll checks that listing from one
// chat shows that chat's tasks and hides tasks that originated in another chat.
func TestSpawnStatusTool_ChannelFiltering_ListAll(t *testing.T) {
	provider := &MockLLMProvider{}
	// Per-test directory instead of a hard-coded /tmp path (portable, isolated).
	manager := NewSubagentManager(provider, "test-model", t.TempDir())

	manager.mu.Lock()
	manager.tasks["subagent-1"] = &SubagentTask{
		ID: "subagent-1", Task: "mine", Status: "running",
		OriginChannel: "telegram", OriginChatID: "chat-A",
	}
	manager.tasks["subagent-2"] = &SubagentTask{
		ID: "subagent-2", Task: "other user", Status: "running",
		OriginChannel: "telegram", OriginChatID: "chat-B",
	}
	manager.mu.Unlock()

	tool := NewSpawnStatusTool(manager)

	// Caller is chat-A — should only see subagent-1.
	ctx := WithToolContext(context.Background(), "telegram", "chat-A")
	result := tool.Execute(ctx, map[string]any{})
	if result.IsError {
		t.Fatalf("Unexpected error: %s", result.ForLLM)
	}
	if !strings.Contains(result.ForLLM, "subagent-1") {
		t.Errorf("Expected own task in output, got:\n%s", result.ForLLM)
	}
	if strings.Contains(result.ForLLM, "subagent-2") {
		t.Errorf("Should NOT see other chat's task, got:\n%s", result.ForLLM)
	}
}
// TestSpawnStatusTool_ChannelFiltering_GetByID checks that a lookup by ID from
// a different chat is rejected and does not reveal the task's result.
func TestSpawnStatusTool_ChannelFiltering_GetByID(t *testing.T) {
	provider := &MockLLMProvider{}
	// Per-test directory instead of a hard-coded /tmp path (portable, isolated).
	manager := NewSubagentManager(provider, "test-model", t.TempDir())

	manager.mu.Lock()
	manager.tasks["subagent-99"] = &SubagentTask{
		ID: "subagent-99", Task: "secret", Status: "completed", Result: "private data",
		OriginChannel: "slack", OriginChatID: "room-Z",
	}
	manager.mu.Unlock()

	tool := NewSpawnStatusTool(manager)

	// Different chat trying to look up subagent-99 by ID.
	ctx := WithToolContext(context.Background(), "slack", "room-OTHER")
	result := tool.Execute(ctx, map[string]any{"task_id": "subagent-99"})
	if !result.IsError {
		t.Errorf("Expected error (cross-chat lookup blocked), got: %s", result.ForLLM)
	}
	// Being an error is not enough: the response must not carry the other
	// chat's result text either.
	if strings.Contains(result.ForLLM, "private data") {
		t.Errorf("Cross-chat lookup leaked the task result: %s", result.ForLLM)
	}
}
// TestSpawnStatusTool_ChannelFiltering_NoContext checks that a caller without
// any injected channel/chat context sees tasks regardless of their origin.
func TestSpawnStatusTool_ChannelFiltering_NoContext(t *testing.T) {
	provider := &MockLLMProvider{}
	// Per-test directory instead of a hard-coded /tmp path (portable, isolated).
	manager := NewSubagentManager(provider, "test-model", t.TempDir())

	manager.mu.Lock()
	manager.tasks["subagent-1"] = &SubagentTask{
		ID: "subagent-1", Task: "t", Status: "completed",
		OriginChannel: "telegram", OriginChatID: "chat-A",
	}
	manager.mu.Unlock()

	tool := NewSpawnStatusTool(manager)

	// No ToolContext injected (e.g. a direct programmatic call that bypasses
	// WithToolContext entirely) — callerChannel and callerChatID are both "".
	// Note: the normal CLI path uses ProcessDirectWithChannel("cli", "direct"),
	// which *does* inject a non-empty context; this test covers the case where
	// no context injection happens at all.
	// The filter conditions require a non-empty caller value, so all tasks pass through.
	result := tool.Execute(context.Background(), map[string]any{})
	if result.IsError {
		t.Fatalf("Unexpected error: %s", result.ForLLM)
	}
	if !strings.Contains(result.ForLLM, "subagent-1") {
		t.Errorf("Expected task visible from no-context caller, got:\n%s", result.ForLLM)
	}
}
+25
View File
@@ -255,6 +255,18 @@ func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {
return task, ok
}
// GetTaskCopy returns a by-value copy of the task registered under taskID. The
// copy is made while sm.mu is read-locked. The boolean is false, and the
// returned task is the zero value, when no task has that ID.
// NOTE(review): the copy is a plain struct copy; any reference-typed fields of
// SubagentTask would still be shared with the original — confirm.
func (sm *SubagentManager) GetTaskCopy(taskID string) (SubagentTask, bool) {
	sm.mu.RLock()
	defer sm.mu.RUnlock()

	if task, found := sm.tasks[taskID]; found {
		return *task, true
	}
	return SubagentTask{}, false
}
func (sm *SubagentManager) ListTasks() []*SubagentTask {
sm.mu.RLock()
defer sm.mu.RUnlock()
@@ -266,6 +278,19 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
return tasks
}
// ListTaskCopies returns by-value copies of every registered task, made while
// sm.mu is read-locked. The order follows map iteration and is therefore
// unspecified; the returned slice is never nil.
func (sm *SubagentManager) ListTaskCopies() []SubagentTask {
	sm.mu.RLock()
	defer sm.mu.RUnlock()

	snapshots := make([]SubagentTask, len(sm.tasks))
	next := 0
	for _, task := range sm.tasks {
		snapshots[next] = *task
		next++
	}
	return snapshots
}
// SubagentTool executes a subagent task synchronously and returns the result.
// It directly calls SubTurnSpawner with Async=false for synchronous execution.
type SubagentTool struct {
+2 -1
View File
@@ -12,6 +12,7 @@ import (
"github.com/google/uuid"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
)
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
@@ -67,7 +68,7 @@ func DownloadFile(urlStr, filename string, opts DownloadOptions) string {
opts.LoggerPrefix = "utils"
}
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
mediaDir := media.TempDir()
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]any{
"error": err.Error(),