mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
chore: revert unrelated golines formatting
This commit is contained in:
+5
-17
@@ -15,10 +15,7 @@ import (
|
||||
|
||||
// JobExecutor is the interface for executing cron jobs through the agent
|
||||
type JobExecutor interface {
|
||||
ProcessDirectWithChannel(
|
||||
ctx context.Context,
|
||||
content, sessionKey, channel, chatID string,
|
||||
) (string, error)
|
||||
ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error)
|
||||
// PublishResponseIfNeeded sends response to the outbound bus only when the
|
||||
// agent did not already deliver content through the message tool in this round.
|
||||
PublishResponseIfNeeded(ctx context.Context, channel, chatID, response string)
|
||||
@@ -37,13 +34,8 @@ type CronTool struct {
|
||||
// NewCronTool creates a new CronTool
|
||||
// execTimeout: 0 means no timeout, >0 sets the timeout duration
|
||||
func NewCronTool(
|
||||
cronService *cron.CronService,
|
||||
executor JobExecutor,
|
||||
msgBus *bus.MessageBus,
|
||||
workspace string,
|
||||
restrict bool,
|
||||
execTimeout time.Duration,
|
||||
config *config.Config,
|
||||
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
|
||||
execTimeout time.Duration, config *config.Config,
|
||||
) (*CronTool, error) {
|
||||
allowCommand := true
|
||||
execEnabled := true
|
||||
@@ -164,9 +156,7 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
|
||||
chatID := ToolChatID(ctx)
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return ErrorResult(
|
||||
"no session context (channel/chat_id not set). Use this tool in an active conversation.",
|
||||
)
|
||||
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
|
||||
}
|
||||
|
||||
message, ok := args["message"].(string)
|
||||
@@ -218,9 +208,7 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
|
||||
// Validate type parameter (server-side whitelist, not just LLM schema hint)
|
||||
msgType, _ := args["type"].(string)
|
||||
if msgType != "" && msgType != "message" && msgType != "directive" {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("invalid type %q, must be 'message' or 'directive'", msgType),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("invalid type %q, must be 'message' or 'directive'", msgType))
|
||||
}
|
||||
|
||||
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel. When
|
||||
|
||||
+8
-34
@@ -49,11 +49,7 @@ func (s *stubJobExecutor) PublishResponseIfNeeded(
|
||||
s.publishedChatID = chatID
|
||||
}
|
||||
|
||||
func newTestCronToolWithExecutorAndConfig(
|
||||
t *testing.T,
|
||||
executor JobExecutor,
|
||||
cfg *config.Config,
|
||||
) *CronTool {
|
||||
func newTestCronToolWithExecutorAndConfig(t *testing.T, executor JobExecutor, cfg *config.Config) *CronTool {
|
||||
t.Helper()
|
||||
storePath := filepath.Join(t.TempDir(), "cron.json")
|
||||
cronService := cron.NewCronService(storePath, nil)
|
||||
@@ -106,10 +102,7 @@ func TestCronTool_CommandDoesNotRequireConfirmByDefault(t *testing.T) {
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf(
|
||||
"expected command scheduling without confirm to succeed by default, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Fatalf("expected command scheduling without confirm to succeed by default, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
@@ -197,10 +190,7 @@ func TestCronTool_CommandAllowedFromInternalChannel(t *testing.T) {
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf(
|
||||
"expected command scheduling to succeed from internal channel, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Fatalf("expected command scheduling to succeed from internal channel, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
@@ -235,10 +225,7 @@ func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) {
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf(
|
||||
"expected non-command reminder to succeed from remote channel, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,11 +297,7 @@ func TestCronTool_ExecuteJobPublishesAgentResponse(t *testing.T) {
|
||||
t.Fatalf("sessionKey = %q, want cron-job-1", executor.lastKey)
|
||||
}
|
||||
if executor.lastChan != "telegram" || executor.lastChatID != "chat-1" {
|
||||
t.Fatalf(
|
||||
"executor target = %s/%s, want telegram/chat-1",
|
||||
executor.lastChan,
|
||||
executor.lastChatID,
|
||||
)
|
||||
t.Fatalf("executor target = %s/%s, want telegram/chat-1", executor.lastChan, executor.lastChatID)
|
||||
}
|
||||
if executor.lastPrompt != "send me a poem" {
|
||||
t.Fatalf("prompt = %q, want original message", executor.lastPrompt)
|
||||
@@ -323,11 +306,7 @@ func TestCronTool_ExecuteJobPublishesAgentResponse(t *testing.T) {
|
||||
t.Fatalf("published response = %q, want generated reply", executor.publishedResp)
|
||||
}
|
||||
if executor.publishedChan != "telegram" || executor.publishedChatID != "chat-1" {
|
||||
t.Fatalf(
|
||||
"published target = %s/%s, want telegram/chat-1",
|
||||
executor.publishedChan,
|
||||
executor.publishedChatID,
|
||||
)
|
||||
t.Fatalf("published target = %s/%s, want telegram/chat-1", executor.publishedChan, executor.publishedChatID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,10 +342,7 @@ func TestCronTool_ExecuteJobSkipsWhenMessageToolAlreadySent(t *testing.T) {
|
||||
}
|
||||
|
||||
if executor.publishedResp != "" {
|
||||
t.Fatalf(
|
||||
"expected no published response when message tool already sent, got: %q",
|
||||
executor.publishedResp,
|
||||
)
|
||||
t.Fatalf("expected no published response when message tool already sent, got: %q", executor.publishedResp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -410,9 +386,7 @@ func TestCronTool_ExecuteJobDirectiveWithDeliverRoutesToAgent(t *testing.T) {
|
||||
}
|
||||
|
||||
if executor.lastPrompt == "" {
|
||||
t.Fatal(
|
||||
"expected agent to be called for directive+deliver, but ProcessDirectWithChannel was not invoked",
|
||||
)
|
||||
t.Fatal("expected agent to be called for directive+deliver, but ProcessDirectWithChannel was not invoked")
|
||||
}
|
||||
if executor.publishedResp != "agent processed" {
|
||||
t.Fatalf("published response = %q, want %q", executor.publishedResp, "agent processed")
|
||||
|
||||
+3
-14
@@ -16,11 +16,7 @@ type EditFileTool struct {
|
||||
}
|
||||
|
||||
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
|
||||
func NewEditFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *EditFileTool {
|
||||
func NewEditFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *EditFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
@@ -83,11 +79,7 @@ type AppendFileTool struct {
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewAppendFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *AppendFileTool {
|
||||
func NewAppendFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *AppendFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
@@ -174,10 +166,7 @@ func replaceEditContent(content []byte, oldText, newText string) ([]byte, error)
|
||||
|
||||
count := strings.Count(contentStr, oldText)
|
||||
if count > 1 {
|
||||
return nil, fmt.Errorf(
|
||||
"old_text appears %d times. Please provide more context to make it unique",
|
||||
count,
|
||||
)
|
||||
return nil, fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count)
|
||||
}
|
||||
|
||||
newContent := strings.Replace(contentStr, oldText, newText, 1)
|
||||
|
||||
@@ -76,8 +76,7 @@ func TestEditTool_EditFile_NotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention file not found
|
||||
if !strings.Contains(result.ForLLM, "not found") &&
|
||||
!strings.Contains(result.ForUser, "not found") {
|
||||
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
|
||||
t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -104,8 +103,7 @@ func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention old_text not found
|
||||
if !strings.Contains(result.ForLLM, "not found") &&
|
||||
!strings.Contains(result.ForUser, "not found") {
|
||||
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
|
||||
t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
+3
-13
@@ -20,11 +20,7 @@ import (
|
||||
|
||||
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
|
||||
|
||||
func validatePathWithAllowPaths(
|
||||
path, workspace string,
|
||||
restrict bool,
|
||||
patterns []*regexp.Regexp,
|
||||
) (string, error) {
|
||||
func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) {
|
||||
if workspace == "" {
|
||||
return path, fmt.Errorf("workspace is not defined")
|
||||
}
|
||||
@@ -487,11 +483,7 @@ type WriteFileTool struct {
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewWriteFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *WriteFileTool {
|
||||
func NewWriteFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *WriteFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
@@ -544,9 +536,7 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *ToolR
|
||||
|
||||
if !overwrite {
|
||||
if _, err := t.fs.Open(path); err == nil {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("file: %s already exists. Set overwrite=true to replace.", path),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("file: %s already exists. Set overwrite=true to replace.", path))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -59,13 +59,8 @@ func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to open file") &&
|
||||
!strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf(
|
||||
"Expected error message, got ForLLM: %s, ForUser: %s",
|
||||
result.ForLLM,
|
||||
result.ForUser,
|
||||
)
|
||||
if !strings.Contains(result.ForLLM, "failed to open file") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,8 +78,7 @@ func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention required parameter
|
||||
if !strings.Contains(result.ForLLM, "path is required") &&
|
||||
!strings.Contains(result.ForUser, "path is required") {
|
||||
if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") {
|
||||
t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -303,12 +297,7 @@ func TestFilesystemTool_WriteFile_OverwriteSandboxed(t *testing.T) {
|
||||
"content": "replaced in sandbox",
|
||||
"overwrite": true,
|
||||
})
|
||||
assert.False(
|
||||
t,
|
||||
result.IsError,
|
||||
"expected success in sandbox mode with overwrite=true, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
assert.False(t, result.IsError, "expected success in sandbox mode with overwrite=true, got: %s", result.ForLLM)
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(workspace, testFile))
|
||||
assert.NoError(t, err)
|
||||
@@ -336,8 +325,7 @@ func TestFilesystemTool_ListDir_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should list files and directories
|
||||
if !strings.Contains(result.ForLLM, "file1.txt") ||
|
||||
!strings.Contains(result.ForLLM, "file2.txt") {
|
||||
if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") {
|
||||
t.Errorf("Expected files in listing, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subdir") {
|
||||
@@ -361,13 +349,8 @@ func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to read") &&
|
||||
!strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf(
|
||||
"Expected error message, got ForLLM: %s, ForUser: %s",
|
||||
result.ForLLM,
|
||||
result.ForUser,
|
||||
)
|
||||
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,8 +397,7 @@ func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
|
||||
// os.Root might return different errors depending on platform/implementation
|
||||
// but it definitely should error.
|
||||
// Our wrapper returns "access denied or file not found"
|
||||
if !strings.Contains(result.ForLLM, "access denied") &&
|
||||
!strings.Contains(result.ForLLM, "file not found") &&
|
||||
if !strings.Contains(result.ForLLM, "access denied") && !strings.Contains(result.ForLLM, "file not found") &&
|
||||
!strings.Contains(result.ForLLM, "no such file") {
|
||||
t.Fatalf("expected symlink escape error, got: %s", result.ForLLM)
|
||||
}
|
||||
@@ -434,20 +416,10 @@ func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) {
|
||||
})
|
||||
|
||||
// We EXPECT IsError=true (access blocked due to empty workspace)
|
||||
assert.True(
|
||||
t,
|
||||
result.IsError,
|
||||
"Security Regression: Empty workspace allowed access! content: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
assert.True(t, result.IsError, "Security Regression: Empty workspace allowed access! content: %s", result.ForLLM)
|
||||
|
||||
// Verify it failed for the right reason
|
||||
assert.Contains(
|
||||
t,
|
||||
result.ForLLM,
|
||||
"workspace is not defined",
|
||||
"Expected 'workspace is not defined' error",
|
||||
)
|
||||
assert.Contains(t, result.ForLLM, "workspace is not defined", "Expected 'workspace is not defined' error")
|
||||
}
|
||||
|
||||
// TestRootMkdirAll verifies that root.MkdirAll (used by atomicWriteFileInRoot) handles all cases:
|
||||
@@ -681,10 +653,7 @@ func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) {
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"path": filepath.Join(linkPath, "secret.txt")},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": filepath.Join(linkPath, "secret.txt")})
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
+2
-6
@@ -65,9 +65,7 @@ func (t *I2CTool) Parameters() map[string]any {
|
||||
|
||||
func (t *I2CTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
if runtime.GOOS != "linux" {
|
||||
return ErrorResult(
|
||||
"I2C is only supported on Linux. This tool requires /dev/i2c-* device files.",
|
||||
)
|
||||
return ErrorResult("I2C is only supported on Linux. This tool requires /dev/i2c-* device files.")
|
||||
}
|
||||
|
||||
action, ok := args["action"].(string)
|
||||
@@ -85,9 +83,7 @@ func (t *I2CTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
case "write":
|
||||
return t.writeDevice(args)
|
||||
default:
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("unknown action: %s (valid: detect, scan, read, write)", action),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("unknown action: %s (valid: detect, scan, read, write)", action))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+7
-35
@@ -55,12 +55,7 @@ func smbusProbe(fd int, addr int, hasQuick bool) bool {
|
||||
size: i2cSmbusQuick,
|
||||
data: nil,
|
||||
}
|
||||
_, _, errno := syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
i2cSmbus,
|
||||
uintptr(unsafe.Pointer(&args)),
|
||||
)
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSmbus, uintptr(unsafe.Pointer(&args)))
|
||||
return errno == 0
|
||||
}
|
||||
|
||||
@@ -72,12 +67,7 @@ func smbusProbe(fd int, addr int, hasQuick bool) bool {
|
||||
size: i2cSmbusByte,
|
||||
data: &data,
|
||||
}
|
||||
_, _, errno := syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
i2cSmbus,
|
||||
uintptr(unsafe.Pointer(&args)),
|
||||
)
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSmbus, uintptr(unsafe.Pointer(&args)))
|
||||
return errno == 0
|
||||
}
|
||||
|
||||
@@ -93,29 +83,16 @@ func (t *I2CTool) scan(args map[string]any) *ToolResult {
|
||||
devPath := fmt.Sprintf("/dev/i2c-%s", bus)
|
||||
fd, err := syscall.Open(devPath, syscall.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf(
|
||||
"failed to open %s: %v (check permissions and i2c-dev module)",
|
||||
devPath,
|
||||
err,
|
||||
),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and i2c-dev module)", devPath, err))
|
||||
}
|
||||
defer syscall.Close(fd)
|
||||
|
||||
// Query adapter capabilities to determine available probe methods.
|
||||
// I2C_FUNCS writes an unsigned long, which is word-sized on Linux.
|
||||
var funcs uintptr
|
||||
_, _, errno := syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
i2cFuncs,
|
||||
uintptr(unsafe.Pointer(&funcs)),
|
||||
)
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cFuncs, uintptr(unsafe.Pointer(&funcs)))
|
||||
if errno != 0 {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("failed to query I2C adapter capabilities on %s: %v", devPath, errno),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("failed to query I2C adapter capabilities on %s: %v", devPath, errno))
|
||||
}
|
||||
|
||||
hasQuick := funcs&i2cFuncSmbusQuick != 0
|
||||
@@ -123,10 +100,7 @@ func (t *I2CTool) scan(args map[string]any) *ToolResult {
|
||||
|
||||
if !hasQuick && !hasReadByte {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf(
|
||||
"I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely",
|
||||
devPath,
|
||||
),
|
||||
fmt.Sprintf("I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely", devPath),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -158,9 +132,7 @@ func (t *I2CTool) scan(args map[string]any) *ToolResult {
|
||||
}
|
||||
|
||||
if len(found) == 0 {
|
||||
return SilentResult(
|
||||
fmt.Sprintf("No devices found on %s. Check wiring and pull-up resistors.", devPath),
|
||||
)
|
||||
return SilentResult(fmt.Sprintf("No devices found on %s. Check wiring and pull-up resistors.", devPath))
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]any{
|
||||
|
||||
+6
-28
@@ -314,10 +314,7 @@ func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Cont
|
||||
return result
|
||||
}
|
||||
|
||||
func (t *MCPTool) storeEmbeddedResource(
|
||||
ctx context.Context,
|
||||
content *mcp.EmbeddedResource,
|
||||
) (string, string) {
|
||||
func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string) {
|
||||
if content == nil || content.Resource == nil {
|
||||
return "", "[MCP returned an embedded resource without data.]"
|
||||
}
|
||||
@@ -377,39 +374,23 @@ func (t *MCPTool) storeBinaryContent(
|
||||
|
||||
dir := media.TempDir()
|
||||
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s) but it could not be stored.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
|
||||
ext := extensionForMIMEType(mimeType)
|
||||
tmpFile, err := os.CreateTemp(dir, "mcp-*"+ext)
|
||||
if err != nil {
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s) but it could not be stored.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
if _, err = tmpFile.Write(data); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s) but it could not be stored.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
if err = tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s) but it could not be stored.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
|
||||
scope := fmt.Sprintf(
|
||||
@@ -489,10 +470,7 @@ func summarizeEmbeddedResource(content *mcp.EmbeddedResource) string {
|
||||
normalizedMIMEType(resource.MIMEType),
|
||||
)
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"[MCP returned embedded resource (%s).]",
|
||||
normalizedMIMEType(resource.MIMEType),
|
||||
)
|
||||
return fmt.Sprintf("[MCP returned embedded resource (%s).]", normalizedMIMEType(resource.MIMEType))
|
||||
}
|
||||
|
||||
func annotationsAllowUser(annotations *mcp.Annotations) bool {
|
||||
|
||||
@@ -571,10 +571,7 @@ func TestMCPTool_Execute_EmbeddedResourceBlobStoredAsMedia(t *testing.T) {
|
||||
result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil)
|
||||
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf(
|
||||
"expected embedded resource blob to be stored as media, got %d refs",
|
||||
len(result.Media),
|
||||
)
|
||||
t.Fatalf("expected embedded resource blob to be stored as media, got %d refs", len(result.Media))
|
||||
}
|
||||
path, _, err := store.ResolveWithMeta(result.Media[0])
|
||||
if err != nil {
|
||||
|
||||
@@ -43,10 +43,7 @@ func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
|
||||
// - ForLLM contains send status description
|
||||
if result.ForLLM != "Message sent to test-channel:test-chat-id" {
|
||||
t.Errorf(
|
||||
"Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM)
|
||||
}
|
||||
|
||||
// - ForUser is empty (user already received message directly)
|
||||
@@ -91,10 +88,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
t.Error("Expected Silent=true")
|
||||
}
|
||||
if result.ForLLM != "Message sent to custom-channel:custom-chat-id" {
|
||||
t.Errorf(
|
||||
"Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -215,43 +215,28 @@ func storeInlineDataURL(
|
||||
payload = strings.NewReplacer("\n", "", "\r", "", "\t", "", " ", "").Replace(payload)
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Sprintf(
|
||||
"[Tool returned inline media content (%s) that could not be decoded.]",
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) that could not be decoded.]", mimeType)
|
||||
}
|
||||
|
||||
dir := media.TempDir()
|
||||
if err = os.MkdirAll(dir, 0o700); err != nil {
|
||||
return "", fmt.Sprintf(
|
||||
"[Tool returned inline media content (%s) but it could not be stored.]",
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
|
||||
}
|
||||
|
||||
ext := extensionForMIMEType(mimeType)
|
||||
tmpFile, err := os.CreateTemp(dir, "tool-inline-*"+ext)
|
||||
if err != nil {
|
||||
return "", fmt.Sprintf(
|
||||
"[Tool returned inline media content (%s) but it could not be stored.]",
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
if _, err = tmpFile.Write(decoded); err != nil {
|
||||
tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf(
|
||||
"[Tool returned inline media content (%s) but it could not be stored.]",
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
|
||||
}
|
||||
if err = tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf(
|
||||
"[Tool returned inline media content (%s) but it could not be stored.]",
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be stored.]", mimeType)
|
||||
}
|
||||
|
||||
filename := sanitizeIdentifierComponent(toolName) + ext
|
||||
@@ -270,10 +255,7 @@ func storeInlineDataURL(
|
||||
}, scope)
|
||||
if err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf(
|
||||
"[Tool returned inline media content (%s) but it could not be registered.]",
|
||||
mimeType,
|
||||
)
|
||||
return "", fmt.Sprintf("[Tool returned inline media content (%s) but it could not be registered.]", mimeType)
|
||||
}
|
||||
|
||||
return ref, fmt.Sprintf(inlineMediaStoredMessage, mimeType)
|
||||
|
||||
+1
-4
@@ -80,10 +80,7 @@ func (tr *ToolResult) ContentForLLM() string {
|
||||
}
|
||||
}
|
||||
if len(tr.ArtifactTags) > 0 {
|
||||
artifactNote := "Local artifact paths: " + strings.Join(
|
||||
tr.ArtifactTags,
|
||||
" ",
|
||||
) + "\n" + artifactPathsLLMNote
|
||||
artifactNote := "Local artifact paths: " + strings.Join(tr.ArtifactTags, " ") + "\n" + artifactPathsLLMNote
|
||||
if content == "" {
|
||||
content = artifactNote
|
||||
} else if !strings.Contains(content, artifactNote) {
|
||||
|
||||
@@ -142,11 +142,7 @@ func TestToolResultJSONSerialization(t *testing.T) {
|
||||
t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM)
|
||||
}
|
||||
if decoded.ForUser != tt.result.ForUser {
|
||||
t.Errorf(
|
||||
"ForUser mismatch: got '%s', want '%s'",
|
||||
decoded.ForUser,
|
||||
tt.result.ForUser,
|
||||
)
|
||||
t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser)
|
||||
}
|
||||
if decoded.Silent != tt.result.Silent {
|
||||
t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent)
|
||||
|
||||
@@ -56,38 +56,19 @@ func (t *RegexSearchTool) Execute(ctx context.Context, args map[string]any) *Too
|
||||
}
|
||||
|
||||
if len(pattern) > MaxRegexPatternLength {
|
||||
logger.WarnCF(
|
||||
"discovery",
|
||||
"Regex pattern rejected (too long)",
|
||||
map[string]any{"len": len(pattern)},
|
||||
)
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("Pattern too long: max %d characters allowed", MaxRegexPatternLength),
|
||||
)
|
||||
logger.WarnCF("discovery", "Regex pattern rejected (too long)", map[string]any{"len": len(pattern)})
|
||||
return ErrorResult(fmt.Sprintf("Pattern too long: max %d characters allowed", MaxRegexPatternLength))
|
||||
}
|
||||
|
||||
logger.DebugCF("discovery", "Regex search", map[string]any{"pattern": pattern})
|
||||
|
||||
res, err := t.registry.SearchRegex(pattern, t.maxSearchResults)
|
||||
if err != nil {
|
||||
logger.WarnCF(
|
||||
"discovery",
|
||||
"Invalid regex pattern",
|
||||
map[string]any{"pattern": pattern, "error": err.Error()},
|
||||
)
|
||||
return ErrorResult(
|
||||
fmt.Sprintf(
|
||||
"Invalid regex pattern syntax: %v. Please fix your regex and try again.",
|
||||
err,
|
||||
),
|
||||
)
|
||||
logger.WarnCF("discovery", "Invalid regex pattern", map[string]any{"pattern": pattern, "error": err.Error()})
|
||||
return ErrorResult(fmt.Sprintf("Invalid regex pattern syntax: %v. Please fix your regex and try again.", err))
|
||||
}
|
||||
|
||||
logger.InfoCF(
|
||||
"discovery",
|
||||
"Regex search completed",
|
||||
map[string]any{"pattern": pattern, "results": len(res)},
|
||||
)
|
||||
logger.InfoCF("discovery", "Regex search completed", map[string]any{"pattern": pattern, "results": len(res)})
|
||||
return formatDiscoveryResponse(t.registry, res, t.ttl)
|
||||
}
|
||||
|
||||
@@ -157,11 +138,7 @@ func (t *BM25SearchTool) Execute(ctx context.Context, args map[string]any) *Tool
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF(
|
||||
"discovery",
|
||||
"BM25 search completed",
|
||||
map[string]any{"query": query, "results": len(results)},
|
||||
)
|
||||
logger.InfoCF("discovery", "BM25 search completed", map[string]any{"query": query, "results": len(results)})
|
||||
return formatDiscoveryResponse(t.registry, results, t.ttl)
|
||||
}
|
||||
|
||||
@@ -173,10 +150,7 @@ type ToolSearchResult struct {
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) SearchRegex(
|
||||
pattern string,
|
||||
maxSearchResults int,
|
||||
) ([]ToolSearchResult, error) {
|
||||
func (r *ToolRegistry) SearchRegex(pattern string, maxSearchResults int) ([]ToolSearchResult, error) {
|
||||
if maxSearchResults <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -214,11 +188,7 @@ func (r *ToolRegistry) SearchRegex(
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func formatDiscoveryResponse(
|
||||
registry *ToolRegistry,
|
||||
results []ToolSearchResult,
|
||||
ttl int,
|
||||
) *ToolResult {
|
||||
func formatDiscoveryResponse(registry *ToolRegistry, results []ToolSearchResult, ttl int) *ToolResult {
|
||||
if len(results) == 0 {
|
||||
return SilentResult("No tools found matching the query.")
|
||||
}
|
||||
@@ -304,11 +274,7 @@ func (t *BM25SearchTool) getOrBuildEngine() *bm25CachedEngine {
|
||||
cached := &bm25CachedEngine{engine: buildBM25Engine(docs)}
|
||||
t.cachedEngine = cached
|
||||
t.cacheVersion = snap.Version
|
||||
logger.DebugCF(
|
||||
"discovery",
|
||||
"BM25 engine rebuilt",
|
||||
map[string]any{"docs": len(docs), "version": snap.Version},
|
||||
)
|
||||
logger.DebugCF("discovery", "BM25 engine rebuilt", map[string]any{"docs": len(docs), "version": snap.Version})
|
||||
return cached
|
||||
}
|
||||
|
||||
|
||||
@@ -93,10 +93,7 @@ func TestRegexSearchTool_Execute(t *testing.T) {
|
||||
reg.mu.RLock()
|
||||
defer reg.mu.RUnlock()
|
||||
if reg.tools["mcp_read_file"].TTL != 5 {
|
||||
t.Errorf(
|
||||
"Expected TTL of 'mcp_read_file' to be promoted to 5, got %d",
|
||||
reg.tools["mcp_read_file"].TTL,
|
||||
)
|
||||
t.Errorf("Expected TTL of 'mcp_read_file' to be promoted to 5, got %d", reg.tools["mcp_read_file"].TTL)
|
||||
}
|
||||
if reg.tools["mcp_fetch_net"].TTL != 0 {
|
||||
t.Errorf("Expected 'mcp_fetch_net' to NOT be promoted (TTL=0)")
|
||||
|
||||
@@ -142,10 +142,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return ErrorResult(fmt.Sprintf("failed to register media: %v", err))
|
||||
}
|
||||
|
||||
return MediaResult(
|
||||
fmt.Sprintf("File %q sent to user", filename),
|
||||
[]string{ref},
|
||||
).WithResponseHandled()
|
||||
return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref}).WithResponseHandled()
|
||||
}
|
||||
|
||||
// detectMediaType determines the MIME type of a file.
|
||||
|
||||
@@ -79,11 +79,7 @@ func TestSendFileTool_FileTooLarge(t *testing.T) {
|
||||
func TestSendFileTool_DefaultMaxSize(t *testing.T) {
|
||||
tool := NewSendFileTool("/tmp", false, 0, nil)
|
||||
if tool.maxFileSize != config.DefaultMaxMediaSize {
|
||||
t.Errorf(
|
||||
"expected default max size %d, got %d",
|
||||
config.DefaultMaxMediaSize,
|
||||
tool.maxFileSize,
|
||||
)
|
||||
t.Errorf("expected default max size %d, got %d", config.DefaultMaxMediaSize, tool.maxFileSize)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,11 +162,7 @@ func TestSendFileTool_AllowsWhitelistedMediaTempPath(t *testing.T) {
|
||||
t.Cleanup(func() { _ = os.Remove(testPath) })
|
||||
|
||||
pattern := regexp.MustCompile(
|
||||
"^" + regexp.QuoteMeta(
|
||||
filepath.Clean(mediaDir),
|
||||
) + "(?:" + regexp.QuoteMeta(
|
||||
string(os.PathSeparator),
|
||||
) + "|$)",
|
||||
"^" + regexp.QuoteMeta(filepath.Clean(mediaDir)) + "(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
|
||||
)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
|
||||
+14
-58
@@ -113,11 +113,7 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func NewExecTool(
|
||||
workingDir string,
|
||||
restrict bool,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) (*ExecTool, error) {
|
||||
func NewExecTool(workingDir string, restrict bool, allowPaths ...[]*regexp.Regexp) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil, allowPaths...)
|
||||
}
|
||||
|
||||
@@ -197,16 +193,8 @@ func (t *ExecTool) Parameters() map[string]any {
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{
|
||||
"run",
|
||||
"list",
|
||||
"poll",
|
||||
"read",
|
||||
"write",
|
||||
"kill",
|
||||
"send-keys",
|
||||
},
|
||||
"type": "string",
|
||||
"enum": []string{"run", "list", "poll", "read", "write", "kill", "send-keys"},
|
||||
"description": "Action: run (execute command), list (show sessions), poll (check status), read (get output), write (send input), kill (terminate), send-keys (send keys to PTY)",
|
||||
},
|
||||
"command": map[string]any{
|
||||
@@ -312,12 +300,7 @@ func (t *ExecTool) executeRun(ctx context.Context, args map[string]any) *ToolRes
|
||||
cwd := t.workingDir
|
||||
if wd, ok := args["cwd"].(string); ok && wd != "" {
|
||||
if t.restrictToWorkspace && t.workingDir != "" {
|
||||
resolvedWD, err := validatePathWithAllowPaths(
|
||||
wd,
|
||||
t.workingDir,
|
||||
true,
|
||||
t.allowedPathPatterns,
|
||||
)
|
||||
resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns)
|
||||
if err != nil {
|
||||
return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
|
||||
}
|
||||
@@ -343,9 +326,7 @@ func (t *ExecTool) executeRun(ctx context.Context, args map[string]any) *ToolRes
|
||||
if t.restrictToWorkspace && t.workingDir != "" && cwd != t.workingDir {
|
||||
resolved, err := filepath.EvalSymlinks(cwd)
|
||||
if err != nil {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err))
|
||||
}
|
||||
if isAllowedPath(resolved, t.allowedPathPatterns) {
|
||||
cwd = resolved
|
||||
@@ -383,14 +364,7 @@ func (t *ExecTool) runSync(ctx context.Context, command, cwd string) *ToolResult
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
cmd = exec.CommandContext(
|
||||
cmdCtx,
|
||||
"powershell",
|
||||
"-NoProfile",
|
||||
"-NonInteractive",
|
||||
"-Command",
|
||||
command,
|
||||
)
|
||||
cmd = exec.CommandContext(cmdCtx, "powershell", "-NoProfile", "-NonInteractive", "-Command", command)
|
||||
} else {
|
||||
cmd = exec.CommandContext(cmdCtx, "sh", "-c", command)
|
||||
}
|
||||
@@ -468,10 +442,7 @@ func (t *ExecTool) runSync(ctx context.Context, command, cwd string) *ToolResult
|
||||
|
||||
maxLen := 10000
|
||||
if len(output) > maxLen {
|
||||
output = output[:maxLen] + fmt.Sprintf(
|
||||
"\n... (truncated, %d more chars)",
|
||||
len(output)-maxLen,
|
||||
)
|
||||
output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -489,11 +460,7 @@ func (t *ExecTool) runSync(ctx context.Context, command, cwd string) *ToolResult
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) runBackground(
|
||||
ctx context.Context,
|
||||
command, cwd string,
|
||||
ptyEnabled bool,
|
||||
) *ToolResult {
|
||||
func (t *ExecTool) runBackground(ctx context.Context, command, cwd string, ptyEnabled bool) *ToolResult {
|
||||
sessionID := generateSessionID()
|
||||
session := &ProcessSession{
|
||||
ID: sessionID,
|
||||
@@ -586,8 +553,7 @@ func (t *ExecTool) runBackground(
|
||||
n, err := session.ptyMaster.Read(buf)
|
||||
if n > 0 {
|
||||
raw := string(buf[:n])
|
||||
if mode := detectPtyKeyMode(raw); mode != PtyKeyModeNotFound &&
|
||||
mode != session.GetPtyKeyMode() {
|
||||
if mode := detectPtyKeyMode(raw); mode != PtyKeyModeNotFound && mode != session.GetPtyKeyMode() {
|
||||
session.SetPtyKeyMode(mode)
|
||||
}
|
||||
|
||||
@@ -768,16 +734,12 @@ func (t *ExecTool) executeWrite(args map[string]any) *ToolResult {
|
||||
}
|
||||
|
||||
if session.IsDone() {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
|
||||
if err := session.Write(data); err != nil {
|
||||
if errors.Is(err, ErrSessionDone) {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("failed to write to session: %v", err))
|
||||
}
|
||||
@@ -808,9 +770,7 @@ func (t *ExecTool) executeKill(args map[string]any) *ToolResult {
|
||||
}
|
||||
|
||||
if session.IsDone() {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
|
||||
if err := session.Kill(); err != nil {
|
||||
@@ -1032,16 +992,12 @@ func (t *ExecTool) executeSendKeys(args map[string]any) *ToolResult {
|
||||
}
|
||||
|
||||
if session.IsDone() {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
|
||||
if err := session.Write(data); err != nil {
|
||||
if errors.Is(err, ErrSessionDone) {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("process already exited with code %d", session.GetExitCode()),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("process already exited with code %d", session.GetExitCode()))
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("failed to send keys: %v", err))
|
||||
}
|
||||
|
||||
+18
-77
@@ -100,13 +100,8 @@ func TestShellTool_Timeout(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention timeout
|
||||
if !strings.Contains(result.ForLLM, "timed out") &&
|
||||
!strings.Contains(result.ForUser, "timed out") {
|
||||
t.Errorf(
|
||||
"Expected timeout message, got ForLLM: %s, ForUser: %s",
|
||||
result.ForLLM,
|
||||
result.ForUser,
|
||||
)
|
||||
if !strings.Contains(result.ForLLM, "timed out") && !strings.Contains(result.ForUser, "timed out") {
|
||||
t.Errorf("Expected timeout message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,11 +156,7 @@ func TestShellTool_DangerousCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
|
||||
t.Errorf(
|
||||
"Expected 'blocked' message, got ForLLM: %s, ForUser: %s",
|
||||
result.ForLLM,
|
||||
result.ForUser,
|
||||
)
|
||||
t.Errorf("Expected 'blocked' message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,11 +177,7 @@ func TestShellTool_DangerousCommand_KillBlocked(t *testing.T) {
|
||||
t.Errorf("Expected kill command to be blocked")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
|
||||
t.Errorf(
|
||||
"Expected blocked message, got ForLLM: %s, ForUser: %s",
|
||||
result.ForLLM,
|
||||
result.ForUser,
|
||||
)
|
||||
t.Errorf("Expected blocked message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,10 +269,7 @@ func TestShellTool_WorkingDir_OutsideWorkspace(t *testing.T) {
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatalf(
|
||||
"expected working_dir outside workspace to be blocked, got output: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Fatalf("expected working_dir outside workspace to be blocked, got output: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "blocked") {
|
||||
t.Errorf("expected 'blocked' in error, got: %s", result.ForLLM)
|
||||
@@ -460,10 +444,7 @@ func TestShellTool_DevNullAllowed(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, cmd := range commands {
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"action": "run", "command": cmd},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "blocked") {
|
||||
t.Errorf("command should not be blocked: %s\n error: %s", cmd, result.ForLLM)
|
||||
}
|
||||
@@ -492,10 +473,7 @@ func TestShellTool_BlockDevices(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, cmd := range blocked {
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"action": "run", "command": cmd},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
|
||||
if !result.IsError {
|
||||
t.Errorf("expected block device write to be blocked: %s", cmd)
|
||||
}
|
||||
@@ -519,16 +497,9 @@ func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, cmd := range commands {
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"action": "run", "command": cmd},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
|
||||
t.Errorf(
|
||||
"safe path should not be blocked by workspace check: %s\n error: %s",
|
||||
cmd,
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("safe path should not be blocked by workspace check: %s\n error: %s", cmd, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -620,10 +591,7 @@ func TestShellTool_CustomAllowPatterns(t *testing.T) {
|
||||
"command": "git push origin main",
|
||||
})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "blocked") {
|
||||
t.Errorf(
|
||||
"custom allow pattern should exempt 'git push origin main', got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("custom allow pattern should exempt 'git push origin main', got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// "git push upstream main" should still be blocked (does not match allow pattern).
|
||||
@@ -661,11 +629,7 @@ func TestShellTool_URLsNotBlocked(t *testing.T) {
|
||||
result := tool.Execute(ctx, map[string]any{"action": "run", "command": cmd})
|
||||
cancel()
|
||||
if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
|
||||
t.Errorf(
|
||||
"command with URL should not be blocked by workspace check: %s\n error: %s",
|
||||
cmd,
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("command with URL should not be blocked by workspace check: %s\n error: %s", cmd, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -688,10 +652,7 @@ func TestShellTool_FileURISandboxing(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, cmd := range blockedCommands {
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"action": "run", "command": cmd},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
|
||||
if !result.IsError || !strings.Contains(result.ForLLM, "path outside working dir") {
|
||||
t.Errorf("file:// URI outside workspace should be blocked: %s", cmd)
|
||||
}
|
||||
@@ -709,16 +670,9 @@ func TestShellTool_FileURISandboxing(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, cmd := range allowedCommands {
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"action": "run", "command": cmd},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
|
||||
if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
|
||||
t.Errorf(
|
||||
"file:// URI inside workspace should be allowed: %s\n error: %s",
|
||||
cmd,
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("file:// URI inside workspace should be allowed: %s\n error: %s", cmd, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -742,10 +696,7 @@ func TestShellTool_URLBypassPrevented(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, cmd := range blockedCommands {
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"action": "run", "command": cmd},
|
||||
)
|
||||
result := tool.Execute(context.Background(), map[string]any{"action": "run", "command": cmd})
|
||||
if !result.IsError || !strings.Contains(result.ForLLM, "path outside working dir") {
|
||||
t.Errorf("bypass attempt should be blocked: %q\n got: %s", cmd, result.ForLLM)
|
||||
}
|
||||
@@ -1270,9 +1221,7 @@ func TestShellTool_PTY_ProcessGroupKill(t *testing.T) {
|
||||
// The binary is created in /tmp/test_pgroup.c and compiled as part of test setup.
|
||||
testBinary := "/tmp/test_pgroup"
|
||||
if _, err := os.Stat(testBinary); os.IsNotExist(err) {
|
||||
t.Skip(
|
||||
"Test binary /tmp/test_pgroup not found - run: gcc -o /tmp/test_pgroup /tmp/test_pgroup.c",
|
||||
)
|
||||
t.Skip("Test binary /tmp/test_pgroup not found - run: gcc -o /tmp/test_pgroup /tmp/test_pgroup.c")
|
||||
}
|
||||
|
||||
tool, err := NewExecTool("", false)
|
||||
@@ -1606,16 +1555,8 @@ func TestDetectPtyKeyMode(t *testing.T) {
|
||||
{"rmkx only", "\x1b[?1l\x1b>", PtyKeyModeCSI},
|
||||
{"both smkx first", "\x1b[?1h\x1b=...\x1b[?1l\x1b>", PtyKeyModeCSI},
|
||||
{"both rmkx first", "\x1b[?1l\x1b>...\x1b[?1h\x1b=", PtyKeyModeSS3},
|
||||
{
|
||||
"multiple toggles smkx last",
|
||||
"\x1b[?1h\x1b=...\x1b[?1l\x1b>...\x1b[?1h\x1b=",
|
||||
PtyKeyModeSS3,
|
||||
},
|
||||
{
|
||||
"multiple toggles rmkx last",
|
||||
"\x1b[?1l\x1b>...\x1b[?1h\x1b=...\x1b[?1l\x1b>",
|
||||
PtyKeyModeCSI,
|
||||
},
|
||||
{"multiple toggles smkx last", "\x1b[?1h\x1b=...\x1b[?1l\x1b>...\x1b[?1h\x1b=", PtyKeyModeSS3},
|
||||
{"multiple toggles rmkx last", "\x1b[?1l\x1b>...\x1b[?1h\x1b=...\x1b[?1l\x1b>", PtyKeyModeCSI},
|
||||
{"partial smkx", "\x1b[?1h", PtyKeyModeSS3},
|
||||
{"partial rmkx", "\x1b[?1l", PtyKeyModeCSI},
|
||||
}
|
||||
|
||||
@@ -96,11 +96,7 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To
|
||||
if !force {
|
||||
if _, err := os.Stat(targetDir); err == nil {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf(
|
||||
"skill %q already installed at %s. Use force=true to reinstall.",
|
||||
slug,
|
||||
targetDir,
|
||||
),
|
||||
fmt.Sprintf("skill %q already installed at %s. Use force=true to reinstall.", slug, targetDir),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
@@ -146,9 +142,7 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To
|
||||
"error": rmErr.Error(),
|
||||
})
|
||||
}
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug),
|
||||
)
|
||||
return ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug))
|
||||
}
|
||||
|
||||
// Write origin metadata.
|
||||
@@ -168,10 +162,7 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To
|
||||
// Build result with moderation warning if suspicious.
|
||||
var output string
|
||||
if result.IsSuspicious {
|
||||
output = fmt.Sprintf(
|
||||
"⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n",
|
||||
slug,
|
||||
)
|
||||
output = fmt.Sprintf("⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n", slug)
|
||||
}
|
||||
output += fmt.Sprintf("Successfully installed skill %q v%s from %s registry.\nLocation: %s\n",
|
||||
slug, result.Version, registry.Name(), targetDir)
|
||||
|
||||
@@ -17,10 +17,7 @@ type FindSkillsTool struct {
|
||||
// NewFindSkillsTool creates a new FindSkillsTool.
|
||||
// registryMgr is the shared registry manager (built from config in createToolRegistry).
|
||||
// cache is the search cache for deduplicating similar queries.
|
||||
func NewFindSkillsTool(
|
||||
registryMgr *skills.RegistryManager,
|
||||
cache *skills.SearchCache,
|
||||
) *FindSkillsTool {
|
||||
func NewFindSkillsTool(registryMgr *skills.RegistryManager, cache *skills.SearchCache) *FindSkillsTool {
|
||||
return &FindSkillsTool{
|
||||
registryMgr: registryMgr,
|
||||
cache: cache,
|
||||
|
||||
@@ -77,12 +77,10 @@ func (t *SpawnStatusTool) Execute(ctx context.Context, args map[string]any) *Too
|
||||
}
|
||||
|
||||
// Restrict lookup to tasks that belong to this conversation.
|
||||
if callerChannel != "" && taskCopy.OriginChannel != "" &&
|
||||
taskCopy.OriginChannel != callerChannel {
|
||||
if callerChannel != "" && taskCopy.OriginChannel != "" && taskCopy.OriginChannel != callerChannel {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
if callerChatID != "" && taskCopy.OriginChatID != "" &&
|
||||
taskCopy.OriginChatID != callerChatID {
|
||||
if callerChatID != "" && taskCopy.OriginChatID != "" && taskCopy.OriginChatID != callerChatID {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
|
||||
|
||||
@@ -195,12 +195,7 @@ func TestSpawnStatusTool_TaskID_NonString(t *testing.T) {
|
||||
for _, badVal := range []any{42, 3.14, true, map[string]any{"x": 1}, []string{"a"}} {
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": badVal})
|
||||
if !result.IsError {
|
||||
t.Errorf(
|
||||
"Expected error for task_id=%T(%v), got success: %s",
|
||||
badVal,
|
||||
badVal,
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("Expected error for task_id=%T(%v), got success: %s", badVal, badVal, result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "task_id must be a string") {
|
||||
t.Errorf("Expected type-error message, got: %s", result.ForLLM)
|
||||
@@ -324,10 +319,7 @@ func TestSpawnStatusTool_SortByCreatedTimestamp(t *testing.T) {
|
||||
t.Fatalf("Both task IDs should appear in output:\n%s", result.ForLLM)
|
||||
}
|
||||
if pos2 > pos10 {
|
||||
t.Errorf(
|
||||
"Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+2
-6
@@ -69,9 +69,7 @@ func (t *SPITool) Parameters() map[string]any {
|
||||
|
||||
func (t *SPITool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
if runtime.GOOS != "linux" {
|
||||
return ErrorResult(
|
||||
"SPI is only supported on Linux. This tool requires /dev/spidev* device files.",
|
||||
)
|
||||
return ErrorResult("SPI is only supported on Linux. This tool requires /dev/spidev* device files.")
|
||||
}
|
||||
|
||||
action, ok := args["action"].(string)
|
||||
@@ -126,9 +124,7 @@ func (t *SPITool) list() *ToolResult {
|
||||
// parseSPIArgs extracts and validates common SPI parameters
|
||||
//
|
||||
//nolint:unused // Used by spi_linux.go
|
||||
func parseSPIArgs(
|
||||
args map[string]any,
|
||||
) (device string, speed uint32, mode uint8, bits uint8, errMsg string) {
|
||||
func parseSPIArgs(args map[string]any) (device string, speed uint32, mode uint8, bits uint8, errMsg string) {
|
||||
dev, ok := args["device"].(string)
|
||||
if !ok || dev == "" {
|
||||
return "", 0, 0, 0, "device is required (e.g. \"2.0\" for /dev/spidev2.0)"
|
||||
|
||||
+6
-37
@@ -38,46 +38,25 @@ type spiTransfer struct {
|
||||
func configureSPI(devPath string, mode uint8, bits uint8, speed uint32) (int, *ToolResult) {
|
||||
fd, err := syscall.Open(devPath, syscall.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return -1, ErrorResult(
|
||||
fmt.Sprintf(
|
||||
"failed to open %s: %v (check permissions and spidev module)",
|
||||
devPath,
|
||||
err,
|
||||
),
|
||||
)
|
||||
return -1, ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and spidev module)", devPath, err))
|
||||
}
|
||||
|
||||
// Set SPI mode
|
||||
_, _, errno := syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
spiIocWrMode,
|
||||
uintptr(unsafe.Pointer(&mode)),
|
||||
)
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrMode, uintptr(unsafe.Pointer(&mode)))
|
||||
if errno != 0 {
|
||||
syscall.Close(fd)
|
||||
return -1, ErrorResult(fmt.Sprintf("failed to set SPI mode %d: %v", mode, errno))
|
||||
}
|
||||
|
||||
// Set bits per word
|
||||
_, _, errno = syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
spiIocWrBitsPerWord,
|
||||
uintptr(unsafe.Pointer(&bits)),
|
||||
)
|
||||
_, _, errno = syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrBitsPerWord, uintptr(unsafe.Pointer(&bits)))
|
||||
if errno != 0 {
|
||||
syscall.Close(fd)
|
||||
return -1, ErrorResult(fmt.Sprintf("failed to set bits per word %d: %v", bits, errno))
|
||||
}
|
||||
|
||||
// Set max speed
|
||||
_, _, errno = syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
spiIocWrMaxSpeedHz,
|
||||
uintptr(unsafe.Pointer(&speed)),
|
||||
)
|
||||
_, _, errno = syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrMaxSpeedHz, uintptr(unsafe.Pointer(&speed)))
|
||||
if errno != 0 {
|
||||
syscall.Close(fd)
|
||||
return -1, ErrorResult(fmt.Sprintf("failed to set SPI speed %d Hz: %v", speed, errno))
|
||||
@@ -138,12 +117,7 @@ func (t *SPITool) transfer(args map[string]any) *ToolResult {
|
||||
bitsPerWord: bits,
|
||||
}
|
||||
|
||||
_, _, errno := syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
spiIocMessage1,
|
||||
uintptr(unsafe.Pointer(&xfer)),
|
||||
)
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocMessage1, uintptr(unsafe.Pointer(&xfer)))
|
||||
runtime.KeepAlive(txBuf)
|
||||
runtime.KeepAlive(rxBuf)
|
||||
if errno != 0 {
|
||||
@@ -200,12 +174,7 @@ func (t *SPITool) readDevice(args map[string]any) *ToolResult {
|
||||
bitsPerWord: bits,
|
||||
}
|
||||
|
||||
_, _, errno := syscall.Syscall(
|
||||
syscall.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
spiIocMessage1,
|
||||
uintptr(unsafe.Pointer(&xfer)),
|
||||
)
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocMessage1, uintptr(unsafe.Pointer(&xfer)))
|
||||
runtime.KeepAlive(txBuf)
|
||||
runtime.KeepAlive(rxBuf)
|
||||
if errno != 0 {
|
||||
|
||||
@@ -316,11 +316,7 @@ func TestSubagentTool_ForUserTruncation(t *testing.T) {
|
||||
// ForUser should be truncated to 500 chars + "..."
|
||||
maxUserLen := 500
|
||||
if len(result.ForUser) > maxUserLen+3 { // +3 for "..."
|
||||
t.Errorf(
|
||||
"ForUser should be truncated to ~%d chars, got: %d",
|
||||
maxUserLen,
|
||||
len(result.ForUser),
|
||||
)
|
||||
t.Errorf("ForUser should be truncated to ~%d chars, got: %d", maxUserLen, len(result.ForUser))
|
||||
}
|
||||
|
||||
// ForLLM should have full content
|
||||
|
||||
+2
-15
@@ -64,13 +64,7 @@ func RunToolLoop(
|
||||
llmOpts = map[string]any{}
|
||||
}
|
||||
// 3. Call LLM
|
||||
response, err := config.Provider.Chat(
|
||||
ctx,
|
||||
messages,
|
||||
providerToolDefs,
|
||||
config.Model,
|
||||
llmOpts,
|
||||
)
|
||||
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
|
||||
if err != nil {
|
||||
logger.ErrorCF("toolloop", "LLM call failed",
|
||||
map[string]any{
|
||||
@@ -154,14 +148,7 @@ func RunToolLoop(
|
||||
|
||||
var toolResult *ToolResult
|
||||
if config.Tools != nil {
|
||||
toolResult = config.Tools.ExecuteWithContext(
|
||||
ctx,
|
||||
tc.Name,
|
||||
tc.Arguments,
|
||||
channel,
|
||||
chatID,
|
||||
nil,
|
||||
)
|
||||
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
|
||||
} else {
|
||||
toolResult = ErrorResult("No tools available")
|
||||
}
|
||||
|
||||
@@ -151,10 +151,7 @@ func TestValidateToolArgs(t *testing.T) {
|
||||
schema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"color": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []any{"red", "green", "blue"},
|
||||
},
|
||||
"color": map[string]any{"type": "string", "enum": []any{"red", "green", "blue"}},
|
||||
},
|
||||
},
|
||||
args: map[string]any{"color": "red"},
|
||||
@@ -164,10 +161,7 @@ func TestValidateToolArgs(t *testing.T) {
|
||||
schema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"color": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []any{"red", "green", "blue"},
|
||||
},
|
||||
"color": map[string]any{"type": "string", "enum": []any{"red", "green", "blue"}},
|
||||
},
|
||||
},
|
||||
args: map[string]any{"color": "yellow"},
|
||||
@@ -178,10 +172,7 @@ func TestValidateToolArgs(t *testing.T) {
|
||||
schema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"color": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"red", "green", "blue"},
|
||||
},
|
||||
"color": map[string]any{"type": "string", "enum": []string{"red", "green", "blue"}},
|
||||
},
|
||||
},
|
||||
args: map[string]any{"color": "green"},
|
||||
@@ -191,10 +182,7 @@ func TestValidateToolArgs(t *testing.T) {
|
||||
schema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"color": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"red", "green", "blue"},
|
||||
},
|
||||
"color": map[string]any{"type": "string", "enum": []string{"red", "green", "blue"}},
|
||||
},
|
||||
},
|
||||
args: map[string]any{"color": "yellow"},
|
||||
@@ -354,11 +342,7 @@ func TestValidateToolArgs_RegistryIntegration(t *testing.T) {
|
||||
}
|
||||
|
||||
// Extra property — should fail with validation error
|
||||
result = r.Execute(
|
||||
context.Background(),
|
||||
"read_file",
|
||||
map[string]any{"path": "/x", "__inject": true},
|
||||
)
|
||||
result = r.Execute(context.Background(), "read_file", map[string]any{"path": "/x", "__inject": true})
|
||||
if !result.IsError {
|
||||
t.Error("expected validation error for extra property")
|
||||
}
|
||||
|
||||
+32
-130
@@ -54,8 +54,7 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
// ForUser should contain summary
|
||||
if !strings.Contains(result.ForUser, "bytes") &&
|
||||
!strings.Contains(result.ForUser, "extractor") {
|
||||
if !strings.Contains(result.ForUser, "bytes") && !strings.Contains(result.ForUser, "extractor") {
|
||||
t.Errorf("Expected ForUser to contain summary, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
@@ -76,11 +75,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -105,11 +100,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -134,11 +125,7 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -154,8 +141,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention only http/https allowed
|
||||
if !strings.Contains(result.ForLLM, "http/https") &&
|
||||
!strings.Contains(result.ForUser, "http/https") {
|
||||
if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") {
|
||||
t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -164,11 +150,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -182,8 +164,7 @@ func TestWebTool_WebFetch_MissingURL(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention URL is required
|
||||
if !strings.Contains(result.ForLLM, "url is required") &&
|
||||
!strings.Contains(result.ForUser, "url is required") {
|
||||
if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") {
|
||||
t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -203,11 +184,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
|
||||
tool, err := NewWebFetchTool(1000, format, testFetchLimit) // Limit to 1000 chars
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -239,10 +216,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
// Text should end with the truncation notice
|
||||
if text, ok := resultMap["text"].(string); ok {
|
||||
if !strings.HasSuffix(text, "[Content truncated due to size limit]") {
|
||||
t.Errorf(
|
||||
"Expected text to end with truncation notice, got: %q",
|
||||
text[max(0, len(text)-60):],
|
||||
)
|
||||
t.Errorf("Expected text to end with truncation notice, got: %q", text[max(0, len(text)-60):])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -289,13 +263,11 @@ func TestWebTool_WebFetch_TruncationNotice(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", tt.contentType)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(tt.body))
|
||||
}),
|
||||
)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", tt.contentType)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(tt.body))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebFetchTool(maxChars, tt.format, testFetchLimit)
|
||||
@@ -319,11 +291,7 @@ func TestWebTool_WebFetch_TruncationNotice(t *testing.T) {
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(text, truncationNotice) {
|
||||
t.Errorf(
|
||||
"expected text to end with %q, got suffix: %q",
|
||||
truncationNotice,
|
||||
text[max(0, len(text)-60):],
|
||||
)
|
||||
t.Errorf("expected text to end with %q, got suffix: %q", truncationNotice, text[max(0, len(text)-60):])
|
||||
}
|
||||
|
||||
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
|
||||
@@ -392,11 +360,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
||||
// Initialize the tool
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
// Prepare the arguments pointing to the URL of our local mock server
|
||||
@@ -416,8 +380,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
||||
// Search for the exact error string we set earlier in the Execute method
|
||||
expectedErrorMsg := fmt.Sprintf("size exceeded %d bytes limit", testFetchLimit)
|
||||
|
||||
if !strings.Contains(result.ForLLM, expectedErrorMsg) &&
|
||||
!strings.Contains(result.ForUser, expectedErrorMsg) {
|
||||
if !strings.Contains(result.ForLLM, expectedErrorMsg) && !strings.Contains(result.ForUser, expectedErrorMsg) {
|
||||
t.Errorf("test failed: expected error %q, but got: %+v", expectedErrorMsg, result)
|
||||
}
|
||||
}
|
||||
@@ -570,11 +533,7 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -759,13 +718,7 @@ func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
host, _ := serverHostAndPort(t, server.URL)
|
||||
tool, err := NewWebFetchToolWithConfig(
|
||||
50000,
|
||||
"",
|
||||
format,
|
||||
testFetchLimit,
|
||||
[]string{singleHostCIDR(t, host)},
|
||||
)
|
||||
tool, err := NewWebFetchToolWithConfig(50000, "", format, testFetchLimit, []string{singleHostCIDR(t, host)})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
@@ -800,10 +753,7 @@ func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf(
|
||||
"expected success when private host access is allowed in tests, got %q",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Errorf("expected success when private host access is allowed in tests, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1023,11 +973,7 @@ func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
|
||||
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -1049,19 +995,9 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
tool, err := NewWebFetchToolWithProxy(
|
||||
1024,
|
||||
"http://127.0.0.1:7890",
|
||||
format,
|
||||
testFetchLimit,
|
||||
nil,
|
||||
)
|
||||
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", format, testFetchLimit, nil)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else if tool.maxChars != 1024 {
|
||||
t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024)
|
||||
}
|
||||
@@ -1072,11 +1008,7 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
|
||||
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", format, testFetchLimit, nil)
|
||||
if err != nil {
|
||||
logger.ErrorCF(
|
||||
"agent",
|
||||
"Failed to create web fetch tool",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
if tool.maxChars != 50000 {
|
||||
@@ -1085,13 +1017,7 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) {
|
||||
_, err := NewWebFetchToolWithConfig(
|
||||
1024,
|
||||
"",
|
||||
format,
|
||||
testFetchLimit,
|
||||
[]string{"not-an-ip-or-cidr"},
|
||||
)
|
||||
_, err := NewWebFetchToolWithConfig(1024, "", format, testFetchLimit, []string{"not-an-ip-or-cidr"})
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid whitelist entry to fail")
|
||||
}
|
||||
@@ -1247,11 +1173,7 @@ func TestWebTool_TavilySearch_RangeMapping(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"results": []map[string]any{
|
||||
{
|
||||
"title": "Recent result",
|
||||
"url": "https://example.com/recent",
|
||||
"content": "snippet",
|
||||
},
|
||||
{"title": "Recent result", "url": "https://example.com/recent", "content": "snippet"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
@@ -1381,10 +1303,7 @@ func TestWebFetchTool_CloudflareChallenge_RetryFailsToo(t *testing.T) {
|
||||
|
||||
// Should not be an error — the retry response is used as-is (403 is a valid HTTP response)
|
||||
if result.IsError {
|
||||
t.Fatalf(
|
||||
"expected non-error result even when retry is also blocked, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
t.Fatalf("expected non-error result even when retry is also blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
// Status in the JSON result should reflect the 403
|
||||
if !strings.Contains(result.ForLLM, "403") {
|
||||
@@ -1549,10 +1468,7 @@ func TestWebTool_GLMSearch_Success(t *testing.T) {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer test-glm-key" {
|
||||
t.Errorf(
|
||||
"Expected Authorization Bearer test-glm-key, got %s",
|
||||
r.Header.Get("Authorization"),
|
||||
)
|
||||
t.Errorf("Expected Authorization Bearer test-glm-key, got %s", r.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
@@ -1618,21 +1534,14 @@ func TestWebTool_GLMSearch_RangeMapping(t *testing.T) {
|
||||
t.Fatalf("failed to decode payload: %v", err)
|
||||
}
|
||||
if payload["search_recency_filter"] != "oneMonth" {
|
||||
t.Fatalf(
|
||||
"expected search_recency_filter=oneMonth, got %v",
|
||||
payload["search_recency_filter"],
|
||||
)
|
||||
t.Fatalf("expected search_recency_filter=oneMonth, got %v", payload["search_recency_filter"])
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"search_result": []map[string]any{
|
||||
{
|
||||
"title": "Recent GLM Result",
|
||||
"content": "snippet",
|
||||
"link": "https://example.com/glm-range",
|
||||
},
|
||||
{"title": "Recent GLM Result", "content": "snippet", "link": "https://example.com/glm-range"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
@@ -1664,21 +1573,14 @@ func TestWebTool_BaiduSearch_RangeMapping(t *testing.T) {
|
||||
t.Fatalf("failed to decode payload: %v", err)
|
||||
}
|
||||
if payload["search_recency_filter"] != "week" {
|
||||
t.Fatalf(
|
||||
"expected search_recency_filter=week for day fallback, got %v",
|
||||
payload["search_recency_filter"],
|
||||
)
|
||||
t.Fatalf("expected search_recency_filter=week for day fallback, got %v", payload["search_recency_filter"])
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"references": []map[string]any{
|
||||
{
|
||||
"title": "Recent Baidu Result",
|
||||
"url": "https://example.com/baidu",
|
||||
"content": "snippet",
|
||||
},
|
||||
{"title": "Recent Baidu Result", "url": "https://example.com/baidu", "content": "snippet"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
|
||||
Reference in New Issue
Block a user