refactor(tools): reorganize tool packages and facades

This commit is contained in:
lc6464
2026-04-17 13:44:31 +08:00
parent ee634dc8db
commit 4c133dc2d9
44 changed files with 778 additions and 98 deletions
+134
View File
@@ -0,0 +1,134 @@
package integrationtools
import (
"fmt"
"math"
"mime"
"path/filepath"
"regexp"
"strconv"
"strings"
"unicode"
)
var (
inlineMarkdownDataURLRe = regexp.MustCompile(`!\[[^\]]*\]\((data:[^)]+)\)`)
inlineRawDataURLRe = regexp.MustCompile(`data:[^;\s]+;base64,[A-Za-z0-9+/=\r\n]+`)
)
const (
largeBase64OmittedMessage = "[Tool returned a large base64-like payload; omitted from model context.]"
inlineMediaOmittedMessage = "[Tool returned inline media content; omitted from model context.]"
)
func sanitizeToolLLMContent(text string) string {
trimmed := strings.TrimSpace(text)
if trimmed == "" {
return text
}
if inlineMarkdownDataURLRe.MatchString(trimmed) || inlineRawDataURLRe.MatchString(trimmed) {
cleaned := inlineMarkdownDataURLRe.ReplaceAllString(trimmed, "")
cleaned = inlineRawDataURLRe.ReplaceAllString(cleaned, "")
cleaned = strings.TrimSpace(cleaned)
if cleaned == "" {
return inlineMediaOmittedMessage
}
return cleaned + "\n" + inlineMediaOmittedMessage
}
if looksLikeLargeBase64Payload(trimmed) {
return largeBase64OmittedMessage
}
return text
}
func looksLikeLargeBase64Payload(text string) bool {
trimmed := strings.TrimSpace(text)
if len(trimmed) < 1024 {
return false
}
nonSpace := 0
base64Like := 0
spaceCount := 0
for _, r := range trimmed {
if unicode.IsSpace(r) {
spaceCount++
continue
}
nonSpace++
if (r >= 'A' && r <= 'Z') ||
(r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '+' || r == '/' || r == '=' {
base64Like++
}
}
if nonSpace == 0 {
return false
}
ratio := float64(base64Like) / float64(nonSpace)
return ratio >= 0.97 && spaceCount <= len(trimmed)/128
}
func extensionForMIMEType(mimeType string) string {
if mimeType == "" {
return ".bin"
}
if exts, err := mime.ExtensionsByType(mimeType); err == nil && len(exts) > 0 {
return exts[0]
}
switch strings.ToLower(mimeType) {
case "image/jpeg":
return ".jpg"
case "image/png":
return ".png"
case "image/gif":
return ".gif"
case "image/webp":
return ".webp"
case "audio/wav", "audio/x-wav":
return ".wav"
case "audio/mpeg":
return ".mp3"
case "audio/ogg":
return ".ogg"
case "video/mp4":
return ".mp4"
default:
return filepath.Ext(mimeType)
}
}
func getInt64Arg(args map[string]any, key string, defaultVal int64) (int64, error) {
raw, exists := args[key]
if !exists {
return defaultVal, nil
}
switch v := raw.(type) {
case float64:
if v != math.Trunc(v) {
return 0, fmt.Errorf("%s must be an integer, got float %v", key, v)
}
if v > math.MaxInt64 || v < math.MinInt64 {
return 0, fmt.Errorf("%s value %v overflows int64", key, v)
}
return int64(v), nil
case int:
return int64(v), nil
case int64:
return v, nil
case string:
parsed, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return 0, fmt.Errorf("invalid integer format for %s parameter: %w", key, err)
}
return parsed, nil
default:
return 0, fmt.Errorf("unsupported type %T for %s parameter", raw, key)
}
}
+601
View File
@@ -0,0 +1,601 @@
package integrationtools
import (
"context"
"encoding/json"
"fmt"
"hash/fnv"
"os"
"path/filepath"
"strings"
"time"
"unicode/utf8"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
)
// MCPManager defines the interface for MCP manager operations
// This allows for easier testing with mock implementations
type MCPManager interface {
CallTool(
ctx context.Context,
serverName, toolName string,
arguments map[string]any,
) (*mcp.CallToolResult, error)
}
// MCPTool wraps an MCP tool to implement the Tool interface
type MCPTool struct {
manager MCPManager
serverName string
tool *mcp.Tool
mediaStore media.MediaStore
workspace string
maxInlineTextRunes int
}
// NewMCPTool creates a new MCP tool wrapper
func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
return &MCPTool{
manager: manager,
serverName: serverName,
tool: tool,
maxInlineTextRunes: maxMCPInlineTextRunes,
}
}
func (t *MCPTool) SetMediaStore(store media.MediaStore) {
t.mediaStore = store
}
func (t *MCPTool) SetWorkspace(workspace string) {
t.workspace = strings.TrimSpace(workspace)
}
func (t *MCPTool) SetMaxInlineTextRunes(limit int) {
if limit > 0 {
t.maxInlineTextRunes = limit
}
}
const maxMCPInlineTextRunes = 16 * 1024
// sanitizeIdentifierComponent normalizes a string so it can be safely used
// as part of a tool/function identifier for downstream providers.
// It:
// - lowercases the string
// - replaces any character not in [a-z0-9_-] with '_'
// - collapses multiple consecutive '_' into a single '_'
// - trims leading/trailing '_'
// - falls back to "unnamed" if the result is empty
// - truncates overly long components to a reasonable length
func sanitizeIdentifierComponent(s string) string {
const maxLen = 64
s = strings.ToLower(s)
var b strings.Builder
b.Grow(len(s))
prevUnderscore := false
for _, r := range s {
isAllowed := (r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_' || r == '-'
if !isAllowed {
// Normalize any disallowed character to '_'
if !prevUnderscore {
b.WriteRune('_')
prevUnderscore = true
}
continue
}
if r == '_' {
if prevUnderscore {
continue
}
prevUnderscore = true
} else {
prevUnderscore = false
}
b.WriteRune(r)
}
result := strings.Trim(b.String(), "_")
if result == "" {
result = "unnamed"
}
if len(result) > maxLen {
result = result[:maxLen]
}
return result
}
// Name returns the tool name, prefixed with the server name.
// The total length is capped at 64 characters (OpenAI-compatible API limit).
// A short hash of the original (unsanitized) server and tool names is appended
// whenever sanitization is lossy or the name is truncated, ensuring that two
// names which differ only in disallowed characters remain distinct after sanitization.
func (t *MCPTool) Name() string {
// Prefix with server name to avoid conflicts, and sanitize components
sanitizedServer := sanitizeIdentifierComponent(t.serverName)
sanitizedTool := sanitizeIdentifierComponent(t.tool.Name)
full := fmt.Sprintf("mcp_%s_%s", sanitizedServer, sanitizedTool)
// Check if sanitization was lossless (only lowercasing, no char replacement/truncation)
lossless := strings.ToLower(t.serverName) == sanitizedServer &&
strings.ToLower(t.tool.Name) == sanitizedTool
const maxTotal = 64
if lossless && len(full) <= maxTotal {
return full
}
// Sanitization was lossy or name too long: append hash of the ORIGINAL names
// (not the sanitized names) so different originals always yield different hashes.
h := fnv.New32a()
_, _ = h.Write([]byte(t.serverName + "\x00" + t.tool.Name))
suffix := fmt.Sprintf("%08x", h.Sum32()) // 8 chars
base := full
if len(base) > maxTotal-9 {
base = strings.TrimRight(full[:maxTotal-9], "_")
}
return base + "_" + suffix
}
// Description returns the tool description
func (t *MCPTool) Description() string {
desc := t.tool.Description
if desc == "" {
desc = fmt.Sprintf("MCP tool from %s server", t.serverName)
}
// Add server info to description
return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc)
}
// Parameters returns the tool parameters schema
func (t *MCPTool) Parameters() map[string]any {
// The InputSchema is already a JSON Schema object
schema := t.tool.InputSchema
// Handle nil schema
if schema == nil {
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
// Try direct conversion first (fast path)
if schemaMap, ok := schema.(map[string]any); ok {
return schemaMap
}
// Handle json.RawMessage and []byte - unmarshal directly
var jsonData []byte
if rawMsg, ok := schema.(json.RawMessage); ok {
jsonData = rawMsg
} else if bytes, ok := schema.([]byte); ok {
jsonData = bytes
}
if jsonData != nil {
var result map[string]any
if err := json.Unmarshal(jsonData, &result); err == nil {
return result
}
// Fallback on error
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
// For other types (structs, etc.), convert via JSON marshal/unmarshal
var err error
jsonData, err = json.Marshal(schema)
if err != nil {
// Fallback to empty schema if marshaling fails
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
var result map[string]any
if err := json.Unmarshal(jsonData, &result); err != nil {
// Fallback to empty schema if unmarshaling fails
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
return result
}
// Execute executes the MCP tool
func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args)
if err != nil {
return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err)
}
if result == nil {
nilErr := fmt.Errorf("MCP tool returned nil result without error")
return ErrorResult("MCP tool execution failed: nil result").WithError(nilErr)
}
// Handle error result from server
if result.IsError {
errMsg := extractContentText(result.Content)
return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)).
WithError(fmt.Errorf("MCP tool error: %s", errMsg))
}
return t.normalizeResultContent(ctx, result.Content)
}
// extractContentText extracts text from MCP content array
func extractContentText(content []mcp.Content) string {
var parts []string
for _, c := range content {
switch v := c.(type) {
case *mcp.TextContent:
parts = append(parts, sanitizeToolLLMContent(v.Text))
case *mcp.ImageContent:
parts = append(parts, fmt.Sprintf("[Image: %s]", normalizedMIMEType(v.MIMEType)))
case *mcp.AudioContent:
parts = append(parts, fmt.Sprintf("[Audio: %s]", normalizedMIMEType(v.MIMEType)))
case *mcp.ResourceLink:
parts = append(parts, summarizeResourceLink(v))
case *mcp.EmbeddedResource:
parts = append(parts, summarizeEmbeddedResource(v))
default:
// For other content types, use string representation
parts = append(parts, fmt.Sprintf("[Content: %T]", v))
}
}
return sanitizeToolLLMContent(strings.Join(parts, "\n"))
}
func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Content) *ToolResult {
llmParts := make([]string, 0, len(content))
rawTextParts := make([]string, 0, len(content))
mediaRefs := make([]string, 0, len(content))
for _, c := range content {
switch v := c.(type) {
case *mcp.TextContent:
rawText := strings.TrimSpace(v.Text)
if rawText != "" {
rawTextParts = append(rawTextParts, rawText)
}
safeText := strings.TrimSpace(sanitizeToolLLMContent(v.Text))
if safeText != "" {
llmParts = append(llmParts, safeText)
}
case *mcp.ImageContent:
ref, note := t.storeBinaryContent(
ctx,
"image",
normalizedMIMEType(v.MIMEType),
v.Data,
v.Annotations,
)
if ref != "" {
mediaRefs = append(mediaRefs, ref)
}
if note != "" {
llmParts = append(llmParts, note)
}
case *mcp.AudioContent:
ref, note := t.storeBinaryContent(
ctx,
"audio",
normalizedMIMEType(v.MIMEType),
v.Data,
v.Annotations,
)
if ref != "" {
mediaRefs = append(mediaRefs, ref)
}
if note != "" {
llmParts = append(llmParts, note)
}
case *mcp.ResourceLink:
llmParts = append(llmParts, summarizeResourceLink(v))
case *mcp.EmbeddedResource:
ref, note, rawText := t.storeEmbeddedResource(ctx, v)
if ref != "" {
mediaRefs = append(mediaRefs, ref)
}
if rawText != "" {
rawTextParts = append(rawTextParts, rawText)
}
if note != "" {
llmParts = append(llmParts, note)
}
default:
llmParts = append(llmParts, fmt.Sprintf("[MCP returned unsupported content type %T]", v))
}
}
forLLM := strings.Join(compactStrings(llmParts), "\n")
rawText := strings.Join(compactStrings(rawTextParts), "\n")
if artifactResult := t.persistLargeTextArtifact(rawText); artifactResult != nil {
artifactResult.Media = mediaRefs
return artifactResult
}
result := &ToolResult{
ForLLM: forLLM,
Media: mediaRefs,
}
return result
}
func (t *MCPTool) persistLargeTextArtifact(text string) *ToolResult {
text = strings.TrimSpace(text)
limit := t.maxInlineTextRunes
if limit <= 0 {
limit = maxMCPInlineTextRunes
}
size := utf8.RuneCountInString(text)
if text == "" || size <= limit || t.workspace == "" {
return nil
}
dir := filepath.Join(t.workspace, ".artifacts", "mcp")
if err := os.MkdirAll(dir, 0o700); err != nil {
return t.largeTextArtifactFallback(text, err)
}
// TODO: Add lifecycle cleanup/retention for MCP artifact files.
pattern := fmt.Sprintf(
"%s_%s_*.txt",
sanitizeIdentifierComponent(t.serverName),
sanitizeIdentifierComponent(t.tool.Name),
)
tmpFile, err := os.CreateTemp(dir, pattern)
if err != nil {
return t.largeTextArtifactFallback(text, err)
}
path := tmpFile.Name()
if _, err = tmpFile.WriteString(text); err != nil {
_ = tmpFile.Close()
_ = os.Remove(path)
return t.largeTextArtifactFallback(text, err)
}
if err = tmpFile.Close(); err != nil {
_ = os.Remove(path)
return t.largeTextArtifactFallback(text, err)
}
return &ToolResult{
ForLLM: fmt.Sprintf(
"[MCP returned a large text result (%d chars); omitted from model context and saved as a local artifact.]",
size,
),
ArtifactTags: []string{"[file:" + path + "]"},
}
}
func (t *MCPTool) largeTextArtifactFallback(text string, err error) *ToolResult {
size := utf8.RuneCountInString(text)
logger.WarnCF("tool", "Failed to persist large MCP text artifact", map[string]any{
"server": t.serverName,
"tool": t.tool.Name,
"chars": size,
"error": err.Error(),
})
return &ToolResult{
ForLLM: fmt.Sprintf(
"[MCP returned a large text result (%d chars); omitted from model context because artifact persistence failed.]",
size,
),
}
}
func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string, string) {
if content == nil || content.Resource == nil {
return "", "[MCP returned an embedded resource without data.]", ""
}
resource := content.Resource
if len(resource.Blob) > 0 {
ref, note := t.storeBinaryContent(
ctx,
"resource",
normalizedMIMEType(resource.MIMEType),
resource.Blob,
content.Annotations,
)
return ref, note, ""
}
rawText := strings.TrimSpace(resource.Text)
if rawText != "" {
return "", sanitizeToolLLMContent(resource.Text), rawText
}
return "", summarizeEmbeddedResource(content), ""
}
func (t *MCPTool) storeBinaryContent(
ctx context.Context,
kind string,
mimeType string,
data []byte,
annotations *mcp.Annotations,
) (string, string) {
if len(data) == 0 {
return "", fmt.Sprintf("[MCP returned %s content (%s) but it was empty.]", kind, mimeType)
}
if !annotationsAllowUser(annotations) {
return "", fmt.Sprintf(
"[MCP returned %s content (%s) for non-user audience; omitted from model context.]",
kind,
mimeType,
)
}
if t.mediaStore == nil {
return "", fmt.Sprintf(
"[MCP returned %s content (%s); omitted from model context because media delivery is unavailable.]",
kind,
mimeType,
)
}
channel := ToolChannel(ctx)
chatID := ToolChatID(ctx)
if channel == "" || chatID == "" {
return "", fmt.Sprintf(
"[MCP returned %s content (%s); omitted from model context because no target chat was available.]",
kind,
mimeType,
)
}
dir := media.TempDir()
if err := os.MkdirAll(dir, 0o700); err != nil {
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
}
ext := extensionForMIMEType(mimeType)
tmpFile, err := os.CreateTemp(dir, "mcp-*"+ext)
if err != nil {
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
}
tmpPath := tmpFile.Name()
if _, err = tmpFile.Write(data); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
}
if err = tmpFile.Close(); err != nil {
_ = os.Remove(tmpPath)
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
}
scope := fmt.Sprintf(
"tool:mcp:%s:%s:%s:%d",
sanitizeIdentifierComponent(t.serverName),
channel,
chatID,
time.Now().UnixNano(),
)
filename := fmt.Sprintf(
"%s_%s%s",
sanitizeIdentifierComponent(t.serverName),
sanitizeIdentifierComponent(t.tool.Name),
ext,
)
ref, err := t.mediaStore.Store(tmpPath, media.MediaMeta{
Filename: filename,
ContentType: mimeType,
Source: fmt.Sprintf(
"tool:mcp:%s:%s",
sanitizeIdentifierComponent(t.serverName),
sanitizeIdentifierComponent(t.tool.Name),
),
}, scope)
if err != nil {
_ = os.Remove(tmpPath)
return "", fmt.Sprintf(
"[MCP returned %s content (%s) but it could not be registered as media.]",
kind,
mimeType,
)
}
return ref, fmt.Sprintf(
"[MCP returned %s content (%s); omitted from model context and stored as a local media artifact.]",
kind,
mimeType,
)
}
func summarizeResourceLink(content *mcp.ResourceLink) string {
if content == nil {
return "[MCP returned an empty resource link.]"
}
parts := []string{"[MCP returned resource link"}
if content.Name != "" {
parts = append(parts, fmt.Sprintf("name=%q", content.Name))
}
if content.URI != "" {
parts = append(parts, fmt.Sprintf("uri=%q", content.URI))
}
if content.MIMEType != "" {
parts = append(parts, fmt.Sprintf("mime=%q", content.MIMEType))
}
if content.Description != "" {
desc := strings.TrimSpace(content.Description)
if len(desc) > 200 {
desc = desc[:200] + "..."
}
parts = append(parts, fmt.Sprintf("description=%q", desc))
}
return strings.Join(parts, ", ") + "]"
}
func summarizeEmbeddedResource(content *mcp.EmbeddedResource) string {
if content == nil || content.Resource == nil {
return "[MCP returned an embedded resource.]"
}
resource := content.Resource
if resource.URI != "" {
return fmt.Sprintf(
"[MCP returned embedded resource %q (%s).]",
resource.URI,
normalizedMIMEType(resource.MIMEType),
)
}
return fmt.Sprintf("[MCP returned embedded resource (%s).]", normalizedMIMEType(resource.MIMEType))
}
func annotationsAllowUser(annotations *mcp.Annotations) bool {
if annotations == nil || len(annotations.Audience) == 0 {
return true
}
for _, audience := range annotations.Audience {
if strings.EqualFold(string(audience), "user") {
return true
}
}
return false
}
func normalizedMIMEType(mimeType string) string {
if strings.TrimSpace(mimeType) == "" {
return "application/octet-stream"
}
return mimeType
}
func compactStrings(parts []string) []string {
compact := make([]string, 0, len(parts))
for _, part := range parts {
if strings.TrimSpace(part) == "" {
continue
}
compact = append(compact, part)
}
return compact
}
+810
View File
@@ -0,0 +1,810 @@
package integrationtools
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/media"
)
// MockMCPManager is a mock implementation of MCPManager interface for testing
type MockMCPManager struct {
callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error)
}
func (m *MockMCPManager) CallTool(
ctx context.Context,
serverName, toolName string,
arguments map[string]any,
) (*mcp.CallToolResult, error) {
if m.callToolFunc != nil {
return m.callToolFunc(ctx, serverName, toolName, arguments)
}
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "mock result"},
},
IsError: false,
}, nil
}
// TestNewMCPTool verifies MCP tool creation
func TestNewMCPTool(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{
Name: "test_tool",
Description: "A test tool",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"input": map[string]any{
"type": "string",
"description": "Test input",
},
},
},
}
mcpTool := NewMCPTool(manager, "test_server", tool)
if mcpTool == nil {
t.Fatal("NewMCPTool should not return nil")
}
// Verify tool properties we can access
if mcpTool.Name() != "mcp_test_server_test_tool" {
t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name())
}
}
// TestMCPTool_Name verifies tool name with server prefix
func TestMCPTool_Name(t *testing.T) {
tests := []struct {
name string
serverName string
toolName string
expected string
}{
{
name: "simple name",
serverName: "github",
toolName: "create_issue",
expected: "mcp_github_create_issue",
},
{
name: "filesystem server",
serverName: "filesystem",
toolName: "read_file",
expected: "mcp_filesystem_read_file",
},
{
name: "remote server",
serverName: "remote-api",
toolName: "fetch_data",
expected: "mcp_remote-api_fetch_data",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{Name: tt.toolName}
mcpTool := NewMCPTool(manager, tt.serverName, tool)
result := mcpTool.Name()
if result != tt.expected {
t.Errorf("Expected name '%s', got '%s'", tt.expected, result)
}
})
}
}
// TestMCPTool_Description verifies tool description generation
func TestMCPTool_Description(t *testing.T) {
tests := []struct {
name string
serverName string
toolDescription string
expectContains []string
}{
{
name: "with description",
serverName: "github",
toolDescription: "Create a GitHub issue",
expectContains: []string{"[MCP:github]", "Create a GitHub issue"},
},
{
name: "empty description",
serverName: "filesystem",
toolDescription: "",
expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{
Name: "test_tool",
Description: tt.toolDescription,
}
mcpTool := NewMCPTool(manager, tt.serverName, tool)
result := mcpTool.Description()
for _, expected := range tt.expectContains {
if !strings.Contains(result, expected) {
t.Errorf("Description should contain '%s', got: %s", expected, result)
}
}
})
}
}
// TestMCPTool_Parameters verifies parameter schema conversion
func TestMCPTool_Parameters(t *testing.T) {
tests := []struct {
name string
inputSchema any
expectType string
checkProperty string
expectProperty bool
}{
{
name: "map schema",
inputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{
"type": "string",
"description": "Search query",
},
},
"required": []string{"query"},
},
expectType: "object",
checkProperty: "query",
expectProperty: true,
},
{
name: "nil schema",
inputSchema: nil,
expectType: "object",
expectProperty: false,
},
{
name: "json.RawMessage schema",
inputSchema: []byte(`{
"type": "object",
"properties": {
"repo": {
"type": "string",
"description": "Repository name"
},
"stars": {
"type": "integer",
"description": "Minimum stars"
}
},
"required": ["repo"]
}`),
expectType: "object",
checkProperty: "repo",
expectProperty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{
Name: "test_tool",
InputSchema: tt.inputSchema,
}
mcpTool := NewMCPTool(manager, "test_server", tool)
params := mcpTool.Parameters()
if params == nil {
t.Fatal("Parameters should not be nil")
}
if params["type"] != tt.expectType {
t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"])
}
// Check if property exists when expected
if tt.checkProperty != "" {
properties, ok := params["properties"].(map[string]any)
if !ok && tt.expectProperty {
t.Errorf("Expected properties to be a map")
return
}
if ok {
_, hasProperty := properties[tt.checkProperty]
if hasProperty != tt.expectProperty {
t.Errorf("Expected property '%s' existence: %v, got: %v",
tt.checkProperty, tt.expectProperty, hasProperty)
}
}
}
})
}
}
// TestMCPTool_Execute_Success tests successful tool execution
func TestMCPTool_Execute_Success(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
// Verify correct parameters passed
if serverName != "github" {
t.Errorf("Expected serverName 'github', got '%s'", serverName)
}
if toolName != "search_repos" {
t.Errorf("Expected toolName 'search_repos', got '%s'", toolName)
}
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "Found 3 repositories"},
},
IsError: false,
}, nil
},
}
tool := &mcp.Tool{
Name: "search_repos",
Description: "Search GitHub repositories",
}
mcpTool := NewMCPTool(manager, "github", tool)
ctx := context.Background()
args := map[string]any{
"query": "golang mcp",
}
result := mcpTool.Execute(ctx, args)
if result == nil {
t.Fatal("Result should not be nil")
}
if result.IsError {
t.Errorf("Expected no error, got error: %s", result.ForLLM)
}
if result.ForLLM != "Found 3 repositories" {
t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM)
}
}
// TestMCPTool_Execute_ManagerError tests execution when manager returns error
func TestMCPTool_Execute_ManagerError(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return nil, fmt.Errorf("connection failed")
},
}
tool := &mcp.Tool{Name: "test_tool"}
mcpTool := NewMCPTool(manager, "test_server", tool)
ctx := context.Background()
result := mcpTool.Execute(ctx, map[string]any{})
if result == nil {
t.Fatal("Result should not be nil")
}
if !result.IsError {
t.Error("Expected IsError to be true")
}
if !strings.Contains(result.ForLLM, "MCP tool execution failed") {
t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "connection failed") {
t.Errorf("Error message should include original error, got: %s", result.ForLLM)
}
}
// TestMCPTool_Execute_ServerError tests execution when server returns error
func TestMCPTool_Execute_ServerError(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "Invalid API key"},
},
IsError: true,
}, nil
},
}
tool := &mcp.Tool{Name: "test_tool"}
mcpTool := NewMCPTool(manager, "test_server", tool)
ctx := context.Background()
result := mcpTool.Execute(ctx, map[string]any{})
if result == nil {
t.Fatal("Result should not be nil")
}
if !result.IsError {
t.Error("Expected IsError to be true")
}
if !strings.Contains(result.ForLLM, "MCP tool returned error") {
t.Errorf("Error message should mention server error, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Invalid API key") {
t.Errorf("Error message should include server message, got: %s", result.ForLLM)
}
}
// TestMCPTool_Execute_MultipleContent tests execution with multiple content items
func TestMCPTool_Execute_MultipleContent(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "First line"},
&mcp.TextContent{Text: "Second line"},
&mcp.TextContent{Text: "Third line"},
},
IsError: false,
}, nil
},
}
tool := &mcp.Tool{Name: "multi_output"}
mcpTool := NewMCPTool(manager, "test_server", tool)
ctx := context.Background()
result := mcpTool.Execute(ctx, map[string]any{})
if result.IsError {
t.Errorf("Expected no error, got: %s", result.ForLLM)
}
expected := "First line\nSecond line\nThird line"
if result.ForLLM != expected {
t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM)
}
}
// TestExtractContentText_TextContent tests text content extraction
func TestExtractContentText_TextContent(t *testing.T) {
content := []mcp.Content{
&mcp.TextContent{Text: "Hello World"},
&mcp.TextContent{Text: "Second message"},
}
result := extractContentText(content)
expected := "Hello World\nSecond message"
if result != expected {
t.Errorf("Expected '%s', got '%s'", expected, result)
}
}
// TestExtractContentText_ImageContent tests image content extraction
func TestExtractContentText_ImageContent(t *testing.T) {
content := []mcp.Content{
&mcp.ImageContent{
Data: []byte("base64data"),
MIMEType: "image/png",
},
}
result := extractContentText(content)
if !strings.Contains(result, "[Image:") {
t.Errorf("Expected image indicator, got: %s", result)
}
if !strings.Contains(result, "image/png") {
t.Errorf("Expected MIME type in output, got: %s", result)
}
}
// TestExtractContentText_MixedContent tests mixed content types
func TestExtractContentText_MixedContent(t *testing.T) {
content := []mcp.Content{
&mcp.TextContent{Text: "Description"},
&mcp.ImageContent{
Data: []byte("data"),
MIMEType: "image/jpeg",
},
&mcp.TextContent{Text: "More text"},
}
result := extractContentText(content)
if !strings.Contains(result, "Description") {
t.Errorf("Should contain text content, got: %s", result)
}
if !strings.Contains(result, "[Image:") {
t.Errorf("Should contain image indicator, got: %s", result)
}
if !strings.Contains(result, "More text") {
t.Errorf("Should contain second text, got: %s", result)
}
}
// TestExtractContentText_EmptyContent tests empty content array
func TestExtractContentText_EmptyContent(t *testing.T) {
content := []mcp.Content{}
result := extractContentText(content)
if result != "" {
t.Errorf("Expected empty string for empty content, got: %s", result)
}
}
// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface
func TestMCPTool_InterfaceCompliance(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{Name: "test"}
mcpTool := NewMCPTool(manager, "test_server", tool)
// Verify it implements Tool interface
var _ Tool = mcpTool
}
// TestMCPTool_Parameters_MapSchema tests schema that's already a map
func TestMCPTool_Parameters_MapSchema(t *testing.T) {
manager := &MockMCPManager{}
schema := map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{
"type": "string",
"description": "The name parameter",
},
},
"required": []string{"name"},
}
tool := &mcp.Tool{
Name: "test_tool",
InputSchema: schema,
}
mcpTool := NewMCPTool(manager, "test_server", tool)
params := mcpTool.Parameters()
// Should return the schema as-is when it's already a map
if params["type"] != "object" {
t.Errorf("Expected type 'object', got '%v'", params["type"])
}
props, ok := params["properties"].(map[string]any)
if !ok {
t.Error("Properties should be a map")
}
nameParam, ok := props["name"].(map[string]any)
if !ok {
t.Error("Name parameter should exist")
}
if nameParam["type"] != "string" {
t.Errorf("Name type should be 'string', got '%v'", nameParam["type"])
}
}
func TestMCPTool_Execute_ImageContentStoredAsMedia(t *testing.T) {
store := media.NewFileMediaStore()
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.ImageContent{
Data: []byte("fake-image-bytes"),
MIMEType: "image/png",
},
},
}, nil
},
}
mcpTool := NewMCPTool(manager, "screenshoto", &mcp.Tool{Name: "take_screenshot"})
mcpTool.SetMediaStore(store)
result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil)
if result.IsError {
t.Fatalf("expected success, got %q", result.ForLLM)
}
if len(result.Media) != 1 {
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
}
if result.ResponseHandled {
t.Fatal("expected MCP image artifact not to mark response as handled")
}
if !strings.Contains(result.ForLLM, "stored as a local media artifact") {
t.Fatalf("expected local media artifact note, got %q", result.ForLLM)
}
path, meta, err := store.ResolveWithMeta(result.Media[0])
if err != nil {
t.Fatalf("expected stored media ref to resolve: %v", err)
}
if meta.ContentType != "image/png" {
t.Fatalf("expected image/png content type, got %q", meta.ContentType)
}
if filepath.Ext(path) != ".png" {
t.Fatalf("expected png temp file, got %q", path)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("expected stored media file to be readable: %v", err)
}
if string(data) != "fake-image-bytes" {
t.Fatalf("expected stored media bytes to match input, got %q", string(data))
}
}
func TestMCPTool_Execute_EmbeddedResourceBlobStoredAsMedia(t *testing.T) {
store := media.NewFileMediaStore()
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.EmbeddedResource{
Resource: &mcp.ResourceContents{
URI: "file:///tmp/report.png",
MIMEType: "image/png",
Blob: []byte("blob-bytes"),
},
},
},
}, nil
},
}
mcpTool := NewMCPTool(manager, "grafana", &mcp.Tool{Name: "get_dashboard_image"})
mcpTool.SetMediaStore(store)
result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil)
if len(result.Media) != 1 {
t.Fatalf("expected embedded resource blob to be stored as media, got %d refs", len(result.Media))
}
path, _, err := store.ResolveWithMeta(result.Media[0])
if err != nil {
t.Fatalf("expected stored media ref to resolve: %v", err)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("expected stored media file to be readable: %v", err)
}
if string(data) != "blob-bytes" {
t.Fatalf("expected stored blob bytes to match input, got %q", string(data))
}
}
func TestMCPTool_Execute_RespectsUserAudienceForBinaryContent(t *testing.T) {
store := media.NewFileMediaStore()
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.ImageContent{
Data: []byte("assistant-only"),
MIMEType: "image/png",
Annotations: &mcp.Annotations{Audience: []mcp.Role{"assistant"}},
},
},
}, nil
},
}
mcpTool := NewMCPTool(manager, "screenshoto", &mcp.Tool{Name: "take_screenshot"})
mcpTool.SetMediaStore(store)
result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil)
if len(result.Media) != 0 {
t.Fatalf("expected no media ref for non-user audience, got %d", len(result.Media))
}
if !strings.Contains(result.ForLLM, "non-user audience") {
t.Fatalf("expected audience note, got %q", result.ForLLM)
}
}
func TestMCPTool_Execute_LargeBase64TextIsOmittedFromContext(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: strings.Repeat("QUJD", 400)},
},
}, nil
},
}
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
result := mcpTool.Execute(context.Background(), nil)
if result.ForLLM != largeBase64OmittedMessage {
t.Fatalf("expected sanitized large base64 note, got %q", result.ForLLM)
}
}
func TestMCPTool_Execute_LargeBase64TextArtifactPreservesRawPayload(t *testing.T) {
workspace := t.TempDir()
largeBase64 := strings.Repeat("QUJD", 400)
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: largeBase64},
},
}, nil
},
}
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
mcpTool.SetWorkspace(workspace)
mcpTool.SetMaxInlineTextRunes(32)
result := mcpTool.Execute(context.Background(), nil)
if !strings.Contains(result.ForLLM, "saved as a local artifact") {
t.Fatalf("expected artifact note, got %q", result.ForLLM)
}
if result.ForLLM == largeBase64OmittedMessage {
t.Fatalf("expected artifact note instead of sanitized base64 placeholder")
}
if len(result.ArtifactTags) != 1 {
t.Fatalf("expected 1 artifact tag, got %d", len(result.ArtifactTags))
}
tag := result.ArtifactTags[0]
const prefix = "[file:"
if !strings.HasPrefix(tag, prefix) || !strings.HasSuffix(tag, "]") {
t.Fatalf("expected file artifact tag, got %q", tag)
}
path := strings.TrimSuffix(strings.TrimPrefix(tag, prefix), "]")
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("expected artifact file to be readable: %v", err)
}
if string(data) != largeBase64 {
t.Fatalf("expected artifact file contents to preserve raw MCP payload")
}
}
func TestMCPTool_Execute_LargeTextStoredAsArtifact(t *testing.T) {
workspace := t.TempDir()
largeText := strings.Repeat("This is a large MCP text payload.\n", 800)
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: largeText},
},
}, nil
},
}
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
mcpTool.SetWorkspace(workspace)
result := mcpTool.Execute(context.Background(), nil)
if strings.Contains(result.ForLLM, "This is a large MCP text payload") {
t.Fatalf("expected large MCP text to be omitted from ForLLM, got %q", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "saved as a local artifact") {
t.Fatalf("expected artifact note, got %q", result.ForLLM)
}
if len(result.ArtifactTags) != 1 {
t.Fatalf("expected 1 artifact tag, got %d", len(result.ArtifactTags))
}
tag := result.ArtifactTags[0]
const prefix = "[file:"
if !strings.HasPrefix(tag, prefix) || !strings.HasSuffix(tag, "]") {
t.Fatalf("expected file artifact tag, got %q", tag)
}
path := strings.TrimSuffix(strings.TrimPrefix(tag, prefix), "]")
if !strings.HasPrefix(path, workspace) {
t.Fatalf("expected artifact inside workspace, got %q", path)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("expected artifact file to be readable: %v", err)
}
if string(data) != strings.TrimSpace(largeText) {
t.Fatalf("expected artifact file contents to match source text")
}
}
func TestMCPTool_Execute_CustomInlineTextThreshold(t *testing.T) {
workspace := t.TempDir()
text := strings.Repeat("small custom threshold text\n", 20)
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: text},
},
}, nil
},
}
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
mcpTool.SetWorkspace(workspace)
mcpTool.SetMaxInlineTextRunes(32)
result := mcpTool.Execute(context.Background(), nil)
if len(result.ArtifactTags) != 1 {
t.Fatalf("expected custom threshold to persist artifact, got %+v", result)
}
if strings.Contains(result.ForLLM, "small custom threshold text") {
t.Fatalf("expected text to be omitted from ForLLM, got %q", result.ForLLM)
}
}
func TestMCPTool_Execute_LargeTextArtifactFailureStillOmitsContext(t *testing.T) {
workspaceRoot := t.TempDir()
workspaceFile := filepath.Join(workspaceRoot, "not-a-directory")
if err := os.WriteFile(workspaceFile, []byte("x"), 0o600); err != nil {
t.Fatalf("failed to create workspace file: %v", err)
}
largeText := strings.Repeat("This is a large MCP text payload.\n", 800)
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: largeText},
},
}, nil
},
}
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
mcpTool.SetWorkspace(workspaceFile)
result := mcpTool.Execute(context.Background(), nil)
if strings.Contains(result.ForLLM, "This is a large MCP text payload") {
t.Fatalf("expected large MCP text to be omitted from ForLLM, got %q", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "artifact persistence failed") {
t.Fatalf("expected persistence failure note, got %q", result.ForLLM)
}
if len(result.ArtifactTags) != 0 {
t.Fatalf("expected no artifact tags on persistence failure, got %+v", result.ArtifactTags)
}
}
func TestMCPTool_Execute_WhitespaceWorkspaceDisablesArtifactPersistence(t *testing.T) {
largeText := strings.Repeat("This is a large MCP text payload.\n", 800)
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: largeText},
},
}, nil
},
}
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
mcpTool.SetWorkspace(" \n\t ")
result := mcpTool.Execute(context.Background(), nil)
if len(result.ArtifactTags) != 0 {
t.Fatalf("expected no artifact tags for whitespace workspace, got %+v", result.ArtifactTags)
}
if !strings.Contains(result.ForLLM, "This is a large MCP text payload") {
t.Fatalf("expected large text to remain inline when workspace is blank, got %q", result.ForLLM)
}
}
+143
View File
@@ -0,0 +1,143 @@
package integrationtools
import (
"context"
"fmt"
"sync"
)
type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error
// sentTarget records the channel+chatID that the message tool sent to.
type sentTarget struct {
Channel string
ChatID string
}
type MessageTool struct {
sendCallback SendCallbackWithContext
mu sync.Mutex
// sentTargets tracks targets sent to in the current round, keyed by session key
// to support parallel turns for different sessions.
sentTargets map[string][]sentTarget
}
func NewMessageTool() *MessageTool {
return &MessageTool{
sentTargets: make(map[string][]sentTarget),
}
}
func (t *MessageTool) Name() string {
return "message"
}
func (t *MessageTool) Description() string {
return "Send a message to user on a chat channel. Use this when you want to communicate something."
}
func (t *MessageTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"content": map[string]any{
"type": "string",
"description": "The message content to send",
},
"channel": map[string]any{
"type": "string",
"description": "Optional: target channel (telegram, whatsapp, etc.)",
},
"chat_id": map[string]any{
"type": "string",
"description": "Optional: target chat/user ID",
},
"reply_to_message_id": map[string]any{
"type": "string",
"description": "Optional: reply target message ID for channels that support threaded replies",
},
},
"required": []string{"content"},
}
}
// ResetSentInRound resets the per-round send tracker for the given session key.
// Called by the agent loop at the start of each inbound message processing round.
func (t *MessageTool) ResetSentInRound(sessionKey string) {
t.mu.Lock()
defer t.mu.Unlock()
// Delete the key entirely to prevent unbounded map growth over time
// with many unique sessions. Truncating the slice keeps the key alive.
delete(t.sentTargets, sessionKey)
}
// HasSentInRound returns true if the message tool sent a message during the current round.
func (t *MessageTool) HasSentInRound(sessionKey string) bool {
t.mu.Lock()
defer t.mu.Unlock()
return len(t.sentTargets[sessionKey]) > 0
}
// HasSentTo returns true if the message tool sent to the specific channel+chatID
// during the current round. Used by PublishResponseIfNeeded to avoid suppressing
// the final response when the message tool only sent to a different conversation.
func (t *MessageTool) HasSentTo(sessionKey, channel, chatID string) bool {
t.mu.Lock()
defer t.mu.Unlock()
for _, st := range t.sentTargets[sessionKey] {
if st.Channel == channel && st.ChatID == chatID {
return true
}
}
return false
}
func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) {
t.sendCallback = callback
}
func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
content, ok := args["content"].(string)
if !ok {
return &ToolResult{ForLLM: "content is required", IsError: true}
}
channel, _ := args["channel"].(string)
chatID, _ := args["chat_id"].(string)
replyToMessageID, _ := args["reply_to_message_id"].(string)
if channel == "" {
channel = ToolChannel(ctx)
}
if chatID == "" {
chatID = ToolChatID(ctx)
}
if channel == "" || chatID == "" {
return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true}
}
if t.sendCallback == nil {
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
}
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil {
return &ToolResult{
ForLLM: fmt.Sprintf("sending message: %v", err),
IsError: true,
Err: err,
}
}
sessionKey := ToolSessionKey(ctx)
t.mu.Lock()
t.sentTargets[sessionKey] = append(t.sentTargets[sessionKey], sentTarget{Channel: channel, ChatID: chatID})
t.mu.Unlock()
// Silent: user already received the message directly
return &ToolResult{
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
Silent: true,
}
}
+331
View File
@@ -0,0 +1,331 @@
package integrationtools
import (
"context"
"errors"
"testing"
"github.com/sipeed/picoclaw/pkg/session"
)
func TestMessageTool_Execute_Success(t *testing.T) {
tool := NewMessageTool()
var sentChannel, sentChatID, sentContent string
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
sentChannel = channel
sentChatID = chatID
sentContent = content
if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil {
t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v",
ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx))
}
return nil
})
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{
"content": "Hello, world!",
}
result := tool.Execute(ctx, args)
// Verify message was sent with correct parameters
if sentChannel != "test-channel" {
t.Errorf("Expected channel 'test-channel', got '%s'", sentChannel)
}
if sentChatID != "test-chat-id" {
t.Errorf("Expected chatID 'test-chat-id', got '%s'", sentChatID)
}
if sentContent != "Hello, world!" {
t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent)
}
// Verify ToolResult meets US-011 criteria:
// - Send success returns SilentResult (Silent=true)
if !result.Silent {
t.Error("Expected Silent=true for successful send")
}
// - ForLLM contains send status description
if result.ForLLM != "Message sent to test-channel:test-chat-id" {
t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM)
}
// - ForUser is empty (user already received message directly)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty, got '%s'", result.ForUser)
}
// - IsError should be false
if result.IsError {
t.Error("Expected IsError=false for successful send")
}
}
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
tool := NewMessageTool()
var sentChannel, sentChatID string
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
sentChannel = channel
sentChatID = chatID
return nil
})
ctx := WithToolContext(context.Background(), "default-channel", "default-chat-id")
args := map[string]any{
"content": "Test message",
"channel": "custom-channel",
"chat_id": "custom-chat-id",
}
result := tool.Execute(ctx, args)
// Verify custom channel/chatID were used instead of defaults
if sentChannel != "custom-channel" {
t.Errorf("Expected channel 'custom-channel', got '%s'", sentChannel)
}
if sentChatID != "custom-chat-id" {
t.Errorf("Expected chatID 'custom-chat-id', got '%s'", sentChatID)
}
if !result.Silent {
t.Error("Expected Silent=true")
}
if result.ForLLM != "Message sent to custom-channel:custom-chat-id" {
t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_SendFailure(t *testing.T) {
tool := NewMessageTool()
sendErr := errors.New("network error")
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
return sendErr
})
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify ToolResult for send failure:
// - Send failure returns ErrorResult (IsError=true)
if !result.IsError {
t.Error("Expected IsError=true for failed send")
}
// - ForLLM contains error description
expectedErrMsg := "sending message: network error"
if result.ForLLM != expectedErrMsg {
t.Errorf("Expected ForLLM '%s', got '%s'", expectedErrMsg, result.ForLLM)
}
// - Err field should contain original error
if result.Err == nil {
t.Error("Expected Err to be set")
}
if result.Err != sendErr {
t.Errorf("Expected Err to be sendErr, got %v", result.Err)
}
}
func TestMessageTool_Execute_MissingContent(t *testing.T) {
tool := NewMessageTool()
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{} // content missing
result := tool.Execute(ctx, args)
// Verify error result for missing content
if !result.IsError {
t.Error("Expected IsError=true for missing content")
}
if result.ForLLM != "content is required" {
t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
tool := NewMessageTool()
// No WithToolContext — channel/chatID are empty
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
return nil
})
ctx := context.Background()
args := map[string]any{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify error when no target channel specified
if !result.IsError {
t.Error("Expected IsError=true when no target channel")
}
if result.ForLLM != "No target channel/chat specified" {
t.Errorf("Expected ForLLM 'No target channel/chat specified', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
tool := NewMessageTool()
// No SetSendCallback called
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify error when send callback not configured
if !result.IsError {
t.Error("Expected IsError=true when send callback not configured")
}
if result.ForLLM != "Message sending not configured" {
t.Errorf("Expected ForLLM 'Message sending not configured', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Name(t *testing.T) {
tool := NewMessageTool()
if tool.Name() != "message" {
t.Errorf("Expected name 'message', got '%s'", tool.Name())
}
}
func TestMessageTool_Description(t *testing.T) {
tool := NewMessageTool()
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
}
func TestMessageTool_Parameters(t *testing.T) {
tool := NewMessageTool()
params := tool.Parameters()
// Verify parameters structure
typ, ok := params["type"].(string)
if !ok || typ != "object" {
t.Error("Expected type 'object'")
}
props, ok := params["properties"].(map[string]any)
if !ok {
t.Fatal("Expected properties to be a map")
}
// Check required properties
required, ok := params["required"].([]string)
if !ok || len(required) != 1 || required[0] != "content" {
t.Error("Expected 'content' to be required")
}
// Check content property
contentProp, ok := props["content"].(map[string]any)
if !ok {
t.Error("Expected 'content' property")
}
if contentProp["type"] != "string" {
t.Error("Expected content type to be 'string'")
}
// Check channel property (optional)
channelProp, ok := props["channel"].(map[string]any)
if !ok {
t.Error("Expected 'channel' property")
}
if channelProp["type"] != "string" {
t.Error("Expected channel type to be 'string'")
}
// Check chat_id property (optional)
chatIDProp, ok := props["chat_id"].(map[string]any)
if !ok {
t.Error("Expected 'chat_id' property")
}
if chatIDProp["type"] != "string" {
t.Error("Expected chat_id type to be 'string'")
}
// Check reply_to_message_id property (optional)
replyToProp, ok := props["reply_to_message_id"].(map[string]any)
if !ok {
t.Error("Expected 'reply_to_message_id' property")
}
if replyToProp["type"] != "string" {
t.Error("Expected reply_to_message_id type to be 'string'")
}
}
func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
tool := NewMessageTool()
var sentReplyTo string
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
sentReplyTo = replyToMessageID
return nil
})
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{
"content": "Reply test",
"reply_to_message_id": "msg-123",
}
result := tool.Execute(ctx, args)
if result.IsError {
t.Fatalf("expected success, got error: %s", result.ForLLM)
}
if sentReplyTo != "msg-123" {
t.Fatalf("expected reply_to_message_id msg-123, got %q", sentReplyTo)
}
}
func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
tool := NewMessageTool()
var gotAgentID, gotSessionKey string
var gotScope *session.SessionScope
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
gotAgentID = ToolAgentID(ctx)
gotSessionKey = ToolSessionKey(ctx)
gotScope = ToolSessionScope(ctx)
return nil
})
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
ctx = WithToolSessionContext(ctx, "main", "sk_v1_tool", &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "telegram",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "direct:test-chat-id",
},
})
result := tool.Execute(ctx, map[string]any{"content": "Hello, world!"})
if result.IsError {
t.Fatalf("expected success, got error: %s", result.ForLLM)
}
if gotAgentID != "main" {
t.Fatalf("ToolAgentID() = %q, want main", gotAgentID)
}
if gotSessionKey != "sk_v1_tool" {
t.Fatalf("ToolSessionKey() = %q, want sk_v1_tool", gotSessionKey)
}
if gotScope == nil || gotScope.Values["chat"] != "direct:test-chat-id" {
t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope)
}
}
+87
View File
@@ -0,0 +1,87 @@
package integrationtools
import (
"context"
"fmt"
)
type ReactionCallback func(ctx context.Context, channel, chatID, messageID string) error
type ReactionTool struct {
reactionCallback ReactionCallback
}
func NewReactionTool() *ReactionTool {
return &ReactionTool{}
}
func (t *ReactionTool) Name() string {
return "reaction"
}
func (t *ReactionTool) Description() string {
return "Add a reaction to a message. Defaults to the current inbound message when message_id is omitted."
}
func (t *ReactionTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"message_id": map[string]any{
"type": "string",
"description": "Optional: target message ID; defaults to the current inbound message",
},
"channel": map[string]any{
"type": "string",
"description": "Optional: target channel (telegram, whatsapp, etc.)",
},
"chat_id": map[string]any{
"type": "string",
"description": "Optional: target chat/user ID",
},
},
}
}
func (t *ReactionTool) SetReactionCallback(callback ReactionCallback) {
t.reactionCallback = callback
}
func (t *ReactionTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
channel, _ := args["channel"].(string)
chatID, _ := args["chat_id"].(string)
messageID, _ := args["message_id"].(string)
if channel == "" {
channel = ToolChannel(ctx)
}
if chatID == "" {
chatID = ToolChatID(ctx)
}
if messageID == "" {
messageID = ToolMessageID(ctx)
}
if channel == "" || chatID == "" {
return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true}
}
if messageID == "" {
return &ToolResult{ForLLM: "message_id is required", IsError: true}
}
if t.reactionCallback == nil {
return &ToolResult{ForLLM: "Reaction not configured", IsError: true}
}
if err := t.reactionCallback(ctx, channel, chatID, messageID); err != nil {
return &ToolResult{
ForLLM: fmt.Sprintf("adding reaction: %v", err),
IsError: true,
Err: err,
}
}
return &ToolResult{
ForLLM: fmt.Sprintf("Reaction added to %s:%s message %s", channel, chatID, messageID),
Silent: true,
}
}
+96
View File
@@ -0,0 +1,96 @@
package integrationtools
import (
"context"
"errors"
"testing"
)
func TestReactionTool_Execute_UsesContextMessageIDByDefault(t *testing.T) {
tool := NewReactionTool()
var gotChannel, gotChatID, gotMessageID string
tool.SetReactionCallback(func(ctx context.Context, channel, chatID, messageID string) error {
gotChannel = channel
gotChatID = chatID
gotMessageID = messageID
return nil
})
ctx := WithToolInboundContext(context.Background(), "telegram", "chat-1", "msg-100", "")
result := tool.Execute(ctx, map[string]any{})
if result.IsError {
t.Fatalf("expected success, got error: %s", result.ForLLM)
}
if gotChannel != "telegram" || gotChatID != "chat-1" || gotMessageID != "msg-100" {
t.Fatalf("unexpected callback args: channel=%q chatID=%q messageID=%q", gotChannel, gotChatID, gotMessageID)
}
}
func TestReactionTool_Execute_AllowsExplicitMessageIDOverride(t *testing.T) {
tool := NewReactionTool()
var gotMessageID string
tool.SetReactionCallback(func(ctx context.Context, channel, chatID, messageID string) error {
gotMessageID = messageID
return nil
})
ctx := WithToolInboundContext(context.Background(), "telegram", "chat-1", "msg-context", "")
result := tool.Execute(ctx, map[string]any{"message_id": "msg-explicit"})
if result.IsError {
t.Fatalf("expected success, got error: %s", result.ForLLM)
}
if gotMessageID != "msg-explicit" {
t.Fatalf("expected explicit message id, got %q", gotMessageID)
}
}
func TestReactionTool_Execute_MissingMessageID(t *testing.T) {
tool := NewReactionTool()
tool.SetReactionCallback(func(ctx context.Context, channel, chatID, messageID string) error { return nil })
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
result := tool.Execute(ctx, map[string]any{})
if !result.IsError {
t.Fatal("expected error")
}
if result.ForLLM != "message_id is required" {
t.Fatalf("unexpected error message: %q", result.ForLLM)
}
}
func TestReactionTool_Execute_CallbackError(t *testing.T) {
tool := NewReactionTool()
tool.SetReactionCallback(func(ctx context.Context, channel, chatID, messageID string) error {
return errors.New("unsupported")
})
ctx := WithToolInboundContext(context.Background(), "telegram", "chat-1", "msg-100", "")
result := tool.Execute(ctx, map[string]any{})
if !result.IsError {
t.Fatal("expected error")
}
if result.Err == nil {
t.Fatal("expected wrapped error")
}
}
func TestReactionTool_Parameters(t *testing.T) {
tool := NewReactionTool()
params := tool.Parameters()
props, ok := params["properties"].(map[string]any)
if !ok {
t.Fatal("expected properties map")
}
if _, ok := props["message_id"]; !ok {
t.Fatal("expected message_id parameter")
}
if _, ok := props["channel"]; !ok {
t.Fatal("expected channel parameter")
}
if _, ok := props["chat_id"]; !ok {
t.Fatal("expected chat_id parameter")
}
}
+77
View File
@@ -0,0 +1,77 @@
package integrationtools
import (
"context"
"github.com/sipeed/picoclaw/pkg/session"
toolshared "github.com/sipeed/picoclaw/pkg/tools/shared"
)
type (
Tool = toolshared.Tool
ToolResult = toolshared.ToolResult
AsyncCallback = toolshared.AsyncCallback
)
func WithToolContext(ctx context.Context, channel, chatID string) context.Context {
return toolshared.WithToolContext(ctx, channel, chatID)
}
func WithToolInboundContext(
ctx context.Context,
channel, chatID, messageID, replyToMessageID string,
) context.Context {
return toolshared.WithToolInboundContext(ctx, channel, chatID, messageID, replyToMessageID)
}
func WithToolSessionContext(
ctx context.Context,
agentID, sessionKey string,
scope *session.SessionScope,
) context.Context {
return toolshared.WithToolSessionContext(ctx, agentID, sessionKey, scope)
}
func ToolChannel(ctx context.Context) string {
return toolshared.ToolChannel(ctx)
}
func ToolChatID(ctx context.Context) string {
return toolshared.ToolChatID(ctx)
}
func ToolMessageID(ctx context.Context) string {
return toolshared.ToolMessageID(ctx)
}
func ToolAgentID(ctx context.Context) string {
return toolshared.ToolAgentID(ctx)
}
func ToolSessionKey(ctx context.Context) string {
return toolshared.ToolSessionKey(ctx)
}
func ToolSessionScope(ctx context.Context) *session.SessionScope {
return toolshared.ToolSessionScope(ctx)
}
func ErrorResult(message string) *ToolResult {
return toolshared.ErrorResult(message)
}
func SilentResult(forLLM string) *ToolResult {
return toolshared.SilentResult(forLLM)
}
func NewToolResult(forLLM string) *ToolResult {
return toolshared.NewToolResult(forLLM)
}
func UserResult(content string) *ToolResult {
return toolshared.UserResult(content)
}
func MediaResult(forLLM string, mediaRefs []string) *ToolResult {
return toolshared.MediaResult(forLLM, mediaRefs)
}
+308
View File
@@ -0,0 +1,308 @@
package integrationtools
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/utils"
)
const defaultSkillRegistryName = "github"
var persistInstalledSkillOriginMeta = writeOriginMeta
// InstallSkillTool allows the LLM agent to install skills from registries.
// It shares the same RegistryManager that FindSkillsTool uses,
// so all registries configured in config are available for installation.
type InstallSkillTool struct {
registryMgr *skills.RegistryManager
workspace string
mu sync.Mutex
}
// NewInstallSkillTool creates a new InstallSkillTool.
// registryMgr is the shared registry manager (same instance as FindSkillsTool).
// workspace is the root workspace directory; skills install to {workspace}/skills/{slug}/.
func NewInstallSkillTool(registryMgr *skills.RegistryManager, workspace string) *InstallSkillTool {
return &InstallSkillTool{
registryMgr: registryMgr,
workspace: workspace,
mu: sync.Mutex{},
}
}
func (t *InstallSkillTool) Name() string {
return "install_skill"
}
func (t *InstallSkillTool) Description() string {
return "Install a skill from a registry by slug. Defaults to GitHub when registry is omitted. Downloads and extracts the skill into the workspace. Use find_skills first to discover available skills."
}
func (t *InstallSkillTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"slug": map[string]any{
"type": "string",
"description": "The unique slug of the skill to install (e.g., 'github', 'docker-compose')",
},
"version": map[string]any{
"type": "string",
"description": "Specific version to install (optional, defaults to latest)",
},
"registry": map[string]any{
"type": "string",
"description": "Registry to install from (optional, defaults to 'github')",
},
"force": map[string]any{
"type": "boolean",
"description": "Force reinstall if skill already exists (default false)",
},
},
"required": []string{"slug"},
}
}
func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
// Install lock to prevent concurrent directory operations.
// Ideally this should be done at a `slug` level, currently, its at a `workspace` level.
t.mu.Lock()
defer t.mu.Unlock()
slug, _ := args["slug"].(string)
if strings.TrimSpace(slug) == "" {
return ErrorResult("identifier is required and must be a non-empty string")
}
// Validate registry
registryName, _ := args["registry"].(string)
if registryName == "" {
registryName = defaultSkillRegistryName
}
if err := utils.ValidateSkillIdentifier(registryName); err != nil {
return ErrorResult(fmt.Sprintf("invalid registry %q: error: %s", registryName, err.Error()))
}
// Resolve which registry to use.
registry := t.registryMgr.GetRegistry(registryName)
if registry == nil {
return ErrorResult(fmt.Sprintf("registry %q not found", registryName))
}
// Validate target and resolve install directory.
dirName, err := registry.ResolveInstallDirName(slug)
if err != nil {
return ErrorResult(fmt.Sprintf("invalid slug %q: error: %s", slug, err.Error()))
}
version, _ := args["version"].(string)
force, _ := args["force"].(bool)
// Check if already installed.
skillsDir := filepath.Join(t.workspace, "skills")
targetDir := filepath.Join(skillsDir, dirName)
backupDir := ""
restorePreviousInstall := func() {
if backupDir == "" {
return
}
if rmErr := os.RemoveAll(targetDir); rmErr != nil {
logger.ErrorCF("tool", "Failed to remove failed install before restore",
map[string]any{
"tool": "install_skill",
"target_dir": targetDir,
"error": rmErr.Error(),
})
return
}
if restoreErr := os.Rename(backupDir, targetDir); restoreErr != nil {
logger.ErrorCF("tool", "Failed to restore previous install after failed reinstall",
map[string]any{
"tool": "install_skill",
"backup_dir": backupDir,
"target_dir": targetDir,
"error": restoreErr.Error(),
})
return
}
backupDir = ""
}
if !force {
if _, statErr := os.Stat(targetDir); statErr == nil {
return ErrorResult(
fmt.Sprintf("skill %q already installed at %s. Use force=true to reinstall.", slug, targetDir),
)
}
} else {
if _, statErr := os.Stat(targetDir); statErr == nil {
backupDir = filepath.Join(skillsDir, fmt.Sprintf(".%s.picoclaw-backup-%d", dirName, time.Now().UnixNano()))
if renameErr := os.Rename(targetDir, backupDir); renameErr != nil {
return ErrorResult(fmt.Sprintf("failed to prepare reinstall for %q: %v", slug, renameErr))
}
} else if !os.IsNotExist(statErr) {
return ErrorResult(fmt.Sprintf("failed to inspect existing install for %q: %v", slug, statErr))
}
}
// Ensure skills directory exists.
if mkdirErr := os.MkdirAll(skillsDir, 0o755); mkdirErr != nil {
restorePreviousInstall()
return ErrorResult(fmt.Sprintf("failed to create skills directory: %v", mkdirErr))
}
// Download and install (handles metadata, version resolution, extraction).
result, err := registry.DownloadAndInstall(ctx, slug, version, targetDir)
if err != nil {
// Clean up partial install.
rmErr := os.RemoveAll(targetDir)
if rmErr != nil {
logger.ErrorCF("tool", "Failed to remove partial install",
map[string]any{
"tool": "install_skill",
"target_dir": targetDir,
"error": rmErr.Error(),
})
}
restorePreviousInstall()
return ErrorResult(fmt.Sprintf("failed to install %q: %v", slug, err))
}
// Moderation: block malware.
if result.IsMalwareBlocked {
rmErr := os.RemoveAll(targetDir)
if rmErr != nil {
logger.ErrorCF("tool", "Failed to remove partial install",
map[string]any{
"tool": "install_skill",
"target_dir": targetDir,
"error": rmErr.Error(),
})
}
restorePreviousInstall()
return ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug))
}
if !workspaceHasValidInstalledSkill(t.workspace, dirName) {
rmErr := os.RemoveAll(targetDir)
if rmErr != nil {
logger.ErrorCF("tool", "Failed to remove invalid installed skill",
map[string]any{
"tool": "install_skill",
"target_dir": targetDir,
"error": rmErr.Error(),
})
}
restorePreviousInstall()
return ErrorResult(fmt.Sprintf("failed to install %q: registry archive is not a valid skill", slug))
}
// Write origin metadata.
if err := persistInstalledSkillOriginMeta(targetDir, registry, slug, result.Version); err != nil {
logger.ErrorCF("tool", "Failed to write origin metadata",
map[string]any{
"tool": "install_skill",
"error": err.Error(),
"target": targetDir,
"registry": registry.Name(),
"slug": slug,
"version": result.Version,
})
rmErr := os.RemoveAll(targetDir)
if rmErr != nil {
logger.ErrorCF("tool", "Failed to roll back install after metadata write failure",
map[string]any{
"tool": "install_skill",
"target_dir": targetDir,
"error": rmErr.Error(),
})
}
restorePreviousInstall()
return ErrorResult(fmt.Sprintf("failed to persist skill metadata for %q: %v", slug, err))
}
if backupDir != "" {
if rmErr := os.RemoveAll(backupDir); rmErr != nil {
logger.ErrorCF("tool", "Failed to remove previous install backup after successful reinstall",
map[string]any{
"tool": "install_skill",
"backup_dir": backupDir,
"error": rmErr.Error(),
})
}
}
// Build result with moderation warning if suspicious.
var output string
if result.IsSuspicious {
output = fmt.Sprintf("⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n", slug)
}
output += fmt.Sprintf("Successfully installed skill %q v%s from %s registry.\nLocation: %s\n",
slug, result.Version, registry.Name(), targetDir)
if result.Summary != "" {
output += fmt.Sprintf("Description: %s\n", result.Summary)
}
output += "\nThe skill is now available and can be loaded in the current session."
return SilentResult(output)
}
// originMeta tracks which registry a skill was installed from.
type originMeta struct {
Version int `json:"version"`
OriginKind string `json:"origin_kind,omitempty"`
Registry string `json:"registry"`
Slug string `json:"slug"`
RegistryURL string `json:"registry_url,omitempty"`
InstalledVersion string `json:"installed_version"`
InstalledAt int64 `json:"installed_at"`
}
func writeOriginMeta(targetDir string, registry skills.SkillRegistry, slug, version string) error {
normalizedSlug, registryURL := skills.BuildInstallMetadataForRegistryInstance(registry, slug, version)
registryName := ""
if registry != nil {
registryName = registry.Name()
}
meta := originMeta{
Version: 1,
OriginKind: "third_party",
Registry: registryName,
Slug: normalizedSlug,
RegistryURL: registryURL,
InstalledVersion: version,
InstalledAt: time.Now().UnixMilli(),
}
data, err := json.MarshalIndent(meta, "", " ")
if err != nil {
return err
}
// Use unified atomic write utility with explicit sync for flash storage reliability.
return fileutil.WriteFileAtomic(filepath.Join(targetDir, ".skill-origin.json"), data, 0o600)
}
func workspaceHasValidInstalledSkill(workspace, directory string) bool {
loader := skills.NewSkillsLoader(workspace, "", "")
for _, skill := range loader.ListSkills() {
if skill.Source != "workspace" {
continue
}
if filepath.Base(filepath.Dir(skill.Path)) == directory {
return true
}
}
return false
}
@@ -0,0 +1,423 @@
package integrationtools
import (
"context"
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sipeed/picoclaw/pkg/skills"
)
type mockInstallRegistry struct{}
const validSkillMarkdown = "---\nname: pr-review\ndescription: Review pull requests\n---\n# PR Review\n"
func (m *mockInstallRegistry) Name() string { return "clawhub" }
func (m *mockInstallRegistry) ResolveInstallDirName(target string) (string, error) {
return target, nil
}
func (m *mockInstallRegistry) SkillURL(slug, _ string) string { return slug }
func (m *mockInstallRegistry) Search(context.Context, string, int) ([]skills.SearchResult, error) {
return nil, nil
}
func (m *mockInstallRegistry) GetSkillMeta(context.Context, string) (*skills.SkillMeta, error) {
return nil, nil
}
func (m *mockInstallRegistry) DownloadAndInstall(
_ context.Context,
_ string,
_ string,
targetDir string,
) (*skills.InstallResult, error) {
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return nil, err
}
if err := os.WriteFile(filepath.Join(targetDir, "SKILL.md"), []byte(validSkillMarkdown), 0o600); err != nil {
return nil, err
}
return &skills.InstallResult{Version: "test"}, nil
}
type mockGitHubInstallRegistry struct{}
func (m *mockGitHubInstallRegistry) Name() string { return "github" }
func (m *mockGitHubInstallRegistry) ResolveInstallDirName(target string) (string, error) {
return "pr-review", nil
}
func (m *mockGitHubInstallRegistry) SkillURL(slug, _ string) string { return slug }
func (m *mockGitHubInstallRegistry) Search(context.Context, string, int) ([]skills.SearchResult, error) {
return nil, nil
}
func (m *mockGitHubInstallRegistry) GetSkillMeta(context.Context, string) (*skills.SkillMeta, error) {
return nil, nil
}
func (m *mockGitHubInstallRegistry) DownloadAndInstall(
_ context.Context,
_ string,
_ string,
targetDir string,
) (*skills.InstallResult, error) {
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return nil, err
}
if err := os.WriteFile(filepath.Join(targetDir, "SKILL.md"), []byte(validSkillMarkdown), 0o600); err != nil {
return nil, err
}
return &skills.InstallResult{Version: "main"}, nil
}
type stubGitHubInstallRegistry struct {
*skills.GitHubRegistry
}
func (m *stubGitHubInstallRegistry) DownloadAndInstall(
_ context.Context,
_ string,
_ string,
targetDir string,
) (*skills.InstallResult, error) {
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return nil, err
}
if err := os.WriteFile(filepath.Join(targetDir, "SKILL.md"), []byte(validSkillMarkdown), 0o600); err != nil {
return nil, err
}
return &skills.InstallResult{Version: "main"}, nil
}
type mockInvalidInstallRegistry struct{}
type mockFailingInstallRegistry struct{}
func (m *mockInvalidInstallRegistry) Name() string { return "clawhub" }
func (m *mockInvalidInstallRegistry) ResolveInstallDirName(target string) (string, error) {
return target, nil
}
func (m *mockInvalidInstallRegistry) SkillURL(slug, _ string) string { return slug }
func (m *mockInvalidInstallRegistry) Search(context.Context, string, int) ([]skills.SearchResult, error) {
return nil, nil
}
func (m *mockInvalidInstallRegistry) GetSkillMeta(context.Context, string) (*skills.SkillMeta, error) {
return nil, nil
}
func (m *mockInvalidInstallRegistry) DownloadAndInstall(
_ context.Context,
_ string,
_ string,
targetDir string,
) (*skills.InstallResult, error) {
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return nil, err
}
if err := os.WriteFile(
filepath.Join(targetDir, "SKILL.md"),
[]byte("---\nname: bad_skill\ndescription: invalid name\n---\n# Invalid\n"),
0o600,
); err != nil {
return nil, err
}
return &skills.InstallResult{Version: "test"}, nil
}
func (m *mockFailingInstallRegistry) Name() string { return "clawhub" }
func (m *mockFailingInstallRegistry) ResolveInstallDirName(target string) (string, error) {
return target, nil
}
func (m *mockFailingInstallRegistry) SkillURL(slug, _ string) string { return slug }
func (m *mockFailingInstallRegistry) Search(context.Context, string, int) ([]skills.SearchResult, error) {
return nil, nil
}
func (m *mockFailingInstallRegistry) GetSkillMeta(context.Context, string) (*skills.SkillMeta, error) {
return nil, nil
}
func (m *mockFailingInstallRegistry) DownloadAndInstall(
_ context.Context,
_ string,
_ string,
_ string,
) (*skills.InstallResult, error) {
return nil, assert.AnError
}
func TestInstallSkillToolName(t *testing.T) {
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
assert.Equal(t, "install_skill", tool.Name())
}
func TestInstallSkillToolMissingSlug(t *testing.T) {
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
result := tool.Execute(context.Background(), map[string]any{})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string")
}
func TestInstallSkillToolEmptySlug(t *testing.T) {
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
result := tool.Execute(context.Background(), map[string]any{
"slug": " ",
})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string")
}
func TestInstallSkillToolUnsafeSlug(t *testing.T) {
registryMgr := skills.NewRegistryManager()
registryMgr.AddRegistry(skills.NewClawHubRegistry(skills.ClawHubConfig{Enabled: true}))
tool := NewInstallSkillTool(registryMgr, t.TempDir())
cases := []string{
"../etc/passwd",
"path/traversal",
"path\\traversal",
}
for _, slug := range cases {
result := tool.Execute(context.Background(), map[string]any{
"slug": slug,
"registry": "clawhub",
})
assert.True(t, result.IsError, "slug %q should be rejected", slug)
assert.Contains(t, result.ForLLM, "invalid slug")
}
}
func TestInstallSkillToolAlreadyExists(t *testing.T) {
workspace := t.TempDir()
skillDir := filepath.Join(workspace, "skills", "existing-skill")
require.NoError(t, os.MkdirAll(skillDir, 0o755))
registryMgr := skills.NewRegistryManager()
registryMgr.AddRegistry(&mockInstallRegistry{})
tool := NewInstallSkillTool(registryMgr, workspace)
result := tool.Execute(context.Background(), map[string]any{
"slug": "existing-skill",
"registry": "clawhub",
})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "already installed")
}
func TestInstallSkillToolRegistryNotFound(t *testing.T) {
workspace := t.TempDir()
tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace)
result := tool.Execute(context.Background(), map[string]any{
"slug": "some-skill",
"registry": "nonexistent",
})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "registry")
assert.Contains(t, result.ForLLM, "not found")
}
func TestInstallSkillToolParameters(t *testing.T) {
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
params := tool.Parameters()
props, ok := params["properties"].(map[string]any)
assert.True(t, ok)
assert.Contains(t, props, "slug")
assert.Contains(t, props, "version")
assert.Contains(t, props, "registry")
assert.Contains(t, props, "force")
required, ok := params["required"].([]string)
assert.True(t, ok)
assert.Contains(t, required, "slug")
assert.NotContains(t, required, "registry")
}
func TestInstallSkillToolMissingRegistry(t *testing.T) {
registryMgr := skills.NewRegistryManager()
registryMgr.AddRegistry(&mockGitHubInstallRegistry{})
tool := NewInstallSkillTool(registryMgr, t.TempDir())
result := tool.Execute(context.Background(), map[string]any{
"slug": "some-skill",
})
assert.False(t, result.IsError)
assert.Contains(t, result.ForLLM, `Successfully installed skill`)
}
func TestInstallSkillToolAllowsGitHubURLSlug(t *testing.T) {
registry := skills.GitHubRegistryConfig{Enabled: true, BaseURL: "https://github.com"}.BuildRegistry()
githubRegistry, ok := registry.(*skills.GitHubRegistry)
require.True(t, ok)
registryMgr := skills.NewRegistryManager()
registryMgr.AddRegistry(&stubGitHubInstallRegistry{GitHubRegistry: githubRegistry})
workspace := t.TempDir()
tool := NewInstallSkillTool(registryMgr, workspace)
slug := "https://github.com/synthetic-lab/octofriend/tree/main/.agents/skills/pr-review"
result := tool.Execute(context.Background(), map[string]any{
"slug": slug,
"registry": "github",
})
assert.False(t, result.IsError)
assert.Contains(t, result.ForLLM, `Successfully installed skill`)
data, err := os.ReadFile(filepath.Join(workspace, "skills", "pr-review", ".skill-origin.json"))
require.NoError(t, err)
var meta originMeta
require.NoError(t, json.Unmarshal(data, &meta))
assert.Equal(t, "third_party", meta.OriginKind)
assert.Equal(t, "github", meta.Registry)
assert.Equal(t, "synthetic-lab/octofriend/.agents/skills/pr-review", meta.Slug)
assert.Equal(t, slug, meta.RegistryURL)
assert.Equal(t, "main", meta.InstalledVersion)
assert.NotZero(t, meta.InstalledAt)
}
func TestInstallSkillToolPreservesGitHubSourceURLWithEnterpriseRegistry(t *testing.T) {
registry := skills.GitHubRegistryConfig{Enabled: true, BaseURL: "https://ghe.example.com/git"}.BuildRegistry()
githubRegistry, ok := registry.(*skills.GitHubRegistry)
require.True(t, ok)
registryMgr := skills.NewRegistryManager()
registryMgr.AddRegistry(&stubGitHubInstallRegistry{GitHubRegistry: githubRegistry})
workspace := t.TempDir()
tool := NewInstallSkillTool(registryMgr, workspace)
slug := "https://github.com/synthetic-lab/octofriend/tree/main/.agents/skills/pr-review"
result := tool.Execute(context.Background(), map[string]any{
"slug": slug,
"registry": "github",
})
assert.False(t, result.IsError)
data, err := os.ReadFile(filepath.Join(workspace, "skills", "pr-review", ".skill-origin.json"))
require.NoError(t, err)
var meta originMeta
require.NoError(t, json.Unmarshal(data, &meta))
assert.Equal(t, "synthetic-lab/octofriend/.agents/skills/pr-review", meta.Slug)
assert.Equal(t, slug, meta.RegistryURL)
assert.Equal(t, "main", meta.InstalledVersion)
}
func TestInstallSkillToolRejectsInvalidInstalledSkill(t *testing.T) {
workspace := t.TempDir()
registryMgr := skills.NewRegistryManager()
registryMgr.AddRegistry(&mockInvalidInstallRegistry{})
tool := NewInstallSkillTool(registryMgr, workspace)
result := tool.Execute(context.Background(), map[string]any{
"slug": "broken-skill",
"registry": "clawhub",
})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "not a valid skill")
_, err := os.Stat(filepath.Join(workspace, "skills", "broken-skill"))
assert.True(t, os.IsNotExist(err))
}
func TestInstallSkillToolRollsBackOnOriginMetadataWriteFailure(t *testing.T) {
workspace := t.TempDir()
registryMgr := skills.NewRegistryManager()
registryMgr.AddRegistry(&mockInstallRegistry{})
tool := NewInstallSkillTool(registryMgr, workspace)
previousPersist := persistInstalledSkillOriginMeta
persistInstalledSkillOriginMeta = func(string, skills.SkillRegistry, string, string) error {
return assert.AnError
}
defer func() {
persistInstalledSkillOriginMeta = previousPersist
}()
result := tool.Execute(context.Background(), map[string]any{
"slug": "rollback-skill",
"registry": "clawhub",
})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "failed to persist skill metadata")
_, err := os.Stat(filepath.Join(workspace, "skills", "rollback-skill"))
assert.True(t, os.IsNotExist(err))
}
func TestInstallSkillToolForceReinstallRestoresPreviousSkillAfterDownloadFailure(t *testing.T) {
workspace := t.TempDir()
skillDir := filepath.Join(workspace, "skills", "existing-skill")
require.NoError(t, os.MkdirAll(skillDir, 0o755))
oldContent := []byte("---\nname: existing-skill\ndescription: Existing skill\n---\n# Existing\n")
require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), oldContent, 0o600))
registryMgr := skills.NewRegistryManager()
registryMgr.AddRegistry(&mockFailingInstallRegistry{})
tool := NewInstallSkillTool(registryMgr, workspace)
result := tool.Execute(context.Background(), map[string]any{
"slug": "existing-skill",
"registry": "clawhub",
"force": true,
})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "failed to install")
gotContent, err := os.ReadFile(filepath.Join(skillDir, "SKILL.md"))
require.NoError(t, err)
assert.Equal(t, oldContent, gotContent)
}
func TestInstallSkillToolForceReinstallRestoresPreviousSkillAfterMetadataFailure(t *testing.T) {
workspace := t.TempDir()
skillDir := filepath.Join(workspace, "skills", "existing-skill")
require.NoError(t, os.MkdirAll(skillDir, 0o755))
oldContent := []byte("---\nname: existing-skill\ndescription: Existing skill\n---\n# Existing\n")
require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), oldContent, 0o600))
registryMgr := skills.NewRegistryManager()
registryMgr.AddRegistry(&mockInstallRegistry{})
tool := NewInstallSkillTool(registryMgr, workspace)
previousPersist := persistInstalledSkillOriginMeta
persistInstalledSkillOriginMeta = func(string, skills.SkillRegistry, string, string) error {
return assert.AnError
}
defer func() {
persistInstalledSkillOriginMeta = previousPersist
}()
result := tool.Execute(context.Background(), map[string]any{
"slug": "existing-skill",
"registry": "clawhub",
"force": true,
})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "failed to persist skill metadata")
gotContent, err := os.ReadFile(filepath.Join(skillDir, "SKILL.md"))
require.NoError(t, err)
assert.Equal(t, oldContent, gotContent)
}
+119
View File
@@ -0,0 +1,119 @@
package integrationtools
import (
"context"
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/skills"
)
// FindSkillsTool allows the LLM agent to search for installable skills from registries.
type FindSkillsTool struct {
registryMgr *skills.RegistryManager
cache *skills.SearchCache
}
// NewFindSkillsTool creates a new FindSkillsTool.
// registryMgr is the shared registry manager (built from config in createToolRegistry).
// cache is the search cache for deduplicating similar queries.
func NewFindSkillsTool(registryMgr *skills.RegistryManager, cache *skills.SearchCache) *FindSkillsTool {
return &FindSkillsTool{
registryMgr: registryMgr,
cache: cache,
}
}
func (t *FindSkillsTool) Name() string {
return "find_skills"
}
func (t *FindSkillsTool) Description() string {
return "Search for installable skills from skill registries. Returns skill slugs, descriptions, versions, and relevance scores. Use this to discover skills before installing them with install_skill."
}
func (t *FindSkillsTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{
"type": "string",
"description": "Search query describing the desired skill capability (e.g., 'github integration', 'database management')",
},
"limit": map[string]any{
"type": "integer",
"description": "Maximum number of results to return (1-20, default 5)",
"minimum": 1.0,
"maximum": 20.0,
},
},
"required": []string{"query"},
}
}
func (t *FindSkillsTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
query, ok := args["query"].(string)
query = strings.ToLower(strings.TrimSpace(query))
if !ok || query == "" {
return ErrorResult("query is required and must be a non-empty string")
}
limit := 5
if l, ok := args["limit"].(float64); ok {
li := int(l)
if li >= 1 && li <= 20 {
limit = li
}
}
// Check cache first.
if t.cache != nil {
if cached, hit := t.cache.Get(query); hit {
return SilentResult(formatSearchResults(query, cached, true))
}
}
// Search all registries.
results, err := t.registryMgr.SearchAll(ctx, query, limit)
if err != nil {
return ErrorResult(fmt.Sprintf("skill search failed: %v", err))
}
// Cache the results.
if t.cache != nil && len(results) > 0 {
t.cache.Put(query, results)
}
return SilentResult(formatSearchResults(query, results, false))
}
func formatSearchResults(query string, results []skills.SearchResult, cached bool) string {
if len(results) == 0 {
return fmt.Sprintf("No skills found for query: %q", query)
}
var sb strings.Builder
source := ""
if cached {
source = " (cached)"
}
sb.WriteString(fmt.Sprintf("Found %d skills for %q%s:\n\n", len(results), query, source))
for i, r := range results {
sb.WriteString(fmt.Sprintf("%d. **%s**", i+1, r.Slug))
if r.Version != "" {
sb.WriteString(fmt.Sprintf(" v%s", r.Version))
}
sb.WriteString(fmt.Sprintf(" (score: %.3f, registry: %s)\n", r.Score, r.RegistryName))
if r.DisplayName != "" && r.DisplayName != r.Slug {
sb.WriteString(fmt.Sprintf(" Name: %s\n", r.DisplayName))
}
if r.Summary != "" {
sb.WriteString(fmt.Sprintf(" %s\n", r.Summary))
}
sb.WriteString("\n")
}
sb.WriteString("Use install_skill with the slug to install a skill.")
return sb.String()
}
@@ -0,0 +1,90 @@
package integrationtools
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/sipeed/picoclaw/pkg/skills"
)
func TestFindSkillsToolName(t *testing.T) {
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
assert.Equal(t, "find_skills", tool.Name())
}
func TestFindSkillsToolMissingQuery(t *testing.T) {
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
result := tool.Execute(context.Background(), map[string]any{})
assert.True(t, result.IsError)
assert.Contains(t, result.ForLLM, "query is required")
}
func TestFindSkillsToolEmptyQuery(t *testing.T) {
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
result := tool.Execute(context.Background(), map[string]any{
"query": " ",
})
assert.True(t, result.IsError)
}
func TestFindSkillsToolCacheHit(t *testing.T) {
cache := skills.NewSearchCache(10, 5*60*1000*1000*1000) // 5 min
cache.Put("github", []skills.SearchResult{
{Slug: "github", Score: 0.9, RegistryName: "clawhub"},
})
tool := NewFindSkillsTool(skills.NewRegistryManager(), cache)
result := tool.Execute(context.Background(), map[string]any{
"query": "github",
})
assert.False(t, result.IsError)
assert.Contains(t, result.ForLLM, "github")
assert.Contains(t, result.ForLLM, "cached")
}
func TestFindSkillsToolParameters(t *testing.T) {
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
params := tool.Parameters()
props, ok := params["properties"].(map[string]any)
assert.True(t, ok)
assert.Contains(t, props, "query")
assert.Contains(t, props, "limit")
required, ok := params["required"].([]string)
assert.True(t, ok)
assert.Contains(t, required, "query")
}
func TestFindSkillsToolDescription(t *testing.T) {
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
assert.NotEmpty(t, tool.Description())
assert.Contains(t, tool.Description(), "skill")
}
func TestFormatSearchResultsEmpty(t *testing.T) {
result := formatSearchResults("test query", nil, false)
assert.Contains(t, result, "No skills found")
}
func TestFormatSearchResultsWithData(t *testing.T) {
results := []skills.SearchResult{
{
Slug: "github",
Score: 0.95,
DisplayName: "GitHub",
Summary: "GitHub API integration",
Version: "1.0.0",
RegistryName: "clawhub",
},
}
output := formatSearchResults("github", results, false)
assert.Contains(t, output, "github")
assert.Contains(t, output, "v1.0.0")
assert.Contains(t, output, "0.950")
assert.Contains(t, output, "clawhub")
assert.Contains(t, output, "install_skill")
}
+82
View File
@@ -0,0 +1,82 @@
package integrationtools
import (
"context"
"strings"
"github.com/sipeed/picoclaw/pkg/audio/tts"
"github.com/sipeed/picoclaw/pkg/media"
)
type SendTTSTool struct {
provider tts.TTSProvider
mediaStore media.MediaStore
}
func NewSendTTSTool(provider tts.TTSProvider, store media.MediaStore) *SendTTSTool {
return &SendTTSTool{
provider: provider,
mediaStore: store,
}
}
func (t *SendTTSTool) Name() string { return "send_tts" }
func (t *SendTTSTool) Description() string {
return "Synthesize speech from text and send it as an audio file to the user."
}
func (t *SendTTSTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"text": map[string]any{
"type": "string",
"description": "The text to synthesize into speech. NOTE: Reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally.",
},
"filename": map[string]any{
"type": "string",
"description": "Optional filename for the audio file (e.g., response.ogg).",
},
},
"required": []string{"text"},
}
}
func (t *SendTTSTool) SetMediaStore(store media.MediaStore) {
t.mediaStore = store
}
func (t *SendTTSTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
text, _ := args["text"].(string)
text = strings.TrimSpace(text)
if text == "" {
return ErrorResult("text is required")
}
channel := ToolChannel(ctx)
chatID := ToolChatID(ctx)
filename, _ := args["filename"].(string)
ref, err := tts.SynthesizeAndStore(
ctx,
t.provider,
t.mediaStore,
text,
filename,
channel,
chatID,
)
if err != nil {
return ErrorResult(err.Error()).WithError(err)
}
// Return with ForUser set to original text, Media containing the audio ref,
// and mark as ResponseHandled so the audio is sent immediately without LLM intervention.
return &ToolResult{
ForLLM: "TTS audio sent",
ForUser: text,
Media: []string{ref},
ResponseHandled: true,
}
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff