mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge upstream/main into feat/subturn-poc
Includes JSONL session persistence (#1170), spawn_status tool, Azure provider, credential encryption, and various fixes. SubTurn features preserved and integrated with new spawn_status functionality.
This commit is contained in:
+25
-2
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
@@ -66,7 +67,7 @@ func NewAgentInstance(
|
||||
readRestrict := restrict && !defaults.AllowReadOutsideWorkspace
|
||||
|
||||
// Compile path whitelist patterns from config.
|
||||
allowReadPaths := compilePatterns(cfg.Tools.AllowReadPaths)
|
||||
allowReadPaths := buildAllowReadPatterns(cfg)
|
||||
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
|
||||
|
||||
toolsRegistry := tools.NewToolRegistry()
|
||||
@@ -82,7 +83,7 @@ func NewAgentInstance(
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("exec") {
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg, allowReadPaths)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
|
||||
}
|
||||
@@ -282,6 +283,28 @@ func compilePatterns(patterns []string) []*regexp.Regexp {
|
||||
return compiled
|
||||
}
|
||||
|
||||
func buildAllowReadPatterns(cfg *config.Config) []*regexp.Regexp {
|
||||
var configured []string
|
||||
if cfg != nil {
|
||||
configured = cfg.Tools.AllowReadPaths
|
||||
}
|
||||
|
||||
compiled := compilePatterns(configured)
|
||||
mediaDirPattern := regexp.MustCompile(mediaTempDirPattern())
|
||||
for _, pattern := range compiled {
|
||||
if pattern.String() == mediaDirPattern.String() {
|
||||
return compiled
|
||||
}
|
||||
}
|
||||
|
||||
return append(compiled, mediaDirPattern)
|
||||
}
|
||||
|
||||
func mediaTempDirPattern() string {
|
||||
sep := regexp.QuoteMeta(string(os.PathSeparator))
|
||||
return "^" + regexp.QuoteMeta(filepath.Clean(media.TempDir())) + "(?:" + sep + "|$)"
|
||||
}
|
||||
|
||||
// Close releases resources held by the agent's session store.
|
||||
func (a *AgentInstance) Close() error {
|
||||
if a.Sessions != nil {
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
|
||||
@@ -160,3 +164,85 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
|
||||
}
|
||||
|
||||
mediaFile, err := os.CreateTemp(mediaDir, "instance-tool-*.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp(mediaDir) error = %v", err)
|
||||
}
|
||||
mediaPath := mediaFile.Name()
|
||||
if _, err := mediaFile.WriteString("attachment content"); err != nil {
|
||||
mediaFile.Close()
|
||||
t.Fatalf("WriteString(mediaFile) error = %v", err)
|
||||
}
|
||||
if err := mediaFile.Close(); err != nil {
|
||||
t.Fatalf("Close(mediaFile) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.Remove(mediaPath) })
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "test-model",
|
||||
RestrictToWorkspace: true,
|
||||
},
|
||||
},
|
||||
Tools: config.ToolsConfig{
|
||||
ReadFile: config.ReadFileToolConfig{Enabled: true},
|
||||
ListDir: config.ToolConfig{Enabled: true},
|
||||
Exec: config.ExecConfig{
|
||||
ToolConfig: config.ToolConfig{Enabled: true},
|
||||
EnableDenyPatterns: true,
|
||||
AllowRemote: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
|
||||
|
||||
readTool, ok := agent.Tools.Get("read_file")
|
||||
if !ok {
|
||||
t.Fatal("read_file tool not registered")
|
||||
}
|
||||
readResult := readTool.Execute(context.Background(), map[string]any{"path": mediaPath})
|
||||
if readResult.IsError {
|
||||
t.Fatalf("read_file should allow media temp dir, got: %s", readResult.ForLLM)
|
||||
}
|
||||
if !strings.Contains(readResult.ForLLM, "attachment content") {
|
||||
t.Fatalf("read_file output missing media content: %s", readResult.ForLLM)
|
||||
}
|
||||
|
||||
listTool, ok := agent.Tools.Get("list_dir")
|
||||
if !ok {
|
||||
t.Fatal("list_dir tool not registered")
|
||||
}
|
||||
listResult := listTool.Execute(context.Background(), map[string]any{"path": mediaDir})
|
||||
if listResult.IsError {
|
||||
t.Fatalf("list_dir should allow media temp dir, got: %s", listResult.ForLLM)
|
||||
}
|
||||
if !strings.Contains(listResult.ForLLM, filepath.Base(mediaPath)) {
|
||||
t.Fatalf("list_dir output missing media file: %s", listResult.ForLLM)
|
||||
}
|
||||
|
||||
execTool, ok := agent.Tools.Get("exec")
|
||||
if !ok {
|
||||
t.Fatal("exec tool not registered")
|
||||
}
|
||||
execResult := execTool.Execute(context.Background(), map[string]any{
|
||||
"command": "cat " + filepath.Base(mediaPath),
|
||||
"working_dir": mediaDir,
|
||||
})
|
||||
if execResult.IsError {
|
||||
t.Fatalf("exec should allow media temp dir, got: %s", execResult.ForLLM)
|
||||
}
|
||||
if !strings.Contains(execResult.ForLLM, "attachment content") {
|
||||
t.Fatalf("exec output missing media content: %s", execResult.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
+70
-61
@@ -124,6 +124,8 @@ func registerSharedTools(
|
||||
registry *AgentRegistry,
|
||||
provider providers.LLMProvider,
|
||||
) {
|
||||
allowReadPaths := buildAllowReadPatterns(cfg)
|
||||
|
||||
for _, agentID := range registry.ListAgentIDs() {
|
||||
agent, ok := registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
@@ -202,6 +204,7 @@ func registerSharedTools(
|
||||
cfg.Agents.Defaults.RestrictToWorkspace,
|
||||
cfg.Agents.Defaults.GetMaxMediaSize(),
|
||||
nil,
|
||||
allowReadPaths,
|
||||
)
|
||||
agent.Tools.Register(sendFileTool)
|
||||
}
|
||||
@@ -229,72 +232,75 @@ func registerSharedTools(
|
||||
}
|
||||
}
|
||||
|
||||
// Spawn tool with allowlist checker
|
||||
if cfg.Tools.IsToolEnabled("spawn") {
|
||||
if cfg.Tools.IsToolEnabled("subagent") {
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
|
||||
// Set the spawner that links into AgentLoop's turnState
|
||||
subagentManager.SetSpawner(func(
|
||||
ctx context.Context,
|
||||
task, label, targetAgentID string,
|
||||
tls *tools.ToolRegistry,
|
||||
maxTokens int,
|
||||
temperature float64,
|
||||
hasMaxTokens, hasTemperature bool,
|
||||
) (*tools.ToolResult, error) {
|
||||
// 1. Recover parent Turn State from Context
|
||||
parentTS := turnStateFromContext(ctx)
|
||||
if parentTS == nil {
|
||||
// Fallback: If no turnState exists in context, create an isolated ad-hoc root turn state
|
||||
// so that the tool can still function outside of an agent loop (e.g. tests, raw invocations).
|
||||
parentTS = &turnState{
|
||||
ctx: ctx,
|
||||
turnID: "adhoc-root",
|
||||
depth: 0,
|
||||
session: newEphemeralSession(nil),
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, 5),
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Build Tools slice from registry
|
||||
var tlSlice []tools.Tool
|
||||
for _, name := range tls.List() {
|
||||
if t, ok := tls.Get(name); ok {
|
||||
tlSlice = append(tlSlice, t)
|
||||
}
|
||||
}
|
||||
// Spawn and spawn_status tools share a SubagentManager.
|
||||
// Construct it when either tool is enabled (both require subagent).
|
||||
spawnEnabled := cfg.Tools.IsToolEnabled("spawn")
|
||||
spawnStatusEnabled := cfg.Tools.IsToolEnabled("spawn_status")
|
||||
if (spawnEnabled || spawnStatusEnabled) && cfg.Tools.IsToolEnabled("subagent") {
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
|
||||
// 3. System Prompt
|
||||
systemPrompt := "You are a subagent. Complete the given task independently and report the result.\n" +
|
||||
"You have access to tools - use them as needed to complete your task.\n" +
|
||||
"After completing the task, provide a clear summary of what was done.\n\n" +
|
||||
"Task: " + task
|
||||
|
||||
// 4. Resolve Model
|
||||
modelToUse := agent.Model
|
||||
if targetAgentID != "" {
|
||||
if targetAgent, ok := al.GetRegistry().GetAgent(targetAgentID); ok {
|
||||
modelToUse = targetAgent.Model
|
||||
}
|
||||
// Set the spawner that links into AgentLoop's turnState
|
||||
subagentManager.SetSpawner(func(
|
||||
ctx context.Context,
|
||||
task, label, targetAgentID string,
|
||||
tls *tools.ToolRegistry,
|
||||
maxTokens int,
|
||||
temperature float64,
|
||||
hasMaxTokens, hasTemperature bool,
|
||||
) (*tools.ToolResult, error) {
|
||||
// 1. Recover parent Turn State from Context
|
||||
parentTS := turnStateFromContext(ctx)
|
||||
if parentTS == nil {
|
||||
// Fallback: If no turnState exists in context, create an isolated ad-hoc root turn state
|
||||
// so that the tool can still function outside of an agent loop (e.g. tests, raw invocations).
|
||||
parentTS = &turnState{
|
||||
ctx: ctx,
|
||||
turnID: "adhoc-root",
|
||||
depth: 0,
|
||||
session: newEphemeralSession(nil),
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, 5),
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Build SubTurnConfig
|
||||
cfg := SubTurnConfig{
|
||||
Model: modelToUse,
|
||||
Tools: tlSlice,
|
||||
SystemPrompt: systemPrompt,
|
||||
// 2. Build Tools slice from registry
|
||||
var tlSlice []tools.Tool
|
||||
for _, name := range tls.List() {
|
||||
if t, ok := tls.Get(name); ok {
|
||||
tlSlice = append(tlSlice, t)
|
||||
}
|
||||
if hasMaxTokens {
|
||||
cfg.MaxTokens = maxTokens
|
||||
}
|
||||
|
||||
// 3. System Prompt
|
||||
systemPrompt := "You are a subagent. Complete the given task independently and report the result.\n" +
|
||||
"You have access to tools - use them as needed to complete your task.\n" +
|
||||
"After completing the task, provide a clear summary of what was done.\n\n" +
|
||||
"Task: " + task
|
||||
|
||||
// 4. Resolve Model
|
||||
modelToUse := agent.Model
|
||||
if targetAgentID != "" {
|
||||
if targetAgent, ok := al.GetRegistry().GetAgent(targetAgentID); ok {
|
||||
modelToUse = targetAgent.Model
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Spawn SubTurn
|
||||
return spawnSubTurn(ctx, al, parentTS, cfg)
|
||||
})
|
||||
// 5. Build SubTurnConfig
|
||||
cfg := SubTurnConfig{
|
||||
Model: modelToUse,
|
||||
Tools: tlSlice,
|
||||
SystemPrompt: systemPrompt,
|
||||
}
|
||||
if hasMaxTokens {
|
||||
cfg.MaxTokens = maxTokens
|
||||
}
|
||||
|
||||
// 6. Spawn SubTurn
|
||||
return spawnSubTurn(ctx, al, parentTS, cfg)
|
||||
})
|
||||
|
||||
if spawnEnabled {
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
currentAgentID := agentID
|
||||
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
|
||||
@@ -311,9 +317,12 @@ func registerSharedTools(
|
||||
subagentTool := tools.NewSubagentTool(subagentManager)
|
||||
subagentTool.SetSpawner(spawner)
|
||||
agent.Tools.Register(subagentTool)
|
||||
} else {
|
||||
logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil)
|
||||
}
|
||||
if spawnStatusEnabled {
|
||||
agent.Tools.Register(tools.NewSpawnStatusTool(subagentManager))
|
||||
}
|
||||
} else if (spawnEnabled || spawnStatusEnabled) && !cfg.Tools.IsToolEnabled("subagent") {
|
||||
logger.WarnCF("agent", "spawn/spawn_status tools require subagent to be enabled", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -618,7 +618,7 @@ func (c *FeishuChannel) downloadResource(
|
||||
}
|
||||
|
||||
// Write to the shared picoclaw_media directory using a unique name to avoid collisions.
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
mediaDir := media.TempDir()
|
||||
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
|
||||
logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{
|
||||
"error": mkdirErr.Error(),
|
||||
|
||||
@@ -357,7 +357,6 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
|
||||
if len(m.channels) == 0 {
|
||||
logger.WarnC("channels", "No channels enabled")
|
||||
return errors.New("no channels enabled")
|
||||
}
|
||||
|
||||
logger.InfoC("channels", "Starting all channels")
|
||||
@@ -397,7 +396,7 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
"addr": m.httpServer.Addr,
|
||||
})
|
||||
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.ErrorCF("channels", "Shared HTTP server error", map[string]any{
|
||||
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -35,8 +35,6 @@ const (
|
||||
roomKindCacheTTL = 5 * time.Minute
|
||||
roomKindCacheCleanupPeriod = 1 * time.Minute
|
||||
roomKindCacheMaxEntries = 2048
|
||||
|
||||
matrixMediaTempDirName = "picoclaw_media"
|
||||
)
|
||||
|
||||
var matrixMentionHrefRegexp = regexp.MustCompile(`(?i)<a[^>]+href=["']([^"']+)["']`)
|
||||
@@ -1105,7 +1103,7 @@ func (c *MatrixChannel) stripSelfMention(text string) string {
|
||||
}
|
||||
|
||||
func matrixMediaTempDir() (string, error) {
|
||||
mediaDir := filepath.Join(os.TempDir(), matrixMediaTempDirName)
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestMatrixLocalpartMentionRegexp(t *testing.T) {
|
||||
@@ -165,7 +166,7 @@ func TestMatrixMediaTempDir(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("matrixMediaTempDir failed: %v", err)
|
||||
}
|
||||
if filepath.Base(dir) != matrixMediaTempDirName {
|
||||
if filepath.Base(dir) != media.TempDirName {
|
||||
t.Fatalf("unexpected media dir base: %q", filepath.Base(dir))
|
||||
}
|
||||
|
||||
|
||||
@@ -251,7 +251,13 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := c.upgrader.Upgrade(w, r, nil)
|
||||
// Echo the matched subprotocol back so the browser accepts the upgrade.
|
||||
var responseHeader http.Header
|
||||
if proto := c.matchedSubprotocol(r); proto != "" {
|
||||
responseHeader = http.Header{"Sec-WebSocket-Protocol": {proto}}
|
||||
}
|
||||
|
||||
conn, err := c.upgrader.Upgrade(w, r, responseHeader)
|
||||
if err != nil {
|
||||
logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
@@ -282,8 +288,10 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
go c.readLoop(pc)
|
||||
}
|
||||
|
||||
// authenticate checks the Bearer token from the Authorization header.
|
||||
// Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled.
|
||||
// authenticate checks the request for a valid token:
|
||||
// 1. Authorization: Bearer <token> header
|
||||
// 2. Sec-WebSocket-Protocol "token.<value>" (for browsers that can't set headers)
|
||||
// 3. Query parameter "token" (only when AllowTokenQuery is on)
|
||||
func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
token := c.config.Token
|
||||
if token == "" {
|
||||
@@ -298,6 +306,11 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Check Sec-WebSocket-Protocol subprotocol ("token.<value>")
|
||||
if c.matchedSubprotocol(r) != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check query parameter only when explicitly allowed
|
||||
if c.config.AllowTokenQuery {
|
||||
if r.URL.Query().Get("token") == token {
|
||||
@@ -308,6 +321,18 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// matchedSubprotocol returns the "token.<value>" subprotocol that matches
|
||||
// the configured token, or "" if none do.
|
||||
func (c *PicoChannel) matchedSubprotocol(r *http.Request) string {
|
||||
token := c.config.Token
|
||||
for _, proto := range websocket.Subprotocols(r) {
|
||||
if after, ok := strings.CutPrefix(proto, "token."); ok && after == token {
|
||||
return proto
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// readLoop reads messages from a WebSocket connection.
|
||||
func (c *PicoChannel) readLoop(pc *picoConn) {
|
||||
defer func() {
|
||||
|
||||
+79
-6
@@ -4,11 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
)
|
||||
|
||||
@@ -624,8 +626,9 @@ func (c *ModelConfig) Validate() error {
|
||||
}
|
||||
|
||||
type GatewayConfig struct {
|
||||
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
HotReload bool `json:"hot_reload" env:"PICOCLAW_GATEWAY_HOT_RELOAD"`
|
||||
}
|
||||
|
||||
type ToolDiscoveryConfig struct {
|
||||
@@ -698,8 +701,9 @@ type WebToolsConfig struct {
|
||||
}
|
||||
|
||||
type CronToolsConfig struct {
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
|
||||
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
|
||||
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
|
||||
AllowCommand bool ` env:"PICOCLAW_TOOLS_CRON_ALLOW_COMMAND" json:"allow_command"`
|
||||
}
|
||||
|
||||
type ExecConfig struct {
|
||||
@@ -749,6 +753,7 @@ type ToolsConfig struct {
|
||||
ReadFile ReadFileToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
|
||||
SendFile ToolConfig `json:"send_file" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
|
||||
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
|
||||
SpawnStatus ToolConfig `json:"spawn_status" envPrefix:"PICOCLAW_TOOLS_SPAWN_STATUS_"`
|
||||
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
|
||||
Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"`
|
||||
WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"`
|
||||
@@ -838,10 +843,24 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if passphrase := credential.PassphraseProvider(); passphrase != "" {
|
||||
for _, m := range cfg.ModelList {
|
||||
if m.APIKey != "" && !strings.HasPrefix(m.APIKey, "enc://") && !strings.HasPrefix(m.APIKey, "file://") {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
"picoclaw: warning: model %q has a plaintext api_key; call SaveConfig to encrypt it\n",
|
||||
m.ModelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := env.Parse(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := resolveAPIKeys(cfg.ModelList, filepath.Dir(path)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Migrate legacy channel config fields to new unified structures
|
||||
cfg.migrateChannelConfigs()
|
||||
|
||||
@@ -858,6 +877,48 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextAPIKeys returns a copy of models with plaintext api_key values
|
||||
// encrypted. Returns (nil, nil) when nothing changed (all keys already sealed or
|
||||
// empty). Returns (nil, error) if any key fails to encrypt — callers must treat
|
||||
// this as a hard failure to prevent a mixed plaintext/ciphertext state on disk.
|
||||
// Symmetric counterpart of resolveAPIKeys: both operate purely on []ModelConfig
|
||||
// and leave JSON marshaling to the caller.
|
||||
func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelConfig, error) {
|
||||
sealed := make([]ModelConfig, len(models))
|
||||
copy(sealed, models)
|
||||
changed := false
|
||||
for i := range sealed {
|
||||
m := &sealed[i]
|
||||
if m.APIKey == "" || strings.HasPrefix(m.APIKey, "enc://") || strings.HasPrefix(m.APIKey, "file://") {
|
||||
continue
|
||||
}
|
||||
encrypted, err := credential.Encrypt(passphrase, "", m.APIKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot seal api_key for model %q: %w", m.ModelName, err)
|
||||
}
|
||||
m.APIKey = encrypted
|
||||
changed = true
|
||||
}
|
||||
if !changed {
|
||||
return nil, nil
|
||||
}
|
||||
return sealed, nil
|
||||
}
|
||||
|
||||
// resolveAPIKeys decrypts or dereferences each api_key in models in-place.
|
||||
// Supports plaintext (no-op), file:// (read from configDir), and enc:// (AES-GCM decrypt).
|
||||
func resolveAPIKeys(models []ModelConfig, configDir string) error {
|
||||
cr := credential.NewResolver(configDir)
|
||||
for i := range models {
|
||||
resolved, err := cr.Resolve(models[i].APIKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("model_list[%d] (%s): %w", i, models[i].ModelName, err)
|
||||
}
|
||||
models[i].APIKey = resolved
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) migrateChannelConfigs() {
|
||||
// Discord: mention_only -> group_trigger.mention_only
|
||||
if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly {
|
||||
@@ -872,12 +933,22 @@ func (c *Config) migrateChannelConfigs() {
|
||||
}
|
||||
|
||||
func SaveConfig(path string, cfg *Config) error {
|
||||
if passphrase := credential.PassphraseProvider(); passphrase != "" {
|
||||
sealed, err := encryptPlaintextAPIKeys(cfg.ModelList, passphrase)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sealed != nil {
|
||||
tmp := *cfg
|
||||
tmp.ModelList = sealed
|
||||
cfg = &tmp
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Use unified atomic write utility with explicit sync for flash storage reliability.
|
||||
return fileutil.WriteFileAtomic(path, data, 0o600)
|
||||
}
|
||||
|
||||
@@ -1044,6 +1115,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
|
||||
return t.ReadFile.Enabled
|
||||
case "spawn":
|
||||
return t.Spawn.Enabled
|
||||
case "spawn_status":
|
||||
return t.SpawnStatus.Enabled
|
||||
case "spi":
|
||||
return t.SPI.Enabled
|
||||
case "subagent":
|
||||
|
||||
+386
-5
@@ -7,8 +7,22 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
)
|
||||
|
||||
// mustSetupSSHKey generates a temporary Ed25519 SSH key in t.TempDir() and sets
|
||||
// PICOCLAW_SSH_KEY_PATH to its path for the duration of the test. This is required
|
||||
// whenever a test exercises encryption/decryption via credential.Encrypt or SaveConfig.
|
||||
func mustSetupSSHKey(t *testing.T) {
|
||||
t.Helper()
|
||||
keyPath := filepath.Join(t.TempDir(), "picoclaw_ed25519.key")
|
||||
if err := credential.GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("mustSetupSSHKey: %v", err)
|
||||
}
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", keyPath)
|
||||
}
|
||||
|
||||
func TestAgentModelConfig_UnmarshalString(t *testing.T) {
|
||||
var m AgentModelConfig
|
||||
if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil {
|
||||
@@ -253,6 +267,9 @@ func TestDefaultConfig_Gateway(t *testing.T) {
|
||||
if cfg.Gateway.Port == 0 {
|
||||
t.Error("Gateway port should have default value")
|
||||
}
|
||||
if cfg.Gateway.HotReload {
|
||||
t.Error("Gateway hot reload should be disabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Providers verifies provider structure
|
||||
@@ -391,6 +408,13 @@ func TestDefaultConfig_ExecAllowRemoteEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_CronAllowCommandEnabled(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if !cfg.Tools.Cron.AllowCommand {
|
||||
t.Fatal("DefaultConfig().Tools.Cron.AllowCommand should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
@@ -423,6 +447,22 @@ func TestLoadConfig_ExecAllowRemoteDefaultsTrueWhenUnset(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_CronAllowCommandDefaultsTrueWhenUnset(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(`{"tools":{"cron":{"exec_timeout_minutes":5}}}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if !cfg.Tools.Cron.AllowCommand {
|
||||
t.Fatal("tools.cron.allow_command should remain true when unset in config file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
@@ -482,13 +522,19 @@ func TestDefaultConfig_DMScope(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultConfig_WorkspacePath_Default(t *testing.T) {
|
||||
// Unset to ensure we test the default
|
||||
t.Setenv("PICOCLAW_HOME", "")
|
||||
// Set a known home for consistent test results
|
||||
t.Setenv("HOME", "/tmp/home")
|
||||
|
||||
var fakeHome string
|
||||
if runtime.GOOS == "windows" {
|
||||
fakeHome = `C:\tmp\home`
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
} else {
|
||||
fakeHome = "/tmp/home"
|
||||
t.Setenv("HOME", fakeHome)
|
||||
}
|
||||
|
||||
cfg := DefaultConfig()
|
||||
want := filepath.Join("/tmp/home", ".picoclaw", "workspace")
|
||||
want := filepath.Join(fakeHome, ".picoclaw", "workspace")
|
||||
|
||||
if cfg.Agents.Defaults.Workspace != want {
|
||||
t.Errorf("Default workspace path = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
@@ -499,7 +545,7 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw/home")
|
||||
|
||||
cfg := DefaultConfig()
|
||||
want := "/custom/picoclaw/home/workspace"
|
||||
want := filepath.Join("/custom/picoclaw/home", "workspace")
|
||||
|
||||
if cfg.Agents.Defaults.Workspace != want {
|
||||
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
@@ -621,3 +667,338 @@ func TestFlexibleStringSlice_UnmarshalText_EmptySliceConsistency(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLoadConfig_WarnsForPlaintextAPIKey verifies that LoadConfig resolves a plaintext
|
||||
// api_key into memory but does NOT rewrite the config file. File writes are the sole
|
||||
// responsibility of SaveConfig.
|
||||
func TestLoadConfig_WarnsForPlaintextAPIKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
const original = `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
|
||||
if err := os.WriteFile(cfgPath, []byte(original), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
cfg, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
// In-memory value must be the resolved plaintext.
|
||||
if cfg.ModelList[0].APIKey != "sk-plaintext" {
|
||||
t.Errorf("in-memory api_key = %q, want %q", cfg.ModelList[0].APIKey, "sk-plaintext")
|
||||
}
|
||||
// The file on disk must remain unchanged — LoadConfig must not write anything.
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if string(raw) != original {
|
||||
t.Errorf("LoadConfig must not modify the config file; got:\n%s", string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveConfig_EncryptsPlaintextAPIKey verifies that SaveConfig writes enc:// ciphertext
|
||||
// to disk and that a subsequent LoadConfig decrypts it back to the original plaintext.
|
||||
func TestSaveConfig_EncryptsPlaintextAPIKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.ModelList = []ModelConfig{
|
||||
{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
|
||||
}
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig: %v", err)
|
||||
}
|
||||
|
||||
// Disk must contain enc://, not the raw key.
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if !strings.Contains(string(raw), "enc://") {
|
||||
t.Errorf("saved file should contain enc://, got:\n%s", string(raw))
|
||||
}
|
||||
if strings.Contains(string(raw), "sk-plaintext") {
|
||||
t.Errorf("saved file must not contain the plaintext key")
|
||||
}
|
||||
|
||||
// A fresh load must decrypt back to the original plaintext.
|
||||
cfg2, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig after SaveConfig: %v", err)
|
||||
}
|
||||
if cfg2.ModelList[0].APIKey != "sk-plaintext" {
|
||||
t.Errorf("loaded api_key = %q, want %q", cfg2.ModelList[0].APIKey, "sk-plaintext")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_NoSealWithoutPassphrase verifies that api_key values are left
|
||||
// unchanged when PICOCLAW_KEY_PASSPHRASE is not set.
|
||||
func TestLoadConfig_NoSealWithoutPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
|
||||
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
if _, err := LoadConfig(cfgPath); err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if strings.Contains(string(raw), "enc://") {
|
||||
t.Error("config file must not be modified when no passphrase is set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_FileRefNotSealed verifies that file:// api_key references are not
|
||||
// converted to enc:// values (they are resolved at runtime by the Resolver).
|
||||
func TestLoadConfig_FileRefNotSealed(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
keyFile := filepath.Join(dir, "openai.key")
|
||||
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"file://openai.key"}]}`
|
||||
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
if _, err := LoadConfig(cfgPath); err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if !strings.Contains(string(raw), "file://openai.key") {
|
||||
t.Error("file:// reference should be preserved unchanged in the config file")
|
||||
}
|
||||
if strings.Contains(string(raw), "enc://") {
|
||||
t.Error("file:// reference must not be converted to enc://")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveConfig_MixedKeys verifies that SaveConfig encrypts only plaintext api_keys
|
||||
// and leaves already-encrypted (enc://) and file:// entries unchanged.
|
||||
func TestSaveConfig_MixedKeys(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
// Pre-encrypt one key so we have a genuine enc:// value to put in the config.
|
||||
if err := SaveConfig(cfgPath, &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "pre", Model: "openai/gpt-4", APIKey: "sk-already-plain"},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("setup SaveConfig: %v", err)
|
||||
}
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
// Extract the enc:// value from the saved file.
|
||||
var tmp struct {
|
||||
ModelList []struct {
|
||||
APIKey string `json:"api_key"`
|
||||
} `json:"model_list"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tmp); err != nil || len(tmp.ModelList) == 0 {
|
||||
t.Fatalf("setup: could not parse saved config: %v", err)
|
||||
}
|
||||
alreadyEncrypted := tmp.ModelList[0].APIKey
|
||||
if !strings.HasPrefix(alreadyEncrypted, "enc://") {
|
||||
t.Fatalf("setup: expected enc:// key, got %q", alreadyEncrypted)
|
||||
}
|
||||
|
||||
// Build a config with three models:
|
||||
// 1. plaintext → must be encrypted by SaveConfig
|
||||
// 2. enc:// → must be left unchanged (already encrypted)
|
||||
// 3. file:// → must be left unchanged (file reference)
|
||||
keyFile := filepath.Join(dir, "api.key")
|
||||
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "plain", Model: "openai/gpt-4", APIKey: "sk-new-plaintext"},
|
||||
{ModelName: "enc", Model: "openai/gpt-4", APIKey: alreadyEncrypted},
|
||||
{ModelName: "file", Model: "openai/gpt-4", APIKey: "file://api.key"},
|
||||
},
|
||||
}
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ = os.ReadFile(cfgPath)
|
||||
s := string(raw)
|
||||
|
||||
// 1. Plaintext must be encrypted.
|
||||
if strings.Contains(s, "sk-new-plaintext") {
|
||||
t.Error("plaintext key must not appear in saved file")
|
||||
}
|
||||
// 2. The pre-existing enc:// value must still be present (byte-for-byte unchanged).
|
||||
if !strings.Contains(s, alreadyEncrypted) {
|
||||
t.Error("pre-existing enc:// entry must be preserved unchanged")
|
||||
}
|
||||
// 3. file:// must be preserved.
|
||||
if !strings.Contains(s, "file://api.key") {
|
||||
t.Error("file:// reference must be preserved unchanged")
|
||||
}
|
||||
|
||||
// Now load and verify all three decrypt/resolve correctly.
|
||||
cfg2, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig after SaveConfig: %v", err)
|
||||
}
|
||||
byName := make(map[string]string)
|
||||
for _, m := range cfg2.ModelList {
|
||||
byName[m.ModelName] = m.APIKey
|
||||
}
|
||||
if byName["plain"] != "sk-new-plaintext" {
|
||||
t.Errorf("plain model api_key = %q, want %q", byName["plain"], "sk-new-plaintext")
|
||||
}
|
||||
if byName["enc"] != "sk-already-plain" {
|
||||
t.Errorf("enc model api_key = %q, want %q", byName["enc"], "sk-already-plain")
|
||||
}
|
||||
if byName["file"] != "sk-from-file" {
|
||||
t.Errorf("file model api_key = %q, want %q", byName["file"], "sk-from-file")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_MixedKeys_NoPassphrase verifies that when PICOCLAW_KEY_PASSPHRASE
|
||||
// is not set, enc:// entries cause LoadConfig to return an error, while plaintext
|
||||
// and file:// entries in the same config are not affected.
|
||||
func TestLoadConfig_MixedKeys_NoPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// First encrypt a key so we have a real enc:// value.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
mustSetupSSHKey(t)
|
||||
if err := SaveConfig(cfgPath, &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "m", Model: "openai/gpt-4", APIKey: "sk-secret"},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("setup SaveConfig: %v", err)
|
||||
}
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
var tmp struct {
|
||||
ModelList []struct {
|
||||
APIKey string `json:"api_key"`
|
||||
} `json:"model_list"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tmp); err != nil {
|
||||
t.Fatalf("setup parse: %v", err)
|
||||
}
|
||||
encValue := tmp.ModelList[0].APIKey
|
||||
|
||||
// Write a mixed config: enc:// + plaintext + file://
|
||||
keyFile := filepath.Join(dir, "api.key")
|
||||
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
mixed, _ := json.Marshal(map[string]any{
|
||||
"model_list": []map[string]any{
|
||||
{"model_name": "enc", "model": "openai/gpt-4", "api_key": encValue},
|
||||
{"model_name": "plain", "model": "openai/gpt-4", "api_key": "sk-plain"},
|
||||
{"model_name": "file", "model": "openai/gpt-4", "api_key": "file://api.key"},
|
||||
},
|
||||
})
|
||||
if err := os.WriteFile(cfgPath, mixed, 0o600); err != nil {
|
||||
t.Fatalf("setup write: %v", err)
|
||||
}
|
||||
|
||||
// Now clear the passphrase — LoadConfig must fail because enc:// cannot be decrypted.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
|
||||
_, err := LoadConfig(cfgPath)
|
||||
if err == nil {
|
||||
t.Fatal("LoadConfig should fail when enc:// key is present and no passphrase is set")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "passphrase required") {
|
||||
t.Errorf("error should mention passphrase required, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveConfig_UsesPassphraseProvider verifies that SaveConfig encrypts plaintext
|
||||
// api_keys using credential.PassphraseProvider() rather than os.Getenv directly.
|
||||
// This matters for the launcher, which clears the environment variable and redirects
|
||||
// PassphraseProvider to an in-memory SecureStore.
|
||||
func TestSaveConfig_UsesPassphraseProvider(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// Ensure the env var is empty — passphrase must come from PassphraseProvider only.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
// Replace PassphraseProvider with an in-memory function (simulating SecureStore).
|
||||
const testPassphrase = "provider-passphrase"
|
||||
orig := credential.PassphraseProvider
|
||||
credential.PassphraseProvider = func() string { return testPassphrase }
|
||||
t.Cleanup(func() { credential.PassphraseProvider = orig })
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.ModelList = []ModelConfig{
|
||||
{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
|
||||
}
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if !strings.Contains(string(raw), "enc://") {
|
||||
t.Errorf("SaveConfig should have encrypted plaintext key via PassphraseProvider; got:\n%s", raw)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_UsesPassphraseProvider verifies that LoadConfig decrypts enc:// keys
|
||||
// using credential.PassphraseProvider() rather than os.Getenv directly.
|
||||
func TestLoadConfig_UsesPassphraseProvider(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// Ensure the env var is empty throughout.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
const testPassphrase = "provider-passphrase"
|
||||
const plainKey = "sk-secret"
|
||||
|
||||
// First, encrypt the key using the same passphrase.
|
||||
encrypted, err := credential.Encrypt(testPassphrase, "", plainKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := json.Marshal(map[string]any{
|
||||
"model_list": []map[string]any{
|
||||
{"model_name": "test", "model": "openai/gpt-4", "api_key": encrypted},
|
||||
},
|
||||
})
|
||||
if err = os.WriteFile(cfgPath, raw, 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
// Redirect PassphraseProvider — env var is empty, so without this the load would fail.
|
||||
orig := credential.PassphraseProvider
|
||||
credential.PassphraseProvider = func() string { return testPassphrase }
|
||||
t.Cleanup(func() { credential.PassphraseProvider = orig })
|
||||
|
||||
cfg, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
if cfg.ModelList[0].APIKey != plainKey {
|
||||
t.Errorf("api_key = %q, want %q", cfg.ModelList[0].APIKey, plainKey)
|
||||
}
|
||||
}
|
||||
|
||||
+16
-2
@@ -385,10 +385,20 @@ func DefaultConfig() *Config {
|
||||
APIBase: "http://localhost:8000/v1",
|
||||
APIKey: "",
|
||||
},
|
||||
|
||||
// Azure OpenAI - https://portal.azure.com
|
||||
// model_name is a user-friendly alias; the model field's path after "azure/" is your deployment name
|
||||
{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIBase: "https://your-resource.openai.azure.com",
|
||||
APIKey: "",
|
||||
},
|
||||
},
|
||||
Gateway: GatewayConfig{
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
HotReload: false,
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
MediaCleanup: MediaCleanupConfig{
|
||||
@@ -444,6 +454,7 @@ func DefaultConfig() *Config {
|
||||
Enabled: true,
|
||||
},
|
||||
ExecTimeoutMinutes: 5,
|
||||
AllowCommand: true,
|
||||
},
|
||||
Exec: ExecConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
@@ -513,6 +524,9 @@ func DefaultConfig() *Config {
|
||||
Spawn: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
SpawnStatus: ToolConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
SPI: ToolConfig{
|
||||
Enabled: false, // Hardware tool - Linux only
|
||||
},
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
// Package credential resolves API credential values for model_list entries.
|
||||
//
|
||||
// An API key is a form of authorization credential. This package centralizes
|
||||
// how raw credential strings—plaintext or file references—are resolved into
|
||||
// their actual values, keeping that logic out of the config loader.
|
||||
//
|
||||
// Supported formats for the api_key field:
|
||||
//
|
||||
// - Plaintext: "sk-abc123" → returned as-is
|
||||
// - File ref: "file://filename.key" → content read from configDir/filename.key
|
||||
// - Encrypted: "enc://<base64>" → AES-256-GCM decrypt via PICOCLAW_KEY_PASSPHRASE
|
||||
// - Empty: "" → returned as-is (auth_method=oauth etc.)
|
||||
//
|
||||
// Encryption uses AES-256-GCM with HKDF-SHA256 key derivation (< 1ms, safe for embedded Linux).
|
||||
// An SSH private key is required for both encryption and decryption.
|
||||
// Key derivation:
|
||||
//
|
||||
// HKDF-SHA256(ikm=HMAC-SHA256(SHA256(sshKeyBytes), passphrase), salt, info)
|
||||
//
|
||||
// SSH key path resolution priority:
|
||||
//
|
||||
// 1. sshKeyPath argument to Encrypt (explicit)
|
||||
// 2. PICOCLAW_SSH_KEY_PATH env var
|
||||
// 3. ~/.ssh/picoclaw_ed25519.key (os.UserHomeDir is cross-platform)
|
||||
package credential
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hkdf"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PassphraseEnvVar is the environment variable that holds the encryption passphrase.
|
||||
// Other packages (e.g. config) reference this constant to avoid duplicating the string.
|
||||
const PassphraseEnvVar = "PICOCLAW_KEY_PASSPHRASE"
|
||||
|
||||
// PassphraseProvider is the function used to retrieve the passphrase for enc://
|
||||
// credential decryption. It defaults to reading PICOCLAW_KEY_PASSPHRASE from the
|
||||
// process environment. Replace it at startup to use a different source, such as
|
||||
// an in-memory SecureStore, so that all LoadConfig() calls everywhere share the
|
||||
// same passphrase source without needing os.Environ.
|
||||
//
|
||||
// Example (launcher main.go):
|
||||
//
|
||||
// credential.PassphraseProvider = apiHandler.passphraseStore.Get
|
||||
var PassphraseProvider func() string = func() string {
|
||||
return os.Getenv(PassphraseEnvVar)
|
||||
}
|
||||
|
||||
// ErrPassphraseRequired is returned when an enc:// credential is encountered but
|
||||
// no passphrase is available from PassphraseProvider. Callers can detect this
|
||||
// with errors.Is to distinguish a missing-passphrase condition from other errors.
|
||||
var ErrPassphraseRequired = errors.New("credential: enc:// passphrase required")
|
||||
|
||||
// ErrDecryptionFailed is returned when an enc:// credential cannot be decrypted,
|
||||
// indicating a wrong passphrase or SSH key. Callers can detect this with errors.Is.
|
||||
var ErrDecryptionFailed = errors.New("credential: enc:// decryption failed (wrong passphrase or SSH key?)")
|
||||
|
||||
const (
|
||||
fileScheme = "file://"
|
||||
encScheme = "enc://"
|
||||
hkdfInfo = "picoclaw-credential-v1"
|
||||
saltLen = 16
|
||||
nonceLen = 12
|
||||
keyLen = 32
|
||||
sshKeyEnv = "PICOCLAW_SSH_KEY_PATH"
|
||||
)
|
||||
|
||||
// Resolver resolves raw credential strings for model_list api_key fields.
|
||||
// File references are resolved relative to the directory of the config file.
|
||||
type Resolver struct {
|
||||
configDir string
|
||||
resolvedConfigDir string // symlink-resolved form of configDir
|
||||
}
|
||||
|
||||
// NewResolver returns a Resolver that resolves file:// references relative to
|
||||
// configDir (typically filepath.Dir of the config file path).
|
||||
func NewResolver(configDir string) *Resolver {
|
||||
resolved := configDir
|
||||
if configDir != "" {
|
||||
if linkedPath, err := filepath.EvalSymlinks(configDir); err == nil {
|
||||
resolved = linkedPath
|
||||
}
|
||||
}
|
||||
return &Resolver{configDir: configDir, resolvedConfigDir: resolved}
|
||||
}
|
||||
|
||||
// Resolve returns the actual credential value for raw:
|
||||
//
|
||||
// - "" → "" (no error; auth_method=oauth needs no key)
|
||||
// - "file://name.key" → trimmed content of configDir/name.key
|
||||
// - anything else → raw unchanged (plaintext credential)
|
||||
func (r *Resolver) Resolve(raw string) (string, error) {
|
||||
if raw == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(raw, fileScheme) {
|
||||
fileName := strings.TrimSpace(strings.TrimPrefix(raw, fileScheme))
|
||||
if fileName == "" {
|
||||
return "", fmt.Errorf("credential: file:// reference has no filename")
|
||||
}
|
||||
|
||||
baseDir := r.resolvedConfigDir
|
||||
if baseDir == "" {
|
||||
baseDir = r.configDir
|
||||
}
|
||||
keyPath := filepath.Join(baseDir, fileName)
|
||||
// Resolve symlinks before enforcing containment to prevent escaping via symlinks.
|
||||
realKeyPath, err := filepath.EvalSymlinks(keyPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: failed to resolve credential file path %q: %w", keyPath, err)
|
||||
}
|
||||
if !isWithinDir(realKeyPath, baseDir) {
|
||||
return "", fmt.Errorf("credential: file:// path escapes config directory")
|
||||
}
|
||||
data, err := os.ReadFile(realKeyPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: failed to read credential file %q: %w", realKeyPath, err)
|
||||
}
|
||||
|
||||
value := strings.TrimSpace(string(data))
|
||||
if value == "" {
|
||||
return "", fmt.Errorf("credential: credential file %q is empty", realKeyPath)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(raw, encScheme) {
|
||||
return resolveEncrypted(raw)
|
||||
}
|
||||
|
||||
// Plaintext credential — return unchanged.
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
// resolveEncrypted decrypts an enc:// credential using PassphraseProvider.
|
||||
func resolveEncrypted(raw string) (string, error) {
|
||||
passphrase := PassphraseProvider()
|
||||
if passphrase == "" {
|
||||
return "", ErrPassphraseRequired
|
||||
}
|
||||
|
||||
sshKeyPath := pickSSHKeyPath("") // override="": consult env then auto-detect
|
||||
|
||||
b64 := strings.TrimPrefix(raw, encScheme)
|
||||
blob, err := base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: enc:// invalid base64: %w", err)
|
||||
}
|
||||
if len(blob) < saltLen+nonceLen+1 {
|
||||
return "", fmt.Errorf("credential: enc:// payload too short")
|
||||
}
|
||||
|
||||
salt := blob[:saltLen]
|
||||
nonce := blob[saltLen : saltLen+nonceLen]
|
||||
ciphertext := blob[saltLen+nonceLen:]
|
||||
|
||||
key, err := deriveKey(passphrase, sshKeyPath, salt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: enc:// cipher init: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: enc:// gcm init: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %w", ErrDecryptionFailed, err)
|
||||
}
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext and returns an enc:// credential string.
|
||||
//
|
||||
// passphrase is required (PICOCLAW_KEY_PASSPHRASE value).
|
||||
// sshKeyPath is the SSH private key file to use; pass "" to auto-detect via
|
||||
// PICOCLAW_SSH_KEY_PATH env var or ~/.ssh/picoclaw_ed25519.key.
|
||||
// An SSH private key must be resolvable or Encrypt returns an error.
|
||||
func Encrypt(passphrase, sshKeyPath, plaintext string) (string, error) {
|
||||
if passphrase == "" {
|
||||
return "", fmt.Errorf("credential: passphrase must not be empty")
|
||||
}
|
||||
sshKeyPath = pickSSHKeyPath(sshKeyPath)
|
||||
|
||||
salt := make([]byte, saltLen)
|
||||
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||||
return "", fmt.Errorf("credential: failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
key, err := deriveKey(passphrase, sshKeyPath, salt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: cipher init: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: gcm init: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, nonceLen)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("credential: failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nil, nonce, []byte(plaintext), nil)
|
||||
blob := make([]byte, 0, saltLen+nonceLen+len(ciphertext))
|
||||
blob = append(blob, salt...)
|
||||
blob = append(blob, nonce...)
|
||||
blob = append(blob, ciphertext...)
|
||||
return encScheme + base64.StdEncoding.EncodeToString(blob), nil
|
||||
}
|
||||
|
||||
// isWithinDir reports whether path is contained within (or equal to) dir.
|
||||
// Uses filepath.IsLocal on the relative path for robust cross-platform traversal detection.
|
||||
func isWithinDir(path, dir string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(dir), filepath.Clean(path))
|
||||
return err == nil && filepath.IsLocal(rel)
|
||||
}
|
||||
|
||||
// allowedSSHKeyPath reports whether path is in a permitted location for SSH key files:
|
||||
// - exact match with PICOCLAW_SSH_KEY_PATH env var
|
||||
// - within the PICOCLAW_HOME env var directory
|
||||
// - within ~/.ssh/
|
||||
func allowedSSHKeyPath(path string) bool {
|
||||
if path == "" {
|
||||
return true // passphrase-only mode; no file will be read
|
||||
}
|
||||
clean := filepath.Clean(path)
|
||||
|
||||
// Exact match with PICOCLAW_SSH_KEY_PATH.
|
||||
if envPath, ok := os.LookupEnv(sshKeyEnv); ok && envPath != "" {
|
||||
if clean == filepath.Clean(envPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Within PICOCLAW_HOME.
|
||||
if picoHome := os.Getenv("PICOCLAW_HOME"); picoHome != "" {
|
||||
if isWithinDir(clean, picoHome) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Within ~/.ssh/.
|
||||
if userHome, err := os.UserHomeDir(); err == nil {
|
||||
if isWithinDir(clean, filepath.Join(userHome, ".ssh")) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// deriveKey derives a 32-byte AES-256 key from passphrase and SSH private key.
|
||||
//
|
||||
// ikm = HMAC-SHA256(key=SHA256(sshKeyBytes), msg=passphrase)
|
||||
// Final key: HKDF-SHA256(ikm, salt, info="picoclaw-credential-v1", 32 bytes)
|
||||
// sshKeyPath must be non-empty; returns an error otherwise.
|
||||
func deriveKey(passphrase, sshKeyPath string, salt []byte) ([]byte, error) {
|
||||
if sshKeyPath == "" {
|
||||
return nil, fmt.Errorf(
|
||||
"credential: SSH private key is required but not found" +
|
||||
" (set PICOCLAW_SSH_KEY_PATH or place key at ~/.ssh/picoclaw_ed25519.key)")
|
||||
}
|
||||
if !allowedSSHKeyPath(sshKeyPath) {
|
||||
return nil, fmt.Errorf(
|
||||
"credential: SSH key path %q is not in an allowed location (PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/)",
|
||||
sshKeyPath,
|
||||
)
|
||||
}
|
||||
sshBytes, err := os.ReadFile(sshKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("credential: cannot read SSH key %q: %w", sshKeyPath, err)
|
||||
}
|
||||
sshHash := sha256.Sum256(sshBytes)
|
||||
mac := hmac.New(sha256.New, sshHash[:])
|
||||
mac.Write([]byte(passphrase))
|
||||
ikm := mac.Sum(nil)
|
||||
|
||||
key, err := hkdf.Key(sha256.New, ikm, salt, hkdfInfo, keyLen)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("credential: HKDF expand failed: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// pickSSHKeyPath returns the SSH private key path to use for encryption/decryption.
|
||||
//
|
||||
// Priority:
|
||||
// 1. override (non-empty explicit argument)
|
||||
// 2. PICOCLAW_SSH_KEY_PATH env var
|
||||
// 3. ~/.ssh/picoclaw_ed25519.key (auto-detection)
|
||||
//
|
||||
// Returns "" when no key is found; deriveKey will return an error in that case.
|
||||
func pickSSHKeyPath(override string) string {
|
||||
if override != "" {
|
||||
return override
|
||||
}
|
||||
if p, ok := os.LookupEnv(sshKeyEnv); ok {
|
||||
return p // respect explicit setting, even if ""
|
||||
}
|
||||
return findDefaultSSHKey()
|
||||
}
|
||||
|
||||
// findDefaultSSHKey returns the picoclaw-specific SSH key path if it exists.
|
||||
func findDefaultSSHKey() string {
|
||||
p, err := DefaultSSHKeyPath()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package credential_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
)
|
||||
|
||||
func TestResolve_PlainKey(t *testing.T) {
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
got, err := r.Resolve("sk-plaintext-key")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "sk-plaintext-key" {
|
||||
t.Fatalf("got %q, want %q", got, "sk-plaintext-key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_FileKey_Success(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyFile := "openai_plain.key"
|
||||
if err := os.WriteFile(filepath.Join(dir, keyFile), []byte("sk-from-file\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(dir)
|
||||
got, err := r.Resolve("file://" + keyFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "sk-from-file" {
|
||||
t.Fatalf("got %q, want %q", got, "sk-from-file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_FileKey_NotFound(t *testing.T) {
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err := r.Resolve("file://missing.key")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_FileKey_Empty(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyFile := "empty.key"
|
||||
if err := os.WriteFile(filepath.Join(dir, keyFile), []byte(" \n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(dir)
|
||||
_, err := r.Resolve("file://" + keyFile)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty credential file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_EncKey_RoundTrip tests basic encryption/decryption round-trip with an SSH key.
|
||||
func TestResolve_EncKey_RoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key-material\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
const passphrase = "test-passphrase-32bytes-long-ok!"
|
||||
const plaintext = "sk-encrypted-secret"
|
||||
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt(passphrase, "", plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
got, err := r.Resolve(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("Resolve: %v", err)
|
||||
}
|
||||
if got != plaintext {
|
||||
t.Fatalf("got %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_EncKey_WithSSHKey tests that the SSH key file is incorporated into key derivation.
|
||||
func TestResolve_EncKey_WithSSHKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-private-key-material\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
const passphrase = "test-passphrase"
|
||||
const plaintext = "sk-ssh-protected-secret"
|
||||
|
||||
// Set PICOCLAW_SSH_KEY_PATH before Encrypt so the path passes allowedSSHKeyPath validation.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt(passphrase, sshKeyPath, plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
got, err := r.Resolve(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("Resolve: %v", err)
|
||||
}
|
||||
if got != plaintext {
|
||||
t.Fatalf("got %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_NoPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt("some-passphrase", "", "sk-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err = r.Resolve(enc)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when PICOCLAW_KEY_PASSPHRASE is unset, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_BadCiphertext(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err := r.Resolve("enc://!!not-valid-base64!!")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid enc:// payload, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_PayloadTooShort(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
// Valid base64 but fewer bytes than salt(16)+nonce(12)+1 minimum.
|
||||
import64 := "dG9vc2hvcnQ=" // "tooshort" = 8 bytes
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err := r.Resolve("enc://" + import64)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for too-short enc:// payload, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_WrongPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt("correct-passphrase", "", "sk-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "wrong-passphrase")
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err = r.Resolve(enc)
|
||||
if err == nil {
|
||||
t.Fatal("expected decryption error for wrong passphrase, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypt_EmptyPassphrase(t *testing.T) {
|
||||
_, err := credential.Encrypt("", "", "sk-secret")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty passphrase, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveKey_SSHKeyNotFound(t *testing.T) {
|
||||
// Encrypt with a real SSH key path, then try to decrypt with a missing path.
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
// Register the real key path so allowedSSHKeyPath validation passes for Encrypt.
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt("passphrase", sshKeyPath, "sk-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
// Point to a non-existent SSH key so deriveKey's ReadFile fails.
|
||||
// The path is still under the same dir, so allowedSSHKeyPath passes (exact env match).
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", filepath.Join(dir, "nonexistent_key"))
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err = r.Resolve(enc)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when SSH key file is missing, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_FileRef_PathTraversal verifies that file:// references cannot escape configDir
|
||||
// via relative traversal ("../../etc/passwd") or absolute paths ("/abs/path").
|
||||
func TestResolve_FileRef_PathTraversal(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
// Create a file outside configDir that the traversal would point to.
|
||||
outsideFile := filepath.Join(t.TempDir(), "secret.key")
|
||||
if err := os.WriteFile(outsideFile, []byte("stolen"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(filepath.Dir(cfgPath))
|
||||
|
||||
cases := []string{
|
||||
"file://../../secret.key",
|
||||
"file://../secret.key",
|
||||
"file://" + outsideFile, // absolute path
|
||||
}
|
||||
for _, raw := range cases {
|
||||
_, err := r.Resolve(raw)
|
||||
if err == nil {
|
||||
t.Errorf("Resolve(%q): expected path traversal error, got nil", raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_FileRef_withinConfigDir verifies that a legitimate relative file:// ref works.
|
||||
func TestResolve_FileRef_withinConfigDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "my.key"), []byte("sk-valid\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
r := credential.NewResolver(dir)
|
||||
got, err := r.Resolve("file://my.key")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "sk-valid" {
|
||||
t.Fatalf("got %q, want %q", got, "sk-valid")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncrypt_SSHKeyOutsideAllowedDirs verifies that Encrypt rejects SSH key paths
|
||||
// that are not under PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/.
|
||||
func TestEncrypt_SSHKeyOutsideAllowedDirs(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
// Make sure none of the allowed env vars point here.
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
t.Setenv("PICOCLAW_HOME", "")
|
||||
|
||||
_, err := credential.Encrypt("passphrase", sshKeyPath, "sk-secret")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for SSH key outside allowed directories, got nil")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package credential
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// DefaultSSHKeyPath returns the canonical path for the picoclaw-specific SSH key.
|
||||
// The path is always ~/.ssh/picoclaw_ed25519.key (os.UserHomeDir is cross-platform).
|
||||
func DefaultSSHKeyPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: cannot determine home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".ssh", "picoclaw_ed25519.key"), nil
|
||||
}
|
||||
|
||||
// GenerateSSHKey generates an Ed25519 SSH key pair and writes the private key
|
||||
// to path (permissions 0600) and the public key to path+".pub" (permissions 0644).
|
||||
// The ~/.ssh/ directory is created with 0700 if it does not exist.
|
||||
// If the files already exist they are overwritten.
|
||||
func GenerateSSHKey(path string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
|
||||
return fmt.Errorf("credential: keygen: cannot create directory %q: %w", filepath.Dir(path), err)
|
||||
}
|
||||
|
||||
pubRaw, privRaw, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("credential: keygen: ed25519 key generation failed: %w", err)
|
||||
}
|
||||
|
||||
// Marshal private key as OpenSSH PEM.
|
||||
block, err := ssh.MarshalPrivateKey(privRaw, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("credential: keygen: marshal private key: %w", err)
|
||||
}
|
||||
privPEM := pem.EncodeToMemory(block)
|
||||
|
||||
if err = os.WriteFile(path, privPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("credential: keygen: write private key %q: %w", path, err)
|
||||
}
|
||||
|
||||
// Marshal public key as authorized_keys line.
|
||||
sshPub, err := ssh.NewPublicKey(pubRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("credential: keygen: marshal public key: %w", err)
|
||||
}
|
||||
pubLine := ssh.MarshalAuthorizedKey(sshPub)
|
||||
|
||||
pubPath := path + ".pub"
|
||||
if err := os.WriteFile(pubPath, pubLine, 0o644); err != nil {
|
||||
return fmt.Errorf("credential: keygen: write public key %q: %w", pubPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package credential
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestGenerateSSHKey_CreatesFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyPath := filepath.Join(dir, "test_ed25519.key")
|
||||
|
||||
if err := GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
|
||||
// Private key must exist.
|
||||
privInfo, err := os.Stat(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("private key file missing: %v", err)
|
||||
}
|
||||
|
||||
// Check permissions on non-Windows (Windows does not support Unix permission bits).
|
||||
if runtime.GOOS != "windows" {
|
||||
if got := privInfo.Mode().Perm(); got != 0o600 {
|
||||
t.Errorf("private key permissions = %04o, want 0600", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Public key must exist.
|
||||
pubPath := keyPath + ".pub"
|
||||
pubInfo, err := os.Stat(pubPath)
|
||||
if err != nil {
|
||||
t.Fatalf("public key file missing: %v", err)
|
||||
}
|
||||
if runtime.GOOS != "windows" {
|
||||
if got := pubInfo.Mode().Perm(); got != 0o644 {
|
||||
t.Errorf("public key permissions = %04o, want 0644", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Private key must be parseable as an OpenSSH ed25519 key.
|
||||
privPEM, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read private key: %v", err)
|
||||
}
|
||||
privKey, err := ssh.ParseRawPrivateKey(privPEM)
|
||||
if err != nil {
|
||||
t.Fatalf("parse private key: %v", err)
|
||||
}
|
||||
if _, ok := privKey.(*ed25519.PrivateKey); !ok {
|
||||
t.Errorf("private key type = %T, want *ed25519.PrivateKey", privKey)
|
||||
}
|
||||
|
||||
// Public key must be parseable as authorized_keys line.
|
||||
pubBytes, err := os.ReadFile(pubPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read public key: %v", err)
|
||||
}
|
||||
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(pubBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("parse public key: %v", err)
|
||||
}
|
||||
if pubKey == nil {
|
||||
t.Fatal("expected non-nil public key")
|
||||
}
|
||||
if len(rest) > 0 {
|
||||
t.Errorf("unexpected trailing bytes after public key: %d bytes", len(rest))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSSHKey_OverwritesExisting(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyPath := filepath.Join(dir, "test_ed25519.key")
|
||||
|
||||
// Generate twice; second call must not error and must produce a different key.
|
||||
if err := GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("first GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
first, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read first key: %v", err)
|
||||
}
|
||||
|
||||
if err = GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("second GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
second, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read second key: %v", err)
|
||||
}
|
||||
|
||||
// Two independently generated Ed25519 keys must differ.
|
||||
if string(first) == string(second) {
|
||||
t.Error("expected overwritten key to differ from original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSSHKey_CreatesDirectory(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Nested directory that does not yet exist.
|
||||
keyPath := filepath.Join(dir, "subdir", ".ssh", "picoclaw_ed25519.key")
|
||||
|
||||
if err := GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(keyPath); err != nil {
|
||||
t.Fatalf("private key not created: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package credential
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// SecureStore holds a passphrase in memory.
|
||||
//
|
||||
// Uses atomic.Pointer so reads and writes are lock-free.
|
||||
// The passphrase is never written to disk; callers decide how to
|
||||
// transport it outside this store (e.g., via cmd.Env or os.Environ).
|
||||
type SecureStore struct {
|
||||
val atomic.Pointer[string]
|
||||
}
|
||||
|
||||
// NewSecureStore creates an empty SecureStore.
|
||||
func NewSecureStore() *SecureStore {
|
||||
return &SecureStore{}
|
||||
}
|
||||
|
||||
// SetString stores the passphrase. An empty string clears the store.
|
||||
func (s *SecureStore) SetString(passphrase string) {
|
||||
if passphrase == "" {
|
||||
s.val.Store(nil)
|
||||
return
|
||||
}
|
||||
s.val.Store(&passphrase)
|
||||
}
|
||||
|
||||
// Get returns the stored passphrase, or "" if not set.
|
||||
func (s *SecureStore) Get() string {
|
||||
if p := s.val.Load(); p != nil {
|
||||
return *p
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsSet reports whether a passphrase is currently stored.
|
||||
func (s *SecureStore) IsSet() bool {
|
||||
return s.val.Load() != nil
|
||||
}
|
||||
|
||||
// Clear removes the stored passphrase.
|
||||
func (s *SecureStore) Clear() {
|
||||
s.val.Store(nil)
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package credential
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSecureStore_SetGet(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
if s.IsSet() {
|
||||
t.Error("expected empty store")
|
||||
}
|
||||
|
||||
s.SetString("hunter2")
|
||||
if !s.IsSet() {
|
||||
t.Error("expected store to be set")
|
||||
}
|
||||
if got := s.Get(); got != "hunter2" {
|
||||
t.Errorf("Get() = %q, want %q", got, "hunter2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_Clear(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
s.SetString("secret")
|
||||
s.Clear()
|
||||
|
||||
if s.IsSet() {
|
||||
t.Error("expected store to be empty after Clear()")
|
||||
}
|
||||
if got := s.Get(); got != "" {
|
||||
t.Errorf("Get() after Clear() = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_SetOverwrites(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
s.SetString("first")
|
||||
s.SetString("second")
|
||||
|
||||
if got := s.Get(); got != "second" {
|
||||
t.Errorf("Get() = %q, want %q", got, "second")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_EmptyPassphrase(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
s.SetString("") // empty → should not mark as set
|
||||
|
||||
if s.IsSet() {
|
||||
t.Error("empty passphrase should not mark store as set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_ConcurrentSetGet(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
const goroutines = 10
|
||||
const iterations = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
if id%2 == 0 {
|
||||
s.SetString("even")
|
||||
} else {
|
||||
s.SetString("odd")
|
||||
}
|
||||
_ = s.Get()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
final := s.Get()
|
||||
if final != "" && final != "even" && final != "odd" {
|
||||
t.Errorf("Get() returned unexpected value %q after concurrent Set/Get", final)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,594 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/agent"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/dingtalk"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/discord"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/feishu"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/irc"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/line"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/maixcam"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/matrix"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/onebot"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/pico"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/qq"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/slack"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/telegram"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/wecom"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/whatsapp"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/whatsapp_native"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
"github.com/sipeed/picoclaw/pkg/devices"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
const (
|
||||
serviceShutdownTimeout = 30 * time.Second
|
||||
providerReloadTimeout = 30 * time.Second
|
||||
gracefulShutdownTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
type services struct {
|
||||
CronService *cron.CronService
|
||||
HeartbeatService *heartbeat.HeartbeatService
|
||||
MediaStore media.MediaStore
|
||||
ChannelManager *channels.Manager
|
||||
DeviceService *devices.Service
|
||||
HealthServer *health.Server
|
||||
}
|
||||
|
||||
type startupBlockedProvider struct {
|
||||
reason string
|
||||
}
|
||||
|
||||
func (p *startupBlockedProvider) Chat(
|
||||
_ context.Context,
|
||||
_ []providers.Message,
|
||||
_ []providers.ToolDefinition,
|
||||
_ string,
|
||||
_ map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return nil, fmt.Errorf("%s", p.reason)
|
||||
}
|
||||
|
||||
func (p *startupBlockedProvider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Run starts the gateway runtime using the configuration loaded from configPath.
|
||||
func Run(debug bool, configPath string, allowEmptyStartup bool) error {
|
||||
if debug {
|
||||
logger.SetLevel(logger.DEBUG)
|
||||
fmt.Println("🔍 Debug mode enabled")
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %w", err)
|
||||
}
|
||||
|
||||
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating provider: %w", err)
|
||||
}
|
||||
|
||||
if modelID != "" {
|
||||
cfg.Agents.Defaults.ModelName = modelID
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
fmt.Println("\n📦 Agent Status:")
|
||||
startupInfo := agentLoop.GetStartupInfo()
|
||||
toolsInfo := startupInfo["tools"].(map[string]any)
|
||||
skillsInfo := startupInfo["skills"].(map[string]any)
|
||||
fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"])
|
||||
fmt.Printf(" • Skills: %d/%d available\n", skillsInfo["available"], skillsInfo["total"])
|
||||
|
||||
logger.InfoCF("agent", "Agent initialized",
|
||||
map[string]any{
|
||||
"tools_count": toolsInfo["count"],
|
||||
"skills_total": skillsInfo["total"],
|
||||
"skills_available": skillsInfo["available"],
|
||||
})
|
||||
|
||||
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
fmt.Println("Press Ctrl+C to stop")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go agentLoop.Run(ctx)
|
||||
|
||||
var configReloadChan <-chan *config.Config
|
||||
stopWatch := func() {}
|
||||
if cfg.Gateway.HotReload {
|
||||
configReloadChan, stopWatch = setupConfigWatcherPolling(configPath, debug)
|
||||
logger.Info("Config hot reload enabled")
|
||||
}
|
||||
defer stopWatch()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sigChan:
|
||||
logger.Info("Shutting down...")
|
||||
shutdownGateway(runningServices, agentLoop, provider, true)
|
||||
return nil
|
||||
case newCfg := <-configReloadChan:
|
||||
err := handleConfigReload(ctx, agentLoop, newCfg, &provider, runningServices, msgBus, allowEmptyStartup)
|
||||
if err != nil {
|
||||
logger.Errorf("Config reload failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createStartupProvider(
|
||||
cfg *config.Config,
|
||||
allowEmptyStartup bool,
|
||||
) (providers.LLMProvider, string, error) {
|
||||
modelName := cfg.Agents.Defaults.GetModelName()
|
||||
if modelName == "" && allowEmptyStartup {
|
||||
reason := "no default model configured; gateway started in limited mode"
|
||||
fmt.Printf("⚠ Warning: %s\n", reason)
|
||||
logger.WarnCF("gateway", "Gateway started without default model", map[string]any{
|
||||
"limited_mode": true,
|
||||
})
|
||||
return &startupBlockedProvider{reason: reason}, "", nil
|
||||
}
|
||||
|
||||
return providers.CreateProvider(cfg)
|
||||
}
|
||||
|
||||
func setupAndStartServices(
|
||||
cfg *config.Config,
|
||||
agentLoop *agent.AgentLoop,
|
||||
msgBus *bus.MessageBus,
|
||||
) (*services, error) {
|
||||
runningServices := &services{}
|
||||
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
var err error
|
||||
runningServices.CronService, err = setupCronTool(
|
||||
agentLoop,
|
||||
msgBus,
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Agents.Defaults.RestrictToWorkspace,
|
||||
execTimeout,
|
||||
cfg,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error setting up cron service: %w", err)
|
||||
}
|
||||
if err = runningServices.CronService.Start(); err != nil {
|
||||
return nil, fmt.Errorf("error starting cron service: %w", err)
|
||||
}
|
||||
fmt.Println("✓ Cron service started")
|
||||
|
||||
runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Heartbeat.Interval,
|
||||
cfg.Heartbeat.Enabled,
|
||||
)
|
||||
runningServices.HeartbeatService.SetBus(msgBus)
|
||||
runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(agentLoop))
|
||||
if err = runningServices.HeartbeatService.Start(); err != nil {
|
||||
return nil, fmt.Errorf("error starting heartbeat service: %w", err)
|
||||
}
|
||||
fmt.Println("✓ Heartbeat service started")
|
||||
|
||||
runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
Enabled: cfg.Tools.MediaCleanup.Enabled,
|
||||
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
|
||||
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
|
||||
})
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Start()
|
||||
}
|
||||
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
if err != nil {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
return nil, fmt.Errorf("error creating channel manager: %w", err)
|
||||
}
|
||||
|
||||
agentLoop.SetChannelManager(runningServices.ChannelManager)
|
||||
agentLoop.SetMediaStore(runningServices.MediaStore)
|
||||
|
||||
if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
|
||||
agentLoop.SetTranscriber(transcriber)
|
||||
logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
|
||||
}
|
||||
|
||||
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
|
||||
if len(enabledChannels) > 0 {
|
||||
fmt.Printf("✓ Channels enabled: %s\n", enabledChannels)
|
||||
} else {
|
||||
fmt.Println("⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("error starting channels: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
runningServices.DeviceService = devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
MonitorUSB: cfg.Devices.MonitorUSB,
|
||||
}, stateManager)
|
||||
runningServices.DeviceService.SetBus(msgBus)
|
||||
if err = runningServices.DeviceService.Start(context.Background()); err != nil {
|
||||
logger.ErrorCF("device", "Error starting device service", map[string]any{"error": err.Error()})
|
||||
} else if cfg.Devices.Enabled {
|
||||
fmt.Println("✓ Device event service started")
|
||||
}
|
||||
|
||||
return runningServices, nil
|
||||
}
|
||||
|
||||
func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration) {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
defer shutdownCancel()
|
||||
|
||||
if runningServices.ChannelManager != nil {
|
||||
runningServices.ChannelManager.StopAll(shutdownCtx)
|
||||
}
|
||||
if runningServices.DeviceService != nil {
|
||||
runningServices.DeviceService.Stop()
|
||||
}
|
||||
if runningServices.HeartbeatService != nil {
|
||||
runningServices.HeartbeatService.Stop()
|
||||
}
|
||||
if runningServices.CronService != nil {
|
||||
runningServices.CronService.Stop()
|
||||
}
|
||||
if runningServices.MediaStore != nil {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shutdownGateway(
|
||||
runningServices *services,
|
||||
agentLoop *agent.AgentLoop,
|
||||
provider providers.LLMProvider,
|
||||
fullShutdown bool,
|
||||
) {
|
||||
if cp, ok := provider.(providers.StatefulProvider); ok && fullShutdown {
|
||||
cp.Close()
|
||||
}
|
||||
|
||||
stopAndCleanupServices(runningServices, gracefulShutdownTimeout)
|
||||
|
||||
agentLoop.Stop()
|
||||
agentLoop.Close()
|
||||
|
||||
logger.Info("✓ Gateway stopped")
|
||||
}
|
||||
|
||||
func handleConfigReload(
|
||||
ctx context.Context,
|
||||
al *agent.AgentLoop,
|
||||
newCfg *config.Config,
|
||||
providerRef *providers.LLMProvider,
|
||||
runningServices *services,
|
||||
msgBus *bus.MessageBus,
|
||||
allowEmptyStartup bool,
|
||||
) error {
|
||||
logger.Info("🔄 Config file changed, reloading...")
|
||||
|
||||
newModel := newCfg.Agents.Defaults.ModelName
|
||||
if newModel == "" {
|
||||
newModel = newCfg.Agents.Defaults.Model
|
||||
}
|
||||
|
||||
logger.Infof(" New model is '%s', recreating provider...", newModel)
|
||||
|
||||
logger.Info(" Stopping all services...")
|
||||
stopAndCleanupServices(runningServices, serviceShutdownTimeout)
|
||||
|
||||
newProvider, newModelID, err := createStartupProvider(newCfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
logger.Errorf(" ⚠ Error creating new provider: %v", err)
|
||||
logger.Warn(" Attempting to restart services with old provider and config...")
|
||||
if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
|
||||
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
|
||||
}
|
||||
return fmt.Errorf("error creating new provider: %w", err)
|
||||
}
|
||||
|
||||
if newModelID != "" {
|
||||
newCfg.Agents.Defaults.ModelName = newModelID
|
||||
}
|
||||
|
||||
reloadCtx, reloadCancel := context.WithTimeout(context.Background(), providerReloadTimeout)
|
||||
defer reloadCancel()
|
||||
|
||||
if err := al.ReloadProviderAndConfig(reloadCtx, newProvider, newCfg); err != nil {
|
||||
logger.Errorf(" ⚠ Error reloading agent loop: %v", err)
|
||||
if cp, ok := newProvider.(providers.StatefulProvider); ok {
|
||||
cp.Close()
|
||||
}
|
||||
logger.Warn(" Attempting to restart services with old provider and config...")
|
||||
if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
|
||||
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
|
||||
}
|
||||
return fmt.Errorf("error reloading agent loop: %w", err)
|
||||
}
|
||||
|
||||
*providerRef = newProvider
|
||||
|
||||
logger.Info(" Restarting all services with new configuration...")
|
||||
if err := restartServices(al, runningServices, msgBus); err != nil {
|
||||
logger.Errorf(" ⚠ Error restarting services: %v", err)
|
||||
return fmt.Errorf("error restarting services: %w", err)
|
||||
}
|
||||
|
||||
logger.Info(" ✓ Provider, configuration, and services reloaded successfully (thread-safe)")
|
||||
return nil
|
||||
}
|
||||
|
||||
func restartServices(
|
||||
al *agent.AgentLoop,
|
||||
runningServices *services,
|
||||
msgBus *bus.MessageBus,
|
||||
) error {
|
||||
cfg := al.GetConfig()
|
||||
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
var err error
|
||||
runningServices.CronService, err = setupCronTool(
|
||||
al,
|
||||
msgBus,
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Agents.Defaults.RestrictToWorkspace,
|
||||
execTimeout,
|
||||
cfg,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error restarting cron service: %w", err)
|
||||
}
|
||||
if err = runningServices.CronService.Start(); err != nil {
|
||||
return fmt.Errorf("error restarting cron service: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Cron service restarted")
|
||||
|
||||
runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Heartbeat.Interval,
|
||||
cfg.Heartbeat.Enabled,
|
||||
)
|
||||
runningServices.HeartbeatService.SetBus(msgBus)
|
||||
runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(al))
|
||||
if err = runningServices.HeartbeatService.Start(); err != nil {
|
||||
return fmt.Errorf("error restarting heartbeat service: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Heartbeat service restarted")
|
||||
|
||||
runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
Enabled: cfg.Tools.MediaCleanup.Enabled,
|
||||
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
|
||||
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
|
||||
})
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Start()
|
||||
}
|
||||
al.SetMediaStore(runningServices.MediaStore)
|
||||
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error recreating channel manager: %w", err)
|
||||
}
|
||||
al.SetChannelManager(runningServices.ChannelManager)
|
||||
|
||||
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
|
||||
if len(enabledChannels) > 0 {
|
||||
fmt.Printf(" ✓ Channels enabled: %s\n", enabledChannels)
|
||||
} else {
|
||||
fmt.Println(" ⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
return fmt.Errorf("error restarting channels: %w", err)
|
||||
}
|
||||
fmt.Printf(
|
||||
" ✓ Channels restarted, health endpoints at http://%s:%d/health and ready\n",
|
||||
cfg.Gateway.Host,
|
||||
cfg.Gateway.Port,
|
||||
)
|
||||
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
runningServices.DeviceService = devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
MonitorUSB: cfg.Devices.MonitorUSB,
|
||||
}, stateManager)
|
||||
runningServices.DeviceService.SetBus(msgBus)
|
||||
if err := runningServices.DeviceService.Start(context.Background()); err != nil {
|
||||
logger.WarnCF("device", "Failed to restart device service", map[string]any{"error": err.Error()})
|
||||
} else if cfg.Devices.Enabled {
|
||||
fmt.Println(" ✓ Device event service restarted")
|
||||
}
|
||||
|
||||
transcriber := voice.DetectTranscriber(cfg)
|
||||
al.SetTranscriber(transcriber)
|
||||
if transcriber != nil {
|
||||
logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
|
||||
} else {
|
||||
logger.InfoCF("voice", "Transcription disabled", nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Config, func()) {
|
||||
configChan := make(chan *config.Config, 1)
|
||||
stop := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
lastModTime := getFileModTime(configPath)
|
||||
lastSize := getFileSize(configPath)
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
currentModTime := getFileModTime(configPath)
|
||||
currentSize := getFileSize(configPath)
|
||||
|
||||
if currentModTime.After(lastModTime) || currentSize != lastSize {
|
||||
if debug {
|
||||
logger.Debugf("🔍 Config file change detected")
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
lastModTime = currentModTime
|
||||
lastSize = currentSize
|
||||
|
||||
newCfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
logger.Errorf("⚠ Error loading new config: %v", err)
|
||||
logger.Warn(" Using previous valid config")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := newCfg.ValidateModelList(); err != nil {
|
||||
logger.Errorf(" ⚠ New config validation failed: %v", err)
|
||||
logger.Warn(" Using previous valid config")
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("✓ Config file validated and loaded")
|
||||
|
||||
select {
|
||||
case configChan <- newCfg:
|
||||
default:
|
||||
logger.Warn("⚠ Previous config reload still in progress, skipping")
|
||||
}
|
||||
}
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
stopFunc := func() {
|
||||
close(stop)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
return configChan, stopFunc
|
||||
}
|
||||
|
||||
func getFileModTime(path string) time.Time {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return info.ModTime()
|
||||
}
|
||||
|
||||
func getFileSize(path string) int64 {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return info.Size()
|
||||
}
|
||||
|
||||
func setupCronTool(
|
||||
agentLoop *agent.AgentLoop,
|
||||
msgBus *bus.MessageBus,
|
||||
workspace string,
|
||||
restrict bool,
|
||||
execTimeout time.Duration,
|
||||
cfg *config.Config,
|
||||
) (*cron.CronService, error) {
|
||||
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
|
||||
|
||||
cronService := cron.NewCronService(cronStorePath, nil)
|
||||
|
||||
var cronTool *tools.CronTool
|
||||
if cfg.Tools.IsToolEnabled("cron") {
|
||||
var err error
|
||||
cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("critical error during CronTool initialization: %w", err)
|
||||
}
|
||||
|
||||
agentLoop.RegisterTool(cronTool)
|
||||
}
|
||||
|
||||
if cronTool != nil {
|
||||
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
|
||||
result := cronTool.ExecuteJob(context.Background(), job)
|
||||
return result, nil
|
||||
})
|
||||
}
|
||||
|
||||
return cronService, nil
|
||||
}
|
||||
|
||||
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
if channel == "" || chatID == "" {
|
||||
channel, chatID = "cli", "direct"
|
||||
}
|
||||
|
||||
response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
|
||||
if err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
|
||||
}
|
||||
if response == "HEARTBEAT_OK" {
|
||||
return tools.SilentResult("Heartbeat OK")
|
||||
}
|
||||
return tools.SilentResult(response)
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -29,6 +30,7 @@ type StatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Uptime string `json:"uptime"`
|
||||
Checks map[string]Check `json:"checks,omitempty"`
|
||||
Pid int `json:"pid"`
|
||||
}
|
||||
|
||||
func NewServer(host string, port int) *Server {
|
||||
@@ -112,6 +114,7 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
resp := StatusResponse{
|
||||
Status: "ok",
|
||||
Uptime: uptime.String(),
|
||||
Pid: os.Getpid(),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
|
||||
+58
-7
@@ -5,6 +5,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -45,6 +46,9 @@ func init() {
|
||||
consoleWriter := zerolog.ConsoleWriter{
|
||||
Out: os.Stdout,
|
||||
TimeFormat: "15:04:05", // TODO: make it configurable???
|
||||
|
||||
// Custom formatter to handle multiline strings and JSON objects
|
||||
FormatFieldValue: formatFieldValue,
|
||||
}
|
||||
|
||||
logger = zerolog.New(consoleWriter).With().Timestamp().Logger()
|
||||
@@ -52,6 +56,37 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
func formatFieldValue(i any) string {
|
||||
var s string
|
||||
|
||||
switch val := i.(type) {
|
||||
case string:
|
||||
s = val
|
||||
case []byte:
|
||||
s = string(val)
|
||||
default:
|
||||
return fmt.Sprintf("%v", i)
|
||||
}
|
||||
|
||||
if unquoted, err := strconv.Unquote(s); err == nil {
|
||||
s = unquoted
|
||||
}
|
||||
|
||||
if strings.Contains(s, "\n") {
|
||||
return fmt.Sprintf("\n%s", s)
|
||||
}
|
||||
|
||||
if strings.Contains(s, " ") {
|
||||
if (strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}")) ||
|
||||
(strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]")) {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%q", s)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func SetLevel(level LogLevel) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
@@ -163,10 +198,7 @@ func logMessage(level LogLevel, component string, message string, fields map[str
|
||||
event.Str("caller", fmt.Sprintf("<none> %s:%d (%s)", callerFile, callerLine, callerFunc))
|
||||
}
|
||||
|
||||
for k, v := range fields {
|
||||
event.Interface(k, v)
|
||||
}
|
||||
|
||||
appendFields(event, fields)
|
||||
event.Msg(message)
|
||||
|
||||
// Also log to file if enabled
|
||||
@@ -176,9 +208,8 @@ func logMessage(level LogLevel, component string, message string, fields map[str
|
||||
if component != "" {
|
||||
fileEvent.Str("component", component)
|
||||
}
|
||||
for k, v := range fields {
|
||||
fileEvent.Interface(k, v)
|
||||
}
|
||||
|
||||
appendFields(event, fields)
|
||||
fileEvent.Msg(message)
|
||||
}
|
||||
|
||||
@@ -187,6 +218,26 @@ func logMessage(level LogLevel, component string, message string, fields map[str
|
||||
}
|
||||
}
|
||||
|
||||
func appendFields(event *zerolog.Event, fields map[string]any) {
|
||||
for k, v := range fields {
|
||||
// Type switch to avoid double JSON serialization of strings
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
event.Str(k, val)
|
||||
case int:
|
||||
event.Int(k, val)
|
||||
case int64:
|
||||
event.Int64(k, val)
|
||||
case float64:
|
||||
event.Float64(k, val)
|
||||
case bool:
|
||||
event.Bool(k, val)
|
||||
default:
|
||||
event.Interface(k, v) // Fallback for struct, slice and maps
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Debug(message string) {
|
||||
logMessage(DEBUG, "", message, nil)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,20 @@
|
||||
|
||||
package logger
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// botTokenRe matches the bot ID prefix and the secret part of a Telegram bot token.
|
||||
// Groups: 1 = "bot<id>:", 2 = first 4 chars of secret, 3 = middle, 4 = last 4 chars.
|
||||
var botTokenRe = regexp.MustCompile(`(bot\d+:)([A-Za-z0-9_-]{4})[A-Za-z0-9_-]{12,}([A-Za-z0-9_-]{4})`)
|
||||
|
||||
// maskSecrets replaces any embedded bot tokens in s with a redacted placeholder
|
||||
// that keeps the first and last 4 characters of the secret for identification.
|
||||
func maskSecrets(s string) string {
|
||||
return botTokenRe.ReplaceAllString(s, "${1}${2}****${3}")
|
||||
}
|
||||
|
||||
// Logger implements common Logger interface
|
||||
type Logger struct {
|
||||
@@ -12,52 +25,52 @@ type Logger struct {
|
||||
|
||||
// Debug logs debug messages
|
||||
func (b *Logger) Debug(v ...any) {
|
||||
logMessage(DEBUG, b.component, fmt.Sprint(v...), nil)
|
||||
logMessage(DEBUG, b.component, maskSecrets(fmt.Sprint(v...)), nil)
|
||||
}
|
||||
|
||||
// Info logs info messages
|
||||
func (b *Logger) Info(v ...any) {
|
||||
logMessage(INFO, b.component, fmt.Sprint(v...), nil)
|
||||
logMessage(INFO, b.component, maskSecrets(fmt.Sprint(v...)), nil)
|
||||
}
|
||||
|
||||
// Warn logs warning messages
|
||||
func (b *Logger) Warn(v ...any) {
|
||||
logMessage(WARN, b.component, fmt.Sprint(v...), nil)
|
||||
logMessage(WARN, b.component, maskSecrets(fmt.Sprint(v...)), nil)
|
||||
}
|
||||
|
||||
// Error logs error messages
|
||||
func (b *Logger) Error(v ...any) {
|
||||
logMessage(ERROR, b.component, fmt.Sprint(v...), nil)
|
||||
logMessage(ERROR, b.component, maskSecrets(fmt.Sprint(v...)), nil)
|
||||
}
|
||||
|
||||
// Debugf logs formatted debug messages
|
||||
func (b *Logger) Debugf(format string, v ...any) {
|
||||
logMessage(DEBUG, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(DEBUG, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Infof logs formatted info messages
|
||||
func (b *Logger) Infof(format string, v ...any) {
|
||||
logMessage(INFO, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(INFO, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Warnf logs formatted warning messages
|
||||
func (b *Logger) Warnf(format string, v ...any) {
|
||||
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(WARN, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Warningf logs formatted warning messages
|
||||
func (b *Logger) Warningf(format string, v ...any) {
|
||||
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(WARN, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Errorf logs formatted error messages
|
||||
func (b *Logger) Errorf(format string, v ...any) {
|
||||
logMessage(ERROR, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(ERROR, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Fatalf logs formatted fatal messages and exits
|
||||
func (b *Logger) Fatalf(format string, v ...any) {
|
||||
logMessage(FATAL, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(FATAL, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Log logs a message at a given level with caller information
|
||||
@@ -75,7 +88,7 @@ func (b *Logger) Log(msgL, caller int, format string, a ...any) {
|
||||
level = lvl
|
||||
}
|
||||
}
|
||||
logMessage(level, b.component, fmt.Sprintf(format, a...), nil)
|
||||
logMessage(level, b.component, maskSecrets(fmt.Sprintf(format, a...)), nil)
|
||||
}
|
||||
|
||||
// Sync flushes log buffer (no-op for this implementation)
|
||||
|
||||
@@ -141,3 +141,114 @@ func TestLoggerHelperFunctions(t *testing.T) {
|
||||
Debugf("test from %v", "Debugf")
|
||||
WarnF("Warning with fields", map[string]any{"key": "value"})
|
||||
}
|
||||
|
||||
func TestFormatFieldValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input any
|
||||
expected string
|
||||
}{
|
||||
// Basic types test (default case of the switch)
|
||||
{
|
||||
name: "Integer Type",
|
||||
input: 42,
|
||||
expected: "42",
|
||||
},
|
||||
{
|
||||
name: "Boolean Type",
|
||||
input: true,
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "Unsupported Struct Type",
|
||||
input: struct{ A int }{A: 1},
|
||||
expected: "{1}",
|
||||
},
|
||||
|
||||
// Simple strings and byte slices test
|
||||
{
|
||||
name: "Simple string without spaces",
|
||||
input: "simple_value",
|
||||
expected: "simple_value",
|
||||
},
|
||||
{
|
||||
name: "Simple byte slice",
|
||||
input: []byte("byte_value"),
|
||||
expected: "byte_value",
|
||||
},
|
||||
|
||||
// Unquoting test (strconv.Unquote)
|
||||
{
|
||||
name: "Quoted string",
|
||||
input: `"quoted_value"`,
|
||||
expected: "quoted_value",
|
||||
},
|
||||
|
||||
// Strings with newline (\n) test
|
||||
{
|
||||
name: "String with newline",
|
||||
input: "line1\nline2",
|
||||
expected: "\nline1\nline2",
|
||||
},
|
||||
{
|
||||
name: "Quoted string with newline (Unquote -> newline)",
|
||||
input: `"line1\nline2"`, // Escaped \n that Unquote will resolve
|
||||
expected: "\nline1\nline2",
|
||||
},
|
||||
|
||||
// Strings with spaces test (which should be quoted)
|
||||
{
|
||||
name: "String with spaces",
|
||||
input: "hello world",
|
||||
expected: `"hello world"`,
|
||||
},
|
||||
{
|
||||
name: "Quoted string with spaces (Unquote -> has spaces -> Re-quote)",
|
||||
input: `"hello world"`,
|
||||
expected: `"hello world"`,
|
||||
},
|
||||
|
||||
// JSON formats test (strings with spaces that start/end with brackets)
|
||||
{
|
||||
name: "Valid JSON object",
|
||||
input: `{"key": "value"}`,
|
||||
expected: `{"key": "value"}`,
|
||||
},
|
||||
{
|
||||
name: "Valid JSON array",
|
||||
input: `[1, 2, "three"]`,
|
||||
expected: `[1, 2, "three"]`,
|
||||
},
|
||||
{
|
||||
name: "Fake JSON (starts with { but doesn't end with })",
|
||||
input: `{"key": "value"`, // Missing closing bracket, has spaces
|
||||
expected: `"{\"key\": \"value\""`,
|
||||
},
|
||||
{
|
||||
name: "Empty JSON (object)",
|
||||
input: `{ }`,
|
||||
expected: `{ }`,
|
||||
},
|
||||
|
||||
// 7. Edge Cases
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Whitespace only string",
|
||||
input: " ",
|
||||
expected: `" "`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := formatFieldValue(tt.input)
|
||||
if actual != tt.expected {
|
||||
t.Errorf("formatFieldValue() = %q, expected %q", actual, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package media
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
const TempDirName = "picoclaw_media"
|
||||
|
||||
// TempDir returns the shared temporary directory used for downloaded media.
|
||||
func TempDir() string {
|
||||
return filepath.Join(os.TempDir(), TempDirName)
|
||||
}
|
||||
@@ -221,11 +221,17 @@ func buildRequestBody(
|
||||
|
||||
// Add tool_use blocks
|
||||
for _, tc := range msg.ToolCalls {
|
||||
// Handle nil Arguments (GLM-4 may return null input)
|
||||
input := tc.Arguments
|
||||
if input == nil {
|
||||
input = map[string]any{}
|
||||
}
|
||||
|
||||
toolUse := map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": tc.ID,
|
||||
"name": tc.Name,
|
||||
"input": tc.Arguments,
|
||||
"input": input,
|
||||
}
|
||||
content = append(content, toolUse)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type (
|
||||
LLMResponse = protocoltypes.LLMResponse
|
||||
Message = protocoltypes.Message
|
||||
ToolDefinition = protocoltypes.ToolDefinition
|
||||
)
|
||||
|
||||
const (
|
||||
// azureAPIVersion is the Azure OpenAI API version used for all requests.
|
||||
azureAPIVersion = "2024-10-21"
|
||||
defaultRequestTimeout = common.DefaultRequestTimeout
|
||||
)
|
||||
|
||||
// Provider implements the LLM provider interface for Azure OpenAI endpoints.
|
||||
// It handles Azure-specific authentication (api-key header), URL construction
|
||||
// (deployment-based), and request body formatting (max_completion_tokens, no model field).
|
||||
type Provider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// Option configures the Azure Provider.
|
||||
type Option func(*Provider)
|
||||
|
||||
// WithRequestTimeout sets the HTTP request timeout.
|
||||
func WithRequestTimeout(timeout time.Duration) Option {
|
||||
return func(p *Provider) {
|
||||
if timeout > 0 {
|
||||
p.httpClient.Timeout = timeout
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewProvider creates a new Azure OpenAI provider.
|
||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: common.NewHTTPClient(proxy),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt(p)
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// NewProviderWithTimeout creates a new Azure OpenAI provider with a custom request timeout in seconds.
|
||||
func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds int) *Provider {
|
||||
return NewProvider(
|
||||
apiKey, apiBase, proxy,
|
||||
WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
)
|
||||
}
|
||||
|
||||
// Chat sends a chat completion request to the Azure OpenAI endpoint.
|
||||
// The model parameter is used as the Azure deployment name in the URL.
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
if p.apiBase == "" {
|
||||
return nil, fmt.Errorf("Azure API base not configured")
|
||||
}
|
||||
|
||||
// model is the deployment name for Azure OpenAI
|
||||
deployment := model
|
||||
|
||||
// Build Azure-specific URL safely using url.JoinPath and query encoding
|
||||
// to prevent path traversal or query injection via deployment names.
|
||||
base, err := url.JoinPath(p.apiBase, "openai/deployments", deployment, "chat/completions")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build Azure request URL: %w", err)
|
||||
}
|
||||
requestURL := base + "?api-version=" + azureAPIVersion
|
||||
|
||||
// Build request body — no "model" field (Azure infers from deployment URL)
|
||||
requestBody := map[string]any{
|
||||
"messages": common.SerializeMessages(messages),
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = tools
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
|
||||
// Azure OpenAI always uses max_completion_tokens
|
||||
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
|
||||
requestBody["max_completion_tokens"] = maxTokens
|
||||
}
|
||||
|
||||
if temperature, ok := common.AsFloat(options["temperature"]); ok {
|
||||
requestBody["temperature"] = temperature
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Azure uses api-key header instead of Authorization: Bearer
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Api-Key", p.apiKey)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, common.HandleErrorResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
return common.ReadAndParseResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
// GetDefaultModel returns an empty string as Azure deployments are user-configured.
|
||||
func (p *Provider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// writeValidResponse writes a minimal valid Azure OpenAI chat completion response.
|
||||
func writeValidResponse(w http.ResponseWriter) {
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureURLConstruction(t *testing.T) {
|
||||
var capturedPath string
|
||||
var capturedAPIVersion string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedPath = r.URL.Path
|
||||
capturedAPIVersion = r.URL.Query().Get("api-version")
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
wantPath := "/openai/deployments/my-gpt5-deployment/chat/completions"
|
||||
if capturedPath != wantPath {
|
||||
t.Errorf("URL path = %q, want %q", capturedPath, wantPath)
|
||||
}
|
||||
if capturedAPIVersion != azureAPIVersion {
|
||||
t.Errorf("api-version = %q, want %q", capturedAPIVersion, azureAPIVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureAuthHeader(t *testing.T) {
|
||||
var capturedAPIKey string
|
||||
var capturedAuth string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAPIKey = r.Header.Get("Api-Key")
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-azure-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if capturedAPIKey != "test-azure-key" {
|
||||
t.Errorf("api-key header = %q, want %q", capturedAPIKey, "test-azure-key")
|
||||
}
|
||||
if capturedAuth != "" {
|
||||
t.Errorf("Authorization header should be empty, got %q", capturedAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&requestBody)
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if _, exists := requestBody["model"]; exists {
|
||||
t.Error("request body should not contain 'model' field for Azure OpenAI")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&requestBody)
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"deployment",
|
||||
map[string]any{"max_tokens": 2048},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if _, exists := requestBody["max_completion_tokens"]; !exists {
|
||||
t.Error("request body should contain 'max_completion_tokens'")
|
||||
}
|
||||
if _, exists := requestBody["max_tokens"]; exists {
|
||||
t.Error("request body should not contain 'max_tokens'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureHTTPError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("bad-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureParseToolCalls(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{
|
||||
"content": "",
|
||||
"tool_calls": []map[string]any{
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "get_weather",
|
||||
"arguments": `{"city":"Seattle"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].Name != "get_weather" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
|
||||
p := NewProvider("test-key", "", "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty API base")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
|
||||
p := NewProvider("test-key", "https://example.com", "")
|
||||
if p.httpClient.Timeout != defaultRequestTimeout {
|
||||
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
|
||||
p := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(300*time.Second))
|
||||
if p.httpClient.Timeout != 300*time.Second {
|
||||
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
|
||||
p := NewProviderWithTimeout("test-key", "https://example.com", "", 180)
|
||||
if p.httpClient.Timeout != 180*time.Second {
|
||||
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 180*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) {
|
||||
var capturedPath string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedPath = r.URL.RawPath // use RawPath to see percent-encoding
|
||||
if capturedPath == "" {
|
||||
capturedPath = r.URL.Path
|
||||
}
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
|
||||
// Deployment name with characters that could cause path injection
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my deploy/../../admin", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
// The slash and special chars in the deployment name must be escaped, not treated as path separators
|
||||
if capturedPath == "/openai/deployments/my deploy/../../admin/chat/completions" {
|
||||
t.Fatal("deployment name was interpolated without escaping — path injection possible")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,380 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
// Package common provides shared utilities used by multiple LLM provider
|
||||
// implementations (openai_compat, azure, etc.).
|
||||
package common
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
// Re-export protocol types used across providers.
|
||||
type (
|
||||
ToolCall = protocoltypes.ToolCall
|
||||
FunctionCall = protocoltypes.FunctionCall
|
||||
LLMResponse = protocoltypes.LLMResponse
|
||||
UsageInfo = protocoltypes.UsageInfo
|
||||
Message = protocoltypes.Message
|
||||
ToolDefinition = protocoltypes.ToolDefinition
|
||||
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
ExtraContent = protocoltypes.ExtraContent
|
||||
GoogleExtra = protocoltypes.GoogleExtra
|
||||
ReasoningDetail = protocoltypes.ReasoningDetail
|
||||
)
|
||||
|
||||
const DefaultRequestTimeout = 120 * time.Second
|
||||
|
||||
// NewHTTPClient creates an *http.Client with an optional proxy and the default timeout.
|
||||
func NewHTTPClient(proxy string) *http.Client {
|
||||
client := &http.Client{
|
||||
Timeout: DefaultRequestTimeout,
|
||||
}
|
||||
if proxy != "" {
|
||||
parsed, err := url.Parse(proxy)
|
||||
if err == nil {
|
||||
// Preserve http.DefaultTransport settings (TLS, HTTP/2, timeouts, etc.)
|
||||
if base, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||
tr := base.Clone()
|
||||
tr.Proxy = http.ProxyURL(parsed)
|
||||
client.Transport = tr
|
||||
} else {
|
||||
// Fallback: minimal transport if DefaultTransport is not *http.Transport.
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(parsed),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Printf("common: invalid proxy URL %q: %v", proxy, err)
|
||||
}
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// --- Message serialization ---
|
||||
|
||||
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
|
||||
// It mirrors protocoltypes.Message but omits SystemParts, which is an
|
||||
// internal field that would be unknown to third-party endpoints.
|
||||
type openaiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
// SerializeMessages converts internal Message structs to the OpenAI wire format.
|
||||
// - Strips SystemParts (unknown to third-party endpoints)
|
||||
// - Converts messages with Media to multipart content format (text + image_url parts)
|
||||
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
|
||||
func SerializeMessages(messages []Message) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
if len(m.Media) == 0 {
|
||||
out = append(out, openaiMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Multipart content format for messages with media
|
||||
parts := make([]map[string]any, 0, 1+len(m.Media))
|
||||
if m.Content != "" {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "text",
|
||||
"text": m.Content,
|
||||
})
|
||||
}
|
||||
for _, mediaURL := range m.Media {
|
||||
if strings.HasPrefix(mediaURL, "data:image/") {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]any{
|
||||
"url": mediaURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
msg := map[string]any{
|
||||
"role": m.Role,
|
||||
"content": parts,
|
||||
}
|
||||
if m.ToolCallID != "" {
|
||||
msg["tool_call_id"] = m.ToolCallID
|
||||
}
|
||||
if len(m.ToolCalls) > 0 {
|
||||
msg["tool_calls"] = m.ToolCalls
|
||||
}
|
||||
if m.ReasoningContent != "" {
|
||||
msg["reasoning_content"] = m.ReasoningContent
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// --- Response parsing ---
|
||||
|
||||
// ParseResponse parses a JSON chat completion response body into an LLMResponse.
|
||||
func ParseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
} `json:"function"`
|
||||
ExtraContent *struct {
|
||||
Google *struct {
|
||||
ThoughtSignature string `json:"thought_signature"`
|
||||
} `json:"google"`
|
||||
} `json:"extra_content"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage *UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return &LLMResponse{
|
||||
Content: "",
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
choice := apiResponse.Choices[0]
|
||||
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
arguments := make(map[string]any)
|
||||
name := ""
|
||||
|
||||
// Extract thought_signature from Gemini/Google-specific extra content
|
||||
thoughtSignature := ""
|
||||
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
||||
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
||||
}
|
||||
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
arguments = DecodeToolCallArguments(tc.Function.Arguments, name)
|
||||
}
|
||||
|
||||
toolCall := ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
}
|
||||
|
||||
if thoughtSignature != "" {
|
||||
toolCall.ExtraContent = &ExtraContent{
|
||||
Google: &GoogleExtra{
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: choice.Message.Content,
|
||||
ReasoningContent: choice.Message.ReasoningContent,
|
||||
Reasoning: choice.Message.Reasoning,
|
||||
ReasoningDetails: choice.Message.ReasoningDetails,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
|
||||
func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
|
||||
arguments := make(map[string]any)
|
||||
raw = bytes.TrimSpace(raw)
|
||||
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
|
||||
return arguments
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := json.Unmarshal(raw, &decoded); err != nil {
|
||||
log.Printf("common: failed to decode tool call arguments payload for %q: %v", name, err)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
|
||||
switch v := decoded.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return arguments
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &arguments); err != nil {
|
||||
log.Printf("common: failed to decode tool call arguments for %q: %v", name, err)
|
||||
arguments["raw"] = v
|
||||
}
|
||||
return arguments
|
||||
case map[string]any:
|
||||
return v
|
||||
default:
|
||||
log.Printf("common: unsupported tool call arguments type for %q: %T", name, decoded)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
}
|
||||
|
||||
// --- HTTP response helpers ---
|
||||
|
||||
// HandleErrorResponse reads a non-200 response body and returns an appropriate error.
|
||||
func HandleErrorResponse(resp *http.Response, apiBase string) error {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
if readErr != nil {
|
||||
return fmt.Errorf("failed to read response: %w", readErr)
|
||||
}
|
||||
if LooksLikeHTML(body, contentType) {
|
||||
return WrapHTMLResponseError(resp.StatusCode, body, contentType, apiBase)
|
||||
}
|
||||
return fmt.Errorf(
|
||||
"API request failed:\n Status: %d\n Body: %s",
|
||||
resp.StatusCode,
|
||||
ResponsePreview(body, 128),
|
||||
)
|
||||
}
|
||||
|
||||
// ReadAndParseResponse peeks at the response body to detect HTML errors,
|
||||
// then parses the JSON response into an LLMResponse.
|
||||
func ReadAndParseResponse(resp *http.Response, apiBase string) (*LLMResponse, error) {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
prefix, err := reader.Peek(256)
|
||||
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
|
||||
return nil, fmt.Errorf("failed to inspect response: %w", err)
|
||||
}
|
||||
if LooksLikeHTML(prefix, contentType) {
|
||||
return nil, WrapHTMLResponseError(resp.StatusCode, prefix, contentType, apiBase)
|
||||
}
|
||||
out, err := ParseResponse(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// LooksLikeHTML checks if the response body appears to be HTML.
|
||||
func LooksLikeHTML(body []byte, contentType string) bool {
|
||||
contentType = strings.ToLower(strings.TrimSpace(contentType))
|
||||
if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
|
||||
return true
|
||||
}
|
||||
prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
|
||||
return bytes.HasPrefix(prefix, []byte("<!doctype html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<head")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<body"))
|
||||
}
|
||||
|
||||
// WrapHTMLResponseError creates a descriptive error for HTML responses.
|
||||
func WrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
|
||||
respPreview := ResponsePreview(body, 128)
|
||||
return fmt.Errorf(
|
||||
"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
|
||||
apiBase,
|
||||
contentType,
|
||||
statusCode,
|
||||
respPreview,
|
||||
)
|
||||
}
|
||||
|
||||
// ResponsePreview returns a truncated preview of response body for error messages.
|
||||
func ResponsePreview(body []byte, maxLen int) string {
|
||||
trimmed := bytes.TrimSpace(body)
|
||||
if len(trimmed) == 0 {
|
||||
return "<empty>"
|
||||
}
|
||||
if len(trimmed) <= maxLen {
|
||||
return string(trimmed)
|
||||
}
|
||||
return string(trimmed[:maxLen]) + "..."
|
||||
}
|
||||
|
||||
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
|
||||
i := 0
|
||||
for i < len(body) {
|
||||
switch body[i] {
|
||||
case ' ', '\t', '\n', '\r', '\f', '\v':
|
||||
i++
|
||||
default:
|
||||
end := i + maxLen
|
||||
if end > len(body) {
|
||||
end = len(body)
|
||||
}
|
||||
return body[i:end]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Numeric helpers ---
|
||||
|
||||
// AsInt converts various numeric types to int.
|
||||
func AsInt(v any) (int, bool) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val, true
|
||||
case int64:
|
||||
return int(val), true
|
||||
case float64:
|
||||
return int(val), true
|
||||
case float32:
|
||||
return int(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// AsFloat converts various numeric types to float64.
|
||||
func AsFloat(v any) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val, true
|
||||
case float32:
|
||||
return float64(val), true
|
||||
case int:
|
||||
return float64(val), true
|
||||
case int64:
|
||||
return float64(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,558 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
// --- NewHTTPClient tests ---
|
||||
|
||||
func TestNewHTTPClient_DefaultTimeout(t *testing.T) {
|
||||
client := NewHTTPClient("")
|
||||
if client.Timeout != DefaultRequestTimeout {
|
||||
t.Errorf("timeout = %v, want %v", client.Timeout, DefaultRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient_WithProxy(t *testing.T) {
|
||||
client := NewHTTPClient("http://127.0.0.1:8080")
|
||||
transport, ok := client.Transport.(*http.Transport)
|
||||
if !ok || transport == nil {
|
||||
t.Fatalf("expected http.Transport with proxy, got %T", client.Transport)
|
||||
}
|
||||
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
|
||||
gotProxy, err := transport.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("proxy function error: %v", err)
|
||||
}
|
||||
if gotProxy == nil || gotProxy.String() != "http://127.0.0.1:8080" {
|
||||
t.Errorf("proxy = %v, want http://127.0.0.1:8080", gotProxy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient_NoProxy(t *testing.T) {
|
||||
client := NewHTTPClient("")
|
||||
if client.Transport != nil {
|
||||
t.Errorf("expected nil transport without proxy, got %T", client.Transport)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient_InvalidProxy(t *testing.T) {
|
||||
// Should not panic, just log and return client without proxy
|
||||
client := NewHTTPClient("://bad-url")
|
||||
if client == nil {
|
||||
t.Fatal("expected non-nil client even with invalid proxy")
|
||||
}
|
||||
}
|
||||
|
||||
// --- SerializeMessages tests ---
|
||||
|
||||
func TestSerializeMessages_PlainText(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
if msgs[0]["content"] != "hello" {
|
||||
t.Errorf("expected plain string content, got %v", msgs[0]["content"])
|
||||
}
|
||||
if msgs[1]["reasoning_content"] != "thinking..." {
|
||||
t.Errorf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_WithMedia(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
content, ok := msgs[0]["content"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
|
||||
}
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected 2 content parts, got %d", len(content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
if msgs[0]["tool_call_id"] != "call_1" {
|
||||
t.Errorf("tool_call_id not preserved, got %v", msgs[0]["tool_call_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "you are helpful",
|
||||
SystemParts: []protocoltypes.ContentBlock{
|
||||
{Type: "text", Text: "you are helpful"},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
if strings.Contains(string(data), "system_parts") {
|
||||
t.Error("system_parts should not appear in serialized output")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseResponse tests ---
|
||||
|
||||
func TestParseResponse_BasicContent(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"hello world"},"finish_reason":"stop"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Content != "hello world" {
|
||||
t.Errorf("Content = %q, want %q", out.Content, "hello world")
|
||||
}
|
||||
if out.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_EmptyChoices(t *testing.T) {
|
||||
body := `{"choices":[]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Content != "" {
|
||||
t.Errorf("Content = %q, want empty", out.Content)
|
||||
}
|
||||
if out.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithToolCalls(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"SF\"}"}}]},"finish_reason":"tool_calls"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].Name != "get_weather" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
|
||||
}
|
||||
if out.ToolCalls[0].Arguments["city"] != "SF" {
|
||||
t.Errorf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithUsage(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Usage == nil {
|
||||
t.Fatal("Usage is nil")
|
||||
}
|
||||
if out.Usage.PromptTokens != 10 {
|
||||
t.Errorf("PromptTokens = %d, want 10", out.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithReasoningContent(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"2","reasoning_content":"Let me think... 1+1=2"},"finish_reason":"stop"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.ReasoningContent != "Let me think... 1+1=2" {
|
||||
t.Errorf("ReasoningContent = %q, want %q", out.ReasoningContent, "Let me think... 1+1=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_InvalidJSON(t *testing.T) {
|
||||
_, err := ParseResponse(strings.NewReader("not json"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// --- DecodeToolCallArguments tests ---
|
||||
|
||||
func TestDecodeToolCallArguments_ObjectJSON(t *testing.T) {
|
||||
raw := json.RawMessage(`{"city":"Seattle","units":"metric"}`)
|
||||
args := DecodeToolCallArguments(raw, "test")
|
||||
if args["city"] != "Seattle" {
|
||||
t.Errorf("city = %v, want Seattle", args["city"])
|
||||
}
|
||||
if args["units"] != "metric" {
|
||||
t.Errorf("units = %v, want metric", args["units"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_StringJSON(t *testing.T) {
|
||||
raw := json.RawMessage(`"{\"city\":\"SF\"}"`)
|
||||
args := DecodeToolCallArguments(raw, "test")
|
||||
if args["city"] != "SF" {
|
||||
t.Errorf("city = %v, want SF", args["city"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_EmptyInput(t *testing.T) {
|
||||
args := DecodeToolCallArguments(nil, "test")
|
||||
if len(args) != 0 {
|
||||
t.Errorf("expected empty map, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_NullInput(t *testing.T) {
|
||||
args := DecodeToolCallArguments(json.RawMessage(`null`), "test")
|
||||
if len(args) != 0 {
|
||||
t.Errorf("expected empty map, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_InvalidJSON(t *testing.T) {
|
||||
args := DecodeToolCallArguments(json.RawMessage(`not-json`), "test")
|
||||
if _, ok := args["raw"]; !ok {
|
||||
t.Error("expected 'raw' fallback key for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_EmptyStringJSON(t *testing.T) {
|
||||
args := DecodeToolCallArguments(json.RawMessage(`" "`), "test")
|
||||
if len(args) != 0 {
|
||||
t.Errorf("expected empty map for whitespace string, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleErrorResponse tests ---
|
||||
|
||||
func TestHandleErrorResponse_JSONError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"bad request"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
err = HandleErrorResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "400") {
|
||||
t.Errorf("error should contain status code, got %v", err)
|
||||
}
|
||||
if strings.Contains(err.Error(), "HTML") {
|
||||
t.Errorf("should not mention HTML for JSON error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleErrorResponse_HTMLError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
w.Write([]byte("<!DOCTYPE html><html><body>bad gateway</body></html>"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
err = HandleErrorResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "HTML instead of JSON") {
|
||||
t.Errorf("expected HTML error message, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReadAndParseResponse tests ---
|
||||
|
||||
func TestReadAndParseResponse_ValidJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
out, err := ReadAndParseResponse(resp, server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAndParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Content != "ok" {
|
||||
t.Errorf("Content = %q, want %q", out.Content, "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadAndParseResponse_HTMLResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Write([]byte("<!DOCTYPE html><html><body>login page</body></html>"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ReadAndParseResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTML response")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "HTML instead of JSON") {
|
||||
t.Errorf("expected HTML error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- LooksLikeHTML tests ---
|
||||
|
||||
func TestLooksLikeHTML_ContentTypeHTML(t *testing.T) {
|
||||
if !LooksLikeHTML(nil, "text/html; charset=utf-8") {
|
||||
t.Error("expected true for text/html content type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeHTML_ContentTypeXHTML(t *testing.T) {
|
||||
if !LooksLikeHTML(nil, "application/xhtml+xml") {
|
||||
t.Error("expected true for xhtml content type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeHTML_BodyPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{"doctype", "<!DOCTYPE html><html>"},
|
||||
{"html tag", "<html><body>"},
|
||||
{"head tag", "<head><title>"},
|
||||
{"body tag", "<body>content"},
|
||||
{"whitespace before", " \n\t<!DOCTYPE html>"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if !LooksLikeHTML([]byte(tt.body), "application/json") {
|
||||
t.Errorf("expected true for body %q", tt.body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeHTML_NotHTML(t *testing.T) {
|
||||
if LooksLikeHTML([]byte(`{"error":"bad"}`), "application/json") {
|
||||
t.Error("expected false for JSON body")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ResponsePreview tests ---
|
||||
|
||||
func TestResponsePreview_Short(t *testing.T) {
|
||||
got := ResponsePreview([]byte("hello"), 128)
|
||||
if got != "hello" {
|
||||
t.Errorf("got %q, want %q", got, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsePreview_Truncated(t *testing.T) {
|
||||
body := strings.Repeat("a", 200)
|
||||
got := ResponsePreview([]byte(body), 128)
|
||||
if len(got) != 131 { // 128 + "..."
|
||||
t.Errorf("len = %d, want 131", len(got))
|
||||
}
|
||||
if !strings.HasSuffix(got, "...") {
|
||||
t.Error("expected ... suffix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsePreview_Empty(t *testing.T) {
|
||||
got := ResponsePreview([]byte(""), 128)
|
||||
if got != "<empty>" {
|
||||
t.Errorf("got %q, want %q", got, "<empty>")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsePreview_Whitespace(t *testing.T) {
|
||||
got := ResponsePreview([]byte(" \n\t "), 128)
|
||||
if got != "<empty>" {
|
||||
t.Errorf("got %q, want %q for whitespace-only body", got, "<empty>")
|
||||
}
|
||||
}
|
||||
|
||||
// --- AsInt tests ---
|
||||
|
||||
func TestAsInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val any
|
||||
want int
|
||||
ok bool
|
||||
}{
|
||||
{"int", 42, 42, true},
|
||||
{"int64", int64(99), 99, true},
|
||||
{"float64", float64(512), 512, true},
|
||||
{"float32", float32(256), 256, true},
|
||||
{"string", "nope", 0, false},
|
||||
{"nil", nil, 0, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := AsInt(tt.val)
|
||||
if ok != tt.ok || got != tt.want {
|
||||
t.Errorf("AsInt(%v) = (%d, %v), want (%d, %v)", tt.val, got, ok, tt.want, tt.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- AsFloat tests ---
|
||||
|
||||
func TestAsFloat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val any
|
||||
want float64
|
||||
ok bool
|
||||
}{
|
||||
{"float64", float64(0.7), 0.7, true},
|
||||
{"float32", float32(0.5), float64(float32(0.5)), true},
|
||||
{"int", 1, 1.0, true},
|
||||
{"int64", int64(100), 100.0, true},
|
||||
{"string", "nope", 0, false},
|
||||
{"nil", nil, 0, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := AsFloat(tt.val)
|
||||
if ok != tt.ok || got != tt.want {
|
||||
t.Errorf("AsFloat(%v) = (%f, %v), want (%f, %v)", tt.val, got, ok, tt.want, tt.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- WrapHTMLResponseError tests ---
|
||||
|
||||
func TestWrapHTMLResponseError(t *testing.T) {
|
||||
err := WrapHTMLResponseError(502, []byte("<html>bad</html>"), "text/html", "https://api.example.com")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "502") {
|
||||
t.Errorf("expected status code in error, got %v", msg)
|
||||
}
|
||||
if !strings.Contains(msg, "https://api.example.com") {
|
||||
t.Errorf("expected api base in error, got %v", msg)
|
||||
}
|
||||
if !strings.Contains(msg, "HTML instead of JSON") {
|
||||
t.Errorf("expected HTML mention in error, got %v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleErrorResponse with read failure ---
|
||||
|
||||
func TestHandleErrorResponse_EmptyBody(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
// empty body
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
err = HandleErrorResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "500") {
|
||||
t.Errorf("expected status code, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReadAndParseResponse with invalid JSON ---
|
||||
|
||||
func TestReadAndParseResponse_InvalidJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte("not valid json"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ReadAndParseResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseResponse with thought_signature (Google/Gemini) ---
|
||||
|
||||
func TestParseResponse_WithThoughtSignature(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"test_tool","arguments":"{}"},"extra_content":{"google":{"thought_signature":"sig123"}}}]},"finish_reason":"tool_calls"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].ThoughtSignature != "sig123" {
|
||||
t.Errorf("ThoughtSignature = %q, want %q", out.ToolCalls[0].ThoughtSignature, "sig123")
|
||||
}
|
||||
if out.ToolCalls[0].ExtraContent == nil || out.ToolCalls[0].ExtraContent.Google == nil {
|
||||
t.Fatal("ExtraContent.Google is nil")
|
||||
}
|
||||
if out.ToolCalls[0].ExtraContent.Google.ThoughtSignature != "sig123" {
|
||||
t.Errorf("ExtraContent.Google.ThoughtSignature = %q, want %q",
|
||||
out.ToolCalls[0].ExtraContent.Google.ThoughtSignature, "sig123")
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/azure"
|
||||
)
|
||||
|
||||
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
|
||||
@@ -94,6 +95,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "azure", "azure-openai":
|
||||
// Azure OpenAI uses deployment-based URLs, api-key header auth,
|
||||
// and always sends max_completion_tokens.
|
||||
if cfg.APIKey == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for azure protocol")
|
||||
}
|
||||
if cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf(
|
||||
"api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)",
|
||||
)
|
||||
}
|
||||
return azure.NewProviderWithTimeout(
|
||||
cfg.APIKey,
|
||||
cfg.APIBase,
|
||||
cfg.Proxy,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
|
||||
|
||||
@@ -64,6 +64,12 @@ func TestExtractProtocol(t *testing.T) {
|
||||
wantProtocol: "nvidia",
|
||||
wantModelID: "meta/llama-3.1-8b",
|
||||
},
|
||||
{
|
||||
name: "azure with prefix",
|
||||
model: "azure/my-gpt5-deployment",
|
||||
wantProtocol: "azure",
|
||||
wantModelID: "my-gpt5-deployment",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -371,3 +377,69 @@ func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) {
|
||||
t.Fatalf("Chat() error = %q, want timeout-related error", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Azure(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIKey: "test-azure-key",
|
||||
APIBase: "https://my-resource.openai.azure.com",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "my-gpt5-deployment" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "my-gpt5-deployment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_AzureOpenAIAlias(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt4",
|
||||
Model: "azure-openai/my-deployment",
|
||||
APIKey: "test-azure-key",
|
||||
APIBase: "https://my-resource.openai.azure.com",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "my-deployment" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "my-deployment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_AzureMissingAPIKey(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIBase: "https://my-resource.openai.azure.com",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for missing API key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_AzureMissingAPIBase(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIKey: "test-azure-key",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for missing API base")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
package openai_compat
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
@@ -38,7 +36,7 @@ type Provider struct {
|
||||
|
||||
type Option func(*Provider)
|
||||
|
||||
const defaultRequestTimeout = 120 * time.Second
|
||||
const defaultRequestTimeout = common.DefaultRequestTimeout
|
||||
|
||||
func WithMaxTokensField(maxTokensField string) Option {
|
||||
return func(p *Provider) {
|
||||
@@ -55,25 +53,10 @@ func WithRequestTimeout(timeout time.Duration) Option {
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
client := &http.Client{
|
||||
Timeout: defaultRequestTimeout,
|
||||
}
|
||||
|
||||
if proxy != "" {
|
||||
parsed, err := url.Parse(proxy)
|
||||
if err == nil {
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(parsed),
|
||||
}
|
||||
} else {
|
||||
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
|
||||
}
|
||||
}
|
||||
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: client,
|
||||
httpClient: common.NewHTTPClient(proxy),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
@@ -117,7 +100,7 @@ func (p *Provider) Chat(
|
||||
|
||||
requestBody := map[string]any{
|
||||
"model": model,
|
||||
"messages": serializeMessages(messages),
|
||||
"messages": common.SerializeMessages(messages),
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
@@ -125,7 +108,7 @@ func (p *Provider) Chat(
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
|
||||
if maxTokens, ok := asInt(options["max_tokens"]); ok {
|
||||
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
|
||||
// Use configured maxTokensField if specified, otherwise fallback to model-based detection
|
||||
fieldName := p.maxTokensField
|
||||
if fieldName == "" {
|
||||
@@ -141,7 +124,7 @@ func (p *Provider) Chat(
|
||||
requestBody[fieldName] = maxTokens
|
||||
}
|
||||
|
||||
if temperature, ok := asFloat(options["temperature"]); ok {
|
||||
if temperature, ok := common.AsFloat(options["temperature"]); ok {
|
||||
lowerModel := strings.ToLower(model)
|
||||
// Kimi k2 models only support temperature=1.
|
||||
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
|
||||
@@ -185,275 +168,11 @@ func (p *Provider) Chat(
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
|
||||
// Non-200: read a prefix to tell HTML error page apart from JSON error body.
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", readErr)
|
||||
}
|
||||
if looksLikeHTML(body, contentType) {
|
||||
return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase)
|
||||
}
|
||||
return nil, fmt.Errorf(
|
||||
"API request failed:\n Status: %d\n Body: %s",
|
||||
resp.StatusCode,
|
||||
responsePreview(body, 128),
|
||||
)
|
||||
return nil, common.HandleErrorResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
// Peek without consuming so the full stream reaches the JSON decoder.
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort
|
||||
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
|
||||
return nil, fmt.Errorf("failed to inspect response: %w", err)
|
||||
}
|
||||
if looksLikeHTML(prefix, contentType) {
|
||||
return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase)
|
||||
}
|
||||
|
||||
out, err := parseResponse(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
|
||||
respPreview := responsePreview(body, 128)
|
||||
return fmt.Errorf(
|
||||
"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
|
||||
apiBase,
|
||||
contentType,
|
||||
statusCode,
|
||||
respPreview,
|
||||
)
|
||||
}
|
||||
|
||||
func looksLikeHTML(body []byte, contentType string) bool {
|
||||
contentType = strings.ToLower(strings.TrimSpace(contentType))
|
||||
if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
|
||||
return true
|
||||
}
|
||||
prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
|
||||
return bytes.HasPrefix(prefix, []byte("<!doctype html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<head")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<body"))
|
||||
}
|
||||
|
||||
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
|
||||
i := 0
|
||||
for i < len(body) {
|
||||
switch body[i] {
|
||||
case ' ', '\t', '\n', '\r', '\f', '\v':
|
||||
i++
|
||||
default:
|
||||
end := i + maxLen
|
||||
if end > len(body) {
|
||||
end = len(body)
|
||||
}
|
||||
return body[i:end]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func responsePreview(body []byte, maxLen int) string {
|
||||
trimmed := bytes.TrimSpace(body)
|
||||
if len(trimmed) == 0 {
|
||||
return "<empty>"
|
||||
}
|
||||
if len(trimmed) <= maxLen {
|
||||
return string(trimmed)
|
||||
}
|
||||
return string(trimmed[:maxLen]) + "..."
|
||||
}
|
||||
|
||||
func parseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
} `json:"function"`
|
||||
ExtraContent *struct {
|
||||
Google *struct {
|
||||
ThoughtSignature string `json:"thought_signature"`
|
||||
} `json:"google"`
|
||||
} `json:"extra_content"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage *UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return &LLMResponse{
|
||||
Content: "",
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
choice := apiResponse.Choices[0]
|
||||
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
arguments := make(map[string]any)
|
||||
name := ""
|
||||
|
||||
// Extract thought_signature from Gemini/Google-specific extra content
|
||||
thoughtSignature := ""
|
||||
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
||||
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
||||
}
|
||||
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
arguments = decodeToolCallArguments(tc.Function.Arguments, name)
|
||||
}
|
||||
|
||||
// Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence
|
||||
toolCall := ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
}
|
||||
|
||||
if thoughtSignature != "" {
|
||||
toolCall.ExtraContent = &ExtraContent{
|
||||
Google: &GoogleExtra{
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: choice.Message.Content,
|
||||
ReasoningContent: choice.Message.ReasoningContent,
|
||||
Reasoning: choice.Message.Reasoning,
|
||||
ReasoningDetails: choice.Message.ReasoningDetails,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
|
||||
arguments := make(map[string]any)
|
||||
raw = bytes.TrimSpace(raw)
|
||||
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
|
||||
return arguments
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := json.Unmarshal(raw, &decoded); err != nil {
|
||||
log.Printf("openai_compat: failed to decode tool call arguments payload for %q: %v", name, err)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
|
||||
switch v := decoded.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return arguments
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &arguments); err != nil {
|
||||
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
|
||||
arguments["raw"] = v
|
||||
}
|
||||
return arguments
|
||||
case map[string]any:
|
||||
return v
|
||||
default:
|
||||
log.Printf("openai_compat: unsupported tool call arguments type for %q: %T", name, decoded)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
}
|
||||
|
||||
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
|
||||
// It mirrors protocoltypes.Message but omits SystemParts, which is an
|
||||
// internal field that would be unknown to third-party endpoints.
|
||||
type openaiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
// serializeMessages converts internal Message structs to the OpenAI wire format.
|
||||
// - Strips SystemParts (unknown to third-party endpoints)
|
||||
// - Converts messages with Media to multipart content format (text + image_url parts)
|
||||
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
|
||||
func serializeMessages(messages []Message) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
if len(m.Media) == 0 {
|
||||
out = append(out, openaiMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Multipart content format for messages with media
|
||||
parts := make([]map[string]any, 0, 1+len(m.Media))
|
||||
if m.Content != "" {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "text",
|
||||
"text": m.Content,
|
||||
})
|
||||
}
|
||||
for _, mediaURL := range m.Media {
|
||||
if strings.HasPrefix(mediaURL, "data:image/") {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]any{
|
||||
"url": mediaURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
msg := map[string]any{
|
||||
"role": m.Role,
|
||||
"content": parts,
|
||||
}
|
||||
if m.ToolCallID != "" {
|
||||
msg["tool_call_id"] = m.ToolCallID
|
||||
}
|
||||
if len(m.ToolCalls) > 0 {
|
||||
msg["tool_calls"] = m.ToolCalls
|
||||
}
|
||||
if m.ReasoningContent != "" {
|
||||
msg["reasoning_content"] = m.ReasoningContent
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out
|
||||
return common.ReadAndParseResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
func normalizeModel(model, apiBase string) string {
|
||||
@@ -476,36 +195,6 @@ func normalizeModel(model, apiBase string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func asInt(v any) (int, bool) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val, true
|
||||
case int64:
|
||||
return int(val), true
|
||||
case float64:
|
||||
return int(val), true
|
||||
case float32:
|
||||
return int(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func asFloat(v any) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val, true
|
||||
case float32:
|
||||
return float64(val), true
|
||||
case int:
|
||||
return float64(val), true
|
||||
case int64:
|
||||
return float64(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// supportsPromptCacheKey reports whether the given API base is known to
|
||||
// support the prompt_cache_key request field. Currently only OpenAI's own
|
||||
// API and Azure OpenAI support this. All other OpenAI-compatible providers
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
@@ -648,7 +649,7 @@ func TestSerializeMessages_PlainText(t *testing.T) {
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
result := common.SerializeMessages(messages)
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
@@ -670,7 +671,7 @@ func TestSerializeMessages_WithMedia(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
result := common.SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
@@ -703,7 +704,7 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
result := common.SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
@@ -833,7 +834,7 @@ func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
result := common.SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
raw := string(data)
|
||||
|
||||
+53
-20
@@ -20,10 +20,12 @@ type JobExecutor interface {
|
||||
|
||||
// CronTool provides scheduling capabilities for the agent
|
||||
type CronTool struct {
|
||||
cronService *cron.CronService
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
execTool *ExecTool
|
||||
cronService *cron.CronService
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
execTool *ExecTool
|
||||
allowCommand bool
|
||||
execEnabled bool
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
@@ -32,17 +34,32 @@ func NewCronTool(
|
||||
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
|
||||
execTimeout time.Duration, config *config.Config,
|
||||
) (*CronTool, error) {
|
||||
execTool, err := NewExecToolWithConfig(workspace, restrict, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
|
||||
allowCommand := true
|
||||
execEnabled := true
|
||||
if config != nil {
|
||||
allowCommand = config.Tools.Cron.AllowCommand
|
||||
execEnabled = config.Tools.Exec.Enabled
|
||||
}
|
||||
|
||||
execTool.SetTimeout(execTimeout)
|
||||
var execTool *ExecTool
|
||||
if execEnabled {
|
||||
var err error
|
||||
execTool, err = NewExecToolWithConfig(workspace, restrict, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if execTool != nil {
|
||||
execTool.SetTimeout(execTimeout)
|
||||
}
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: execTool,
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: execTool,
|
||||
allowCommand: allowCommand,
|
||||
execEnabled: execEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -76,7 +93,7 @@ func (t *CronTool) Parameters() map[string]any {
|
||||
},
|
||||
"command_confirm": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Required when using command=true. Must be true to explicitly confirm scheduling a shell command.",
|
||||
"description": "Optional explicit confirmation flag for scheduling a shell command. Command execution must also be enabled via tools.cron.allow_command.",
|
||||
},
|
||||
"at_seconds": map[string]any{
|
||||
"type": "integer",
|
||||
@@ -96,7 +113,7 @@ func (t *CronTool) Parameters() map[string]any {
|
||||
},
|
||||
"deliver": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: false",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
@@ -174,22 +191,26 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
|
||||
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
|
||||
}
|
||||
|
||||
// Read deliver parameter, default to true
|
||||
deliver := true
|
||||
// Read deliver parameter, default to false so scheduled tasks execute through the agent
|
||||
deliver := false
|
||||
if d, ok := args["deliver"].(bool); ok {
|
||||
deliver = d
|
||||
}
|
||||
|
||||
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel + explicit confirm.
|
||||
// Non-command reminders (plain messages) remain open to all channels.
|
||||
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel. When
|
||||
// allow_command is disabled, explicit confirmation is required as an override.
|
||||
// Non-command reminders remain open to all channels.
|
||||
command, _ := args["command"].(string)
|
||||
commandConfirm, _ := args["command_confirm"].(bool)
|
||||
if command != "" {
|
||||
if !t.execEnabled {
|
||||
return ErrorResult("command execution is disabled")
|
||||
}
|
||||
if !constants.IsInternalChannel(channel) {
|
||||
return ErrorResult("scheduling command execution is restricted to internal channels")
|
||||
}
|
||||
if !commandConfirm {
|
||||
return ErrorResult("command_confirm=true is required to schedule command execution")
|
||||
if !t.allowCommand && !commandConfirm {
|
||||
return ErrorResult("command_confirm=true is required when allow_command is disabled")
|
||||
}
|
||||
deliver = false
|
||||
}
|
||||
@@ -290,6 +311,18 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
|
||||
// Execute command if present
|
||||
if job.Payload.Command != "" {
|
||||
if !t.execEnabled || t.execTool == nil {
|
||||
output := "Error executing scheduled command: command execution is disabled"
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: output,
|
||||
})
|
||||
return "ok"
|
||||
}
|
||||
|
||||
args := map[string]any{
|
||||
"command": job.Payload.Command,
|
||||
"__channel": channel,
|
||||
|
||||
+126
-6
@@ -5,18 +5,18 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
)
|
||||
|
||||
func newTestCronTool(t *testing.T) *CronTool {
|
||||
func newTestCronToolWithConfig(t *testing.T, cfg *config.Config) *CronTool {
|
||||
t.Helper()
|
||||
storePath := filepath.Join(t.TempDir(), "cron.json")
|
||||
cronService := cron.NewCronService(storePath, nil)
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.DefaultConfig()
|
||||
tool, err := NewCronTool(cronService, nil, msgBus, t.TempDir(), true, 0, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCronTool() error: %v", err)
|
||||
@@ -24,6 +24,11 @@ func newTestCronTool(t *testing.T) *CronTool {
|
||||
return tool
|
||||
}
|
||||
|
||||
func newTestCronTool(t *testing.T) *CronTool {
|
||||
t.Helper()
|
||||
return newTestCronToolWithConfig(t, config.DefaultConfig())
|
||||
}
|
||||
|
||||
// TestCronTool_CommandBlockedFromRemoteChannel verifies command scheduling is restricted to internal channels
|
||||
func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
@@ -44,8 +49,7 @@ func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCronTool_CommandRequiresConfirm verifies command_confirm=true is required
|
||||
func TestCronTool_CommandRequiresConfirm(t *testing.T) {
|
||||
func TestCronTool_CommandDoesNotRequireConfirmByDefault(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
@@ -55,11 +59,79 @@ func TestCronTool_CommandRequiresConfirm(t *testing.T) {
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected command scheduling without confirm to succeed by default, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandRequiresConfirmWhenAllowCommandDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Cron.AllowCommand = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error when command_confirm is missing")
|
||||
t.Fatal("expected command scheduling to require confirm when allow_command is disabled")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "command_confirm=true") {
|
||||
t.Errorf("expected 'command_confirm=true' message, got: %s", result.ForLLM)
|
||||
t.Errorf("expected command_confirm requirement message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandAllowedWithConfirmWhenAllowCommandDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Cron.AllowCommand = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"command_confirm": true,
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf(
|
||||
"expected command scheduling with confirm to succeed when allow_command is disabled, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandBlockedWhenExecDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Exec.Enabled = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"command_confirm": true,
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected command scheduling to be blocked when exec is disabled")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "command execution is disabled") {
|
||||
t.Errorf("expected exec disabled message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,3 +186,51 @@ func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) {
|
||||
t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_NonCommandJobDefaultsDeliverToFalse(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "send me a poem",
|
||||
"at_seconds": float64(600),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected non-command reminder to succeed, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
jobs := tool.cronService.ListJobs(false)
|
||||
if len(jobs) != 1 {
|
||||
t.Fatalf("expected 1 job, got %d", len(jobs))
|
||||
}
|
||||
if jobs[0].Payload.Deliver {
|
||||
t.Fatal("expected deliver=false by default for non-command jobs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_ExecuteJobPublishesErrorWhenExecDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Exec.Enabled = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
job := &cron.CronJob{}
|
||||
job.Payload.Channel = "cli"
|
||||
job.Payload.To = "direct"
|
||||
job.Payload.Command = "df -h"
|
||||
|
||||
if got := tool.ExecuteJob(context.Background(), job); got != "ok" {
|
||||
t.Fatalf("ExecuteJob() = %q, want ok", got)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
msg, ok := tool.msgBus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected outbound message")
|
||||
}
|
||||
if !strings.Contains(msg.Content, "command execution is disabled") {
|
||||
t.Fatalf("expected exec disabled message, got: %s", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
+161
-9
@@ -20,8 +20,7 @@ import (
|
||||
|
||||
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
|
||||
|
||||
// validatePath ensures the given path is within the workspace if restrict is true.
|
||||
func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) {
|
||||
if workspace == "" {
|
||||
return path, fmt.Errorf("workspace is not defined")
|
||||
}
|
||||
@@ -42,6 +41,10 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
}
|
||||
|
||||
if restrict {
|
||||
if isAllowedPath(absPath, patterns) {
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
if !isWithinWorkspace(absPath, absWorkspace) {
|
||||
return "", fmt.Errorf("access denied: path is outside the workspace")
|
||||
}
|
||||
@@ -73,6 +76,137 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
func isAllowedPath(path string, patterns []*regexp.Regexp) bool {
|
||||
if len(patterns) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
cleaned := filepath.Clean(path)
|
||||
if !filepath.IsAbs(cleaned) {
|
||||
return false
|
||||
}
|
||||
if !matchesAllowedPath(cleaned, patterns) {
|
||||
return false
|
||||
}
|
||||
|
||||
resolved, err := resolvePathAgainstExistingAncestor(cleaned)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return matchesAllowedPath(resolved, patterns)
|
||||
}
|
||||
|
||||
func matchesAllowedPath(path string, patterns []*regexp.Regexp) bool {
|
||||
cleaned := filepath.Clean(path)
|
||||
for _, pattern := range patterns {
|
||||
if pattern.MatchString(cleaned) {
|
||||
return true
|
||||
}
|
||||
if root, ok := extractAllowedPathRoot(pattern); ok && isWithinAllowedRoot(cleaned, root) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractAllowedPathRoot(pattern *regexp.Regexp) (string, bool) {
|
||||
raw := pattern.String()
|
||||
if !strings.HasPrefix(raw, "^") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
literal := strings.TrimPrefix(raw, "^")
|
||||
|
||||
// Recognize the common "directory prefix" form: ^<literal>(?:/|$)
|
||||
literal = strings.TrimSuffix(literal, "(?:/|$)")
|
||||
literal = strings.TrimSuffix(literal, `(?:\\|$)`)
|
||||
|
||||
// Reject patterns that still contain regex operators after removing the
|
||||
// optional anchored-directory suffix. That keeps arbitrary regex behavior
|
||||
// unchanged and only enables normalized prefix matching for literal paths.
|
||||
if containsUnescapedRegexMeta(literal) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
unescaped, ok := unescapeRegexLiteral(literal)
|
||||
if !ok || unescaped == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return filepath.Clean(unescaped), filepath.IsAbs(unescaped)
|
||||
}
|
||||
|
||||
func appendUniquePath(paths []string, path string) []string {
|
||||
for _, existing := range paths {
|
||||
if existing == path {
|
||||
return paths
|
||||
}
|
||||
}
|
||||
return append(paths, path)
|
||||
}
|
||||
|
||||
func containsUnescapedRegexMeta(s string) bool {
|
||||
escaped := false
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
switch r {
|
||||
case '.', '+', '*', '?', '(', ')', '[', ']', '{', '}', '|':
|
||||
return true
|
||||
}
|
||||
}
|
||||
return escaped
|
||||
}
|
||||
|
||||
func unescapeRegexLiteral(s string) (string, bool) {
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
|
||||
escaped := false
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
b.WriteRune(r)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
b.WriteRune(r)
|
||||
}
|
||||
|
||||
if escaped {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return b.String(), true
|
||||
}
|
||||
|
||||
func isWithinAllowedRoot(path, root string) bool {
|
||||
candidate := filepath.Clean(path)
|
||||
allowedVariants := []string{filepath.Clean(root)}
|
||||
|
||||
if resolvedRoot, err := resolvePathAgainstExistingAncestor(root); err == nil {
|
||||
allowedVariants = appendUniquePath(allowedVariants, filepath.Clean(resolvedRoot))
|
||||
}
|
||||
|
||||
for _, allowedRoot := range allowedVariants {
|
||||
if isWithinWorkspace(candidate, allowedRoot) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func resolveExistingAncestor(path string) (string, error) {
|
||||
for current := filepath.Clean(path); ; current = filepath.Dir(current) {
|
||||
if resolved, err := filepath.EvalSymlinks(current); err == nil {
|
||||
@@ -86,9 +220,32 @@ func resolveExistingAncestor(path string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func resolvePathAgainstExistingAncestor(path string) (string, error) {
|
||||
cleaned := filepath.Clean(path)
|
||||
for current := cleaned; ; current = filepath.Dir(current) {
|
||||
resolved, err := filepath.EvalSymlinks(current)
|
||||
if err == nil {
|
||||
suffix, relErr := filepath.Rel(current, cleaned)
|
||||
if relErr != nil {
|
||||
return "", relErr
|
||||
}
|
||||
if suffix == "." {
|
||||
return filepath.Clean(resolved), nil
|
||||
}
|
||||
return filepath.Clean(filepath.Join(resolved, suffix)), nil
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
return "", err
|
||||
}
|
||||
if filepath.Dir(current) == current {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isWithinWorkspace(candidate, workspace string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
|
||||
return err == nil && filepath.IsLocal(rel)
|
||||
return err == nil && (rel == "." || filepath.IsLocal(rel))
|
||||
}
|
||||
|
||||
type ReadFileTool struct {
|
||||
@@ -625,12 +782,7 @@ type whitelistFs struct {
|
||||
}
|
||||
|
||||
func (w *whitelistFs) matches(path string) bool {
|
||||
for _, p := range w.patterns {
|
||||
if p.MatchString(path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return isAllowedPath(path, w.patterns)
|
||||
}
|
||||
|
||||
func (w *whitelistFs) ReadFile(path string) ([]byte, error) {
|
||||
|
||||
@@ -521,6 +521,90 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
allowedDir := t.TempDir()
|
||||
secretDir := t.TempDir()
|
||||
secretFile := filepath.Join(secretDir, "secret.txt")
|
||||
if err := os.WriteFile(secretFile, []byte("top secret"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(secretFile) error = %v", err)
|
||||
}
|
||||
|
||||
linkPath := filepath.Join(allowedDir, "link_out")
|
||||
if err := os.Symlink(secretDir, linkPath); err != nil {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": filepath.Join(linkPath, "secret.txt")})
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_WriteAllowsNewFileUnderAllowedDir(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
rootDir := t.TempDir()
|
||||
allowedDir := filepath.Join(rootDir, "allowed")
|
||||
targetFile := filepath.Join(allowedDir, "nested", "file.txt")
|
||||
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
|
||||
tool := NewWriteFileTool(workspace, true, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": targetFile,
|
||||
"content": "outside write",
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected whitelisted write to succeed, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(targetFile)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile(targetFile) error = %v", err)
|
||||
}
|
||||
if string(data) != "outside write" {
|
||||
t.Fatalf("target file content = %q, want %q", string(data), "outside write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_AllowsResolvedAllowedRootAlias(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
realDir := t.TempDir()
|
||||
linkParent := t.TempDir()
|
||||
allowedAlias := filepath.Join(linkParent, "allowed-link")
|
||||
|
||||
if err := os.Symlink(realDir, allowedAlias); err != nil {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
targetFile := filepath.Join(allowedAlias, "nested", "alias.txt")
|
||||
if err := os.MkdirAll(filepath.Dir(targetFile), 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(targetFile dir) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(targetFile, []byte("through alias"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(targetFile) error = %v", err)
|
||||
}
|
||||
|
||||
patterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(
|
||||
"^" + regexp.QuoteMeta(filepath.Clean(allowedAlias)) +
|
||||
"(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
|
||||
),
|
||||
}
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": targetFile})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected symlink-backed allowed root to be readable, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "through alias") {
|
||||
t.Fatalf("expected file content, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadFileTool_ChunkedReading verifies the pagination logic of the tool
|
||||
// by reading a file in multiple chunks using 'offset' and 'length'.
|
||||
func TestReadFileTool_ChunkedReading(t *testing.T) {
|
||||
|
||||
+15
-2
@@ -6,6 +6,7 @@ import (
|
||||
"mime"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
@@ -21,20 +22,32 @@ type SendFileTool struct {
|
||||
restrict bool
|
||||
maxFileSize int
|
||||
mediaStore media.MediaStore
|
||||
allowPaths []*regexp.Regexp
|
||||
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
}
|
||||
|
||||
func NewSendFileTool(workspace string, restrict bool, maxFileSize int, store media.MediaStore) *SendFileTool {
|
||||
func NewSendFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
maxFileSize int,
|
||||
store media.MediaStore,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *SendFileTool {
|
||||
if maxFileSize <= 0 {
|
||||
maxFileSize = config.DefaultMaxMediaSize
|
||||
}
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &SendFileTool{
|
||||
workspace: workspace,
|
||||
restrict: restrict,
|
||||
maxFileSize: maxFileSize,
|
||||
mediaStore: store,
|
||||
allowPaths: patterns,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,7 +105,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return ErrorResult("media store not configured")
|
||||
}
|
||||
|
||||
resolved, err := validatePath(path, t.workspace, t.restrict)
|
||||
resolved, err := validatePathWithAllowPaths(path, t.workspace, t.restrict, t.allowPaths)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid path: %v", err))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -128,6 +129,44 @@ func TestSendFileTool_CustomFilename(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileTool_AllowsWhitelistedMediaTempPath(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
|
||||
}
|
||||
|
||||
testFile, err := os.CreateTemp(mediaDir, "send-file-*.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp(mediaDir) error = %v", err)
|
||||
}
|
||||
testPath := testFile.Name()
|
||||
if _, err := testFile.WriteString("forward me"); err != nil {
|
||||
testFile.Close()
|
||||
t.Fatalf("WriteString(testFile) error = %v", err)
|
||||
}
|
||||
if err := testFile.Close(); err != nil {
|
||||
t.Fatalf("Close(testFile) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.Remove(testPath) })
|
||||
|
||||
pattern := regexp.MustCompile(
|
||||
"^" + regexp.QuoteMeta(filepath.Clean(mediaDir)) + "(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
|
||||
)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewSendFileTool(workspace, true, 0, store, []*regexp.Regexp{pattern})
|
||||
tool.SetContext("feishu", "chat123")
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": testPath})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected whitelisted temp media file to be sendable, got: %s", result.ForLLM)
|
||||
}
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectMediaType_MagicBytes(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
|
||||
+31
-13
@@ -23,6 +23,7 @@ type ExecTool struct {
|
||||
denyPatterns []*regexp.Regexp
|
||||
allowPatterns []*regexp.Regexp
|
||||
customAllowPatterns []*regexp.Regexp
|
||||
allowedPathPatterns []*regexp.Regexp
|
||||
restrictToWorkspace bool
|
||||
allowRemote bool
|
||||
}
|
||||
@@ -95,14 +96,23 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil)
|
||||
func NewExecTool(workingDir string, restrict bool, allowPaths ...[]*regexp.Regexp) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil, allowPaths...)
|
||||
}
|
||||
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
|
||||
func NewExecToolWithConfig(
|
||||
workingDir string,
|
||||
restrict bool,
|
||||
config *config.Config,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) (*ExecTool, error) {
|
||||
denyPatterns := make([]*regexp.Regexp, 0)
|
||||
customAllowPatterns := make([]*regexp.Regexp, 0)
|
||||
var allowedPathPatterns []*regexp.Regexp
|
||||
allowRemote := true
|
||||
if len(allowPaths) > 0 {
|
||||
allowedPathPatterns = allowPaths[0]
|
||||
}
|
||||
|
||||
if config != nil {
|
||||
execConfig := config.Tools.Exec
|
||||
@@ -146,6 +156,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
denyPatterns: denyPatterns,
|
||||
allowPatterns: nil,
|
||||
customAllowPatterns: customAllowPatterns,
|
||||
allowedPathPatterns: allowedPathPatterns,
|
||||
restrictToWorkspace: restrict,
|
||||
allowRemote: allowRemote,
|
||||
}, nil
|
||||
@@ -198,7 +209,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
cwd := t.workingDir
|
||||
if wd, ok := args["working_dir"].(string); ok && wd != "" {
|
||||
if t.restrictToWorkspace && t.workingDir != "" {
|
||||
resolvedWD, err := validatePath(wd, t.workingDir, true)
|
||||
resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns)
|
||||
if err != nil {
|
||||
return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
|
||||
}
|
||||
@@ -226,16 +237,20 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err))
|
||||
}
|
||||
absWorkspace, _ := filepath.Abs(t.workingDir)
|
||||
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
|
||||
if wsResolved == "" {
|
||||
wsResolved = absWorkspace
|
||||
if isAllowedPath(resolved, t.allowedPathPatterns) {
|
||||
cwd = resolved
|
||||
} else {
|
||||
absWorkspace, _ := filepath.Abs(t.workingDir)
|
||||
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
|
||||
if wsResolved == "" {
|
||||
wsResolved = absWorkspace
|
||||
}
|
||||
rel, err := filepath.Rel(wsResolved, resolved)
|
||||
if err != nil || !filepath.IsLocal(rel) {
|
||||
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
|
||||
}
|
||||
cwd = resolved
|
||||
}
|
||||
rel, err := filepath.Rel(wsResolved, resolved)
|
||||
if err != nil || !filepath.IsLocal(rel) {
|
||||
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
|
||||
}
|
||||
cwd = resolved
|
||||
}
|
||||
|
||||
// timeout == 0 means no timeout
|
||||
@@ -412,6 +427,9 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
if safePaths[p] {
|
||||
continue
|
||||
}
|
||||
if isAllowedPath(p, t.allowedPathPatterns) {
|
||||
continue
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(cwdPath, p)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SpawnStatusTool reports the status of subagents that were spawned via the
|
||||
// spawn tool. It can query a specific task by ID, or list every known task with
|
||||
// a summary count broken-down by status.
|
||||
type SpawnStatusTool struct {
|
||||
manager *SubagentManager
|
||||
}
|
||||
|
||||
// NewSpawnStatusTool creates a SpawnStatusTool backed by the given manager.
|
||||
func NewSpawnStatusTool(manager *SubagentManager) *SpawnStatusTool {
|
||||
return &SpawnStatusTool{manager: manager}
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Name() string {
|
||||
return "spawn_status"
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Description() string {
|
||||
return "Get the status of spawned subagents. " +
|
||||
"Returns a list of all subagents and their current state " +
|
||||
"(running, completed, failed, or canceled), or retrieves details " +
|
||||
"for a specific subagent task when task_id is provided. " +
|
||||
"Results are scoped to the current conversation's channel and chat ID; " +
|
||||
"all tasks are listed only when no channel/chat context is injected " +
|
||||
"(e.g. direct programmatic calls via Execute)."
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"task_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional task ID (e.g. \"subagent-1\") to inspect a specific " +
|
||||
"subagent. When omitted, all visible subagents are listed.",
|
||||
},
|
||||
},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
if t.manager == nil {
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
// Derive the calling conversation's identity so we can scope results to the
|
||||
// current chat only — preventing cross-conversation task leakage in
|
||||
// multi-user deployments.
|
||||
callerChannel := ToolChannel(ctx)
|
||||
callerChatID := ToolChatID(ctx)
|
||||
|
||||
var taskID string
|
||||
if rawTaskID, ok := args["task_id"]; ok && rawTaskID != nil {
|
||||
taskIDStr, ok := rawTaskID.(string)
|
||||
if !ok {
|
||||
return ErrorResult("task_id must be a string")
|
||||
}
|
||||
taskID = strings.TrimSpace(taskIDStr)
|
||||
}
|
||||
|
||||
if taskID != "" {
|
||||
// GetTaskCopy returns a consistent snapshot under the manager lock,
|
||||
// eliminating any data race with the concurrent subagent goroutine.
|
||||
taskCopy, ok := t.manager.GetTaskCopy(taskID)
|
||||
if !ok {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
|
||||
// Restrict lookup to tasks that belong to this conversation.
|
||||
if callerChannel != "" && taskCopy.OriginChannel != "" && taskCopy.OriginChannel != callerChannel {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
if callerChatID != "" && taskCopy.OriginChatID != "" && taskCopy.OriginChatID != callerChatID {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
|
||||
return NewToolResult(spawnStatusFormatTask(&taskCopy))
|
||||
}
|
||||
|
||||
// ListTaskCopies returns consistent snapshots under the manager lock.
|
||||
origTasks := t.manager.ListTaskCopies()
|
||||
if len(origTasks) == 0 {
|
||||
return NewToolResult("No subagents have been spawned yet.")
|
||||
}
|
||||
|
||||
tasks := make([]*SubagentTask, 0, len(origTasks))
|
||||
for i := range origTasks {
|
||||
cpy := &origTasks[i]
|
||||
|
||||
// Filter to tasks that originate from the current conversation only.
|
||||
if callerChannel != "" && cpy.OriginChannel != "" && cpy.OriginChannel != callerChannel {
|
||||
continue
|
||||
}
|
||||
if callerChatID != "" && cpy.OriginChatID != "" && cpy.OriginChatID != callerChatID {
|
||||
continue
|
||||
}
|
||||
|
||||
tasks = append(tasks, cpy)
|
||||
}
|
||||
|
||||
if len(tasks) == 0 {
|
||||
return NewToolResult("No subagents found for this conversation.")
|
||||
}
|
||||
|
||||
// Order by creation time (ascending) so spawning order is preserved.
|
||||
// Fall back to ID string for tasks created in the same millisecond.
|
||||
sort.Slice(tasks, func(i, j int) bool {
|
||||
if tasks[i].Created != tasks[j].Created {
|
||||
return tasks[i].Created < tasks[j].Created
|
||||
}
|
||||
return tasks[i].ID < tasks[j].ID
|
||||
})
|
||||
|
||||
counts := map[string]int{}
|
||||
for _, task := range tasks {
|
||||
counts[task.Status]++
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Subagent status report (%d total):\n", len(tasks)))
|
||||
for _, status := range []string{"running", "completed", "failed", "canceled"} {
|
||||
if n := counts[status]; n > 0 {
|
||||
label := strings.ToUpper(status[:1]) + status[1:] + ":"
|
||||
sb.WriteString(fmt.Sprintf(" %-10s %d\n", label, n))
|
||||
}
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
for _, task := range tasks {
|
||||
sb.WriteString(spawnStatusFormatTask(task))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
return NewToolResult(strings.TrimRight(sb.String(), "\n"))
|
||||
}
|
||||
|
||||
// spawnStatusFormatTask renders a single SubagentTask as a human-readable block.
|
||||
func spawnStatusFormatTask(task *SubagentTask) string {
|
||||
var sb strings.Builder
|
||||
|
||||
header := fmt.Sprintf("[%s] status=%s", task.ID, task.Status)
|
||||
if task.Label != "" {
|
||||
header += fmt.Sprintf(" label=%q", task.Label)
|
||||
}
|
||||
if task.AgentID != "" {
|
||||
header += fmt.Sprintf(" agent=%s", task.AgentID)
|
||||
}
|
||||
if task.Created > 0 {
|
||||
created := time.UnixMilli(task.Created).UTC().Format("2006-01-02 15:04:05 UTC")
|
||||
header += fmt.Sprintf(" created=%s", created)
|
||||
}
|
||||
sb.WriteString(header)
|
||||
|
||||
if task.Task != "" {
|
||||
sb.WriteString(fmt.Sprintf("\n task: %s", task.Task))
|
||||
}
|
||||
if task.Result != "" {
|
||||
result := task.Result
|
||||
const maxResultLen = 300
|
||||
runes := []rune(result)
|
||||
if len(runes) > maxResultLen {
|
||||
result = string(runes[:maxResultLen]) + "…"
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("\n result: %s", result))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSpawnStatusTool_Name(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
if tool.Name() != "spawn_status" {
|
||||
t.Errorf("Expected name 'spawn_status', got '%s'", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Description(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
desc := tool.Description()
|
||||
if desc == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(desc), "subagent") {
|
||||
t.Errorf("Description should mention 'subagent', got: %s", desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Parameters(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
params := tool.Parameters()
|
||||
if params["type"] != "object" {
|
||||
t.Errorf("Expected type 'object', got: %v", params["type"])
|
||||
}
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Expected 'properties' to be a map")
|
||||
}
|
||||
if _, hasTaskID := props["task_id"]; !hasTaskID {
|
||||
t.Error("Expected 'task_id' parameter in properties")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_NilManager(t *testing.T) {
|
||||
tool := &SpawnStatusTool{manager: nil}
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if !result.IsError {
|
||||
t.Error("Expected error result when manager is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Empty(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "No subagents") {
|
||||
t.Errorf("Expected 'No subagents' message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ListAll(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Do task A",
|
||||
Label: "task-a",
|
||||
Status: "running",
|
||||
Created: now,
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2",
|
||||
Task: "Do task B",
|
||||
Label: "task-b",
|
||||
Status: "completed",
|
||||
Result: "Done successfully",
|
||||
Created: now,
|
||||
}
|
||||
manager.tasks["subagent-3"] = &SubagentTask{
|
||||
ID: "subagent-3",
|
||||
Task: "Do task C",
|
||||
Status: "failed",
|
||||
Result: "Error: something went wrong",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Summary header
|
||||
if !strings.Contains(result.ForLLM, "3 total") {
|
||||
t.Errorf("Expected total count in header, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Individual task IDs
|
||||
for _, id := range []string{"subagent-1", "subagent-2", "subagent-3"} {
|
||||
if !strings.Contains(result.ForLLM, id) {
|
||||
t.Errorf("Expected task %s in output, got:\n%s", id, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// Status values
|
||||
for _, status := range []string{"running", "completed", "failed"} {
|
||||
if !strings.Contains(result.ForLLM, status) {
|
||||
t.Errorf("Expected status '%s' in output, got:\n%s", status, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// Result content
|
||||
if !strings.Contains(result.ForLLM, "Done successfully") {
|
||||
t.Errorf("Expected result text in output, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_GetByID(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-42"] = &SubagentTask{
|
||||
ID: "subagent-42",
|
||||
Task: "Specific task",
|
||||
Label: "my-task",
|
||||
Status: "failed",
|
||||
Result: "Something went wrong",
|
||||
Created: time.Now().UnixMilli(),
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-42"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-42") {
|
||||
t.Errorf("Expected task ID in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "failed") {
|
||||
t.Errorf("Expected status 'failed' in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Something went wrong") {
|
||||
t.Errorf("Expected result text in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "my-task") {
|
||||
t.Errorf("Expected label in output, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_GetByID_NotFound(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "nonexistent-999"})
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for nonexistent task, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "nonexistent-999") {
|
||||
t.Errorf("Expected task ID in error message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_TaskID_NonString(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
for _, badVal := range []any{42, 3.14, true, map[string]any{"x": 1}, []string{"a"}} {
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": badVal})
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for task_id=%T(%v), got success: %s", badVal, badVal, result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "task_id must be a string") {
|
||||
t.Errorf("Expected type-error message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ResultTruncation(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
longResult := strings.Repeat("X", 500)
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Long task",
|
||||
Status: "completed",
|
||||
Result: longResult,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
// Output should be shorter than the raw result due to truncation
|
||||
if len(result.ForLLM) >= len(longResult) {
|
||||
t.Errorf("Expected result to be truncated, but ForLLM is %d chars", len(result.ForLLM))
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "…") {
|
||||
t.Errorf("Expected truncation indicator '…' in output, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ResultTruncation_Unicode(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
// Each CJK rune is 3 bytes; 400 runes = 1200 bytes — well over the 300-rune limit.
|
||||
cjkChar := string(rune(0x5b57))
|
||||
longResult := strings.Repeat(cjkChar, 400)
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Unicode task",
|
||||
Status: "completed",
|
||||
Result: longResult,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "…") {
|
||||
t.Errorf("Expected truncation indicator in output")
|
||||
}
|
||||
// The truncated result must be valid UTF-8 (no split rune boundaries).
|
||||
if !strings.Contains(result.ForLLM, cjkChar) {
|
||||
t.Errorf("Expected CJK runes to appear intact in output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_StatusCounts(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
for i, status := range []string{"running", "running", "completed", "failed", "canceled"} {
|
||||
id := fmt.Sprintf("subagent-%d", i+1)
|
||||
manager.tasks[id] = &SubagentTask{ID: id, Task: "t", Status: status}
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
// The summary line should mention all statuses that have counts
|
||||
for _, want := range []string{"Running:", "Completed:", "Failed:", "Canceled:"} {
|
||||
if !strings.Contains(result.ForLLM, want) {
|
||||
t.Errorf("Expected %q in summary, got:\n%s", want, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_SortByCreatedTimestamp(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
manager.mu.Lock()
|
||||
// Intentionally insert with out-of-order IDs and timestamps that reflect
|
||||
// true spawn order: subagent-2 was spawned first, subagent-10 second.
|
||||
manager.tasks["subagent-10"] = &SubagentTask{
|
||||
ID: "subagent-10", Task: "second", Status: "running",
|
||||
Created: now + 1,
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2", Task: "first", Status: "running",
|
||||
Created: now,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
pos2 := strings.Index(result.ForLLM, "subagent-2")
|
||||
pos10 := strings.Index(result.ForLLM, "subagent-10")
|
||||
if pos2 < 0 || pos10 < 0 {
|
||||
t.Fatalf("Both task IDs should appear in output:\n%s", result.ForLLM)
|
||||
}
|
||||
if pos2 > pos10 {
|
||||
t.Errorf("Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_ListAll(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1", Task: "mine", Status: "running",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-A",
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2", Task: "other user", Status: "running",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-B",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// Caller is chat-A — should only see subagent-1.
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-A")
|
||||
result := tool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-1") {
|
||||
t.Errorf("Expected own task in output, got:\n%s", result.ForLLM)
|
||||
}
|
||||
if strings.Contains(result.ForLLM, "subagent-2") {
|
||||
t.Errorf("Should NOT see other chat's task, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_GetByID(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-99"] = &SubagentTask{
|
||||
ID: "subagent-99", Task: "secret", Status: "completed", Result: "private data",
|
||||
OriginChannel: "slack", OriginChatID: "room-Z",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// Different chat trying to look up subagent-99 by ID.
|
||||
ctx := WithToolContext(context.Background(), "slack", "room-OTHER")
|
||||
result := tool.Execute(ctx, map[string]any{"task_id": "subagent-99"})
|
||||
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error (cross-chat lookup blocked), got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_NoContext(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1", Task: "t", Status: "completed",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-A",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// No ToolContext injected (e.g. a direct programmatic call that bypasses
|
||||
// WithToolContext entirely) — callerChannel and callerChatID are both "".
|
||||
// Note: the normal CLI path uses ProcessDirectWithChannel("cli", "direct"),
|
||||
// which *does* inject a non-empty context; this test covers the case where
|
||||
// no context injection happens at all.
|
||||
// The filter conditions require a non-empty caller value, so all tasks pass through.
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-1") {
|
||||
t.Errorf("Expected task visible from no-context caller, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -255,6 +255,18 @@ func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {
|
||||
return task, ok
|
||||
}
|
||||
|
||||
// GetTaskCopy returns a copy of the task with the given ID, taken under the
|
||||
// read lock, so the caller receives a consistent snapshot with no data race.
|
||||
func (sm *SubagentManager) GetTaskCopy(taskID string) (SubagentTask, bool) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
task, ok := sm.tasks[taskID]
|
||||
if !ok {
|
||||
return SubagentTask{}, false
|
||||
}
|
||||
return *task, true
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
@@ -266,6 +278,19 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
return tasks
|
||||
}
|
||||
|
||||
// ListTaskCopies returns value copies of all tasks, taken under the read lock,
|
||||
// so callers receive consistent snapshots with no data race.
|
||||
func (sm *SubagentManager) ListTaskCopies() []SubagentTask {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
copies := make([]SubagentTask, 0, len(sm.tasks))
|
||||
for _, task := range sm.tasks {
|
||||
copies = append(copies, *task)
|
||||
}
|
||||
return copies
|
||||
}
|
||||
|
||||
// SubagentTool executes a subagent task synchronously and returns the result.
|
||||
// It directly calls SubTurnSpawner with Async=false for synchronous execution.
|
||||
type SubagentTool struct {
|
||||
|
||||
+2
-1
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
|
||||
@@ -67,7 +68,7 @@ func DownloadFile(urlStr, filename string, opts DownloadOptions) string {
|
||||
opts.LoggerPrefix = "utils"
|
||||
}
|
||||
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]any{
|
||||
"error": err.Error(),
|
||||
|
||||
Reference in New Issue
Block a user