mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(tools): reorganize tool packages and facades
This commit is contained in:
@@ -0,0 +1,134 @@
|
||||
package integrationtools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"mime"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
var (
|
||||
inlineMarkdownDataURLRe = regexp.MustCompile(`!\[[^\]]*\]\((data:[^)]+)\)`)
|
||||
inlineRawDataURLRe = regexp.MustCompile(`data:[^;\s]+;base64,[A-Za-z0-9+/=\r\n]+`)
|
||||
)
|
||||
|
||||
const (
|
||||
largeBase64OmittedMessage = "[Tool returned a large base64-like payload; omitted from model context.]"
|
||||
inlineMediaOmittedMessage = "[Tool returned inline media content; omitted from model context.]"
|
||||
)
|
||||
|
||||
func sanitizeToolLLMContent(text string) string {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if trimmed == "" {
|
||||
return text
|
||||
}
|
||||
if inlineMarkdownDataURLRe.MatchString(trimmed) || inlineRawDataURLRe.MatchString(trimmed) {
|
||||
cleaned := inlineMarkdownDataURLRe.ReplaceAllString(trimmed, "")
|
||||
cleaned = inlineRawDataURLRe.ReplaceAllString(cleaned, "")
|
||||
cleaned = strings.TrimSpace(cleaned)
|
||||
if cleaned == "" {
|
||||
return inlineMediaOmittedMessage
|
||||
}
|
||||
return cleaned + "\n" + inlineMediaOmittedMessage
|
||||
}
|
||||
if looksLikeLargeBase64Payload(trimmed) {
|
||||
return largeBase64OmittedMessage
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func looksLikeLargeBase64Payload(text string) bool {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if len(trimmed) < 1024 {
|
||||
return false
|
||||
}
|
||||
|
||||
nonSpace := 0
|
||||
base64Like := 0
|
||||
spaceCount := 0
|
||||
|
||||
for _, r := range trimmed {
|
||||
if unicode.IsSpace(r) {
|
||||
spaceCount++
|
||||
continue
|
||||
}
|
||||
nonSpace++
|
||||
if (r >= 'A' && r <= 'Z') ||
|
||||
(r >= 'a' && r <= 'z') ||
|
||||
(r >= '0' && r <= '9') ||
|
||||
r == '+' || r == '/' || r == '=' {
|
||||
base64Like++
|
||||
}
|
||||
}
|
||||
|
||||
if nonSpace == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
ratio := float64(base64Like) / float64(nonSpace)
|
||||
return ratio >= 0.97 && spaceCount <= len(trimmed)/128
|
||||
}
|
||||
|
||||
func extensionForMIMEType(mimeType string) string {
|
||||
if mimeType == "" {
|
||||
return ".bin"
|
||||
}
|
||||
if exts, err := mime.ExtensionsByType(mimeType); err == nil && len(exts) > 0 {
|
||||
return exts[0]
|
||||
}
|
||||
|
||||
switch strings.ToLower(mimeType) {
|
||||
case "image/jpeg":
|
||||
return ".jpg"
|
||||
case "image/png":
|
||||
return ".png"
|
||||
case "image/gif":
|
||||
return ".gif"
|
||||
case "image/webp":
|
||||
return ".webp"
|
||||
case "audio/wav", "audio/x-wav":
|
||||
return ".wav"
|
||||
case "audio/mpeg":
|
||||
return ".mp3"
|
||||
case "audio/ogg":
|
||||
return ".ogg"
|
||||
case "video/mp4":
|
||||
return ".mp4"
|
||||
default:
|
||||
return filepath.Ext(mimeType)
|
||||
}
|
||||
}
|
||||
|
||||
func getInt64Arg(args map[string]any, key string, defaultVal int64) (int64, error) {
|
||||
raw, exists := args[key]
|
||||
if !exists {
|
||||
return defaultVal, nil
|
||||
}
|
||||
|
||||
switch v := raw.(type) {
|
||||
case float64:
|
||||
if v != math.Trunc(v) {
|
||||
return 0, fmt.Errorf("%s must be an integer, got float %v", key, v)
|
||||
}
|
||||
if v > math.MaxInt64 || v < math.MinInt64 {
|
||||
return 0, fmt.Errorf("%s value %v overflows int64", key, v)
|
||||
}
|
||||
return int64(v), nil
|
||||
case int:
|
||||
return int64(v), nil
|
||||
case int64:
|
||||
return v, nil
|
||||
case string:
|
||||
parsed, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid integer format for %s parameter: %w", key, err)
|
||||
}
|
||||
return parsed, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported type %T for %s parameter", raw, key)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,601 @@
|
||||
package integrationtools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
// MCPManager defines the interface for MCP manager operations
|
||||
// This allows for easier testing with mock implementations
|
||||
type MCPManager interface {
|
||||
CallTool(
|
||||
ctx context.Context,
|
||||
serverName, toolName string,
|
||||
arguments map[string]any,
|
||||
) (*mcp.CallToolResult, error)
|
||||
}
|
||||
|
||||
// MCPTool wraps an MCP tool to implement the Tool interface
|
||||
type MCPTool struct {
|
||||
manager MCPManager
|
||||
serverName string
|
||||
tool *mcp.Tool
|
||||
mediaStore media.MediaStore
|
||||
workspace string
|
||||
maxInlineTextRunes int
|
||||
}
|
||||
|
||||
// NewMCPTool creates a new MCP tool wrapper
|
||||
func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
|
||||
return &MCPTool{
|
||||
manager: manager,
|
||||
serverName: serverName,
|
||||
tool: tool,
|
||||
maxInlineTextRunes: maxMCPInlineTextRunes,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MCPTool) SetMediaStore(store media.MediaStore) {
|
||||
t.mediaStore = store
|
||||
}
|
||||
|
||||
func (t *MCPTool) SetWorkspace(workspace string) {
|
||||
t.workspace = strings.TrimSpace(workspace)
|
||||
}
|
||||
|
||||
func (t *MCPTool) SetMaxInlineTextRunes(limit int) {
|
||||
if limit > 0 {
|
||||
t.maxInlineTextRunes = limit
|
||||
}
|
||||
}
|
||||
|
||||
const maxMCPInlineTextRunes = 16 * 1024
|
||||
|
||||
// sanitizeIdentifierComponent normalizes a string so it can be safely used
|
||||
// as part of a tool/function identifier for downstream providers.
|
||||
// It:
|
||||
// - lowercases the string
|
||||
// - replaces any character not in [a-z0-9_-] with '_'
|
||||
// - collapses multiple consecutive '_' into a single '_'
|
||||
// - trims leading/trailing '_'
|
||||
// - falls back to "unnamed" if the result is empty
|
||||
// - truncates overly long components to a reasonable length
|
||||
func sanitizeIdentifierComponent(s string) string {
|
||||
const maxLen = 64
|
||||
|
||||
s = strings.ToLower(s)
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
|
||||
prevUnderscore := false
|
||||
for _, r := range s {
|
||||
isAllowed := (r >= 'a' && r <= 'z') ||
|
||||
(r >= '0' && r <= '9') ||
|
||||
r == '_' || r == '-'
|
||||
|
||||
if !isAllowed {
|
||||
// Normalize any disallowed character to '_'
|
||||
if !prevUnderscore {
|
||||
b.WriteRune('_')
|
||||
prevUnderscore = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if r == '_' {
|
||||
if prevUnderscore {
|
||||
continue
|
||||
}
|
||||
prevUnderscore = true
|
||||
} else {
|
||||
prevUnderscore = false
|
||||
}
|
||||
|
||||
b.WriteRune(r)
|
||||
}
|
||||
|
||||
result := strings.Trim(b.String(), "_")
|
||||
if result == "" {
|
||||
result = "unnamed"
|
||||
}
|
||||
|
||||
if len(result) > maxLen {
|
||||
result = result[:maxLen]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Name returns the tool name, prefixed with the server name.
|
||||
// The total length is capped at 64 characters (OpenAI-compatible API limit).
|
||||
// A short hash of the original (unsanitized) server and tool names is appended
|
||||
// whenever sanitization is lossy or the name is truncated, ensuring that two
|
||||
// names which differ only in disallowed characters remain distinct after sanitization.
|
||||
func (t *MCPTool) Name() string {
|
||||
// Prefix with server name to avoid conflicts, and sanitize components
|
||||
sanitizedServer := sanitizeIdentifierComponent(t.serverName)
|
||||
sanitizedTool := sanitizeIdentifierComponent(t.tool.Name)
|
||||
full := fmt.Sprintf("mcp_%s_%s", sanitizedServer, sanitizedTool)
|
||||
|
||||
// Check if sanitization was lossless (only lowercasing, no char replacement/truncation)
|
||||
lossless := strings.ToLower(t.serverName) == sanitizedServer &&
|
||||
strings.ToLower(t.tool.Name) == sanitizedTool
|
||||
|
||||
const maxTotal = 64
|
||||
if lossless && len(full) <= maxTotal {
|
||||
return full
|
||||
}
|
||||
|
||||
// Sanitization was lossy or name too long: append hash of the ORIGINAL names
|
||||
// (not the sanitized names) so different originals always yield different hashes.
|
||||
h := fnv.New32a()
|
||||
_, _ = h.Write([]byte(t.serverName + "\x00" + t.tool.Name))
|
||||
suffix := fmt.Sprintf("%08x", h.Sum32()) // 8 chars
|
||||
|
||||
base := full
|
||||
if len(base) > maxTotal-9 {
|
||||
base = strings.TrimRight(full[:maxTotal-9], "_")
|
||||
}
|
||||
return base + "_" + suffix
|
||||
}
|
||||
|
||||
// Description returns the tool description
|
||||
func (t *MCPTool) Description() string {
|
||||
desc := t.tool.Description
|
||||
if desc == "" {
|
||||
desc = fmt.Sprintf("MCP tool from %s server", t.serverName)
|
||||
}
|
||||
// Add server info to description
|
||||
return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc)
|
||||
}
|
||||
|
||||
// Parameters returns the tool parameters schema
|
||||
func (t *MCPTool) Parameters() map[string]any {
|
||||
// The InputSchema is already a JSON Schema object
|
||||
schema := t.tool.InputSchema
|
||||
|
||||
// Handle nil schema
|
||||
if schema == nil {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// Try direct conversion first (fast path)
|
||||
if schemaMap, ok := schema.(map[string]any); ok {
|
||||
return schemaMap
|
||||
}
|
||||
|
||||
// Handle json.RawMessage and []byte - unmarshal directly
|
||||
var jsonData []byte
|
||||
if rawMsg, ok := schema.(json.RawMessage); ok {
|
||||
jsonData = rawMsg
|
||||
} else if bytes, ok := schema.([]byte); ok {
|
||||
jsonData = bytes
|
||||
}
|
||||
|
||||
if jsonData != nil {
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(jsonData, &result); err == nil {
|
||||
return result
|
||||
}
|
||||
// Fallback on error
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// For other types (structs, etc.), convert via JSON marshal/unmarshal
|
||||
var err error
|
||||
jsonData, err = json.Marshal(schema)
|
||||
if err != nil {
|
||||
// Fallback to empty schema if marshaling fails
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(jsonData, &result); err != nil {
|
||||
// Fallback to empty schema if unmarshaling fails
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Execute executes the MCP tool
|
||||
func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
nilErr := fmt.Errorf("MCP tool returned nil result without error")
|
||||
return ErrorResult("MCP tool execution failed: nil result").WithError(nilErr)
|
||||
}
|
||||
|
||||
// Handle error result from server
|
||||
if result.IsError {
|
||||
errMsg := extractContentText(result.Content)
|
||||
return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)).
|
||||
WithError(fmt.Errorf("MCP tool error: %s", errMsg))
|
||||
}
|
||||
|
||||
return t.normalizeResultContent(ctx, result.Content)
|
||||
}
|
||||
|
||||
// extractContentText extracts text from MCP content array
|
||||
func extractContentText(content []mcp.Content) string {
|
||||
var parts []string
|
||||
for _, c := range content {
|
||||
switch v := c.(type) {
|
||||
case *mcp.TextContent:
|
||||
parts = append(parts, sanitizeToolLLMContent(v.Text))
|
||||
case *mcp.ImageContent:
|
||||
parts = append(parts, fmt.Sprintf("[Image: %s]", normalizedMIMEType(v.MIMEType)))
|
||||
case *mcp.AudioContent:
|
||||
parts = append(parts, fmt.Sprintf("[Audio: %s]", normalizedMIMEType(v.MIMEType)))
|
||||
case *mcp.ResourceLink:
|
||||
parts = append(parts, summarizeResourceLink(v))
|
||||
case *mcp.EmbeddedResource:
|
||||
parts = append(parts, summarizeEmbeddedResource(v))
|
||||
default:
|
||||
// For other content types, use string representation
|
||||
parts = append(parts, fmt.Sprintf("[Content: %T]", v))
|
||||
}
|
||||
}
|
||||
return sanitizeToolLLMContent(strings.Join(parts, "\n"))
|
||||
}
|
||||
|
||||
func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Content) *ToolResult {
|
||||
llmParts := make([]string, 0, len(content))
|
||||
rawTextParts := make([]string, 0, len(content))
|
||||
mediaRefs := make([]string, 0, len(content))
|
||||
|
||||
for _, c := range content {
|
||||
switch v := c.(type) {
|
||||
case *mcp.TextContent:
|
||||
rawText := strings.TrimSpace(v.Text)
|
||||
if rawText != "" {
|
||||
rawTextParts = append(rawTextParts, rawText)
|
||||
}
|
||||
safeText := strings.TrimSpace(sanitizeToolLLMContent(v.Text))
|
||||
if safeText != "" {
|
||||
llmParts = append(llmParts, safeText)
|
||||
}
|
||||
case *mcp.ImageContent:
|
||||
ref, note := t.storeBinaryContent(
|
||||
ctx,
|
||||
"image",
|
||||
normalizedMIMEType(v.MIMEType),
|
||||
v.Data,
|
||||
v.Annotations,
|
||||
)
|
||||
if ref != "" {
|
||||
mediaRefs = append(mediaRefs, ref)
|
||||
}
|
||||
if note != "" {
|
||||
llmParts = append(llmParts, note)
|
||||
}
|
||||
case *mcp.AudioContent:
|
||||
ref, note := t.storeBinaryContent(
|
||||
ctx,
|
||||
"audio",
|
||||
normalizedMIMEType(v.MIMEType),
|
||||
v.Data,
|
||||
v.Annotations,
|
||||
)
|
||||
if ref != "" {
|
||||
mediaRefs = append(mediaRefs, ref)
|
||||
}
|
||||
if note != "" {
|
||||
llmParts = append(llmParts, note)
|
||||
}
|
||||
case *mcp.ResourceLink:
|
||||
llmParts = append(llmParts, summarizeResourceLink(v))
|
||||
case *mcp.EmbeddedResource:
|
||||
ref, note, rawText := t.storeEmbeddedResource(ctx, v)
|
||||
if ref != "" {
|
||||
mediaRefs = append(mediaRefs, ref)
|
||||
}
|
||||
if rawText != "" {
|
||||
rawTextParts = append(rawTextParts, rawText)
|
||||
}
|
||||
if note != "" {
|
||||
llmParts = append(llmParts, note)
|
||||
}
|
||||
default:
|
||||
llmParts = append(llmParts, fmt.Sprintf("[MCP returned unsupported content type %T]", v))
|
||||
}
|
||||
}
|
||||
|
||||
forLLM := strings.Join(compactStrings(llmParts), "\n")
|
||||
rawText := strings.Join(compactStrings(rawTextParts), "\n")
|
||||
if artifactResult := t.persistLargeTextArtifact(rawText); artifactResult != nil {
|
||||
artifactResult.Media = mediaRefs
|
||||
return artifactResult
|
||||
}
|
||||
|
||||
result := &ToolResult{
|
||||
ForLLM: forLLM,
|
||||
Media: mediaRefs,
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (t *MCPTool) persistLargeTextArtifact(text string) *ToolResult {
|
||||
text = strings.TrimSpace(text)
|
||||
limit := t.maxInlineTextRunes
|
||||
if limit <= 0 {
|
||||
limit = maxMCPInlineTextRunes
|
||||
}
|
||||
size := utf8.RuneCountInString(text)
|
||||
if text == "" || size <= limit || t.workspace == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
dir := filepath.Join(t.workspace, ".artifacts", "mcp")
|
||||
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return t.largeTextArtifactFallback(text, err)
|
||||
}
|
||||
// TODO: Add lifecycle cleanup/retention for MCP artifact files.
|
||||
|
||||
pattern := fmt.Sprintf(
|
||||
"%s_%s_*.txt",
|
||||
sanitizeIdentifierComponent(t.serverName),
|
||||
sanitizeIdentifierComponent(t.tool.Name),
|
||||
)
|
||||
tmpFile, err := os.CreateTemp(dir, pattern)
|
||||
if err != nil {
|
||||
return t.largeTextArtifactFallback(text, err)
|
||||
}
|
||||
path := tmpFile.Name()
|
||||
if _, err = tmpFile.WriteString(text); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(path)
|
||||
return t.largeTextArtifactFallback(text, err)
|
||||
}
|
||||
if err = tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(path)
|
||||
return t.largeTextArtifactFallback(text, err)
|
||||
}
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf(
|
||||
"[MCP returned a large text result (%d chars); omitted from model context and saved as a local artifact.]",
|
||||
size,
|
||||
),
|
||||
ArtifactTags: []string{"[file:" + path + "]"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MCPTool) largeTextArtifactFallback(text string, err error) *ToolResult {
|
||||
size := utf8.RuneCountInString(text)
|
||||
logger.WarnCF("tool", "Failed to persist large MCP text artifact", map[string]any{
|
||||
"server": t.serverName,
|
||||
"tool": t.tool.Name,
|
||||
"chars": size,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf(
|
||||
"[MCP returned a large text result (%d chars); omitted from model context because artifact persistence failed.]",
|
||||
size,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string, string) {
|
||||
if content == nil || content.Resource == nil {
|
||||
return "", "[MCP returned an embedded resource without data.]", ""
|
||||
}
|
||||
|
||||
resource := content.Resource
|
||||
if len(resource.Blob) > 0 {
|
||||
ref, note := t.storeBinaryContent(
|
||||
ctx,
|
||||
"resource",
|
||||
normalizedMIMEType(resource.MIMEType),
|
||||
resource.Blob,
|
||||
content.Annotations,
|
||||
)
|
||||
return ref, note, ""
|
||||
}
|
||||
|
||||
rawText := strings.TrimSpace(resource.Text)
|
||||
if rawText != "" {
|
||||
return "", sanitizeToolLLMContent(resource.Text), rawText
|
||||
}
|
||||
|
||||
return "", summarizeEmbeddedResource(content), ""
|
||||
}
|
||||
|
||||
func (t *MCPTool) storeBinaryContent(
|
||||
ctx context.Context,
|
||||
kind string,
|
||||
mimeType string,
|
||||
data []byte,
|
||||
annotations *mcp.Annotations,
|
||||
) (string, string) {
|
||||
if len(data) == 0 {
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it was empty.]", kind, mimeType)
|
||||
}
|
||||
if !annotationsAllowUser(annotations) {
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s) for non-user audience; omitted from model context.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
}
|
||||
if t.mediaStore == nil {
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s); omitted from model context because media delivery is unavailable.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
}
|
||||
|
||||
channel := ToolChannel(ctx)
|
||||
chatID := ToolChatID(ctx)
|
||||
if channel == "" || chatID == "" {
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s); omitted from model context because no target chat was available.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
}
|
||||
|
||||
dir := media.TempDir()
|
||||
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
|
||||
ext := extensionForMIMEType(mimeType)
|
||||
tmpFile, err := os.CreateTemp(dir, "mcp-*"+ext)
|
||||
if err != nil {
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
if _, err = tmpFile.Write(data); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
if err = tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
|
||||
}
|
||||
|
||||
scope := fmt.Sprintf(
|
||||
"tool:mcp:%s:%s:%s:%d",
|
||||
sanitizeIdentifierComponent(t.serverName),
|
||||
channel,
|
||||
chatID,
|
||||
time.Now().UnixNano(),
|
||||
)
|
||||
filename := fmt.Sprintf(
|
||||
"%s_%s%s",
|
||||
sanitizeIdentifierComponent(t.serverName),
|
||||
sanitizeIdentifierComponent(t.tool.Name),
|
||||
ext,
|
||||
)
|
||||
|
||||
ref, err := t.mediaStore.Store(tmpPath, media.MediaMeta{
|
||||
Filename: filename,
|
||||
ContentType: mimeType,
|
||||
Source: fmt.Sprintf(
|
||||
"tool:mcp:%s:%s",
|
||||
sanitizeIdentifierComponent(t.serverName),
|
||||
sanitizeIdentifierComponent(t.tool.Name),
|
||||
),
|
||||
}, scope)
|
||||
if err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Sprintf(
|
||||
"[MCP returned %s content (%s) but it could not be registered as media.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
}
|
||||
|
||||
return ref, fmt.Sprintf(
|
||||
"[MCP returned %s content (%s); omitted from model context and stored as a local media artifact.]",
|
||||
kind,
|
||||
mimeType,
|
||||
)
|
||||
}
|
||||
|
||||
func summarizeResourceLink(content *mcp.ResourceLink) string {
|
||||
if content == nil {
|
||||
return "[MCP returned an empty resource link.]"
|
||||
}
|
||||
|
||||
parts := []string{"[MCP returned resource link"}
|
||||
if content.Name != "" {
|
||||
parts = append(parts, fmt.Sprintf("name=%q", content.Name))
|
||||
}
|
||||
if content.URI != "" {
|
||||
parts = append(parts, fmt.Sprintf("uri=%q", content.URI))
|
||||
}
|
||||
if content.MIMEType != "" {
|
||||
parts = append(parts, fmt.Sprintf("mime=%q", content.MIMEType))
|
||||
}
|
||||
if content.Description != "" {
|
||||
desc := strings.TrimSpace(content.Description)
|
||||
if len(desc) > 200 {
|
||||
desc = desc[:200] + "..."
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("description=%q", desc))
|
||||
}
|
||||
return strings.Join(parts, ", ") + "]"
|
||||
}
|
||||
|
||||
func summarizeEmbeddedResource(content *mcp.EmbeddedResource) string {
|
||||
if content == nil || content.Resource == nil {
|
||||
return "[MCP returned an embedded resource.]"
|
||||
}
|
||||
|
||||
resource := content.Resource
|
||||
if resource.URI != "" {
|
||||
return fmt.Sprintf(
|
||||
"[MCP returned embedded resource %q (%s).]",
|
||||
resource.URI,
|
||||
normalizedMIMEType(resource.MIMEType),
|
||||
)
|
||||
}
|
||||
return fmt.Sprintf("[MCP returned embedded resource (%s).]", normalizedMIMEType(resource.MIMEType))
|
||||
}
|
||||
|
||||
func annotationsAllowUser(annotations *mcp.Annotations) bool {
|
||||
if annotations == nil || len(annotations.Audience) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, audience := range annotations.Audience {
|
||||
if strings.EqualFold(string(audience), "user") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizedMIMEType(mimeType string) string {
|
||||
if strings.TrimSpace(mimeType) == "" {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
return mimeType
|
||||
}
|
||||
|
||||
func compactStrings(parts []string) []string {
|
||||
compact := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if strings.TrimSpace(part) == "" {
|
||||
continue
|
||||
}
|
||||
compact = append(compact, part)
|
||||
}
|
||||
return compact
|
||||
}
|
||||
@@ -0,0 +1,810 @@
|
||||
package integrationtools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
// MockMCPManager is a mock implementation of MCPManager interface for testing
|
||||
type MockMCPManager struct {
|
||||
callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error)
|
||||
}
|
||||
|
||||
func (m *MockMCPManager) CallTool(
|
||||
ctx context.Context,
|
||||
serverName, toolName string,
|
||||
arguments map[string]any,
|
||||
) (*mcp.CallToolResult, error) {
|
||||
if m.callToolFunc != nil {
|
||||
return m.callToolFunc(ctx, serverName, toolName, arguments)
|
||||
}
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "mock result"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestNewMCPTool verifies MCP tool creation
|
||||
func TestNewMCPTool(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
Description: "A test tool",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Test input",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
if mcpTool == nil {
|
||||
t.Fatal("NewMCPTool should not return nil")
|
||||
}
|
||||
// Verify tool properties we can access
|
||||
if mcpTool.Name() != "mcp_test_server_test_tool" {
|
||||
t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Name verifies tool name with server prefix
|
||||
func TestMCPTool_Name(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverName string
|
||||
toolName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple name",
|
||||
serverName: "github",
|
||||
toolName: "create_issue",
|
||||
expected: "mcp_github_create_issue",
|
||||
},
|
||||
{
|
||||
name: "filesystem server",
|
||||
serverName: "filesystem",
|
||||
toolName: "read_file",
|
||||
expected: "mcp_filesystem_read_file",
|
||||
},
|
||||
{
|
||||
name: "remote server",
|
||||
serverName: "remote-api",
|
||||
toolName: "fetch_data",
|
||||
expected: "mcp_remote-api_fetch_data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{Name: tt.toolName}
|
||||
mcpTool := NewMCPTool(manager, tt.serverName, tool)
|
||||
|
||||
result := mcpTool.Name()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected name '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Description verifies tool description generation
|
||||
func TestMCPTool_Description(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverName string
|
||||
toolDescription string
|
||||
expectContains []string
|
||||
}{
|
||||
{
|
||||
name: "with description",
|
||||
serverName: "github",
|
||||
toolDescription: "Create a GitHub issue",
|
||||
expectContains: []string{"[MCP:github]", "Create a GitHub issue"},
|
||||
},
|
||||
{
|
||||
name: "empty description",
|
||||
serverName: "filesystem",
|
||||
toolDescription: "",
|
||||
expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
Description: tt.toolDescription,
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, tt.serverName, tool)
|
||||
|
||||
result := mcpTool.Description()
|
||||
|
||||
for _, expected := range tt.expectContains {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Description should contain '%s', got: %s", expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Parameters verifies parameter schema conversion
|
||||
func TestMCPTool_Parameters(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputSchema any
|
||||
expectType string
|
||||
checkProperty string
|
||||
expectProperty bool
|
||||
}{
|
||||
{
|
||||
name: "map schema",
|
||||
inputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
expectType: "object",
|
||||
checkProperty: "query",
|
||||
expectProperty: true,
|
||||
},
|
||||
{
|
||||
name: "nil schema",
|
||||
inputSchema: nil,
|
||||
expectType: "object",
|
||||
expectProperty: false,
|
||||
},
|
||||
{
|
||||
name: "json.RawMessage schema",
|
||||
inputSchema: []byte(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"stars": {
|
||||
"type": "integer",
|
||||
"description": "Minimum stars"
|
||||
}
|
||||
},
|
||||
"required": ["repo"]
|
||||
}`),
|
||||
expectType: "object",
|
||||
checkProperty: "repo",
|
||||
expectProperty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
InputSchema: tt.inputSchema,
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
params := mcpTool.Parameters()
|
||||
|
||||
if params == nil {
|
||||
t.Fatal("Parameters should not be nil")
|
||||
}
|
||||
|
||||
if params["type"] != tt.expectType {
|
||||
t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"])
|
||||
}
|
||||
|
||||
// Check if property exists when expected
|
||||
if tt.checkProperty != "" {
|
||||
properties, ok := params["properties"].(map[string]any)
|
||||
if !ok && tt.expectProperty {
|
||||
t.Errorf("Expected properties to be a map")
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
_, hasProperty := properties[tt.checkProperty]
|
||||
if hasProperty != tt.expectProperty {
|
||||
t.Errorf("Expected property '%s' existence: %v, got: %v",
|
||||
tt.checkProperty, tt.expectProperty, hasProperty)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_Success tests successful tool execution
|
||||
func TestMCPTool_Execute_Success(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
// Verify correct parameters passed
|
||||
if serverName != "github" {
|
||||
t.Errorf("Expected serverName 'github', got '%s'", serverName)
|
||||
}
|
||||
if toolName != "search_repos" {
|
||||
t.Errorf("Expected toolName 'search_repos', got '%s'", toolName)
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "Found 3 repositories"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{
|
||||
Name: "search_repos",
|
||||
Description: "Search GitHub repositories",
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, "github", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"query": "golang mcp",
|
||||
}
|
||||
|
||||
result := mcpTool.Execute(ctx, args)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Errorf("Expected no error, got error: %s", result.ForLLM)
|
||||
}
|
||||
if result.ForLLM != "Found 3 repositories" {
|
||||
t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_ManagerError tests execution when manager returns error
|
||||
func TestMCPTool_Execute_ManagerError(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return nil, fmt.Errorf("connection failed")
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "test_tool"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "MCP tool execution failed") {
|
||||
t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "connection failed") {
|
||||
t.Errorf("Error message should include original error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_ServerError tests execution when server returns error
|
||||
func TestMCPTool_Execute_ServerError(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "Invalid API key"},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "test_tool"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "MCP tool returned error") {
|
||||
t.Errorf("Error message should mention server error, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Invalid API key") {
|
||||
t.Errorf("Error message should include server message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_MultipleContent tests execution with multiple content items
|
||||
func TestMCPTool_Execute_MultipleContent(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "First line"},
|
||||
&mcp.TextContent{Text: "Second line"},
|
||||
&mcp.TextContent{Text: "Third line"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "multi_output"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("Expected no error, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
expected := "First line\nSecond line\nThird line"
|
||||
if result.ForLLM != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_TextContent tests text content extraction
|
||||
func TestExtractContentText_TextContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.TextContent{Text: "Hello World"},
|
||||
&mcp.TextContent{Text: "Second message"},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
expected := "Hello World\nSecond message"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_ImageContent tests image content extraction
|
||||
func TestExtractContentText_ImageContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("base64data"),
|
||||
MIMEType: "image/png",
|
||||
},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if !strings.Contains(result, "[Image:") {
|
||||
t.Errorf("Expected image indicator, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "image/png") {
|
||||
t.Errorf("Expected MIME type in output, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_MixedContent tests mixed content types
|
||||
func TestExtractContentText_MixedContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.TextContent{Text: "Description"},
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("data"),
|
||||
MIMEType: "image/jpeg",
|
||||
},
|
||||
&mcp.TextContent{Text: "More text"},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if !strings.Contains(result, "Description") {
|
||||
t.Errorf("Should contain text content, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "[Image:") {
|
||||
t.Errorf("Should contain image indicator, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "More text") {
|
||||
t.Errorf("Should contain second text, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_EmptyContent tests empty content array
|
||||
func TestExtractContentText_EmptyContent(t *testing.T) {
|
||||
content := []mcp.Content{}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty string for empty content, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface
|
||||
func TestMCPTool_InterfaceCompliance(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{Name: "test"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
// Verify it implements Tool interface
|
||||
var _ Tool = mcpTool
|
||||
}
|
||||
|
||||
// TestMCPTool_Parameters_MapSchema tests schema that's already a map
|
||||
func TestMCPTool_Parameters_MapSchema(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The name parameter",
|
||||
},
|
||||
},
|
||||
"required": []string{"name"},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
InputSchema: schema,
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
params := mcpTool.Parameters()
|
||||
|
||||
// Should return the schema as-is when it's already a map
|
||||
if params["type"] != "object" {
|
||||
t.Errorf("Expected type 'object', got '%v'", params["type"])
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Properties should be a map")
|
||||
}
|
||||
|
||||
nameParam, ok := props["name"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Name parameter should exist")
|
||||
}
|
||||
|
||||
if nameParam["type"] != "string" {
|
||||
t.Errorf("Name type should be 'string', got '%v'", nameParam["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_ImageContentStoredAsMedia(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("fake-image-bytes"),
|
||||
MIMEType: "image/png",
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "screenshoto", &mcp.Tool{Name: "take_screenshot"})
|
||||
mcpTool.SetMediaStore(store)
|
||||
|
||||
result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil)
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success, got %q", result.ForLLM)
|
||||
}
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
|
||||
}
|
||||
if result.ResponseHandled {
|
||||
t.Fatal("expected MCP image artifact not to mark response as handled")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "stored as a local media artifact") {
|
||||
t.Fatalf("expected local media artifact note, got %q", result.ForLLM)
|
||||
}
|
||||
|
||||
path, meta, err := store.ResolveWithMeta(result.Media[0])
|
||||
if err != nil {
|
||||
t.Fatalf("expected stored media ref to resolve: %v", err)
|
||||
}
|
||||
if meta.ContentType != "image/png" {
|
||||
t.Fatalf("expected image/png content type, got %q", meta.ContentType)
|
||||
}
|
||||
if filepath.Ext(path) != ".png" {
|
||||
t.Fatalf("expected png temp file, got %q", path)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("expected stored media file to be readable: %v", err)
|
||||
}
|
||||
if string(data) != "fake-image-bytes" {
|
||||
t.Fatalf("expected stored media bytes to match input, got %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_EmbeddedResourceBlobStoredAsMedia(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.EmbeddedResource{
|
||||
Resource: &mcp.ResourceContents{
|
||||
URI: "file:///tmp/report.png",
|
||||
MIMEType: "image/png",
|
||||
Blob: []byte("blob-bytes"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "grafana", &mcp.Tool{Name: "get_dashboard_image"})
|
||||
mcpTool.SetMediaStore(store)
|
||||
|
||||
result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil)
|
||||
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf("expected embedded resource blob to be stored as media, got %d refs", len(result.Media))
|
||||
}
|
||||
path, _, err := store.ResolveWithMeta(result.Media[0])
|
||||
if err != nil {
|
||||
t.Fatalf("expected stored media ref to resolve: %v", err)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("expected stored media file to be readable: %v", err)
|
||||
}
|
||||
if string(data) != "blob-bytes" {
|
||||
t.Fatalf("expected stored blob bytes to match input, got %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_RespectsUserAudienceForBinaryContent(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("assistant-only"),
|
||||
MIMEType: "image/png",
|
||||
Annotations: &mcp.Annotations{Audience: []mcp.Role{"assistant"}},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "screenshoto", &mcp.Tool{Name: "take_screenshot"})
|
||||
mcpTool.SetMediaStore(store)
|
||||
|
||||
result := mcpTool.Execute(WithToolContext(context.Background(), "telegram", "chat-42"), nil)
|
||||
|
||||
if len(result.Media) != 0 {
|
||||
t.Fatalf("expected no media ref for non-user audience, got %d", len(result.Media))
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "non-user audience") {
|
||||
t.Fatalf("expected audience note, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_LargeBase64TextIsOmittedFromContext(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: strings.Repeat("QUJD", 400)},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
|
||||
|
||||
result := mcpTool.Execute(context.Background(), nil)
|
||||
|
||||
if result.ForLLM != largeBase64OmittedMessage {
|
||||
t.Fatalf("expected sanitized large base64 note, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_LargeBase64TextArtifactPreservesRawPayload(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
largeBase64 := strings.Repeat("QUJD", 400)
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: largeBase64},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
|
||||
mcpTool.SetWorkspace(workspace)
|
||||
mcpTool.SetMaxInlineTextRunes(32)
|
||||
|
||||
result := mcpTool.Execute(context.Background(), nil)
|
||||
|
||||
if !strings.Contains(result.ForLLM, "saved as a local artifact") {
|
||||
t.Fatalf("expected artifact note, got %q", result.ForLLM)
|
||||
}
|
||||
if result.ForLLM == largeBase64OmittedMessage {
|
||||
t.Fatalf("expected artifact note instead of sanitized base64 placeholder")
|
||||
}
|
||||
if len(result.ArtifactTags) != 1 {
|
||||
t.Fatalf("expected 1 artifact tag, got %d", len(result.ArtifactTags))
|
||||
}
|
||||
tag := result.ArtifactTags[0]
|
||||
const prefix = "[file:"
|
||||
if !strings.HasPrefix(tag, prefix) || !strings.HasSuffix(tag, "]") {
|
||||
t.Fatalf("expected file artifact tag, got %q", tag)
|
||||
}
|
||||
path := strings.TrimSuffix(strings.TrimPrefix(tag, prefix), "]")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("expected artifact file to be readable: %v", err)
|
||||
}
|
||||
if string(data) != largeBase64 {
|
||||
t.Fatalf("expected artifact file contents to preserve raw MCP payload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_LargeTextStoredAsArtifact(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
largeText := strings.Repeat("This is a large MCP text payload.\n", 800)
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: largeText},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
|
||||
mcpTool.SetWorkspace(workspace)
|
||||
|
||||
result := mcpTool.Execute(context.Background(), nil)
|
||||
|
||||
if strings.Contains(result.ForLLM, "This is a large MCP text payload") {
|
||||
t.Fatalf("expected large MCP text to be omitted from ForLLM, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "saved as a local artifact") {
|
||||
t.Fatalf("expected artifact note, got %q", result.ForLLM)
|
||||
}
|
||||
if len(result.ArtifactTags) != 1 {
|
||||
t.Fatalf("expected 1 artifact tag, got %d", len(result.ArtifactTags))
|
||||
}
|
||||
tag := result.ArtifactTags[0]
|
||||
const prefix = "[file:"
|
||||
if !strings.HasPrefix(tag, prefix) || !strings.HasSuffix(tag, "]") {
|
||||
t.Fatalf("expected file artifact tag, got %q", tag)
|
||||
}
|
||||
path := strings.TrimSuffix(strings.TrimPrefix(tag, prefix), "]")
|
||||
if !strings.HasPrefix(path, workspace) {
|
||||
t.Fatalf("expected artifact inside workspace, got %q", path)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("expected artifact file to be readable: %v", err)
|
||||
}
|
||||
if string(data) != strings.TrimSpace(largeText) {
|
||||
t.Fatalf("expected artifact file contents to match source text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_CustomInlineTextThreshold(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
text := strings.Repeat("small custom threshold text\n", 20)
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: text},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
|
||||
mcpTool.SetWorkspace(workspace)
|
||||
mcpTool.SetMaxInlineTextRunes(32)
|
||||
|
||||
result := mcpTool.Execute(context.Background(), nil)
|
||||
|
||||
if len(result.ArtifactTags) != 1 {
|
||||
t.Fatalf("expected custom threshold to persist artifact, got %+v", result)
|
||||
}
|
||||
if strings.Contains(result.ForLLM, "small custom threshold text") {
|
||||
t.Fatalf("expected text to be omitted from ForLLM, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_LargeTextArtifactFailureStillOmitsContext(t *testing.T) {
|
||||
workspaceRoot := t.TempDir()
|
||||
workspaceFile := filepath.Join(workspaceRoot, "not-a-directory")
|
||||
if err := os.WriteFile(workspaceFile, []byte("x"), 0o600); err != nil {
|
||||
t.Fatalf("failed to create workspace file: %v", err)
|
||||
}
|
||||
|
||||
largeText := strings.Repeat("This is a large MCP text payload.\n", 800)
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: largeText},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
|
||||
mcpTool.SetWorkspace(workspaceFile)
|
||||
|
||||
result := mcpTool.Execute(context.Background(), nil)
|
||||
|
||||
if strings.Contains(result.ForLLM, "This is a large MCP text payload") {
|
||||
t.Fatalf("expected large MCP text to be omitted from ForLLM, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "artifact persistence failed") {
|
||||
t.Fatalf("expected persistence failure note, got %q", result.ForLLM)
|
||||
}
|
||||
if len(result.ArtifactTags) != 0 {
|
||||
t.Fatalf("expected no artifact tags on persistence failure, got %+v", result.ArtifactTags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_WhitespaceWorkspaceDisablesArtifactPersistence(t *testing.T) {
|
||||
largeText := strings.Repeat("This is a large MCP text payload.\n", 800)
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: largeText},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
|
||||
mcpTool.SetWorkspace(" \n\t ")
|
||||
|
||||
result := mcpTool.Execute(context.Background(), nil)
|
||||
|
||||
if len(result.ArtifactTags) != 0 {
|
||||
t.Fatalf("expected no artifact tags for whitespace workspace, got %+v", result.ArtifactTags)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "This is a large MCP text payload") {
|
||||
t.Fatalf("expected large text to remain inline when workspace is blank, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
package integrationtools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error
|
||||
|
||||
// sentTarget records the channel+chatID that the message tool sent to.
|
||||
type sentTarget struct {
|
||||
Channel string
|
||||
ChatID string
|
||||
}
|
||||
|
||||
type MessageTool struct {
|
||||
sendCallback SendCallbackWithContext
|
||||
mu sync.Mutex
|
||||
// sentTargets tracks targets sent to in the current round, keyed by session key
|
||||
// to support parallel turns for different sessions.
|
||||
sentTargets map[string][]sentTarget
|
||||
}
|
||||
|
||||
func NewMessageTool() *MessageTool {
|
||||
return &MessageTool{
|
||||
sentTargets: make(map[string][]sentTarget),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MessageTool) Name() string {
|
||||
return "message"
|
||||
}
|
||||
|
||||
func (t *MessageTool) Description() string {
|
||||
return "Send a message to user on a chat channel. Use this when you want to communicate something."
|
||||
}
|
||||
|
||||
func (t *MessageTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The message content to send",
|
||||
},
|
||||
"channel": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: target channel (telegram, whatsapp, etc.)",
|
||||
},
|
||||
"chat_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: target chat/user ID",
|
||||
},
|
||||
"reply_to_message_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: reply target message ID for channels that support threaded replies",
|
||||
},
|
||||
},
|
||||
"required": []string{"content"},
|
||||
}
|
||||
}
|
||||
|
||||
// ResetSentInRound resets the per-round send tracker for the given session key.
|
||||
// Called by the agent loop at the start of each inbound message processing round.
|
||||
func (t *MessageTool) ResetSentInRound(sessionKey string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Delete the key entirely to prevent unbounded map growth over time
|
||||
// with many unique sessions. Truncating the slice keeps the key alive.
|
||||
delete(t.sentTargets, sessionKey)
|
||||
}
|
||||
|
||||
// HasSentInRound returns true if the message tool sent a message during the current round.
|
||||
func (t *MessageTool) HasSentInRound(sessionKey string) bool {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return len(t.sentTargets[sessionKey]) > 0
|
||||
}
|
||||
|
||||
// HasSentTo returns true if the message tool sent to the specific channel+chatID
|
||||
// during the current round. Used by PublishResponseIfNeeded to avoid suppressing
|
||||
// the final response when the message tool only sent to a different conversation.
|
||||
func (t *MessageTool) HasSentTo(sessionKey, channel, chatID string) bool {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
for _, st := range t.sentTargets[sessionKey] {
|
||||
if st.Channel == channel && st.ChatID == chatID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) {
|
||||
t.sendCallback = callback
|
||||
}
|
||||
|
||||
func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return &ToolResult{ForLLM: "content is required", IsError: true}
|
||||
}
|
||||
|
||||
channel, _ := args["channel"].(string)
|
||||
chatID, _ := args["chat_id"].(string)
|
||||
replyToMessageID, _ := args["reply_to_message_id"].(string)
|
||||
|
||||
if channel == "" {
|
||||
channel = ToolChannel(ctx)
|
||||
}
|
||||
if chatID == "" {
|
||||
chatID = ToolChatID(ctx)
|
||||
}
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true}
|
||||
}
|
||||
|
||||
if t.sendCallback == nil {
|
||||
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
|
||||
}
|
||||
|
||||
if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil {
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("sending message: %v", err),
|
||||
IsError: true,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
sessionKey := ToolSessionKey(ctx)
|
||||
t.mu.Lock()
|
||||
t.sentTargets[sessionKey] = append(t.sentTargets[sessionKey], sentTarget{Channel: channel, ChatID: chatID})
|
||||
t.mu.Unlock()
|
||||
|
||||
// Silent: user already received the message directly
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
|
||||
Silent: true,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,331 @@
|
||||
package integrationtools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentChannel, sentChatID, sentContent string
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
sentContent = content
|
||||
if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil {
|
||||
t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v",
|
||||
ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
args := map[string]any{
|
||||
"content": "Hello, world!",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify message was sent with correct parameters
|
||||
if sentChannel != "test-channel" {
|
||||
t.Errorf("Expected channel 'test-channel', got '%s'", sentChannel)
|
||||
}
|
||||
if sentChatID != "test-chat-id" {
|
||||
t.Errorf("Expected chatID 'test-chat-id', got '%s'", sentChatID)
|
||||
}
|
||||
if sentContent != "Hello, world!" {
|
||||
t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent)
|
||||
}
|
||||
|
||||
// Verify ToolResult meets US-011 criteria:
|
||||
// - Send success returns SilentResult (Silent=true)
|
||||
if !result.Silent {
|
||||
t.Error("Expected Silent=true for successful send")
|
||||
}
|
||||
|
||||
// - ForLLM contains send status description
|
||||
if result.ForLLM != "Message sent to test-channel:test-chat-id" {
|
||||
t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM)
|
||||
}
|
||||
|
||||
// - ForUser is empty (user already received message directly)
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty, got '%s'", result.ForUser)
|
||||
}
|
||||
|
||||
// - IsError should be false
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError=false for successful send")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentChannel, sentChatID string
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := WithToolContext(context.Background(), "default-channel", "default-chat-id")
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
"channel": "custom-channel",
|
||||
"chat_id": "custom-chat-id",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify custom channel/chatID were used instead of defaults
|
||||
if sentChannel != "custom-channel" {
|
||||
t.Errorf("Expected channel 'custom-channel', got '%s'", sentChannel)
|
||||
}
|
||||
if sentChatID != "custom-chat-id" {
|
||||
t.Errorf("Expected chatID 'custom-chat-id', got '%s'", sentChatID)
|
||||
}
|
||||
|
||||
if !result.Silent {
|
||||
t.Error("Expected Silent=true")
|
||||
}
|
||||
if result.ForLLM != "Message sent to custom-channel:custom-chat-id" {
|
||||
t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
sendErr := errors.New("network error")
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
return sendErr
|
||||
})
|
||||
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify ToolResult for send failure:
|
||||
// - Send failure returns ErrorResult (IsError=true)
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true for failed send")
|
||||
}
|
||||
|
||||
// - ForLLM contains error description
|
||||
expectedErrMsg := "sending message: network error"
|
||||
if result.ForLLM != expectedErrMsg {
|
||||
t.Errorf("Expected ForLLM '%s', got '%s'", expectedErrMsg, result.ForLLM)
|
||||
}
|
||||
|
||||
// - Err field should contain original error
|
||||
if result.Err == nil {
|
||||
t.Error("Expected Err to be set")
|
||||
}
|
||||
if result.Err != sendErr {
|
||||
t.Errorf("Expected Err to be sendErr, got %v", result.Err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_MissingContent(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
args := map[string]any{} // content missing
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error result for missing content
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true for missing content")
|
||||
}
|
||||
if result.ForLLM != "content is required" {
|
||||
t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
// No WithToolContext — channel/chatID are empty
|
||||
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error when no target channel specified
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true when no target channel")
|
||||
}
|
||||
if result.ForLLM != "No target channel/chat specified" {
|
||||
t.Errorf("Expected ForLLM 'No target channel/chat specified', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
// No SetSendCallback called
|
||||
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error when send callback not configured
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true when send callback not configured")
|
||||
}
|
||||
if result.ForLLM != "Message sending not configured" {
|
||||
t.Errorf("Expected ForLLM 'Message sending not configured', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Name(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
if tool.Name() != "message" {
|
||||
t.Errorf("Expected name 'message', got '%s'", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Description(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
desc := tool.Description()
|
||||
if desc == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Parameters(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
params := tool.Parameters()
|
||||
|
||||
// Verify parameters structure
|
||||
typ, ok := params["type"].(string)
|
||||
if !ok || typ != "object" {
|
||||
t.Error("Expected type 'object'")
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Expected properties to be a map")
|
||||
}
|
||||
|
||||
// Check required properties
|
||||
required, ok := params["required"].([]string)
|
||||
if !ok || len(required) != 1 || required[0] != "content" {
|
||||
t.Error("Expected 'content' to be required")
|
||||
}
|
||||
|
||||
// Check content property
|
||||
contentProp, ok := props["content"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Expected 'content' property")
|
||||
}
|
||||
if contentProp["type"] != "string" {
|
||||
t.Error("Expected content type to be 'string'")
|
||||
}
|
||||
|
||||
// Check channel property (optional)
|
||||
channelProp, ok := props["channel"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Expected 'channel' property")
|
||||
}
|
||||
if channelProp["type"] != "string" {
|
||||
t.Error("Expected channel type to be 'string'")
|
||||
}
|
||||
|
||||
// Check chat_id property (optional)
|
||||
chatIDProp, ok := props["chat_id"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Expected 'chat_id' property")
|
||||
}
|
||||
if chatIDProp["type"] != "string" {
|
||||
t.Error("Expected chat_id type to be 'string'")
|
||||
}
|
||||
|
||||
// Check reply_to_message_id property (optional)
|
||||
replyToProp, ok := props["reply_to_message_id"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Expected 'reply_to_message_id' property")
|
||||
}
|
||||
if replyToProp["type"] != "string" {
|
||||
t.Error("Expected reply_to_message_id type to be 'string'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var sentReplyTo string
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
sentReplyTo = replyToMessageID
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
args := map[string]any{
|
||||
"content": "Reply test",
|
||||
"reply_to_message_id": "msg-123",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if sentReplyTo != "msg-123" {
|
||||
t.Fatalf("expected reply_to_message_id msg-123, got %q", sentReplyTo)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
|
||||
var gotAgentID, gotSessionKey string
|
||||
var gotScope *session.SessionScope
|
||||
tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
gotAgentID = ToolAgentID(ctx)
|
||||
gotSessionKey = ToolSessionKey(ctx)
|
||||
gotScope = ToolSessionScope(ctx)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
ctx = WithToolSessionContext(ctx, "main", "sk_v1_tool", &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Dimensions: []string{"chat"},
|
||||
Values: map[string]string{
|
||||
"chat": "direct:test-chat-id",
|
||||
},
|
||||
})
|
||||
|
||||
result := tool.Execute(ctx, map[string]any{"content": "Hello, world!"})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if gotAgentID != "main" {
|
||||
t.Fatalf("ToolAgentID() = %q, want main", gotAgentID)
|
||||
}
|
||||
if gotSessionKey != "sk_v1_tool" {
|
||||
t.Fatalf("ToolSessionKey() = %q, want sk_v1_tool", gotSessionKey)
|
||||
}
|
||||
if gotScope == nil || gotScope.Values["chat"] != "direct:test-chat-id" {
|
||||
t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package integrationtools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type ReactionCallback func(ctx context.Context, channel, chatID, messageID string) error
|
||||
|
||||
type ReactionTool struct {
|
||||
reactionCallback ReactionCallback
|
||||
}
|
||||
|
||||
func NewReactionTool() *ReactionTool {
|
||||
return &ReactionTool{}
|
||||
}
|
||||
|
||||
func (t *ReactionTool) Name() string {
|
||||
return "reaction"
|
||||
}
|
||||
|
||||
func (t *ReactionTool) Description() string {
|
||||
return "Add a reaction to a message. Defaults to the current inbound message when message_id is omitted."
|
||||
}
|
||||
|
||||
func (t *ReactionTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"message_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: target message ID; defaults to the current inbound message",
|
||||
},
|
||||
"channel": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: target channel (telegram, whatsapp, etc.)",
|
||||
},
|
||||
"chat_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional: target chat/user ID",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ReactionTool) SetReactionCallback(callback ReactionCallback) {
|
||||
t.reactionCallback = callback
|
||||
}
|
||||
|
||||
func (t *ReactionTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
channel, _ := args["channel"].(string)
|
||||
chatID, _ := args["chat_id"].(string)
|
||||
messageID, _ := args["message_id"].(string)
|
||||
|
||||
if channel == "" {
|
||||
channel = ToolChannel(ctx)
|
||||
}
|
||||
if chatID == "" {
|
||||
chatID = ToolChatID(ctx)
|
||||
}
|
||||
if messageID == "" {
|
||||
messageID = ToolMessageID(ctx)
|
||||
}
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true}
|
||||
}
|
||||
if messageID == "" {
|
||||
return &ToolResult{ForLLM: "message_id is required", IsError: true}
|
||||
}
|
||||
if t.reactionCallback == nil {
|
||||
return &ToolResult{ForLLM: "Reaction not configured", IsError: true}
|
||||
}
|
||||
|
||||
if err := t.reactionCallback(ctx, channel, chatID, messageID); err != nil {
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("adding reaction: %v", err),
|
||||
IsError: true,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Reaction added to %s:%s message %s", channel, chatID, messageID),
|
||||
Silent: true,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package integrationtools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReactionTool_Execute_UsesContextMessageIDByDefault(t *testing.T) {
|
||||
tool := NewReactionTool()
|
||||
|
||||
var gotChannel, gotChatID, gotMessageID string
|
||||
tool.SetReactionCallback(func(ctx context.Context, channel, chatID, messageID string) error {
|
||||
gotChannel = channel
|
||||
gotChatID = chatID
|
||||
gotMessageID = messageID
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := WithToolInboundContext(context.Background(), "telegram", "chat-1", "msg-100", "")
|
||||
result := tool.Execute(ctx, map[string]any{})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if gotChannel != "telegram" || gotChatID != "chat-1" || gotMessageID != "msg-100" {
|
||||
t.Fatalf("unexpected callback args: channel=%q chatID=%q messageID=%q", gotChannel, gotChatID, gotMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReactionTool_Execute_AllowsExplicitMessageIDOverride(t *testing.T) {
|
||||
tool := NewReactionTool()
|
||||
|
||||
var gotMessageID string
|
||||
tool.SetReactionCallback(func(ctx context.Context, channel, chatID, messageID string) error {
|
||||
gotMessageID = messageID
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := WithToolInboundContext(context.Background(), "telegram", "chat-1", "msg-context", "")
|
||||
result := tool.Execute(ctx, map[string]any{"message_id": "msg-explicit"})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if gotMessageID != "msg-explicit" {
|
||||
t.Fatalf("expected explicit message id, got %q", gotMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReactionTool_Execute_MissingMessageID(t *testing.T) {
|
||||
tool := NewReactionTool()
|
||||
tool.SetReactionCallback(func(ctx context.Context, channel, chatID, messageID string) error { return nil })
|
||||
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
|
||||
result := tool.Execute(ctx, map[string]any{})
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if result.ForLLM != "message_id is required" {
|
||||
t.Fatalf("unexpected error message: %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReactionTool_Execute_CallbackError(t *testing.T) {
|
||||
tool := NewReactionTool()
|
||||
tool.SetReactionCallback(func(ctx context.Context, channel, chatID, messageID string) error {
|
||||
return errors.New("unsupported")
|
||||
})
|
||||
|
||||
ctx := WithToolInboundContext(context.Background(), "telegram", "chat-1", "msg-100", "")
|
||||
result := tool.Execute(ctx, map[string]any{})
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if result.Err == nil {
|
||||
t.Fatal("expected wrapped error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReactionTool_Parameters(t *testing.T) {
|
||||
tool := NewReactionTool()
|
||||
params := tool.Parameters()
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected properties map")
|
||||
}
|
||||
if _, ok := props["message_id"]; !ok {
|
||||
t.Fatal("expected message_id parameter")
|
||||
}
|
||||
if _, ok := props["channel"]; !ok {
|
||||
t.Fatal("expected channel parameter")
|
||||
}
|
||||
if _, ok := props["chat_id"]; !ok {
|
||||
t.Fatal("expected chat_id parameter")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package integrationtools
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
toolshared "github.com/sipeed/picoclaw/pkg/tools/shared"
|
||||
)
|
||||
|
||||
type (
|
||||
Tool = toolshared.Tool
|
||||
ToolResult = toolshared.ToolResult
|
||||
AsyncCallback = toolshared.AsyncCallback
|
||||
)
|
||||
|
||||
func WithToolContext(ctx context.Context, channel, chatID string) context.Context {
|
||||
return toolshared.WithToolContext(ctx, channel, chatID)
|
||||
}
|
||||
|
||||
func WithToolInboundContext(
|
||||
ctx context.Context,
|
||||
channel, chatID, messageID, replyToMessageID string,
|
||||
) context.Context {
|
||||
return toolshared.WithToolInboundContext(ctx, channel, chatID, messageID, replyToMessageID)
|
||||
}
|
||||
|
||||
func WithToolSessionContext(
|
||||
ctx context.Context,
|
||||
agentID, sessionKey string,
|
||||
scope *session.SessionScope,
|
||||
) context.Context {
|
||||
return toolshared.WithToolSessionContext(ctx, agentID, sessionKey, scope)
|
||||
}
|
||||
|
||||
func ToolChannel(ctx context.Context) string {
|
||||
return toolshared.ToolChannel(ctx)
|
||||
}
|
||||
|
||||
func ToolChatID(ctx context.Context) string {
|
||||
return toolshared.ToolChatID(ctx)
|
||||
}
|
||||
|
||||
func ToolMessageID(ctx context.Context) string {
|
||||
return toolshared.ToolMessageID(ctx)
|
||||
}
|
||||
|
||||
func ToolAgentID(ctx context.Context) string {
|
||||
return toolshared.ToolAgentID(ctx)
|
||||
}
|
||||
|
||||
func ToolSessionKey(ctx context.Context) string {
|
||||
return toolshared.ToolSessionKey(ctx)
|
||||
}
|
||||
|
||||
func ToolSessionScope(ctx context.Context) *session.SessionScope {
|
||||
return toolshared.ToolSessionScope(ctx)
|
||||
}
|
||||
|
||||
func ErrorResult(message string) *ToolResult {
|
||||
return toolshared.ErrorResult(message)
|
||||
}
|
||||
|
||||
func SilentResult(forLLM string) *ToolResult {
|
||||
return toolshared.SilentResult(forLLM)
|
||||
}
|
||||
|
||||
func NewToolResult(forLLM string) *ToolResult {
|
||||
return toolshared.NewToolResult(forLLM)
|
||||
}
|
||||
|
||||
func UserResult(content string) *ToolResult {
|
||||
return toolshared.UserResult(content)
|
||||
}
|
||||
|
||||
func MediaResult(forLLM string, mediaRefs []string) *ToolResult {
|
||||
return toolshared.MediaResult(forLLM, mediaRefs)
|
||||
}
|
||||
@@ -0,0 +1,308 @@
|
||||
package integrationtools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const defaultSkillRegistryName = "github"
|
||||
|
||||
var persistInstalledSkillOriginMeta = writeOriginMeta
|
||||
|
||||
// InstallSkillTool allows the LLM agent to install skills from registries.
|
||||
// It shares the same RegistryManager that FindSkillsTool uses,
|
||||
// so all registries configured in config are available for installation.
|
||||
type InstallSkillTool struct {
|
||||
registryMgr *skills.RegistryManager
|
||||
workspace string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewInstallSkillTool creates a new InstallSkillTool.
|
||||
// registryMgr is the shared registry manager (same instance as FindSkillsTool).
|
||||
// workspace is the root workspace directory; skills install to {workspace}/skills/{slug}/.
|
||||
func NewInstallSkillTool(registryMgr *skills.RegistryManager, workspace string) *InstallSkillTool {
|
||||
return &InstallSkillTool{
|
||||
registryMgr: registryMgr,
|
||||
workspace: workspace,
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *InstallSkillTool) Name() string {
|
||||
return "install_skill"
|
||||
}
|
||||
|
||||
func (t *InstallSkillTool) Description() string {
|
||||
return "Install a skill from a registry by slug. Defaults to GitHub when registry is omitted. Downloads and extracts the skill into the workspace. Use find_skills first to discover available skills."
|
||||
}
|
||||
|
||||
func (t *InstallSkillTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"slug": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The unique slug of the skill to install (e.g., 'github', 'docker-compose')",
|
||||
},
|
||||
"version": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Specific version to install (optional, defaults to latest)",
|
||||
},
|
||||
"registry": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Registry to install from (optional, defaults to 'github')",
|
||||
},
|
||||
"force": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Force reinstall if skill already exists (default false)",
|
||||
},
|
||||
},
|
||||
"required": []string{"slug"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
// Install lock to prevent concurrent directory operations.
|
||||
// Ideally this should be done at a `slug` level, currently, its at a `workspace` level.
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
slug, _ := args["slug"].(string)
|
||||
if strings.TrimSpace(slug) == "" {
|
||||
return ErrorResult("identifier is required and must be a non-empty string")
|
||||
}
|
||||
|
||||
// Validate registry
|
||||
registryName, _ := args["registry"].(string)
|
||||
if registryName == "" {
|
||||
registryName = defaultSkillRegistryName
|
||||
}
|
||||
if err := utils.ValidateSkillIdentifier(registryName); err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid registry %q: error: %s", registryName, err.Error()))
|
||||
}
|
||||
|
||||
// Resolve which registry to use.
|
||||
registry := t.registryMgr.GetRegistry(registryName)
|
||||
if registry == nil {
|
||||
return ErrorResult(fmt.Sprintf("registry %q not found", registryName))
|
||||
}
|
||||
|
||||
// Validate target and resolve install directory.
|
||||
dirName, err := registry.ResolveInstallDirName(slug)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid slug %q: error: %s", slug, err.Error()))
|
||||
}
|
||||
|
||||
version, _ := args["version"].(string)
|
||||
force, _ := args["force"].(bool)
|
||||
|
||||
// Check if already installed.
|
||||
skillsDir := filepath.Join(t.workspace, "skills")
|
||||
targetDir := filepath.Join(skillsDir, dirName)
|
||||
backupDir := ""
|
||||
restorePreviousInstall := func() {
|
||||
if backupDir == "" {
|
||||
return
|
||||
}
|
||||
if rmErr := os.RemoveAll(targetDir); rmErr != nil {
|
||||
logger.ErrorCF("tool", "Failed to remove failed install before restore",
|
||||
map[string]any{
|
||||
"tool": "install_skill",
|
||||
"target_dir": targetDir,
|
||||
"error": rmErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if restoreErr := os.Rename(backupDir, targetDir); restoreErr != nil {
|
||||
logger.ErrorCF("tool", "Failed to restore previous install after failed reinstall",
|
||||
map[string]any{
|
||||
"tool": "install_skill",
|
||||
"backup_dir": backupDir,
|
||||
"target_dir": targetDir,
|
||||
"error": restoreErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
backupDir = ""
|
||||
}
|
||||
|
||||
if !force {
|
||||
if _, statErr := os.Stat(targetDir); statErr == nil {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("skill %q already installed at %s. Use force=true to reinstall.", slug, targetDir),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if _, statErr := os.Stat(targetDir); statErr == nil {
|
||||
backupDir = filepath.Join(skillsDir, fmt.Sprintf(".%s.picoclaw-backup-%d", dirName, time.Now().UnixNano()))
|
||||
if renameErr := os.Rename(targetDir, backupDir); renameErr != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to prepare reinstall for %q: %v", slug, renameErr))
|
||||
}
|
||||
} else if !os.IsNotExist(statErr) {
|
||||
return ErrorResult(fmt.Sprintf("failed to inspect existing install for %q: %v", slug, statErr))
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure skills directory exists.
|
||||
if mkdirErr := os.MkdirAll(skillsDir, 0o755); mkdirErr != nil {
|
||||
restorePreviousInstall()
|
||||
return ErrorResult(fmt.Sprintf("failed to create skills directory: %v", mkdirErr))
|
||||
}
|
||||
|
||||
// Download and install (handles metadata, version resolution, extraction).
|
||||
result, err := registry.DownloadAndInstall(ctx, slug, version, targetDir)
|
||||
if err != nil {
|
||||
// Clean up partial install.
|
||||
rmErr := os.RemoveAll(targetDir)
|
||||
if rmErr != nil {
|
||||
logger.ErrorCF("tool", "Failed to remove partial install",
|
||||
map[string]any{
|
||||
"tool": "install_skill",
|
||||
"target_dir": targetDir,
|
||||
"error": rmErr.Error(),
|
||||
})
|
||||
}
|
||||
restorePreviousInstall()
|
||||
return ErrorResult(fmt.Sprintf("failed to install %q: %v", slug, err))
|
||||
}
|
||||
|
||||
// Moderation: block malware.
|
||||
if result.IsMalwareBlocked {
|
||||
rmErr := os.RemoveAll(targetDir)
|
||||
if rmErr != nil {
|
||||
logger.ErrorCF("tool", "Failed to remove partial install",
|
||||
map[string]any{
|
||||
"tool": "install_skill",
|
||||
"target_dir": targetDir,
|
||||
"error": rmErr.Error(),
|
||||
})
|
||||
}
|
||||
restorePreviousInstall()
|
||||
return ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug))
|
||||
}
|
||||
|
||||
if !workspaceHasValidInstalledSkill(t.workspace, dirName) {
|
||||
rmErr := os.RemoveAll(targetDir)
|
||||
if rmErr != nil {
|
||||
logger.ErrorCF("tool", "Failed to remove invalid installed skill",
|
||||
map[string]any{
|
||||
"tool": "install_skill",
|
||||
"target_dir": targetDir,
|
||||
"error": rmErr.Error(),
|
||||
})
|
||||
}
|
||||
restorePreviousInstall()
|
||||
return ErrorResult(fmt.Sprintf("failed to install %q: registry archive is not a valid skill", slug))
|
||||
}
|
||||
|
||||
// Write origin metadata.
|
||||
if err := persistInstalledSkillOriginMeta(targetDir, registry, slug, result.Version); err != nil {
|
||||
logger.ErrorCF("tool", "Failed to write origin metadata",
|
||||
map[string]any{
|
||||
"tool": "install_skill",
|
||||
"error": err.Error(),
|
||||
"target": targetDir,
|
||||
"registry": registry.Name(),
|
||||
"slug": slug,
|
||||
"version": result.Version,
|
||||
})
|
||||
rmErr := os.RemoveAll(targetDir)
|
||||
if rmErr != nil {
|
||||
logger.ErrorCF("tool", "Failed to roll back install after metadata write failure",
|
||||
map[string]any{
|
||||
"tool": "install_skill",
|
||||
"target_dir": targetDir,
|
||||
"error": rmErr.Error(),
|
||||
})
|
||||
}
|
||||
restorePreviousInstall()
|
||||
return ErrorResult(fmt.Sprintf("failed to persist skill metadata for %q: %v", slug, err))
|
||||
}
|
||||
if backupDir != "" {
|
||||
if rmErr := os.RemoveAll(backupDir); rmErr != nil {
|
||||
logger.ErrorCF("tool", "Failed to remove previous install backup after successful reinstall",
|
||||
map[string]any{
|
||||
"tool": "install_skill",
|
||||
"backup_dir": backupDir,
|
||||
"error": rmErr.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Build result with moderation warning if suspicious.
|
||||
var output string
|
||||
if result.IsSuspicious {
|
||||
output = fmt.Sprintf("⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n", slug)
|
||||
}
|
||||
output += fmt.Sprintf("Successfully installed skill %q v%s from %s registry.\nLocation: %s\n",
|
||||
slug, result.Version, registry.Name(), targetDir)
|
||||
|
||||
if result.Summary != "" {
|
||||
output += fmt.Sprintf("Description: %s\n", result.Summary)
|
||||
}
|
||||
output += "\nThe skill is now available and can be loaded in the current session."
|
||||
|
||||
return SilentResult(output)
|
||||
}
|
||||
|
||||
// originMeta tracks which registry a skill was installed from.
|
||||
type originMeta struct {
|
||||
Version int `json:"version"`
|
||||
OriginKind string `json:"origin_kind,omitempty"`
|
||||
Registry string `json:"registry"`
|
||||
Slug string `json:"slug"`
|
||||
RegistryURL string `json:"registry_url,omitempty"`
|
||||
InstalledVersion string `json:"installed_version"`
|
||||
InstalledAt int64 `json:"installed_at"`
|
||||
}
|
||||
|
||||
func writeOriginMeta(targetDir string, registry skills.SkillRegistry, slug, version string) error {
|
||||
normalizedSlug, registryURL := skills.BuildInstallMetadataForRegistryInstance(registry, slug, version)
|
||||
registryName := ""
|
||||
if registry != nil {
|
||||
registryName = registry.Name()
|
||||
}
|
||||
|
||||
meta := originMeta{
|
||||
Version: 1,
|
||||
OriginKind: "third_party",
|
||||
Registry: registryName,
|
||||
Slug: normalizedSlug,
|
||||
RegistryURL: registryURL,
|
||||
InstalledVersion: version,
|
||||
InstalledAt: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(meta, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Use unified atomic write utility with explicit sync for flash storage reliability.
|
||||
return fileutil.WriteFileAtomic(filepath.Join(targetDir, ".skill-origin.json"), data, 0o600)
|
||||
}
|
||||
|
||||
func workspaceHasValidInstalledSkill(workspace, directory string) bool {
|
||||
loader := skills.NewSkillsLoader(workspace, "", "")
|
||||
for _, skill := range loader.ListSkills() {
|
||||
if skill.Source != "workspace" {
|
||||
continue
|
||||
}
|
||||
if filepath.Base(filepath.Dir(skill.Path)) == directory {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,423 @@
|
||||
package integrationtools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
)
|
||||
|
||||
type mockInstallRegistry struct{}
|
||||
|
||||
const validSkillMarkdown = "---\nname: pr-review\ndescription: Review pull requests\n---\n# PR Review\n"
|
||||
|
||||
func (m *mockInstallRegistry) Name() string { return "clawhub" }
|
||||
|
||||
func (m *mockInstallRegistry) ResolveInstallDirName(target string) (string, error) {
|
||||
return target, nil
|
||||
}
|
||||
|
||||
func (m *mockInstallRegistry) SkillURL(slug, _ string) string { return slug }
|
||||
|
||||
func (m *mockInstallRegistry) Search(context.Context, string, int) ([]skills.SearchResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockInstallRegistry) GetSkillMeta(context.Context, string) (*skills.SkillMeta, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockInstallRegistry) DownloadAndInstall(
|
||||
_ context.Context,
|
||||
_ string,
|
||||
_ string,
|
||||
targetDir string,
|
||||
) (*skills.InstallResult, error) {
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(targetDir, "SKILL.md"), []byte(validSkillMarkdown), 0o600); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &skills.InstallResult{Version: "test"}, nil
|
||||
}
|
||||
|
||||
type mockGitHubInstallRegistry struct{}
|
||||
|
||||
func (m *mockGitHubInstallRegistry) Name() string { return "github" }
|
||||
|
||||
func (m *mockGitHubInstallRegistry) ResolveInstallDirName(target string) (string, error) {
|
||||
return "pr-review", nil
|
||||
}
|
||||
|
||||
func (m *mockGitHubInstallRegistry) SkillURL(slug, _ string) string { return slug }
|
||||
|
||||
func (m *mockGitHubInstallRegistry) Search(context.Context, string, int) ([]skills.SearchResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockGitHubInstallRegistry) GetSkillMeta(context.Context, string) (*skills.SkillMeta, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockGitHubInstallRegistry) DownloadAndInstall(
|
||||
_ context.Context,
|
||||
_ string,
|
||||
_ string,
|
||||
targetDir string,
|
||||
) (*skills.InstallResult, error) {
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(targetDir, "SKILL.md"), []byte(validSkillMarkdown), 0o600); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &skills.InstallResult{Version: "main"}, nil
|
||||
}
|
||||
|
||||
type stubGitHubInstallRegistry struct {
|
||||
*skills.GitHubRegistry
|
||||
}
|
||||
|
||||
func (m *stubGitHubInstallRegistry) DownloadAndInstall(
|
||||
_ context.Context,
|
||||
_ string,
|
||||
_ string,
|
||||
targetDir string,
|
||||
) (*skills.InstallResult, error) {
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(targetDir, "SKILL.md"), []byte(validSkillMarkdown), 0o600); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &skills.InstallResult{Version: "main"}, nil
|
||||
}
|
||||
|
||||
type mockInvalidInstallRegistry struct{}
|
||||
|
||||
type mockFailingInstallRegistry struct{}
|
||||
|
||||
func (m *mockInvalidInstallRegistry) Name() string { return "clawhub" }
|
||||
|
||||
func (m *mockInvalidInstallRegistry) ResolveInstallDirName(target string) (string, error) {
|
||||
return target, nil
|
||||
}
|
||||
|
||||
func (m *mockInvalidInstallRegistry) SkillURL(slug, _ string) string { return slug }
|
||||
|
||||
func (m *mockInvalidInstallRegistry) Search(context.Context, string, int) ([]skills.SearchResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockInvalidInstallRegistry) GetSkillMeta(context.Context, string) (*skills.SkillMeta, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockInvalidInstallRegistry) DownloadAndInstall(
|
||||
_ context.Context,
|
||||
_ string,
|
||||
_ string,
|
||||
targetDir string,
|
||||
) (*skills.InstallResult, error) {
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := os.WriteFile(
|
||||
filepath.Join(targetDir, "SKILL.md"),
|
||||
[]byte("---\nname: bad_skill\ndescription: invalid name\n---\n# Invalid\n"),
|
||||
0o600,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &skills.InstallResult{Version: "test"}, nil
|
||||
}
|
||||
|
||||
func (m *mockFailingInstallRegistry) Name() string { return "clawhub" }
|
||||
|
||||
func (m *mockFailingInstallRegistry) ResolveInstallDirName(target string) (string, error) {
|
||||
return target, nil
|
||||
}
|
||||
|
||||
func (m *mockFailingInstallRegistry) SkillURL(slug, _ string) string { return slug }
|
||||
|
||||
func (m *mockFailingInstallRegistry) Search(context.Context, string, int) ([]skills.SearchResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockFailingInstallRegistry) GetSkillMeta(context.Context, string) (*skills.SkillMeta, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockFailingInstallRegistry) DownloadAndInstall(
|
||||
_ context.Context,
|
||||
_ string,
|
||||
_ string,
|
||||
_ string,
|
||||
) (*skills.InstallResult, error) {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
|
||||
func TestInstallSkillToolName(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
assert.Equal(t, "install_skill", tool.Name())
|
||||
}
|
||||
|
||||
func TestInstallSkillToolMissingSlug(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolEmptySlug(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": " ",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolUnsafeSlug(t *testing.T) {
|
||||
registryMgr := skills.NewRegistryManager()
|
||||
registryMgr.AddRegistry(skills.NewClawHubRegistry(skills.ClawHubConfig{Enabled: true}))
|
||||
tool := NewInstallSkillTool(registryMgr, t.TempDir())
|
||||
|
||||
cases := []string{
|
||||
"../etc/passwd",
|
||||
"path/traversal",
|
||||
"path\\traversal",
|
||||
}
|
||||
|
||||
for _, slug := range cases {
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": slug,
|
||||
"registry": "clawhub",
|
||||
})
|
||||
assert.True(t, result.IsError, "slug %q should be rejected", slug)
|
||||
assert.Contains(t, result.ForLLM, "invalid slug")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstallSkillToolAlreadyExists(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
skillDir := filepath.Join(workspace, "skills", "existing-skill")
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
|
||||
registryMgr := skills.NewRegistryManager()
|
||||
registryMgr.AddRegistry(&mockInstallRegistry{})
|
||||
tool := NewInstallSkillTool(registryMgr, workspace)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": "existing-skill",
|
||||
"registry": "clawhub",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "already installed")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolRegistryNotFound(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": "some-skill",
|
||||
"registry": "nonexistent",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "registry")
|
||||
assert.Contains(t, result.ForLLM, "not found")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolParameters(t *testing.T) {
|
||||
tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
|
||||
params := tool.Parameters()
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, props, "slug")
|
||||
assert.Contains(t, props, "version")
|
||||
assert.Contains(t, props, "registry")
|
||||
assert.Contains(t, props, "force")
|
||||
|
||||
required, ok := params["required"].([]string)
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, required, "slug")
|
||||
assert.NotContains(t, required, "registry")
|
||||
}
|
||||
|
||||
func TestInstallSkillToolMissingRegistry(t *testing.T) {
|
||||
registryMgr := skills.NewRegistryManager()
|
||||
registryMgr.AddRegistry(&mockGitHubInstallRegistry{})
|
||||
tool := NewInstallSkillTool(registryMgr, t.TempDir())
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": "some-skill",
|
||||
})
|
||||
assert.False(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, `Successfully installed skill`)
|
||||
}
|
||||
|
||||
func TestInstallSkillToolAllowsGitHubURLSlug(t *testing.T) {
|
||||
registry := skills.GitHubRegistryConfig{Enabled: true, BaseURL: "https://github.com"}.BuildRegistry()
|
||||
githubRegistry, ok := registry.(*skills.GitHubRegistry)
|
||||
require.True(t, ok)
|
||||
|
||||
registryMgr := skills.NewRegistryManager()
|
||||
registryMgr.AddRegistry(&stubGitHubInstallRegistry{GitHubRegistry: githubRegistry})
|
||||
workspace := t.TempDir()
|
||||
tool := NewInstallSkillTool(registryMgr, workspace)
|
||||
|
||||
slug := "https://github.com/synthetic-lab/octofriend/tree/main/.agents/skills/pr-review"
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": slug,
|
||||
"registry": "github",
|
||||
})
|
||||
|
||||
assert.False(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, `Successfully installed skill`)
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(workspace, "skills", "pr-review", ".skill-origin.json"))
|
||||
require.NoError(t, err)
|
||||
|
||||
var meta originMeta
|
||||
require.NoError(t, json.Unmarshal(data, &meta))
|
||||
assert.Equal(t, "third_party", meta.OriginKind)
|
||||
assert.Equal(t, "github", meta.Registry)
|
||||
assert.Equal(t, "synthetic-lab/octofriend/.agents/skills/pr-review", meta.Slug)
|
||||
assert.Equal(t, slug, meta.RegistryURL)
|
||||
assert.Equal(t, "main", meta.InstalledVersion)
|
||||
assert.NotZero(t, meta.InstalledAt)
|
||||
}
|
||||
|
||||
func TestInstallSkillToolPreservesGitHubSourceURLWithEnterpriseRegistry(t *testing.T) {
|
||||
registry := skills.GitHubRegistryConfig{Enabled: true, BaseURL: "https://ghe.example.com/git"}.BuildRegistry()
|
||||
githubRegistry, ok := registry.(*skills.GitHubRegistry)
|
||||
require.True(t, ok)
|
||||
|
||||
registryMgr := skills.NewRegistryManager()
|
||||
registryMgr.AddRegistry(&stubGitHubInstallRegistry{GitHubRegistry: githubRegistry})
|
||||
workspace := t.TempDir()
|
||||
tool := NewInstallSkillTool(registryMgr, workspace)
|
||||
|
||||
slug := "https://github.com/synthetic-lab/octofriend/tree/main/.agents/skills/pr-review"
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": slug,
|
||||
"registry": "github",
|
||||
})
|
||||
|
||||
assert.False(t, result.IsError)
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(workspace, "skills", "pr-review", ".skill-origin.json"))
|
||||
require.NoError(t, err)
|
||||
|
||||
var meta originMeta
|
||||
require.NoError(t, json.Unmarshal(data, &meta))
|
||||
assert.Equal(t, "synthetic-lab/octofriend/.agents/skills/pr-review", meta.Slug)
|
||||
assert.Equal(t, slug, meta.RegistryURL)
|
||||
assert.Equal(t, "main", meta.InstalledVersion)
|
||||
}
|
||||
|
||||
func TestInstallSkillToolRejectsInvalidInstalledSkill(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
registryMgr := skills.NewRegistryManager()
|
||||
registryMgr.AddRegistry(&mockInvalidInstallRegistry{})
|
||||
tool := NewInstallSkillTool(registryMgr, workspace)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": "broken-skill",
|
||||
"registry": "clawhub",
|
||||
})
|
||||
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "not a valid skill")
|
||||
_, err := os.Stat(filepath.Join(workspace, "skills", "broken-skill"))
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
}
|
||||
|
||||
func TestInstallSkillToolRollsBackOnOriginMetadataWriteFailure(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
registryMgr := skills.NewRegistryManager()
|
||||
registryMgr.AddRegistry(&mockInstallRegistry{})
|
||||
tool := NewInstallSkillTool(registryMgr, workspace)
|
||||
|
||||
previousPersist := persistInstalledSkillOriginMeta
|
||||
persistInstalledSkillOriginMeta = func(string, skills.SkillRegistry, string, string) error {
|
||||
return assert.AnError
|
||||
}
|
||||
defer func() {
|
||||
persistInstalledSkillOriginMeta = previousPersist
|
||||
}()
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": "rollback-skill",
|
||||
"registry": "clawhub",
|
||||
})
|
||||
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "failed to persist skill metadata")
|
||||
_, err := os.Stat(filepath.Join(workspace, "skills", "rollback-skill"))
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
}
|
||||
|
||||
func TestInstallSkillToolForceReinstallRestoresPreviousSkillAfterDownloadFailure(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
skillDir := filepath.Join(workspace, "skills", "existing-skill")
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
oldContent := []byte("---\nname: existing-skill\ndescription: Existing skill\n---\n# Existing\n")
|
||||
require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), oldContent, 0o600))
|
||||
|
||||
registryMgr := skills.NewRegistryManager()
|
||||
registryMgr.AddRegistry(&mockFailingInstallRegistry{})
|
||||
tool := NewInstallSkillTool(registryMgr, workspace)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": "existing-skill",
|
||||
"registry": "clawhub",
|
||||
"force": true,
|
||||
})
|
||||
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "failed to install")
|
||||
|
||||
gotContent, err := os.ReadFile(filepath.Join(skillDir, "SKILL.md"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, oldContent, gotContent)
|
||||
}
|
||||
|
||||
func TestInstallSkillToolForceReinstallRestoresPreviousSkillAfterMetadataFailure(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
skillDir := filepath.Join(workspace, "skills", "existing-skill")
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
oldContent := []byte("---\nname: existing-skill\ndescription: Existing skill\n---\n# Existing\n")
|
||||
require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), oldContent, 0o600))
|
||||
|
||||
registryMgr := skills.NewRegistryManager()
|
||||
registryMgr.AddRegistry(&mockInstallRegistry{})
|
||||
tool := NewInstallSkillTool(registryMgr, workspace)
|
||||
|
||||
previousPersist := persistInstalledSkillOriginMeta
|
||||
persistInstalledSkillOriginMeta = func(string, skills.SkillRegistry, string, string) error {
|
||||
return assert.AnError
|
||||
}
|
||||
defer func() {
|
||||
persistInstalledSkillOriginMeta = previousPersist
|
||||
}()
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"slug": "existing-skill",
|
||||
"registry": "clawhub",
|
||||
"force": true,
|
||||
})
|
||||
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "failed to persist skill metadata")
|
||||
|
||||
gotContent, err := os.ReadFile(filepath.Join(skillDir, "SKILL.md"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, oldContent, gotContent)
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package integrationtools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
)
|
||||
|
||||
// FindSkillsTool allows the LLM agent to search for installable skills from registries.
|
||||
type FindSkillsTool struct {
|
||||
registryMgr *skills.RegistryManager
|
||||
cache *skills.SearchCache
|
||||
}
|
||||
|
||||
// NewFindSkillsTool creates a new FindSkillsTool.
|
||||
// registryMgr is the shared registry manager (built from config in createToolRegistry).
|
||||
// cache is the search cache for deduplicating similar queries.
|
||||
func NewFindSkillsTool(registryMgr *skills.RegistryManager, cache *skills.SearchCache) *FindSkillsTool {
|
||||
return &FindSkillsTool{
|
||||
registryMgr: registryMgr,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *FindSkillsTool) Name() string {
|
||||
return "find_skills"
|
||||
}
|
||||
|
||||
func (t *FindSkillsTool) Description() string {
|
||||
return "Search for installable skills from skill registries. Returns skill slugs, descriptions, versions, and relevance scores. Use this to discover skills before installing them with install_skill."
|
||||
}
|
||||
|
||||
func (t *FindSkillsTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Search query describing the desired skill capability (e.g., 'github integration', 'database management')",
|
||||
},
|
||||
"limit": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return (1-20, default 5)",
|
||||
"minimum": 1.0,
|
||||
"maximum": 20.0,
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *FindSkillsTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
query, ok := args["query"].(string)
|
||||
query = strings.ToLower(strings.TrimSpace(query))
|
||||
if !ok || query == "" {
|
||||
return ErrorResult("query is required and must be a non-empty string")
|
||||
}
|
||||
|
||||
limit := 5
|
||||
if l, ok := args["limit"].(float64); ok {
|
||||
li := int(l)
|
||||
if li >= 1 && li <= 20 {
|
||||
limit = li
|
||||
}
|
||||
}
|
||||
|
||||
// Check cache first.
|
||||
if t.cache != nil {
|
||||
if cached, hit := t.cache.Get(query); hit {
|
||||
return SilentResult(formatSearchResults(query, cached, true))
|
||||
}
|
||||
}
|
||||
|
||||
// Search all registries.
|
||||
results, err := t.registryMgr.SearchAll(ctx, query, limit)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("skill search failed: %v", err))
|
||||
}
|
||||
|
||||
// Cache the results.
|
||||
if t.cache != nil && len(results) > 0 {
|
||||
t.cache.Put(query, results)
|
||||
}
|
||||
|
||||
return SilentResult(formatSearchResults(query, results, false))
|
||||
}
|
||||
|
||||
func formatSearchResults(query string, results []skills.SearchResult, cached bool) string {
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No skills found for query: %q", query)
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
source := ""
|
||||
if cached {
|
||||
source = " (cached)"
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("Found %d skills for %q%s:\n\n", len(results), query, source))
|
||||
|
||||
for i, r := range results {
|
||||
sb.WriteString(fmt.Sprintf("%d. **%s**", i+1, r.Slug))
|
||||
if r.Version != "" {
|
||||
sb.WriteString(fmt.Sprintf(" v%s", r.Version))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" (score: %.3f, registry: %s)\n", r.Score, r.RegistryName))
|
||||
if r.DisplayName != "" && r.DisplayName != r.Slug {
|
||||
sb.WriteString(fmt.Sprintf(" Name: %s\n", r.DisplayName))
|
||||
}
|
||||
if r.Summary != "" {
|
||||
sb.WriteString(fmt.Sprintf(" %s\n", r.Summary))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
sb.WriteString("Use install_skill with the slug to install a skill.")
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package integrationtools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
)
|
||||
|
||||
func TestFindSkillsToolName(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
assert.Equal(t, "find_skills", tool.Name())
|
||||
}
|
||||
|
||||
func TestFindSkillsToolMissingQuery(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
assert.True(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "query is required")
|
||||
}
|
||||
|
||||
func TestFindSkillsToolEmptyQuery(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"query": " ",
|
||||
})
|
||||
assert.True(t, result.IsError)
|
||||
}
|
||||
|
||||
func TestFindSkillsToolCacheHit(t *testing.T) {
|
||||
cache := skills.NewSearchCache(10, 5*60*1000*1000*1000) // 5 min
|
||||
cache.Put("github", []skills.SearchResult{
|
||||
{Slug: "github", Score: 0.9, RegistryName: "clawhub"},
|
||||
})
|
||||
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), cache)
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"query": "github",
|
||||
})
|
||||
|
||||
assert.False(t, result.IsError)
|
||||
assert.Contains(t, result.ForLLM, "github")
|
||||
assert.Contains(t, result.ForLLM, "cached")
|
||||
}
|
||||
|
||||
func TestFindSkillsToolParameters(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
params := tool.Parameters()
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, props, "query")
|
||||
assert.Contains(t, props, "limit")
|
||||
|
||||
required, ok := params["required"].([]string)
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, required, "query")
|
||||
}
|
||||
|
||||
func TestFindSkillsToolDescription(t *testing.T) {
|
||||
tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
|
||||
assert.NotEmpty(t, tool.Description())
|
||||
assert.Contains(t, tool.Description(), "skill")
|
||||
}
|
||||
|
||||
func TestFormatSearchResultsEmpty(t *testing.T) {
|
||||
result := formatSearchResults("test query", nil, false)
|
||||
assert.Contains(t, result, "No skills found")
|
||||
}
|
||||
|
||||
func TestFormatSearchResultsWithData(t *testing.T) {
|
||||
results := []skills.SearchResult{
|
||||
{
|
||||
Slug: "github",
|
||||
Score: 0.95,
|
||||
DisplayName: "GitHub",
|
||||
Summary: "GitHub API integration",
|
||||
Version: "1.0.0",
|
||||
RegistryName: "clawhub",
|
||||
},
|
||||
}
|
||||
output := formatSearchResults("github", results, false)
|
||||
assert.Contains(t, output, "github")
|
||||
assert.Contains(t, output, "v1.0.0")
|
||||
assert.Contains(t, output, "0.950")
|
||||
assert.Contains(t, output, "clawhub")
|
||||
assert.Contains(t, output, "install_skill")
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package integrationtools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/audio/tts"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
type SendTTSTool struct {
|
||||
provider tts.TTSProvider
|
||||
mediaStore media.MediaStore
|
||||
}
|
||||
|
||||
func NewSendTTSTool(provider tts.TTSProvider, store media.MediaStore) *SendTTSTool {
|
||||
return &SendTTSTool{
|
||||
provider: provider,
|
||||
mediaStore: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SendTTSTool) Name() string { return "send_tts" }
|
||||
|
||||
func (t *SendTTSTool) Description() string {
|
||||
return "Synthesize speech from text and send it as an audio file to the user."
|
||||
}
|
||||
|
||||
func (t *SendTTSTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The text to synthesize into speech. NOTE: Reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally.",
|
||||
},
|
||||
"filename": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional filename for the audio file (e.g., response.ogg).",
|
||||
},
|
||||
},
|
||||
"required": []string{"text"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SendTTSTool) SetMediaStore(store media.MediaStore) {
|
||||
t.mediaStore = store
|
||||
}
|
||||
|
||||
func (t *SendTTSTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
text, _ := args["text"].(string)
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return ErrorResult("text is required")
|
||||
}
|
||||
|
||||
channel := ToolChannel(ctx)
|
||||
chatID := ToolChatID(ctx)
|
||||
filename, _ := args["filename"].(string)
|
||||
|
||||
ref, err := tts.SynthesizeAndStore(
|
||||
ctx,
|
||||
t.provider,
|
||||
t.mediaStore,
|
||||
text,
|
||||
filename,
|
||||
channel,
|
||||
chatID,
|
||||
)
|
||||
if err != nil {
|
||||
return ErrorResult(err.Error()).WithError(err)
|
||||
}
|
||||
|
||||
// Return with ForUser set to original text, Media containing the audio ref,
|
||||
// and mark as ResponseHandled so the audio is sent immediately without LLM intervention.
|
||||
return &ToolResult{
|
||||
ForLLM: "TTS audio sent",
|
||||
ForUser: text,
|
||||
Media: []string{ref},
|
||||
ResponseHandled: true,
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user