Files
picoclaw/pkg/tools/integration/mcp_tool.go
T

611 lines
16 KiB
Go

package integrationtools
import (
"context"
"encoding/json"
"fmt"
"hash/fnv"
"os"
"path/filepath"
"strings"
"time"
"unicode/utf8"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
toolshared "github.com/sipeed/picoclaw/pkg/tools/shared"
)
// MCPManager defines the interface for MCP manager operations
// This allows for easier testing with mock implementations
type MCPManager interface {
CallTool(
ctx context.Context,
serverName, toolName string,
arguments map[string]any,
) (*mcp.CallToolResult, error)
}
// MCPTool wraps an MCP tool to implement the Tool interface
type MCPTool struct {
manager MCPManager
serverName string
tool *mcp.Tool
mediaStore media.MediaStore
workspace string
maxInlineTextRunes int
}
// NewMCPTool creates a new MCP tool wrapper
func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
return &MCPTool{
manager: manager,
serverName: serverName,
tool: tool,
maxInlineTextRunes: maxMCPInlineTextRunes,
}
}
func (t *MCPTool) SetMediaStore(store media.MediaStore) {
t.mediaStore = store
}
func (t *MCPTool) SetWorkspace(workspace string) {
t.workspace = strings.TrimSpace(workspace)
}
func (t *MCPTool) SetMaxInlineTextRunes(limit int) {
if limit > 0 {
t.maxInlineTextRunes = limit
}
}
const maxMCPInlineTextRunes = 16 * 1024
// sanitizeIdentifierComponent normalizes a string so it can be safely used
// as part of a tool/function identifier for downstream providers.
// It:
// - lowercases the string
// - replaces any character not in [a-z0-9_-] with '_'
// - collapses multiple consecutive '_' into a single '_'
// - trims leading/trailing '_'
// - falls back to "unnamed" if the result is empty
// - truncates overly long components to a reasonable length
func sanitizeIdentifierComponent(s string) string {
const maxLen = 64
s = strings.ToLower(s)
var b strings.Builder
b.Grow(len(s))
prevUnderscore := false
for _, r := range s {
isAllowed := (r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_' || r == '-'
if !isAllowed {
// Normalize any disallowed character to '_'
if !prevUnderscore {
b.WriteRune('_')
prevUnderscore = true
}
continue
}
if r == '_' {
if prevUnderscore {
continue
}
prevUnderscore = true
} else {
prevUnderscore = false
}
b.WriteRune(r)
}
result := strings.Trim(b.String(), "_")
if result == "" {
result = "unnamed"
}
if len(result) > maxLen {
result = result[:maxLen]
}
return result
}
// Name returns the tool name, prefixed with the server name.
// The total length is capped at 64 characters (OpenAI-compatible API limit).
// A short hash of the original (unsanitized) server and tool names is appended
// whenever sanitization is lossy or the name is truncated, ensuring that two
// names which differ only in disallowed characters remain distinct after sanitization.
func (t *MCPTool) Name() string {
// Prefix with server name to avoid conflicts, and sanitize components
sanitizedServer := sanitizeIdentifierComponent(t.serverName)
sanitizedTool := sanitizeIdentifierComponent(t.tool.Name)
full := fmt.Sprintf("mcp_%s_%s", sanitizedServer, sanitizedTool)
// Check if sanitization was lossless (only lowercasing, no char replacement/truncation)
lossless := strings.ToLower(t.serverName) == sanitizedServer &&
strings.ToLower(t.tool.Name) == sanitizedTool
const maxTotal = 64
if lossless && len(full) <= maxTotal {
return full
}
// Sanitization was lossy or name too long: append hash of the ORIGINAL names
// (not the sanitized names) so different originals always yield different hashes.
h := fnv.New32a()
_, _ = h.Write([]byte(t.serverName + "\x00" + t.tool.Name))
suffix := fmt.Sprintf("%08x", h.Sum32()) // 8 chars
base := full
if len(base) > maxTotal-9 {
base = strings.TrimRight(full[:maxTotal-9], "_")
}
return base + "_" + suffix
}
// Description returns the tool description
func (t *MCPTool) Description() string {
desc := t.tool.Description
if desc == "" {
desc = fmt.Sprintf("MCP tool from %s server", t.serverName)
}
// Add server info to description
return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc)
}
func (t *MCPTool) PromptMetadata() toolshared.PromptMetadata {
return toolshared.PromptMetadata{
Layer: toolshared.ToolPromptLayerCapability,
Slot: toolshared.ToolPromptSlotMCP,
Source: "mcp:" + sanitizeIdentifierComponent(t.serverName),
}
}
// Parameters returns the tool parameters schema
func (t *MCPTool) Parameters() map[string]any {
// The InputSchema is already a JSON Schema object
schema := t.tool.InputSchema
// Handle nil schema
if schema == nil {
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
// Try direct conversion first (fast path)
if schemaMap, ok := schema.(map[string]any); ok {
return schemaMap
}
// Handle json.RawMessage and []byte - unmarshal directly
var jsonData []byte
if rawMsg, ok := schema.(json.RawMessage); ok {
jsonData = rawMsg
} else if bytes, ok := schema.([]byte); ok {
jsonData = bytes
}
if jsonData != nil {
var result map[string]any
if err := json.Unmarshal(jsonData, &result); err == nil {
return result
}
// Fallback on error
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
// For other types (structs, etc.), convert via JSON marshal/unmarshal
var err error
jsonData, err = json.Marshal(schema)
if err != nil {
// Fallback to empty schema if marshaling fails
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
var result map[string]any
if err := json.Unmarshal(jsonData, &result); err != nil {
// Fallback to empty schema if unmarshaling fails
return map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []string{},
}
}
return result
}
// Execute executes the MCP tool
func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args)
if err != nil {
return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err)
}
if result == nil {
nilErr := fmt.Errorf("MCP tool returned nil result without error")
return ErrorResult("MCP tool execution failed: nil result").WithError(nilErr)
}
// Handle error result from server
if result.IsError {
errMsg := extractContentText(result.Content)
return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)).
WithError(fmt.Errorf("MCP tool error: %s", errMsg))
}
return t.normalizeResultContent(ctx, result.Content)
}
// extractContentText extracts text from MCP content array
func extractContentText(content []mcp.Content) string {
var parts []string
for _, c := range content {
switch v := c.(type) {
case *mcp.TextContent:
parts = append(parts, sanitizeToolLLMContent(v.Text))
case *mcp.ImageContent:
parts = append(parts, fmt.Sprintf("[Image: %s]", normalizedMIMEType(v.MIMEType)))
case *mcp.AudioContent:
parts = append(parts, fmt.Sprintf("[Audio: %s]", normalizedMIMEType(v.MIMEType)))
case *mcp.ResourceLink:
parts = append(parts, summarizeResourceLink(v))
case *mcp.EmbeddedResource:
parts = append(parts, summarizeEmbeddedResource(v))
default:
// For other content types, use string representation
parts = append(parts, fmt.Sprintf("[Content: %T]", v))
}
}
return sanitizeToolLLMContent(strings.Join(parts, "\n"))
}
func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Content) *ToolResult {
llmParts := make([]string, 0, len(content))
rawTextParts := make([]string, 0, len(content))
mediaRefs := make([]string, 0, len(content))
for _, c := range content {
switch v := c.(type) {
case *mcp.TextContent:
rawText := strings.TrimSpace(v.Text)
if rawText != "" {
rawTextParts = append(rawTextParts, rawText)
}
safeText := strings.TrimSpace(sanitizeToolLLMContent(v.Text))
if safeText != "" {
llmParts = append(llmParts, safeText)
}
case *mcp.ImageContent:
ref, note := t.storeBinaryContent(
ctx,
"image",
normalizedMIMEType(v.MIMEType),
v.Data,
v.Annotations,
)
if ref != "" {
mediaRefs = append(mediaRefs, ref)
}
if note != "" {
llmParts = append(llmParts, note)
}
case *mcp.AudioContent:
ref, note := t.storeBinaryContent(
ctx,
"audio",
normalizedMIMEType(v.MIMEType),
v.Data,
v.Annotations,
)
if ref != "" {
mediaRefs = append(mediaRefs, ref)
}
if note != "" {
llmParts = append(llmParts, note)
}
case *mcp.ResourceLink:
llmParts = append(llmParts, summarizeResourceLink(v))
case *mcp.EmbeddedResource:
ref, note, rawText := t.storeEmbeddedResource(ctx, v)
if ref != "" {
mediaRefs = append(mediaRefs, ref)
}
if rawText != "" {
rawTextParts = append(rawTextParts, rawText)
}
if note != "" {
llmParts = append(llmParts, note)
}
default:
llmParts = append(llmParts, fmt.Sprintf("[MCP returned unsupported content type %T]", v))
}
}
forLLM := strings.Join(compactStrings(llmParts), "\n")
rawText := strings.Join(compactStrings(rawTextParts), "\n")
if artifactResult := t.persistLargeTextArtifact(rawText); artifactResult != nil {
artifactResult.Media = mediaRefs
return artifactResult
}
result := &ToolResult{
ForLLM: forLLM,
Media: mediaRefs,
}
return result
}
func (t *MCPTool) persistLargeTextArtifact(text string) *ToolResult {
text = strings.TrimSpace(text)
limit := t.maxInlineTextRunes
if limit <= 0 {
limit = maxMCPInlineTextRunes
}
size := utf8.RuneCountInString(text)
if text == "" || size <= limit || t.workspace == "" {
return nil
}
dir := filepath.Join(t.workspace, ".artifacts", "mcp")
if err := os.MkdirAll(dir, 0o700); err != nil {
return t.largeTextArtifactFallback(text, err)
}
// TODO: Add lifecycle cleanup/retention for MCP artifact files.
pattern := fmt.Sprintf(
"%s_%s_*.txt",
sanitizeIdentifierComponent(t.serverName),
sanitizeIdentifierComponent(t.tool.Name),
)
tmpFile, err := os.CreateTemp(dir, pattern)
if err != nil {
return t.largeTextArtifactFallback(text, err)
}
path := tmpFile.Name()
if _, err = tmpFile.WriteString(text); err != nil {
_ = tmpFile.Close()
_ = os.Remove(path)
return t.largeTextArtifactFallback(text, err)
}
if err = tmpFile.Close(); err != nil {
_ = os.Remove(path)
return t.largeTextArtifactFallback(text, err)
}
return &ToolResult{
ForLLM: fmt.Sprintf(
"[MCP returned a large text result (%d chars); omitted from model context and saved as a local artifact.]",
size,
),
ArtifactTags: []string{"[file:" + path + "]"},
}
}
func (t *MCPTool) largeTextArtifactFallback(text string, err error) *ToolResult {
size := utf8.RuneCountInString(text)
logger.WarnCF("tool", "Failed to persist large MCP text artifact", map[string]any{
"server": t.serverName,
"tool": t.tool.Name,
"chars": size,
"error": err.Error(),
})
return &ToolResult{
ForLLM: fmt.Sprintf(
"[MCP returned a large text result (%d chars); omitted from model context because artifact persistence failed.]",
size,
),
}
}
func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string, string) {
if content == nil || content.Resource == nil {
return "", "[MCP returned an embedded resource without data.]", ""
}
resource := content.Resource
if len(resource.Blob) > 0 {
ref, note := t.storeBinaryContent(
ctx,
"resource",
normalizedMIMEType(resource.MIMEType),
resource.Blob,
content.Annotations,
)
return ref, note, ""
}
rawText := strings.TrimSpace(resource.Text)
if rawText != "" {
return "", sanitizeToolLLMContent(resource.Text), rawText
}
return "", summarizeEmbeddedResource(content), ""
}
func (t *MCPTool) storeBinaryContent(
ctx context.Context,
kind string,
mimeType string,
data []byte,
annotations *mcp.Annotations,
) (string, string) {
if len(data) == 0 {
return "", fmt.Sprintf("[MCP returned %s content (%s) but it was empty.]", kind, mimeType)
}
if !annotationsAllowUser(annotations) {
return "", fmt.Sprintf(
"[MCP returned %s content (%s) for non-user audience; omitted from model context.]",
kind,
mimeType,
)
}
if t.mediaStore == nil {
return "", fmt.Sprintf(
"[MCP returned %s content (%s); omitted from model context because media delivery is unavailable.]",
kind,
mimeType,
)
}
channel := ToolChannel(ctx)
chatID := ToolChatID(ctx)
if channel == "" || chatID == "" {
return "", fmt.Sprintf(
"[MCP returned %s content (%s); omitted from model context because no target chat was available.]",
kind,
mimeType,
)
}
dir := media.TempDir()
if err := os.MkdirAll(dir, 0o700); err != nil {
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
}
ext := extensionForMIMEType(mimeType)
tmpFile, err := os.CreateTemp(dir, "mcp-*"+ext)
if err != nil {
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
}
tmpPath := tmpFile.Name()
if _, err = tmpFile.Write(data); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
}
if err = tmpFile.Close(); err != nil {
_ = os.Remove(tmpPath)
return "", fmt.Sprintf("[MCP returned %s content (%s) but it could not be stored.]", kind, mimeType)
}
scope := fmt.Sprintf(
"tool:mcp:%s:%s:%s:%d",
sanitizeIdentifierComponent(t.serverName),
channel,
chatID,
time.Now().UnixNano(),
)
filename := fmt.Sprintf(
"%s_%s%s",
sanitizeIdentifierComponent(t.serverName),
sanitizeIdentifierComponent(t.tool.Name),
ext,
)
ref, err := t.mediaStore.Store(tmpPath, media.MediaMeta{
Filename: filename,
ContentType: mimeType,
Source: fmt.Sprintf(
"tool:mcp:%s:%s",
sanitizeIdentifierComponent(t.serverName),
sanitizeIdentifierComponent(t.tool.Name),
),
}, scope)
if err != nil {
_ = os.Remove(tmpPath)
return "", fmt.Sprintf(
"[MCP returned %s content (%s) but it could not be registered as media.]",
kind,
mimeType,
)
}
return ref, fmt.Sprintf(
"[MCP returned %s content (%s); omitted from model context and stored as a local media artifact.]",
kind,
mimeType,
)
}
func summarizeResourceLink(content *mcp.ResourceLink) string {
if content == nil {
return "[MCP returned an empty resource link.]"
}
parts := []string{"[MCP returned resource link"}
if content.Name != "" {
parts = append(parts, fmt.Sprintf("name=%q", content.Name))
}
if content.URI != "" {
parts = append(parts, fmt.Sprintf("uri=%q", content.URI))
}
if content.MIMEType != "" {
parts = append(parts, fmt.Sprintf("mime=%q", content.MIMEType))
}
if content.Description != "" {
desc := strings.TrimSpace(content.Description)
if len(desc) > 200 {
desc = desc[:200] + "..."
}
parts = append(parts, fmt.Sprintf("description=%q", desc))
}
return strings.Join(parts, ", ") + "]"
}
func summarizeEmbeddedResource(content *mcp.EmbeddedResource) string {
if content == nil || content.Resource == nil {
return "[MCP returned an embedded resource.]"
}
resource := content.Resource
if resource.URI != "" {
return fmt.Sprintf(
"[MCP returned embedded resource %q (%s).]",
resource.URI,
normalizedMIMEType(resource.MIMEType),
)
}
return fmt.Sprintf("[MCP returned embedded resource (%s).]", normalizedMIMEType(resource.MIMEType))
}
func annotationsAllowUser(annotations *mcp.Annotations) bool {
if annotations == nil || len(annotations.Audience) == 0 {
return true
}
for _, audience := range annotations.Audience {
if strings.EqualFold(string(audience), "user") {
return true
}
}
return false
}
func normalizedMIMEType(mimeType string) string {
if strings.TrimSpace(mimeType) == "" {
return "application/octet-stream"
}
return mimeType
}
func compactStrings(parts []string) []string {
compact := make([]string, 0, len(parts))
for _, part := range parts {
if strings.TrimSpace(part) == "" {
continue
}
compact = append(compact, part)
}
return compact
}