mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into refactor-inbound-context-routing-session
# Conflicts: # pkg/agent/eventbus_test.go # pkg/agent/loop.go # pkg/bus/bus.go # pkg/bus/types.go # pkg/channels/pico/pico.go # pkg/channels/telegram/telegram.go # pkg/config/config.go # web/backend/api/session.go # web/backend/api/session_test.go
This commit is contained in:
+5
-5
@@ -29,7 +29,7 @@ func (t *EditFileTool) Name() string {
|
||||
}
|
||||
|
||||
func (t *EditFileTool) Description() string {
|
||||
return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."
|
||||
return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file. In `function.arguments`, use \\n for newline and \\\\n for literal backslash-n."
|
||||
}
|
||||
|
||||
func (t *EditFileTool) Parameters() map[string]any {
|
||||
@@ -42,11 +42,11 @@ func (t *EditFileTool) Parameters() map[string]any {
|
||||
},
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The exact text to find and replace",
|
||||
"description": "The exact text to find and replace. In `function.arguments`, use \\n for newline and \\\\n for literal backslash-n.",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The text to replace with",
|
||||
"description": "The text to replace with. In `function.arguments`, use \\n for newline and \\\\n for literal backslash-n.",
|
||||
},
|
||||
},
|
||||
"required": []string{"path", "old_text", "new_text"},
|
||||
@@ -92,7 +92,7 @@ func (t *AppendFileTool) Name() string {
|
||||
}
|
||||
|
||||
func (t *AppendFileTool) Description() string {
|
||||
return "Append content to the end of a file"
|
||||
return "Append content to the end of a file. In `function.arguments`, use \\n for newline and \\\\n for literal backslash-n."
|
||||
}
|
||||
|
||||
func (t *AppendFileTool) Parameters() map[string]any {
|
||||
@@ -105,7 +105,7 @@ func (t *AppendFileTool) Parameters() map[string]any {
|
||||
},
|
||||
"content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The content to append",
|
||||
"description": "The content to append. In `function.arguments`, use \\n for newline and \\\\n for literal backslash-n.",
|
||||
},
|
||||
},
|
||||
"required": []string{"path", "content"},
|
||||
|
||||
+381
-5
@@ -1,18 +1,22 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"math"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
@@ -20,7 +24,11 @@ import (
|
||||
|
||||
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
|
||||
|
||||
func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) {
|
||||
func validatePathWithAllowPaths(
|
||||
path, workspace string,
|
||||
restrict bool,
|
||||
patterns []*regexp.Regexp,
|
||||
) (string, error) {
|
||||
if workspace == "" {
|
||||
return path, fmt.Errorf("workspace is not defined")
|
||||
}
|
||||
@@ -253,6 +261,11 @@ type ReadFileTool struct {
|
||||
maxSize int64
|
||||
}
|
||||
|
||||
// ReadFileLinesTool is the line-oriented file reading tool: its Execute method
// returns file content as numbered lines selected by start_line / max_lines.
type ReadFileLinesTool struct {
	fs      fileSystem // backend used to open the requested path
	maxSize int64      // byte budget for the numbered output of one Execute call
}
|
||||
|
||||
func NewReadFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
@@ -275,14 +288,53 @@ func NewReadFileTool(
|
||||
}
|
||||
}
|
||||
|
||||
// NewReadFileBytesTool is an explicitly named constructor for the ReadFileTool.
// It forwards every argument unchanged to NewReadFileTool and returns its result.
func NewReadFileBytesTool(
	workspace string,
	restrict bool,
	maxReadFileSize int,
	allowPaths ...[]*regexp.Regexp,
) *ReadFileTool {
	return NewReadFileTool(workspace, restrict, maxReadFileSize, allowPaths...)
}
|
||||
|
||||
func NewReadFileLinesTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
maxReadFileSize int,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *ReadFileLinesTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
|
||||
maxSize := int64(maxReadFileSize)
|
||||
if maxSize <= 0 {
|
||||
maxSize = MaxReadFileSize
|
||||
}
|
||||
|
||||
return &ReadFileLinesTool{
|
||||
fs: buildFs(workspace, restrict, patterns),
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the tool identifier, "read_file".
func (t *ReadFileTool) Name() string {
	return "read_file"
}
|
||||
|
||||
// Name returns the tool identifier, "read_file".
func (t *ReadFileLinesTool) Name() string {
	return "read_file"
}
|
||||
|
||||
// Description returns a short summary of the tool that mentions its
// `offset` / `length` pagination parameters.
func (t *ReadFileTool) Description() string {
	return "Read the contents of a file. Supports pagination via `offset` and `length`."
}
|
||||
|
||||
// Description returns a short summary of the line-oriented tool: the
// `LINE_NUMBER|LINE_CONTENT` output format and the `start_line` / `max_lines`
// parameters used for partial reads.
func (t *ReadFileLinesTool) Description() string {
	return "Read a UTF-8 text file from the filesystem. Output always includes line numbers in the format `LINE_NUMBER|LINE_CONTENT` (1-indexed). Supports partial reads via `start_line` and `max_lines` for large text files."
}
|
||||
|
||||
func (t *ReadFileTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
@@ -306,6 +358,28 @@ func (t *ReadFileTool) Parameters() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ReadFileLinesTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the file to read.",
|
||||
},
|
||||
"start_line": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-indexed, inclusive).",
|
||||
"default": 1,
|
||||
},
|
||||
"max_lines": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read.",
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
@@ -447,6 +521,302 @@ func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return NewToolResult(header + "\n\n" + string(data))
|
||||
}
|
||||
|
||||
// Execute reads a range of lines from the file named by args["path"] and
// returns them with a `N|` prefix on every line, preceded by a header that
// describes the range read and whether more content remains.
//
// Recognized arguments:
//   - path (string, required)
//   - start_line (integer >= 1, defaults to 1)
//   - max_lines (integer > 0, optional; absent or nil means "no line limit")
//
// The byte-mode arguments offset, length and limit are rejected with an error.
// Output (prefixes included) is capped at t.maxSize bytes; when the cap is hit
// the header tells the caller how to continue. Files whose first 512 bytes look
// binary are rejected. Every failure is reported through ErrorResult.
func (t *ReadFileLinesTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
	path, ok := args["path"].(string)
	if !ok {
		return ErrorResult("path is required")
	}

	startLine, err := getInt64Arg(args, "start_line", 1)
	if err != nil {
		return ErrorResult(err.Error())
	}
	if startLine < 1 {
		return ErrorResult("start_line must be >= 1")
	}
	// Reject byte-mode argument names outright and point the caller at the
	// line-mode equivalent.
	if _, exists := args["offset"]; exists {
		return ErrorResult("offset is not supported in line mode; use start_line")
	}
	if _, exists := args["length"]; exists {
		return ErrorResult("length is not supported in line mode; use max_lines")
	}
	if _, exists := args["limit"]; exists {
		return ErrorResult("limit is not supported in line mode; use max_lines")
	}

	// limit < 0 means "no line limit"; it only becomes positive when the
	// caller supplied a non-nil max_lines.
	limit := int64(-1)
	if raw, exists := args["max_lines"]; exists && raw != nil {
		limit, err = getInt64Arg(args, "max_lines", -1)
		if err != nil {
			return ErrorResult(err.Error())
		}
		if limit <= 0 {
			return ErrorResult("max_lines, if provided, must be > 0")
		}
	}

	file, err := t.fs.Open(path)
	if err != nil {
		return ErrorResult(err.Error())
	}
	defer file.Close()

	// A Stat failure is not treated as an error here; only a successful Stat
	// that reports a directory aborts the read.
	if info, statErr := file.Stat(); statErr == nil && info.IsDir() {
		return ErrorResult(fmt.Sprintf("failed to open file: path is a directory: %s", path))
	}

	// Sniff up to the first 512 bytes to decide whether the file is text.
	sample := make([]byte, 512)
	sampleN, readErr := file.Read(sample)
	if readErr != nil && readErr != io.EOF {
		return ErrorResult(fmt.Sprintf("failed to read file: %v", readErr))
	}
	sample = sample[:sampleN]
	if isBinaryReadFileData(sample) {
		return ErrorResult("file appears to be binary; switch read_file mode to 'bytes' for byte-based inspection")
	}

	// Replay the sniffed bytes ahead of the rest of the file so that line
	// reading starts from the beginning of the file.
	reader := bufio.NewReaderSize(io.MultiReader(bytes.NewReader(sample), file), 32*1024)

	var content strings.Builder
	lineIndex := int64(1) // 1-indexed number of the next line to be read
	var linesRead int64
	var fileBytesRead int64   // bytes of file content copied to the output
	var outputBytesRead int64 // fileBytesRead plus the bytes of the line prefixes
	var reachedEOF bool
	var byteBudgetTruncated bool
	var lineTruncated bool

	// Skip, without buffering their content, the lines before start_line.
	for lineIndex < startLine {
		hasLine, consumeErr := consumeNextLine(reader)
		if consumeErr != nil {
			return ErrorResult(fmt.Sprintf("failed to read file content: %v", consumeErr))
		}
		if !hasLine {
			reachedEOF = true
			break
		}
		lineIndex++
	}

	// Emit lines until EOF, the line limit, or the byte budget is reached.
	for !reachedEOF && (limit < 0 || linesRead < limit) {
		prefix := formatReadFileLinePrefix(lineIndex)
		// The prefix counts against the budget; stop if not even one content
		// byte would fit after it.
		remaining := t.maxSize - outputBytesRead - int64(len(prefix))
		if remaining <= 0 {
			byteBudgetTruncated = true
			break
		}

		line, complete, hasLine, readLineErr := readNextLinePrefix(reader, remaining)
		if readLineErr != nil {
			return ErrorResult(fmt.Sprintf("failed to read file content: %v", readLineErr))
		}
		if !hasLine {
			reachedEOF = true
			break
		}

		content.WriteString(prefix)
		content.Write(line)
		fileBytesRead += int64(len(line))
		outputBytesRead += int64(len(prefix) + len(line))
		linesRead++
		lineIndex++

		// The line did not fit in the remaining budget and was cut short.
		if !complete {
			byteBudgetTruncated = true
			lineTruncated = true
			break
		}
	}

	// The loop may have stopped exactly at the end of the file (line limit or
	// byte budget reached on the last line). Peek so that such a read is
	// reported as END OF FILE rather than as partial or truncated.
	if !reachedEOF && !lineTruncated {
		hasMoreContent, peekErr := readerHasMoreContent(reader)
		if peekErr != nil {
			return ErrorResult(fmt.Sprintf("failed to inspect remaining file content: %v", peekErr))
		}
		if !hasMoreContent {
			reachedEOF = true
			byteBudgetTruncated = false
		}
	}

	if linesRead == 0 && content.Len() == 0 {
		return NewToolResult(fmt.Sprintf("[END OF FILE - no content at or after start_line=%d]", startLine))
	}

	start := startLine
	endLine := startLine + linesRead - 1
	// Only the base name of the path is shown in the header.
	displayPath := filepath.Base(path)
	header := fmt.Sprintf(
		"[file: %s | read: lines %d-%d (1-indexed) | file_bytes: %d | output_bytes: %d]",
		displayPath, start, endLine, fileBytesRead, outputBytesRead,
	)

	// Append exactly one status line. Case order matters: a mid-line cut wins
	// over a plain budget stop, which wins over a line-limit stop.
	switch {
	case lineTruncated:
		header += fmt.Sprintf(
			"\n[TRUNCATED - line %d exceeded the %d byte read budget and was cut mid-line.]",
			endLine,
			t.maxSize,
		)
	case byteBudgetTruncated:
		if limit > 0 {
			header += fmt.Sprintf(
				"\n[TRUNCATED - byte budget reached. Call read_file again with start_line=%d and max_lines=%d to continue at the next line.]",
				startLine+linesRead,
				limit,
			)
		} else {
			header += fmt.Sprintf(
				"\n[TRUNCATED - byte budget reached. Call read_file again with start_line=%d to continue at the next line.]",
				startLine+linesRead,
			)
		}
	case !reachedEOF && limit > 0 && linesRead >= limit:
		header += fmt.Sprintf(
			"\n[PARTIAL - more content remains. Call read_file again with start_line=%d and max_lines=%d to continue.]",
			startLine+linesRead,
			limit,
		)
	default:
		header += "\n[END OF FILE - no further content.]"
	}

	logger.DebugCF("tool", "ReadFileTool execution completed successfully",
		map[string]any{
			"path":              path,
			"lines_read":        linesRead,
			"file_bytes_read":   fileBytesRead,
			"output_bytes_read": outputBytesRead,
			"truncated":         byteBudgetTruncated,
			"tool":              t.Name(),
		})

	return NewToolResult(header + "\n\n" + content.String())
}
|
||||
|
||||
// formatReadFileLinePrefix renders the marker placed in front of every output
// line: the decimal line number immediately followed by a pipe, e.g. "12|".
func formatReadFileLinePrefix(lineNumber int64) string {
	// 20 bytes hold any int64 in base 10 (sign included); +1 for the pipe.
	prefix := make([]byte, 0, 21)
	prefix = strconv.AppendInt(prefix, lineNumber, 10)
	prefix = append(prefix, '|')
	return string(prefix)
}
|
||||
|
||||
// isBinaryReadFileData reports whether data looks like binary rather than text.
// Only the first 512 bytes are inspected.
//
// The checks run in this order:
//  1. empty input is text;
//  2. any NUL byte means binary;
//  3. a sniffed content type of text/*, JSON, XML or JavaScript means text;
//  4. invalid UTF-8 means binary;
//  5. otherwise the data is binary when more than 10% of the inspected bytes
//     are control characters other than \n, \r, \t, \f and \b.
//
// When the sample fills the whole 512-byte window it may end in the middle of
// a multi-byte UTF-8 sequence. That cut is an artifact of sampling, not a
// property of the file, so the dangling partial rune is dropped before the
// UTF-8 validity check.
func isBinaryReadFileData(data []byte) bool {
	const sampleLimit = 512

	if len(data) == 0 {
		return false
	}

	sample := data
	if len(sample) > sampleLimit {
		sample = sample[:sampleLimit]
	}

	if bytes.IndexByte(sample, 0) >= 0 {
		return true
	}

	contentType := http.DetectContentType(sample)
	if strings.HasPrefix(contentType, "text/") {
		return false
	}
	if strings.HasSuffix(contentType, "/json") ||
		strings.HasSuffix(contentType, "+json") ||
		strings.HasSuffix(contentType, "/xml") ||
		strings.HasSuffix(contentType, "+xml") ||
		strings.Contains(contentType, "javascript") {
		return false
	}

	// A shorter sample is taken as-is: a dangling lead byte there is treated
	// as invalid UTF-8, exactly as before.
	if len(sample) == sampleLimit {
		sample = trimPartialTrailingRune(sample)
	}
	if !utf8.Valid(sample) {
		return true
	}

	controlChars := 0
	for _, b := range sample {
		if b < 0x20 && b != '\n' && b != '\r' && b != '\t' && b != '\f' && b != '\b' {
			controlChars++
		}
	}

	return float64(controlChars)/float64(len(sample)) > 0.1
}

// trimPartialTrailingRune returns b without a trailing, incomplete UTF-8
// sequence. A complete (or outright invalid) tail is returned untouched so
// that utf8.Valid still makes the final decision.
func trimPartialTrailingRune(b []byte) []byte {
	// An incomplete sequence is at most UTFMax-1 bytes long, so only that many
	// trailing bytes need to be examined.
	for i := len(b) - 1; i >= 0 && len(b)-i < utf8.UTFMax; i-- {
		if !utf8.RuneStart(b[i]) {
			continue // continuation byte: keep looking for its lead byte
		}
		if utf8.FullRune(b[i:]) {
			return b
		}
		return b[:i]
	}
	return b
}
|
||||
|
||||
// consumeNextLine advances reader past the next line without keeping its
// content. It returns true when a line was consumed: either a newline was
// found, or EOF was hit after at least one byte had been read. It returns
// false at a clean EOF. Lines longer than the reader's buffer are drained in
// several reads. Any error other than EOF is returned as-is.
func consumeNextLine(reader *bufio.Reader) (bool, error) {
	consumed := false

	for {
		chunk, readErr := reader.ReadSlice('\n')
		consumed = consumed || len(chunk) != 0

		if readErr == nil {
			return true, nil
		}
		if errors.Is(readErr, bufio.ErrBufferFull) {
			// The line is longer than the buffer: keep draining it.
			continue
		}
		if errors.Is(readErr, io.EOF) {
			return consumed, nil
		}
		return false, readErr
	}
}
|
||||
|
||||
// readNextLinePrefix reads the next line from reader but keeps at most
// maxBytes bytes of it (the newline, when present, counts as content).
//
// Results, in order:
//   - the kept bytes;
//   - complete: whether the whole line fit within maxBytes;
//   - hasLine: whether any data was read (false at a clean EOF);
//   - an error for anything other than EOF.
//
// When the budget runs out on a line that spills over the reader's buffer, the
// function returns immediately and leaves the rest of that line unread. With
// maxBytes <= 0 nothing is read and (nil, false, false, nil) is returned.
func readNextLinePrefix(reader *bufio.Reader, maxBytes int64) ([]byte, bool, bool, error) {
	if maxBytes <= 0 {
		return nil, false, false, nil
	}

	var (
		kept    bytes.Buffer
		gotData bool
		whole   = true
	)

	for {
		chunk, readErr := reader.ReadSlice('\n')
		if n := len(chunk); n > 0 {
			gotData = true
			room := maxBytes - int64(kept.Len())
			switch {
			case room <= 0:
				// Budget already spent by earlier chunks of this line.
				whole = false
			case int64(n) > room:
				whole = false
				kept.Write(chunk[:room])
			default:
				kept.Write(chunk)
			}
		}

		if readErr == nil {
			return kept.Bytes(), whole, gotData, nil
		}
		if errors.Is(readErr, bufio.ErrBufferFull) {
			if whole {
				continue
			}
			return kept.Bytes(), false, true, nil
		}
		if errors.Is(readErr, io.EOF) {
			if gotData {
				return kept.Bytes(), whole, true, nil
			}
			return nil, true, false, nil
		}
		return nil, false, false, readErr
	}
}
|
||||
|
||||
// readerHasMoreContent reports whether at least one more byte can be read from
// reader, without consuming it. EOF yields (false, nil); any other peek error
// is returned to the caller.
func readerHasMoreContent(reader *bufio.Reader) (bool, error) {
	if _, peekErr := reader.Peek(1); peekErr != nil {
		if errors.Is(peekErr, io.EOF) {
			return false, nil
		}
		return false, peekErr
	}
	return true, nil
}
|
||||
|
||||
// getInt64Arg extracts an integer argument from the args map, returning the
|
||||
// provided default if the key is absent.
|
||||
func getInt64Arg(args map[string]any, key string, defaultVal int64) (int64, error) {
|
||||
@@ -483,7 +853,11 @@ type WriteFileTool struct {
|
||||
fs fileSystem
|
||||
}
|
||||
|
||||
func NewWriteFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *WriteFileTool {
|
||||
func NewWriteFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *WriteFileTool {
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
@@ -496,7 +870,7 @@ func (t *WriteFileTool) Name() string {
|
||||
}
|
||||
|
||||
func (t *WriteFileTool) Description() string {
|
||||
return "Write content to a file. If the file already exists, you must set overwrite=true to replace it."
|
||||
return "Write content to a file. In `function.arguments`, use \\n for a newline and \\\\n for a literal backslash-n sequence. Content is written byte-for-byte after argument decoding. If the file already exists, you must set overwrite=true to replace it."
|
||||
}
|
||||
|
||||
func (t *WriteFileTool) Parameters() map[string]any {
|
||||
@@ -509,7 +883,7 @@ func (t *WriteFileTool) Parameters() map[string]any {
|
||||
},
|
||||
"content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Content to write to the file",
|
||||
"description": "Content to write to the file. In `function.arguments`, use \\n for newline and \\\\n for literal backslash-n.",
|
||||
},
|
||||
"overwrite": map[string]any{
|
||||
"type": "boolean",
|
||||
@@ -536,7 +910,9 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *ToolR
|
||||
|
||||
if !overwrite {
|
||||
if _, err := t.fs.Open(path); err == nil {
|
||||
return ErrorResult(fmt.Sprintf("file: %s already exists. Set overwrite=true to replace.", path))
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("file: %s already exists. Set overwrite=true to replace.", path),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+460
-28
@@ -18,7 +18,7 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) {
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test content"), 0o644)
|
||||
|
||||
tool := NewReadFileTool("", false, MaxReadFileSize)
|
||||
tool := NewReadFileBytesTool("", false, MaxReadFileSize)
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
@@ -45,7 +45,7 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) {
|
||||
|
||||
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
|
||||
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
|
||||
tool := NewReadFileTool("", false, MaxReadFileSize)
|
||||
tool := NewReadFileBytesTool("", false, MaxReadFileSize)
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"path": "/nonexistent_file_12345.txt",
|
||||
@@ -59,8 +59,13 @@ func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to open file") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
if !strings.Contains(result.ForLLM, "failed to open file") &&
|
||||
!strings.Contains(result.ForUser, "failed to open") {
|
||||
t.Errorf(
|
||||
"Expected error message, got ForLLM: %s, ForUser: %s",
|
||||
result.ForLLM,
|
||||
result.ForUser,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,7 +83,8 @@ func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should mention required parameter
|
||||
if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") {
|
||||
if !strings.Contains(result.ForLLM, "path is required") &&
|
||||
!strings.Contains(result.ForUser, "path is required") {
|
||||
t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -122,6 +128,45 @@ func TestFilesystemTool_WriteFile_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_LiteralBackslashN verifies write_file keeps
|
||||
// literal backslash sequences unchanged when they are passed as plain text.
|
||||
func TestFilesystemTool_WriteFile_LiteralBackslashN(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "literal.txt")
|
||||
|
||||
tool := NewWriteFileTool("", false)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"content": `aaa\naaa`,
|
||||
})
|
||||
|
||||
assert.False(t, result.IsError, "expected success, got: %s", result.ForLLM)
|
||||
|
||||
data, err := os.ReadFile(testFile)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `aaa\naaa`, string(data))
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_PreservesCRLF verifies write_file does not
|
||||
// normalize line endings and writes CRLF bytes as provided.
|
||||
func TestFilesystemTool_WriteFile_PreservesCRLF(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "crlf.txt")
|
||||
content := "line1\r\nline2\r\n"
|
||||
|
||||
tool := NewWriteFileTool("", false)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"content": content,
|
||||
})
|
||||
|
||||
assert.False(t, result.IsError, "expected success, got: %s", result.ForLLM)
|
||||
|
||||
data, err := os.ReadFile(testFile)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte(content), data)
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_CreateDir verifies directory creation
|
||||
func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
@@ -297,7 +342,12 @@ func TestFilesystemTool_WriteFile_OverwriteSandboxed(t *testing.T) {
|
||||
"content": "replaced in sandbox",
|
||||
"overwrite": true,
|
||||
})
|
||||
assert.False(t, result.IsError, "expected success in sandbox mode with overwrite=true, got: %s", result.ForLLM)
|
||||
assert.False(
|
||||
t,
|
||||
result.IsError,
|
||||
"expected success in sandbox mode with overwrite=true, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(workspace, testFile))
|
||||
assert.NoError(t, err)
|
||||
@@ -325,7 +375,8 @@ func TestFilesystemTool_ListDir_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should list files and directories
|
||||
if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") {
|
||||
if !strings.Contains(result.ForLLM, "file1.txt") ||
|
||||
!strings.Contains(result.ForLLM, "file2.txt") {
|
||||
t.Errorf("Expected files in listing, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subdir") {
|
||||
@@ -349,8 +400,13 @@ func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
if !strings.Contains(result.ForLLM, "failed to read") &&
|
||||
!strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf(
|
||||
"Expected error message, got ForLLM: %s, ForUser: %s",
|
||||
result.ForLLM,
|
||||
result.ForUser,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -397,7 +453,8 @@ func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
|
||||
// os.Root might return different errors depending on platform/implementation
|
||||
// but it definitely should error.
|
||||
// Our wrapper returns "access denied or file not found"
|
||||
if !strings.Contains(result.ForLLM, "access denied") && !strings.Contains(result.ForLLM, "file not found") &&
|
||||
if !strings.Contains(result.ForLLM, "access denied") &&
|
||||
!strings.Contains(result.ForLLM, "file not found") &&
|
||||
!strings.Contains(result.ForLLM, "no such file") {
|
||||
t.Fatalf("expected symlink escape error, got: %s", result.ForLLM)
|
||||
}
|
||||
@@ -416,10 +473,20 @@ func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) {
|
||||
})
|
||||
|
||||
// We EXPECT IsError=true (access blocked due to empty workspace)
|
||||
assert.True(t, result.IsError, "Security Regression: Empty workspace allowed access! content: %s", result.ForLLM)
|
||||
assert.True(
|
||||
t,
|
||||
result.IsError,
|
||||
"Security Regression: Empty workspace allowed access! content: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
|
||||
// Verify it failed for the right reason
|
||||
assert.Contains(t, result.ForLLM, "workspace is not defined", "Expected 'workspace is not defined' error")
|
||||
assert.Contains(
|
||||
t,
|
||||
result.ForLLM,
|
||||
"workspace is not defined",
|
||||
"Expected 'workspace is not defined' error",
|
||||
)
|
||||
}
|
||||
|
||||
// TestRootMkdirAll verifies that root.MkdirAll (used by atomicWriteFileInRoot) handles all cases:
|
||||
@@ -653,7 +720,10 @@ func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) {
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": filepath.Join(linkPath, "secret.txt")})
|
||||
result := tool.Execute(
|
||||
context.Background(),
|
||||
map[string]any{"path": filepath.Join(linkPath, "secret.txt")},
|
||||
)
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
@@ -726,7 +796,6 @@ func TestReadFileTool_ChunkedReading(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "pagination_test.txt")
|
||||
|
||||
// Create a test file with exactly 26 bytes of content
|
||||
fullContent := "abcdefghijklmnopqrstuvwxyz"
|
||||
err := os.WriteFile(testFile, []byte(fullContent), 0o644)
|
||||
if err != nil {
|
||||
@@ -748,15 +817,12 @@ func TestReadFileTool_ChunkedReading(t *testing.T) {
|
||||
t.Fatalf("Chunk 1 failed: %s", result1.ForLLM)
|
||||
}
|
||||
|
||||
// Expect the first 10 characters
|
||||
if !strings.Contains(result1.ForLLM, "abcdefghij") {
|
||||
t.Errorf("Chunk 1 should contain 'abcdefghij', got: %s", result1.ForLLM)
|
||||
}
|
||||
// Expect the header to indicate the file is truncated
|
||||
if !strings.Contains(result1.ForLLM, "[TRUNCATED") {
|
||||
t.Errorf("Chunk 1 header should indicate truncation, got: %s", result1.ForLLM)
|
||||
}
|
||||
// Expect the header to suggest the next offset (10)
|
||||
if !strings.Contains(result1.ForLLM, "offset=10") {
|
||||
t.Errorf("Chunk 1 header should suggest next offset=10, got: %s", result1.ForLLM)
|
||||
}
|
||||
@@ -773,17 +839,14 @@ func TestReadFileTool_ChunkedReading(t *testing.T) {
|
||||
t.Fatalf("Chunk 2 failed: %s", result2.ForLLM)
|
||||
}
|
||||
|
||||
// Expect the next 10 characters
|
||||
if !strings.Contains(result2.ForLLM, "klmnopqrst") {
|
||||
t.Errorf("Chunk 2 should contain 'klmnopqrst', got: %s", result2.ForLLM)
|
||||
}
|
||||
// Expect the header to suggest the next offset (20)
|
||||
if !strings.Contains(result2.ForLLM, "offset=20") {
|
||||
t.Errorf("Chunk 2 header should suggest next offset=20, got: %s", result2.ForLLM)
|
||||
}
|
||||
|
||||
// Step 3: Read the final chunk (remaining 6 bytes) ---
|
||||
// We ask for 10 bytes, but only 6 are left in the file
|
||||
args3 := map[string]any{
|
||||
"path": testFile,
|
||||
"offset": 20,
|
||||
@@ -795,16 +858,12 @@ func TestReadFileTool_ChunkedReading(t *testing.T) {
|
||||
t.Fatalf("Chunk 3 failed: %s", result3.ForLLM)
|
||||
}
|
||||
|
||||
// Expect the last 6 characters
|
||||
if !strings.Contains(result3.ForLLM, "uvwxyz") {
|
||||
t.Errorf("Chunk 3 should contain 'uvwxyz', got: %s", result3.ForLLM)
|
||||
}
|
||||
// Expect the header to indicate the end of the file
|
||||
if !strings.Contains(result3.ForLLM, "[END OF FILE") {
|
||||
t.Errorf("Chunk 3 header should indicate end of file, got: %s", result3.ForLLM)
|
||||
}
|
||||
|
||||
// Ensure no TRUNCATED message is present in the final chunk
|
||||
if strings.Contains(result3.ForLLM, "[TRUNCATED") {
|
||||
t.Errorf("Chunk 3 header should NOT indicate truncation, got: %s", result3.ForLLM)
|
||||
}
|
||||
@@ -816,7 +875,6 @@ func TestReadFileTool_OffsetBeyondEOF(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "short.txt")
|
||||
|
||||
// create a file of only 5 bytes
|
||||
err := os.WriteFile(testFile, []byte("12345"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
@@ -827,19 +885,393 @@ func TestReadFileTool_OffsetBeyondEOF(t *testing.T) {
|
||||
|
||||
args := map[string]any{
|
||||
"path": testFile,
|
||||
"offset": int64(100), // Offset beyond the end of the file
|
||||
"offset": int64(100),
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// It should not be classified as a tool execution error
|
||||
if result.IsError {
|
||||
t.Errorf("A mistake was not expected, obtained IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Must return EXACTLY the string provided in the code
|
||||
expectedMsg := "[END OF FILE - no content at this offset]"
|
||||
if result.ForLLM != expectedMsg {
|
||||
t.Errorf("The message %q was expected, obtained: %q", expectedMsg, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileLinesTool_ChunkedReading(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "pagination_lines.txt")
|
||||
|
||||
fullContent := strings.Join([]string{
|
||||
"line 1",
|
||||
"line 2",
|
||||
"line 3",
|
||||
"line 4",
|
||||
"line 5",
|
||||
"line 6",
|
||||
}, "\n") + "\n"
|
||||
err := os.WriteFile(testFile, []byte(fullContent), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
|
||||
|
||||
result1 := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": 1,
|
||||
"max_lines": 2,
|
||||
})
|
||||
if result1.IsError {
|
||||
t.Fatalf("Chunk 1 failed: %s", result1.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result1.ForLLM, "1|line 1\n2|line 2\n") {
|
||||
t.Fatalf("expected first two lines, got: %s", result1.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result1.ForLLM, "lines 1-2") {
|
||||
t.Fatalf("expected line range 1-2, got: %s", result1.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result1.ForLLM, "start_line=3") {
|
||||
t.Fatalf("expected continuation start_line=3, got: %s", result1.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result1.ForLLM, "max_lines=2") {
|
||||
t.Fatalf("expected continuation max_lines=2, got: %s", result1.ForLLM)
|
||||
}
|
||||
|
||||
result2 := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": 3,
|
||||
"max_lines": 2,
|
||||
})
|
||||
if result2.IsError {
|
||||
t.Fatalf("Chunk 2 failed: %s", result2.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result2.ForLLM, "3|line 3\n4|line 4\n") {
|
||||
t.Fatalf("expected middle chunk, got: %s", result2.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result2.ForLLM, "start_line=5") {
|
||||
t.Fatalf("expected continuation start_line=5, got: %s", result2.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result2.ForLLM, "max_lines=2") {
|
||||
t.Fatalf("expected continuation max_lines=2, got: %s", result2.ForLLM)
|
||||
}
|
||||
|
||||
result3 := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": 5,
|
||||
"max_lines": 2,
|
||||
})
|
||||
if result3.IsError {
|
||||
t.Fatalf("Chunk 3 failed: %s", result3.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result3.ForLLM, "5|line 5\n6|line 6\n") {
|
||||
t.Fatalf("expected final chunk, got: %s", result3.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result3.ForLLM, "[END OF FILE") {
|
||||
t.Fatalf("expected EOF marker, got: %s", result3.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileLinesTool_DefaultOffsetAndRemainingLines(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "default_lines.txt")
|
||||
|
||||
err := os.WriteFile(testFile, []byte("line 1\nline 2\nline 3\n"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": 1,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("Execute() error = %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "1|line 1\n2|line 2\n3|line 3\n") {
|
||||
t.Fatalf("expected remaining lines by default, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "lines 1-3") {
|
||||
t.Fatalf("expected line range 1-3, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileTool_LegacyLengthUsesByteModeForText(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "legacy_bytes.txt")
|
||||
|
||||
err := os.WriteFile(testFile, []byte("abcdefghijklmnopqrstuvwxyz"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileBytesTool(tmpDir, false, MaxReadFileSize)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"offset": 10,
|
||||
"length": 5,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("Execute() error = %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "read: bytes 10-14") {
|
||||
t.Fatalf("expected byte-based header, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "klmno") {
|
||||
t.Fatalf("expected byte chunk content, got: %s", result.ForLLM)
|
||||
}
|
||||
if strings.Contains(result.ForLLM, "lines ") {
|
||||
t.Fatalf("expected legacy byte mode, got line-based header: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileLinesTool_OffsetBeyondEOF(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "short_lines.txt")
|
||||
|
||||
err := os.WriteFile(testFile, []byte("line 1\nline 2\n"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": int64(100),
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if result.ForLLM != "[END OF FILE - no content at or after start_line=100]" {
|
||||
t.Fatalf("unexpected EOF message: %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileLinesTool_RegistryValidationSupportsMaxLinesAndRejectsLimit(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "registry_lines.txt")
|
||||
|
||||
err := os.WriteFile(testFile, []byte("line 1\nline 2\nline 3\n"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
reg := NewToolRegistry()
|
||||
reg.Register(NewReadFileLinesTool(tmpDir, false, MaxReadFileSize))
|
||||
|
||||
result := reg.Execute(context.Background(), "read_file", map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": 1,
|
||||
"max_lines": 1,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected max_lines to pass registry validation, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "1|line 1\n") {
|
||||
t.Fatalf("expected first line via max_lines, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
result = reg.Execute(context.Background(), "read_file", map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": 2,
|
||||
"limit": 1,
|
||||
})
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected limit to be rejected, got success: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "unexpected property \"limit\"") {
|
||||
t.Fatalf("expected registry validation error for limit, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileLinesTool_RejectsOffset(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "legacy_offset.txt")
|
||||
|
||||
err := os.WriteFile(testFile, []byte("line 1\nline 2\n"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": 1,
|
||||
"offset": 1,
|
||||
})
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected offset to be rejected, got success: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "offset is not supported in line mode; use start_line") {
|
||||
t.Fatalf("unexpected error for offset in line mode: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileLinesTool_RejectsLength(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "legacy_length.txt")
|
||||
|
||||
err := os.WriteFile(testFile, []byte("line 1\nline 2\n"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": 1,
|
||||
"length": 1,
|
||||
})
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected length to be rejected, got success: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "length is not supported in line mode; use max_lines") {
|
||||
t.Fatalf("unexpected error for length in line mode: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileLinesTool_RejectsLimit(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "legacy_limit.txt")
|
||||
|
||||
err := os.WriteFile(testFile, []byte("line 1\nline 2\n"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": 1,
|
||||
"limit": 1,
|
||||
})
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected limit to be rejected, got success: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "limit is not supported in line mode; use max_lines") {
|
||||
t.Fatalf("unexpected error for limit in line mode: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileLinesTool_BinaryFileRejected(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "binary.dat")
|
||||
|
||||
data := []byte{0x00, 0x01, 'A', 'B', 'C', 'D', 'E', 'F'}
|
||||
err := os.WriteFile(testFile, data, 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": 1,
|
||||
})
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected binary file rejection in line mode, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "switch read_file mode to 'bytes'") {
|
||||
t.Fatalf("expected binary file rejection message, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "mode to 'bytes'") {
|
||||
t.Fatalf("expected suggestion to switch read_file mode, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileLinesTool_TruncatesSingleLongLineAtByteBudget(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "long_line.txt")
|
||||
|
||||
content := "first line\n" + strings.Repeat("x", 70*1024) + "\n"
|
||||
err := os.WriteFile(testFile, []byte(content), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": 1,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("Execute() error = %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "was cut mid-line") {
|
||||
t.Fatalf("expected explicit mid-line truncation warning, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "1|first line\n") {
|
||||
t.Fatalf("expected the first line with line prefix, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "2|") {
|
||||
t.Fatalf("expected line prefix for the truncated line, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileLinesTool_NoTrailingNewline(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "no_trailing_newline.txt")
|
||||
|
||||
err := os.WriteFile(testFile, []byte("line 1\nline 2"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": 1,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("Execute() error = %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "1|line 1\n2|line 2") {
|
||||
t.Fatalf(
|
||||
"expected final line without trailing newline to be preserved, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "[END OF FILE - no further content.]") {
|
||||
t.Fatalf("expected EOF marker, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileLinesTool_ExactByteBudgetBoundaryIncludesPrefix(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "exact_boundary.txt")
|
||||
|
||||
err := os.WriteFile(testFile, []byte("1234567\nsecond line\n"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileLinesTool(tmpDir, false, 10)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"start_line": 1,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("Execute() error = %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "1|1234567\n") {
|
||||
t.Fatalf(
|
||||
"expected first line to fit exactly in the byte budget with its prefix, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
}
|
||||
if strings.Contains(result.ForLLM, "2|") {
|
||||
t.Fatalf(
|
||||
"expected second line to be excluded once the exact output byte budget was reached, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "file_bytes: 8 | output_bytes: 10") {
|
||||
t.Fatalf("expected separate file/output byte counters, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "start_line=2") {
|
||||
t.Fatalf("expected continuation at line 2, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
// LoadImageTool loads a local image file into the MediaStore and returns a
|
||||
// media:// reference. The agent loop's resolveMediaRefs will then base64-encode
|
||||
// it and attach it as an image_url part in the next LLM request, enabling
|
||||
// vision on local files — the same pipeline used when a user sends an image
|
||||
// through a chat channel.
|
||||
//
|
||||
// This is intentionally different from SendFileTool:
|
||||
// - SendFileTool → MediaResult + WithResponseHandled() → sends file to user, ends turn
|
||||
// - LoadImageTool → plain ToolResult with media:// in ForLLM → LLM sees the image next turn
|
||||
type LoadImageTool struct {
|
||||
workspace string
|
||||
restrict bool
|
||||
maxFileSize int
|
||||
mediaStore media.MediaStore
|
||||
allowPaths []*regexp.Regexp
|
||||
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
}
|
||||
|
||||
func NewLoadImageTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
maxFileSize int,
|
||||
store media.MediaStore,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *LoadImageTool {
|
||||
if maxFileSize <= 0 {
|
||||
maxFileSize = config.DefaultMaxMediaSize
|
||||
}
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &LoadImageTool{
|
||||
workspace: workspace,
|
||||
restrict: restrict,
|
||||
maxFileSize: maxFileSize,
|
||||
mediaStore: store,
|
||||
allowPaths: patterns,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the identifier under which this tool is exposed.
func (t *LoadImageTool) Name() string {
	const toolName = "load_image"
	return toolName
}
|
||||
|
||||
func (t *LoadImageTool) Description() string {
|
||||
return "Load a local image file so you can analyze its contents with vision. " +
|
||||
"Supported formats: JPEG, PNG, GIF, WebP, BMP. " +
|
||||
"After calling this tool, describe or analyze the image in your next response."
|
||||
}
|
||||
|
||||
func (t *LoadImageTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the local image file. Relative paths are resolved from workspace.",
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *LoadImageTool) SetContext(channel, chatID string) {
|
||||
t.defaultChannel = channel
|
||||
t.defaultChatID = chatID
|
||||
}
|
||||
|
||||
func (t *LoadImageTool) SetMediaStore(store media.MediaStore) {
|
||||
t.mediaStore = store
|
||||
}
|
||||
|
||||
func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
path, _ := args["path"].(string)
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
// Prefer context-injected channel/chatID (set by ExecuteWithContext), fall back to SetContext values.
|
||||
channel := ToolChannel(ctx)
|
||||
if channel == "" {
|
||||
channel = t.defaultChannel
|
||||
}
|
||||
chatID := ToolChatID(ctx)
|
||||
if chatID == "" {
|
||||
chatID = t.defaultChatID
|
||||
}
|
||||
if channel == "" || chatID == "" {
|
||||
return ErrorResult("no target channel/chat available")
|
||||
}
|
||||
|
||||
if t.mediaStore == nil {
|
||||
return ErrorResult("media store not configured")
|
||||
}
|
||||
|
||||
resolved, err := validatePathWithAllowPaths(path, t.workspace, t.restrict, t.allowPaths)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid path: %v", err))
|
||||
}
|
||||
|
||||
info, err := os.Stat(resolved)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("file not found: %v", err))
|
||||
}
|
||||
if info.IsDir() {
|
||||
return ErrorResult("path is a directory, expected an image file")
|
||||
}
|
||||
if info.Size() > int64(t.maxFileSize) {
|
||||
return ErrorResult(fmt.Sprintf(
|
||||
"file too large: %d bytes (max %d bytes)", info.Size(), t.maxFileSize,
|
||||
))
|
||||
}
|
||||
|
||||
// Detect MIME type — reuse the helper already in send_file.go
|
||||
mediaType := detectMediaType(resolved)
|
||||
if !strings.HasPrefix(mediaType, "image/") {
|
||||
return ErrorResult(fmt.Sprintf(
|
||||
"file does not appear to be an image (detected type: %s)", mediaType,
|
||||
))
|
||||
}
|
||||
|
||||
filename := filepath.Base(resolved)
|
||||
scope := fmt.Sprintf("tool:load_image:%s:%s", channel, chatID)
|
||||
|
||||
ref, err := t.mediaStore.Store(resolved, media.MediaMeta{
|
||||
Filename: filename,
|
||||
ContentType: mediaType,
|
||||
Source: "tool:load_image",
|
||||
CleanupPolicy: media.CleanupPolicyForgetOnly,
|
||||
}, scope)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to register image in media store: %v", err))
|
||||
}
|
||||
|
||||
// Build the tool result text. The media:// ref will be picked up by
|
||||
// resolveMediaRefs in loop_media.go and converted to a base64 data URL
|
||||
// before the next LLM call, exactly like channel-received images.
|
||||
msg := fmt.Sprintf("Image loaded: %s\n[image: %s]", filename, ref)
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: msg,
|
||||
ForUser: fmt.Sprintf("Loaded image: %s", filename),
|
||||
// Media refs inside ForLLM are resolved by resolveMediaRefs in the
|
||||
// agent loop before the next LLM call. Do NOT use MediaResult here —
|
||||
// that would send the file to the user channel instead.
|
||||
Media: []string{ref},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func TestLoadImage_PathRequired(t *testing.T) {
|
||||
tool := NewLoadImageTool("/tmp", false, 0, nil)
|
||||
ctx := WithToolContext(context.Background(), "test", "chat1")
|
||||
result := tool.Execute(ctx, map[string]any{})
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error for missing path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadImage_NilMediaStore(t *testing.T) {
|
||||
tool := NewLoadImageTool("/tmp", false, 0, nil)
|
||||
ctx := WithToolContext(context.Background(), "test", "chat1")
|
||||
result := tool.Execute(ctx, map[string]any{"path": "test.png"})
|
||||
if !result.IsError || result.ForLLM != "media store not configured" {
|
||||
t.Fatalf("expected media store error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadImage_NoChannelContext(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewLoadImageTool("/tmp", false, 0, store)
|
||||
// No WithToolContext — should fail
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": "test.png"})
|
||||
if !result.IsError || result.ForLLM != "no target channel/chat available" {
|
||||
t.Fatalf("expected channel error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadImage_NonImageFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
txtFile := filepath.Join(dir, "readme.txt")
|
||||
os.WriteFile(txtFile, []byte("hello"), 0o644)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewLoadImageTool(dir, false, 0, store)
|
||||
ctx := WithToolContext(context.Background(), "test", "chat1")
|
||||
result := tool.Execute(ctx, map[string]any{"path": txtFile})
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error for non-image file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadImage_DefaultMaxSize(t *testing.T) {
|
||||
tool := NewLoadImageTool("/tmp", false, 0, nil)
|
||||
if tool.maxFileSize != config.DefaultMaxMediaSize {
|
||||
t.Errorf("expected default max size %d, got %d", config.DefaultMaxMediaSize, tool.maxFileSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadImage_FileTooLarge(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
bigFile := filepath.Join(dir, "big.png")
|
||||
// Create a file with PNG header but exceeding max size
|
||||
data := make([]byte, 1024)
|
||||
copy(data, []byte{0x89, 0x50, 0x4E, 0x47}) // PNG magic bytes
|
||||
os.WriteFile(bigFile, data, 0o644)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewLoadImageTool(dir, false, 512, store) // maxSize = 512
|
||||
ctx := WithToolContext(context.Background(), "test", "chat1")
|
||||
result := tool.Execute(ctx, map[string]any{"path": bigFile})
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error for oversized file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubagentManager_SetMediaResolver_StoresResolver(t *testing.T) {
|
||||
manager := NewSubagentManager(nil, "gpt-test", "/tmp")
|
||||
|
||||
called := false
|
||||
manager.SetMediaResolver(func(msgs []providers.Message) []providers.Message {
|
||||
called = true
|
||||
return msgs
|
||||
})
|
||||
|
||||
manager.mu.RLock()
|
||||
got := manager.mediaResolver
|
||||
manager.mu.RUnlock()
|
||||
|
||||
if got == nil {
|
||||
t.Fatal("expected mediaResolver to be set")
|
||||
}
|
||||
|
||||
if called {
|
||||
t.Fatal("resolver should not be called during SetMediaResolver")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadImage_SuccessPath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create a minimal valid PNG file (8-byte signature + minimal IHDR + IEND).
|
||||
// The PNG spec requires the 8-byte magic header: 0x89 P N G \r \n 0x1a \n
|
||||
pngSignature := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
|
||||
// IHDR chunk: length(13) + "IHDR" + 1x1 px, 8-bit RGB, no interlace + CRC
|
||||
ihdr := []byte{
|
||||
0x00, 0x00, 0x00, 0x0D, // chunk length = 13
|
||||
0x49, 0x48, 0x44, 0x52, // "IHDR"
|
||||
0x00, 0x00, 0x00, 0x01, // width = 1
|
||||
0x00, 0x00, 0x00, 0x01, // height = 1
|
||||
0x08, // bit depth = 8
|
||||
0x02, // color type = RGB
|
||||
0x00, 0x00, 0x00, // compression, filter, interlace
|
||||
0x90, 0x77, 0x53, 0xDE, // CRC (valid for this IHDR)
|
||||
}
|
||||
// IEND chunk
|
||||
iend := []byte{
|
||||
0x00, 0x00, 0x00, 0x00, // chunk length = 0
|
||||
0x49, 0x45, 0x4E, 0x44, // "IEND"
|
||||
0xAE, 0x42, 0x60, 0x82, // CRC
|
||||
}
|
||||
|
||||
pngData := make([]byte, 0, len(pngSignature)+len(ihdr)+len(iend))
|
||||
pngData = append(pngData, pngSignature...)
|
||||
pngData = append(pngData, ihdr...)
|
||||
pngData = append(pngData, iend...)
|
||||
|
||||
imgPath := filepath.Join(dir, "test_image.png")
|
||||
if err := os.WriteFile(imgPath, pngData, 0o644); err != nil {
|
||||
t.Fatalf("failed to create test PNG: %v", err)
|
||||
}
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewLoadImageTool(dir, false, 0, store)
|
||||
ctx := WithToolContext(context.Background(), "test", "chat1")
|
||||
|
||||
result := tool.Execute(ctx, map[string]any{"path": imgPath})
|
||||
|
||||
// 1. Must not be an error
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// 2. Media must contain exactly one media:// ref
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
|
||||
}
|
||||
if !strings.HasPrefix(result.Media[0], "media://") {
|
||||
t.Errorf("expected media ref to start with 'media://', got: %s", result.Media[0])
|
||||
}
|
||||
|
||||
// 3. ForLLM must contain the [image: marker
|
||||
if !strings.Contains(result.ForLLM, "[image:") {
|
||||
t.Errorf("expected ForLLM to contain '[image:' marker, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// 4. ForLLM should also contain the media:// ref
|
||||
if !strings.Contains(result.ForLLM, result.Media[0]) {
|
||||
t.Errorf("expected ForLLM to contain media ref %q, got: %s", result.Media[0], result.ForLLM)
|
||||
}
|
||||
|
||||
// 5. Verify the ref is resolvable in the store
|
||||
resolved, err := store.Resolve(result.Media[0])
|
||||
if err != nil {
|
||||
t.Fatalf("media ref not resolvable: %v", err)
|
||||
}
|
||||
if resolved != imgPath {
|
||||
t.Errorf("expected resolved path %q, got %q", imgPath, resolved)
|
||||
}
|
||||
}
|
||||
+115
-18
@@ -6,11 +6,14 @@ import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
@@ -26,18 +29,21 @@ type MCPManager interface {
|
||||
|
||||
// MCPTool wraps an MCP tool to implement the Tool interface
|
||||
type MCPTool struct {
|
||||
manager MCPManager
|
||||
serverName string
|
||||
tool *mcp.Tool
|
||||
mediaStore media.MediaStore
|
||||
manager MCPManager
|
||||
serverName string
|
||||
tool *mcp.Tool
|
||||
mediaStore media.MediaStore
|
||||
workspace string
|
||||
maxInlineTextRunes int
|
||||
}
|
||||
|
||||
// NewMCPTool creates a new MCP tool wrapper
|
||||
func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
|
||||
return &MCPTool{
|
||||
manager: manager,
|
||||
serverName: serverName,
|
||||
tool: tool,
|
||||
manager: manager,
|
||||
serverName: serverName,
|
||||
tool: tool,
|
||||
maxInlineTextRunes: maxMCPInlineTextRunes,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,6 +51,18 @@ func (t *MCPTool) SetMediaStore(store media.MediaStore) {
|
||||
t.mediaStore = store
|
||||
}
|
||||
|
||||
func (t *MCPTool) SetWorkspace(workspace string) {
|
||||
t.workspace = strings.TrimSpace(workspace)
|
||||
}
|
||||
|
||||
func (t *MCPTool) SetMaxInlineTextRunes(limit int) {
|
||||
if limit > 0 {
|
||||
t.maxInlineTextRunes = limit
|
||||
}
|
||||
}
|
||||
|
||||
// maxMCPInlineTextRunes is the default inline text limit, in runes. NewMCPTool
// assigns it to MCPTool.maxInlineTextRunes, and persistLargeTextArtifact falls
// back to it when that field is zero or negative.
const maxMCPInlineTextRunes = 16 * 1024
|
||||
|
||||
// sanitizeIdentifierComponent normalizes a string so it can be safely used
|
||||
// as part of a tool/function identifier for downstream providers.
|
||||
// It:
|
||||
@@ -255,14 +273,19 @@ func extractContentText(content []mcp.Content) string {
|
||||
|
||||
func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Content) *ToolResult {
|
||||
llmParts := make([]string, 0, len(content))
|
||||
rawTextParts := make([]string, 0, len(content))
|
||||
mediaRefs := make([]string, 0, len(content))
|
||||
|
||||
for _, c := range content {
|
||||
switch v := c.(type) {
|
||||
case *mcp.TextContent:
|
||||
text := strings.TrimSpace(sanitizeToolLLMContent(v.Text))
|
||||
if text != "" {
|
||||
llmParts = append(llmParts, text)
|
||||
rawText := strings.TrimSpace(v.Text)
|
||||
if rawText != "" {
|
||||
rawTextParts = append(rawTextParts, rawText)
|
||||
}
|
||||
safeText := strings.TrimSpace(sanitizeToolLLMContent(v.Text))
|
||||
if safeText != "" {
|
||||
llmParts = append(llmParts, safeText)
|
||||
}
|
||||
case *mcp.ImageContent:
|
||||
ref, note := t.storeBinaryContent(
|
||||
@@ -295,10 +318,13 @@ func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Cont
|
||||
case *mcp.ResourceLink:
|
||||
llmParts = append(llmParts, summarizeResourceLink(v))
|
||||
case *mcp.EmbeddedResource:
|
||||
ref, note := t.storeEmbeddedResource(ctx, v)
|
||||
ref, note, rawText := t.storeEmbeddedResource(ctx, v)
|
||||
if ref != "" {
|
||||
mediaRefs = append(mediaRefs, ref)
|
||||
}
|
||||
if rawText != "" {
|
||||
rawTextParts = append(rawTextParts, rawText)
|
||||
}
|
||||
if note != "" {
|
||||
llmParts = append(llmParts, note)
|
||||
}
|
||||
@@ -307,34 +333,105 @@ func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Cont
|
||||
}
|
||||
}
|
||||
|
||||
forLLM := strings.Join(compactStrings(llmParts), "\n")
|
||||
rawText := strings.Join(compactStrings(rawTextParts), "\n")
|
||||
if artifactResult := t.persistLargeTextArtifact(rawText); artifactResult != nil {
|
||||
artifactResult.Media = mediaRefs
|
||||
return artifactResult
|
||||
}
|
||||
|
||||
result := &ToolResult{
|
||||
ForLLM: strings.Join(compactStrings(llmParts), "\n"),
|
||||
ForLLM: forLLM,
|
||||
Media: mediaRefs,
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string) {
|
||||
func (t *MCPTool) persistLargeTextArtifact(text string) *ToolResult {
|
||||
text = strings.TrimSpace(text)
|
||||
limit := t.maxInlineTextRunes
|
||||
if limit <= 0 {
|
||||
limit = maxMCPInlineTextRunes
|
||||
}
|
||||
size := utf8.RuneCountInString(text)
|
||||
if text == "" || size <= limit || t.workspace == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
dir := filepath.Join(t.workspace, ".artifacts", "mcp")
|
||||
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return t.largeTextArtifactFallback(text, err)
|
||||
}
|
||||
// TODO: Add lifecycle cleanup/retention for MCP artifact files.
|
||||
|
||||
pattern := fmt.Sprintf(
|
||||
"%s_%s_*.txt",
|
||||
sanitizeIdentifierComponent(t.serverName),
|
||||
sanitizeIdentifierComponent(t.tool.Name),
|
||||
)
|
||||
tmpFile, err := os.CreateTemp(dir, pattern)
|
||||
if err != nil {
|
||||
return t.largeTextArtifactFallback(text, err)
|
||||
}
|
||||
path := tmpFile.Name()
|
||||
if _, err = tmpFile.WriteString(text); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(path)
|
||||
return t.largeTextArtifactFallback(text, err)
|
||||
}
|
||||
if err = tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(path)
|
||||
return t.largeTextArtifactFallback(text, err)
|
||||
}
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf(
|
||||
"[MCP returned a large text result (%d chars); omitted from model context and saved as a local artifact.]",
|
||||
size,
|
||||
),
|
||||
ArtifactTags: []string{"[file:" + path + "]"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MCPTool) largeTextArtifactFallback(text string, err error) *ToolResult {
|
||||
size := utf8.RuneCountInString(text)
|
||||
logger.WarnCF("tool", "Failed to persist large MCP text artifact", map[string]any{
|
||||
"server": t.serverName,
|
||||
"tool": t.tool.Name,
|
||||
"chars": size,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf(
|
||||
"[MCP returned a large text result (%d chars); omitted from model context because artifact persistence failed.]",
|
||||
size,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string, string) {
|
||||
if content == nil || content.Resource == nil {
|
||||
return "", "[MCP returned an embedded resource without data.]"
|
||||
return "", "[MCP returned an embedded resource without data.]", ""
|
||||
}
|
||||
|
||||
resource := content.Resource
|
||||
if len(resource.Blob) > 0 {
|
||||
return t.storeBinaryContent(
|
||||
ref, note := t.storeBinaryContent(
|
||||
ctx,
|
||||
"resource",
|
||||
normalizedMIMEType(resource.MIMEType),
|
||||
resource.Blob,
|
||||
content.Annotations,
|
||||
)
|
||||
return ref, note, ""
|
||||
}
|
||||
|
||||
if strings.TrimSpace(resource.Text) != "" {
|
||||
return "", sanitizeToolLLMContent(resource.Text)
|
||||
rawText := strings.TrimSpace(resource.Text)
|
||||
if rawText != "" {
|
||||
return "", sanitizeToolLLMContent(resource.Text), rawText
|
||||
}
|
||||
|
||||
return "", summarizeEmbeddedResource(content)
|
||||
return "", summarizeEmbeddedResource(content), ""
|
||||
}
|
||||
|
||||
func (t *MCPTool) storeBinaryContent(
|
||||
|
||||
@@ -634,3 +634,177 @@ func TestMCPTool_Execute_LargeBase64TextIsOmittedFromContext(t *testing.T) {
|
||||
t.Fatalf("expected sanitized large base64 note, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_LargeBase64TextArtifactPreservesRawPayload(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
largeBase64 := strings.Repeat("QUJD", 400)
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: largeBase64},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
|
||||
mcpTool.SetWorkspace(workspace)
|
||||
mcpTool.SetMaxInlineTextRunes(32)
|
||||
|
||||
result := mcpTool.Execute(context.Background(), nil)
|
||||
|
||||
if !strings.Contains(result.ForLLM, "saved as a local artifact") {
|
||||
t.Fatalf("expected artifact note, got %q", result.ForLLM)
|
||||
}
|
||||
if result.ForLLM == largeBase64OmittedMessage {
|
||||
t.Fatalf("expected artifact note instead of sanitized base64 placeholder")
|
||||
}
|
||||
if len(result.ArtifactTags) != 1 {
|
||||
t.Fatalf("expected 1 artifact tag, got %d", len(result.ArtifactTags))
|
||||
}
|
||||
tag := result.ArtifactTags[0]
|
||||
const prefix = "[file:"
|
||||
if !strings.HasPrefix(tag, prefix) || !strings.HasSuffix(tag, "]") {
|
||||
t.Fatalf("expected file artifact tag, got %q", tag)
|
||||
}
|
||||
path := strings.TrimSuffix(strings.TrimPrefix(tag, prefix), "]")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("expected artifact file to be readable: %v", err)
|
||||
}
|
||||
if string(data) != largeBase64 {
|
||||
t.Fatalf("expected artifact file contents to preserve raw MCP payload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_LargeTextStoredAsArtifact(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
largeText := strings.Repeat("This is a large MCP text payload.\n", 800)
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: largeText},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
|
||||
mcpTool.SetWorkspace(workspace)
|
||||
|
||||
result := mcpTool.Execute(context.Background(), nil)
|
||||
|
||||
if strings.Contains(result.ForLLM, "This is a large MCP text payload") {
|
||||
t.Fatalf("expected large MCP text to be omitted from ForLLM, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "saved as a local artifact") {
|
||||
t.Fatalf("expected artifact note, got %q", result.ForLLM)
|
||||
}
|
||||
if len(result.ArtifactTags) != 1 {
|
||||
t.Fatalf("expected 1 artifact tag, got %d", len(result.ArtifactTags))
|
||||
}
|
||||
tag := result.ArtifactTags[0]
|
||||
const prefix = "[file:"
|
||||
if !strings.HasPrefix(tag, prefix) || !strings.HasSuffix(tag, "]") {
|
||||
t.Fatalf("expected file artifact tag, got %q", tag)
|
||||
}
|
||||
path := strings.TrimSuffix(strings.TrimPrefix(tag, prefix), "]")
|
||||
if !strings.HasPrefix(path, workspace) {
|
||||
t.Fatalf("expected artifact inside workspace, got %q", path)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("expected artifact file to be readable: %v", err)
|
||||
}
|
||||
if string(data) != strings.TrimSpace(largeText) {
|
||||
t.Fatalf("expected artifact file contents to match source text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_CustomInlineTextThreshold(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
text := strings.Repeat("small custom threshold text\n", 20)
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: text},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
|
||||
mcpTool.SetWorkspace(workspace)
|
||||
mcpTool.SetMaxInlineTextRunes(32)
|
||||
|
||||
result := mcpTool.Execute(context.Background(), nil)
|
||||
|
||||
if len(result.ArtifactTags) != 1 {
|
||||
t.Fatalf("expected custom threshold to persist artifact, got %+v", result)
|
||||
}
|
||||
if strings.Contains(result.ForLLM, "small custom threshold text") {
|
||||
t.Fatalf("expected text to be omitted from ForLLM, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_LargeTextArtifactFailureStillOmitsContext(t *testing.T) {
|
||||
workspaceRoot := t.TempDir()
|
||||
workspaceFile := filepath.Join(workspaceRoot, "not-a-directory")
|
||||
if err := os.WriteFile(workspaceFile, []byte("x"), 0o600); err != nil {
|
||||
t.Fatalf("failed to create workspace file: %v", err)
|
||||
}
|
||||
|
||||
largeText := strings.Repeat("This is a large MCP text payload.\n", 800)
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: largeText},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
|
||||
mcpTool.SetWorkspace(workspaceFile)
|
||||
|
||||
result := mcpTool.Execute(context.Background(), nil)
|
||||
|
||||
if strings.Contains(result.ForLLM, "This is a large MCP text payload") {
|
||||
t.Fatalf("expected large MCP text to be omitted from ForLLM, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "artifact persistence failed") {
|
||||
t.Fatalf("expected persistence failure note, got %q", result.ForLLM)
|
||||
}
|
||||
if len(result.ArtifactTags) != 0 {
|
||||
t.Fatalf("expected no artifact tags on persistence failure, got %+v", result.ArtifactTags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_WhitespaceWorkspaceDisablesArtifactPersistence(t *testing.T) {
|
||||
largeText := strings.Repeat("This is a large MCP text payload.\n", 800)
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: largeText},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
|
||||
mcpTool.SetWorkspace(" \n\t ")
|
||||
|
||||
result := mcpTool.Execute(context.Background(), nil)
|
||||
|
||||
if len(result.ArtifactTags) != 0 {
|
||||
t.Fatalf("expected no artifact tags for whitespace workspace, got %+v", result.ArtifactTags)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "This is a large MCP text payload") {
|
||||
t.Fatalf("expected large text to remain inline when workspace is blank, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,6 +228,7 @@ func (r *ToolRegistry) ExecuteWithContext(
|
||||
func() {
|
||||
defer func() {
|
||||
if re := recover(); re != nil {
|
||||
logger.RecoverPanicNoExit(re)
|
||||
errMsg := fmt.Sprintf("Tool '%s' crashed with panic: %v", name, re)
|
||||
logger.ErrorCF("tool", "Tool execution panic recovered",
|
||||
map[string]any{
|
||||
|
||||
@@ -67,6 +67,12 @@ type SubagentManager struct {
|
||||
hasTemperature bool
|
||||
nextID int
|
||||
spawner SpawnSubTurnFunc
|
||||
|
||||
// mediaResolver resolves media:// refs in tool-loop messages before
|
||||
// each LLM call in the legacy RunToolLoop fallback path.
|
||||
// This lets subagents reuse the same media handling behavior as the
|
||||
// main agent loop without importing pkg/agent and creating a cycle.
|
||||
mediaResolver func([]providers.Message) []providers.Message
|
||||
}
|
||||
|
||||
func NewSubagentManager(
|
||||
@@ -90,6 +96,17 @@ func (sm *SubagentManager) SetSpawner(spawner SpawnSubTurnFunc) {
|
||||
sm.spawner = spawner
|
||||
}
|
||||
|
||||
// SetMediaResolver injects a message preprocessor that resolves media:// refs
|
||||
// into LLM-ready content before each tool-loop iteration.
|
||||
// This is only used by the legacy RunToolLoop fallback path.
|
||||
func (sm *SubagentManager) SetMediaResolver(
|
||||
resolver func([]providers.Message) []providers.Message,
|
||||
) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.mediaResolver = resolver
|
||||
}
|
||||
|
||||
// SetLLMOptions sets max tokens and temperature for subagent LLM calls.
|
||||
func (sm *SubagentManager) SetLLMOptions(maxTokens int, temperature float64) {
|
||||
sm.mu.Lock()
|
||||
@@ -177,6 +194,7 @@ func (sm *SubagentManager) runTask(
|
||||
temperature := sm.temperature
|
||||
hasMaxTokens := sm.hasMaxTokens
|
||||
hasTemperature := sm.hasTemperature
|
||||
mediaResolver := sm.mediaResolver
|
||||
sm.mu.RUnlock()
|
||||
|
||||
var result *ToolResult
|
||||
@@ -223,6 +241,7 @@ After completing the task, provide a clear summary of what was done.`
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: llmOptions,
|
||||
MediaResolver: mediaResolver,
|
||||
}, messages, task.OriginChannel, task.OriginChatID)
|
||||
|
||||
if err == nil {
|
||||
|
||||
+32
-4
@@ -24,6 +24,11 @@ type ToolLoopConfig struct {
|
||||
Tools *ToolRegistry
|
||||
MaxIterations int
|
||||
LLMOptions map[string]any
|
||||
|
||||
// MediaResolver resolves media:// refs in messages before each LLM call.
|
||||
// This is optional and is mainly used by subagent legacy fallback execution
|
||||
// so subagents can reuse the same multimodal media handling as the main loop.
|
||||
MediaResolver func(messages []providers.Message) []providers.Message
|
||||
}
|
||||
|
||||
// ToolLoopResult contains the result of running the tool loop.
|
||||
@@ -63,8 +68,27 @@ func RunToolLoop(
|
||||
if llmOpts == nil {
|
||||
llmOpts = map[string]any{}
|
||||
}
|
||||
// 3. Call LLM
|
||||
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
|
||||
|
||||
// 3. Resolve media:// refs and Call LLM.
|
||||
// Tools like load_image produce media:// refs in their result messages.
|
||||
// Without this step, the LLM would receive raw "media://uuid" strings
|
||||
// instead of base64-encoded image data URLs.
|
||||
//
|
||||
// We build a separate callMessages slice so that:
|
||||
// (a) the resolver output is used for the LLM call only,
|
||||
// (b) the original `messages` slice keeps the unresolved refs for
|
||||
// subsequent iterations — the resolver is idempotent but working
|
||||
// on the original avoids double-encoding issues.
|
||||
//
|
||||
// On iteration 1 the initial user messages typically have no media://
|
||||
// refs (they come from plain text), so this is effectively a no-op;
|
||||
// it becomes relevant from iteration 2 onward when tool results may
|
||||
// contain media refs.
|
||||
callMessages := messages
|
||||
if config.MediaResolver != nil && iteration > 1 {
|
||||
callMessages = config.MediaResolver(messages)
|
||||
}
|
||||
response, err := config.Provider.Chat(ctx, callMessages, providerToolDefs, config.Model, llmOpts)
|
||||
if err != nil {
|
||||
logger.ErrorCF("toolloop", "LLM call failed",
|
||||
map[string]any{
|
||||
@@ -161,11 +185,15 @@ func RunToolLoop(
|
||||
for _, r := range results {
|
||||
contentForLLM := r.result.ContentForLLM()
|
||||
|
||||
messages = append(messages, providers.Message{
|
||||
toolMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: r.tc.ID,
|
||||
})
|
||||
}
|
||||
if len(r.result.Media) > 0 && !r.result.ResponseHandled {
|
||||
toolMsg.Media = append(toolMsg.Media, r.result.Media...)
|
||||
}
|
||||
messages = append(messages, toolMsg)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/audio/tts"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
type SendTTSTool struct {
|
||||
provider tts.TTSProvider
|
||||
mediaStore media.MediaStore
|
||||
}
|
||||
|
||||
func NewSendTTSTool(provider tts.TTSProvider, store media.MediaStore) *SendTTSTool {
|
||||
return &SendTTSTool{
|
||||
provider: provider,
|
||||
mediaStore: store,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the identifier of this tool.
func (t *SendTTSTool) Name() string {
	const toolName = "send_tts"
	return toolName
}
|
||||
|
||||
func (t *SendTTSTool) Description() string {
|
||||
return "Synthesize speech from text and send it as an audio file to the user."
|
||||
}
|
||||
|
||||
func (t *SendTTSTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The text to synthesize into speech. NOTE: Reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally.",
|
||||
},
|
||||
"filename": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional filename for the audio file (e.g., response.ogg).",
|
||||
},
|
||||
},
|
||||
"required": []string{"text"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SendTTSTool) SetMediaStore(store media.MediaStore) {
|
||||
t.mediaStore = store
|
||||
}
|
||||
|
||||
func (t *SendTTSTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
text, _ := args["text"].(string)
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return ErrorResult("text is required")
|
||||
}
|
||||
|
||||
channel := ToolChannel(ctx)
|
||||
chatID := ToolChatID(ctx)
|
||||
filename, _ := args["filename"].(string)
|
||||
|
||||
ref, err := tts.SynthesizeAndStore(
|
||||
ctx,
|
||||
t.provider,
|
||||
t.mediaStore,
|
||||
text,
|
||||
filename,
|
||||
channel,
|
||||
chatID,
|
||||
)
|
||||
if err != nil {
|
||||
return ErrorResult(err.Error()).WithError(err)
|
||||
}
|
||||
|
||||
// Return with ForUser set to original text, Media containing the audio ref,
|
||||
// and mark as ResponseHandled so the audio is sent immediately without LLM intervention.
|
||||
return &ToolResult{
|
||||
ForLLM: "TTS audio sent",
|
||||
ForUser: text,
|
||||
Media: []string{ref},
|
||||
ResponseHandled: true,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user