diff --git a/config/config.example.json b/config/config.example.json index 7cd0ab8c6..5a0476052 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -130,6 +130,28 @@ }, "cron": { "exec_timeout_minutes": 5 + }, + "mcp": { + "enabled": false, + "servers": { + "filesystem": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + "/tmp" + ], + "protocol": "mcp", + "env": {}, + "working_dir": "", + "init_timeout_seconds": 60, + "call_timeout_seconds": 30, + "max_response_bytes": 65536, + "include_tools": [], + "exclude_tools": [] + } + } } }, "heartbeat": { @@ -144,4 +166,4 @@ "host": "0.0.0.0", "port": 18790 } -} \ No newline at end of file +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index d3afa298e..368db47a8 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -23,6 +23,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/mcp" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/state" @@ -44,8 +45,12 @@ type AgentLoop struct { running atomic.Bool summarizing sync.Map // Tracks which sessions are currently being summarized channelManager *channels.Manager + mcpManager *mcp.Manager + mcpCloseOnce sync.Once } +const defaultWebFetchMaxChars = 50000 + // processOptions configures how a message is processed type processOptions struct { SessionKey string // Session identifier for history/context @@ -60,7 +65,14 @@ type processOptions struct { // createToolRegistry creates a tool registry with common tools. // This is shared between main agent and subagents. -func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry { +func createToolRegistry( + workspace string, + restrict bool, + cfg *config.Config, + msgBus *bus.MessageBus, + mcpManager *mcp.Manager, + discoveredMCPTools []mcp.RegisteredTool, +) *tools.ToolRegistry { registry := tools.NewToolRegistry() // File system tools @@ -85,7 +97,9 @@ func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msg }); searchTool != nil { registry.Register(searchTool) } - registry.Register(tools.NewWebFetchTool(50000)) + registry.Register(tools.NewWebFetchTool(defaultWebFetchMaxChars)) + + tools.RegisterKnownMCPTools(registry, mcpManager, discoveredMCPTools) // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms registry.Register(tools.NewI2CTool()) @@ -113,12 +127,35 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers restrict := cfg.Agents.Defaults.RestrictToWorkspace + var ( + mcpManager *mcp.Manager + discoveredMCPTools []mcp.RegisteredTool + ) + if cfg.Tools.MCP.Enabled { + bootstrap, err := bootstrapMCP(cfg.Tools.MCP) + if err != nil { + logger.WarnCF("agent", "MCP tool bootstrap failed", + map[string]interface{}{ + "error": err.Error(), + }) + } else if bootstrap != nil { + mcpManager = bootstrap.Manager + discoveredMCPTools = bootstrap.Tools + if len(discoveredMCPTools) > 0 { + logger.InfoCF("agent", "MCP tools registered", + map[string]interface{}{ + "count": len(discoveredMCPTools), + }) + } + } + } + // Create tool registry for main agent - toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus) + toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager, discoveredMCPTools) // Create subagent manager with its own tool registry subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus) - subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus) + subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager, discoveredMCPTools) // Subagent doesn't need spawn/subagent tools to avoid recursion subagentManager.SetTools(subagentTools) @@ -151,11 +188,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers contextBuilder: contextBuilder, tools: toolsRegistry, summarizing: sync.Map{}, + mcpManager: mcpManager, } } func (al *AgentLoop) Run(ctx context.Context) error { al.running.Store(true) + defer al.closeMCP() for al.running.Load() { select { @@ -198,6 +237,22 @@ func (al *AgentLoop) Run(ctx context.Context) error { func (al *AgentLoop) Stop() { al.running.Store(false) + al.closeMCP() +} + +func (al *AgentLoop) closeMCP() { + if al.mcpManager == nil { + return + } + + al.mcpCloseOnce.Do(func() { + if err := al.mcpManager.Close(); err != nil { + logger.WarnCF("agent", "Failed to close MCP manager", + map[string]interface{}{ + "error": err.Error(), + }) + } + }) } func (al *AgentLoop) RegisterTool(tool tools.Tool) { diff --git a/pkg/agent/mcp_bootstrap.go b/pkg/agent/mcp_bootstrap.go new file mode 100644 index 000000000..40aa44fff --- /dev/null +++ b/pkg/agent/mcp_bootstrap.go @@ -0,0 +1,110 @@ +package agent + +import ( + "context" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/mcp" +) + +const ( + mcpBootstrapMinTimeout = 10 * time.Second + mcpBootstrapMaxTimeout = 5 * time.Minute + mcpBootstrapGraceTimeout = 5 * time.Second +) + +type mcpBootstrapResult struct { + Manager *mcp.Manager + Tools []mcp.RegisteredTool +} + +func bootstrapMCP(cfg config.MCPToolsConfig) (*mcpBootstrapResult, error) { + serverConfigs := buildMCPServerConfigs(cfg) + if len(serverConfigs) == 0 { + return nil, nil + } + + manager := mcp.NewManager(serverConfigs) + + discoveryTimeout := calculateMCPDiscoveryTimeout(serverConfigs) + discoveryCtx, cancel := context.WithTimeout(context.Background(), discoveryTimeout) + defer cancel() + + discoveredTools, err := manager.DiscoverTools(discoveryCtx) + if err != nil { + _ = manager.Close() + return nil, err + } + + return &mcpBootstrapResult{ + Manager: manager, + Tools: discoveredTools, + }, nil +} + +func calculateMCPDiscoveryTimeout(serverConfigs map[string]mcp.ServerConfig) time.Duration { + maxInitTimeout := mcpBootstrapMinTimeout + + for _, serverConfig := range serverConfigs { + initTimeout := serverConfig.InitTimeout() + if initTimeout > maxInitTimeout { + maxInitTimeout = initTimeout + } + } + + timeout := maxInitTimeout + mcpBootstrapGraceTimeout + if timeout < mcpBootstrapMinTimeout { + return mcpBootstrapMinTimeout + } + if timeout > mcpBootstrapMaxTimeout { + return mcpBootstrapMaxTimeout + } + return timeout +} + +func buildMCPServerConfigs(cfg config.MCPToolsConfig) map[string]mcp.ServerConfig { + servers := make(map[string]mcp.ServerConfig, len(cfg.Servers)) + + for serverName, serverCfg := range cfg.Servers { + if !serverCfg.Enabled { + continue + } + + envCopy := make(map[string]string, len(serverCfg.Env)) + for key, value := range serverCfg.Env { + envCopy[key] = value + } + + servers[serverName] = mcp.ServerConfig{ + Name: serverName, + Command: serverCfg.Command, + Args: append([]string{}, serverCfg.Args...), + Env: envCopy, + WorkingDir: serverCfg.WorkingDir, + Protocol: inferMCPProtocol(serverCfg.Protocol, serverCfg.Command), + InitTimeoutSeconds: serverCfg.InitTimeoutSeconds, + CallTimeoutSeconds: serverCfg.CallTimeoutSeconds, + MaxResponseBytes: serverCfg.MaxResponseBytes, + IncludeTools: append([]string{}, serverCfg.IncludeTools...), + ExcludeTools: append([]string{}, serverCfg.ExcludeTools...), + } + } + + return servers +} + +func inferMCPProtocol(configuredProtocol, command string) string { + if protocol := strings.TrimSpace(configuredProtocol); protocol != "" { + return protocol + } + + // Context7 currently emits JSON-RPC messages as JSONL on stdio, + // so defaulting avoids long startup waits when protocol is omitted. + if strings.Contains(strings.ToLower(command), "context7-mcp") { + return mcp.ProtocolJSONLines + } + + return "" +} diff --git a/pkg/agent/mcp_bootstrap_test.go b/pkg/agent/mcp_bootstrap_test.go new file mode 100644 index 000000000..2e22453ba --- /dev/null +++ b/pkg/agent/mcp_bootstrap_test.go @@ -0,0 +1,79 @@ +package agent + +import ( + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/mcp" +) + +func TestCalculateMCPDiscoveryTimeout_UsesMaxInitWithGrace(t *testing.T) { + serverConfigs := map[string]struct { + initSeconds int + }{ + "fast": {initSeconds: 5}, + "slow": {initSeconds: 60}, + } + + cfg := config.MCPToolsConfig{ + Enabled: true, + Servers: map[string]config.MCPServerConfig{ + "fast": { + Enabled: true, + Command: "fast", + InitTimeoutSeconds: serverConfigs["fast"].initSeconds, + }, + "slow": { + Enabled: true, + Command: "slow", + InitTimeoutSeconds: serverConfigs["slow"].initSeconds, + }, + }, + } + + mcpConfigs := buildMCPServerConfigs(cfg) + timeout := calculateMCPDiscoveryTimeout(mcpConfigs) + + want := 65 * time.Second + if timeout != want { + t.Fatalf("calculateMCPDiscoveryTimeout() = %v, want %v", timeout, want) + } +} + +func TestBuildMCPServerConfigs_SkipsDisabledServers(t *testing.T) { + cfg := config.MCPToolsConfig{ + Enabled: true, + Servers: map[string]config.MCPServerConfig{ + "context7": { + Enabled: true, + Command: "context7-mcp", + Protocol: "jsonl", + }, + "disabled": { + Enabled: false, + Command: "ignored", + }, + }, + } + + mcpConfigs := buildMCPServerConfigs(cfg) + if len(mcpConfigs) != 1 { + t.Fatalf("buildMCPServerConfigs() count = %d, want 1", len(mcpConfigs)) + } + + context7, ok := mcpConfigs["context7"] + if !ok { + t.Fatalf("context7 not found in buildMCPServerConfigs output") + } + if context7.Protocol != "jsonl" { + t.Fatalf("context7 protocol = %q, want jsonl", context7.Protocol) + } +} + +func TestInferMCPProtocol_Context7DefaultsToJSONL(t *testing.T) { + got := inferMCPProtocol("", "context7-mcp") + if got != mcp.ProtocolJSONLines { + t.Fatalf("inferMCPProtocol() = %q, want %s", got, mcp.ProtocolJSONLines) + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 1d34f56f3..a20b5cc2d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "sync" "github.com/caarlos0/env/v11" @@ -51,7 +52,10 @@ type Config struct { Tools ToolsConfig `json:"tools"` Heartbeat HeartbeatConfig `json:"heartbeat"` Devices DevicesConfig `json:"devices"` - mu sync.RWMutex + // MCPServers is a compatibility alias for configs using top-level "mcpServers". + // Canonical config remains tools.mcp.servers. + MCPServers map[string]LegacyMCPServerConfig `json:"mcpServers,omitempty"` + mu sync.RWMutex } type AgentsConfig struct { @@ -222,9 +226,38 @@ type CronToolsConfig struct { ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout } +type MCPServerConfig struct { + Enabled bool `json:"enabled"` + Command string `json:"command"` + Args []string `json:"args"` + Env map[string]string `json:"env"` + WorkingDir string `json:"working_dir"` + Protocol string `json:"protocol"` + InitTimeoutSeconds int `json:"init_timeout_seconds"` + CallTimeoutSeconds int `json:"call_timeout_seconds"` + MaxResponseBytes int `json:"max_response_bytes"` + IncludeTools []string `json:"include_tools"` + ExcludeTools []string `json:"exclude_tools"` +} + +type MCPToolsConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"` + Servers map[string]MCPServerConfig `json:"servers"` +} + +// LegacyMCPServerConfig supports compatibility with "mcpServers" style config. +type LegacyMCPServerConfig struct { + Type string `json:"type"` + Command string `json:"command"` + Args []string `json:"args"` + Env map[string]string `json:"env"` + Protocol string `json:"protocol"` +} + type ToolsConfig struct { Web WebToolsConfig `json:"web"` Cron CronToolsConfig `json:"cron"` + MCP MCPToolsConfig `json:"mcp"` } func DefaultConfig() *Config { @@ -342,6 +375,10 @@ func DefaultConfig() *Config { Cron: CronToolsConfig{ ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations }, + MCP: MCPToolsConfig{ + Enabled: false, + Servers: map[string]MCPServerConfig{}, + }, }, Heartbeat: HeartbeatConfig{ Enabled: true, @@ -373,9 +410,53 @@ func LoadConfig(path string) (*Config, error) { return nil, err } + cfg.applyLegacyMCPServers() + return cfg, nil } +func (c *Config) applyLegacyMCPServers() { + // If canonical MCP config already exists, keep it as source of truth. + if len(c.Tools.MCP.Servers) > 0 { + return + } + if len(c.MCPServers) == 0 { + return + } + + if c.Tools.MCP.Servers == nil { + c.Tools.MCP.Servers = map[string]MCPServerConfig{} + } + + for name, legacy := range c.MCPServers { + if strings.TrimSpace(legacy.Command) == "" { + continue + } + + enabled := true + if legacy.Type != "" && legacy.Type != "stdio" { + enabled = false + } + + envCopy := make(map[string]string, len(legacy.Env)) + for key, value := range legacy.Env { + envCopy[key] = value + } + + c.Tools.MCP.Servers[name] = MCPServerConfig{ + Enabled: enabled, + Command: legacy.Command, + Args: append([]string{}, legacy.Args...), + Env: envCopy, + Protocol: legacy.Protocol, + } + } + + if len(c.Tools.MCP.Servers) > 0 { + c.Tools.MCP.Enabled = true + } +} + func SaveConfig(path string, cfg *Config) error { cfg.mu.RLock() defer cfg.mu.RUnlock() diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index febfd0456..a517dbb5c 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -150,6 +150,17 @@ func TestDefaultConfig_WebTools(t *testing.T) { } } +func TestDefaultConfig_MCPTools(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Tools.MCP.Enabled { + t.Error("MCP tools should be disabled by default") + } + if cfg.Tools.MCP.Servers == nil { + t.Error("MCP servers map should be initialized") + } +} + func TestSaveConfig_FilePermissions(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("file permission bits are not enforced on Windows") @@ -204,3 +215,58 @@ func TestConfig_Complete(t *testing.T) { t.Error("Heartbeat should be enabled by default") } } + +func TestLoadConfig_LegacyMCPServersCompatibility(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + + configJSON := `{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "test-model", + "max_tokens": 1024, + "temperature": 0.7, + "max_tool_iterations": 10 + } + }, + "mcpServers": { + "context7": { + "type": "stdio", + "protocol": "jsonl", + "command": "npx", + "args": ["-y", "@upstash/context7-mcp", "--api-key", "test-key"] + } + } + }` + + if err := os.WriteFile(configPath, []byte(configJSON), 0600); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig failed: %v", err) + } + + if !cfg.Tools.MCP.Enabled { + t.Fatal("Tools.MCP should be enabled from legacy mcpServers") + } + + server, ok := cfg.Tools.MCP.Servers["context7"] + if !ok { + t.Fatal("context7 server not mapped from legacy mcpServers") + } + if !server.Enabled { + t.Fatal("context7 server should be enabled") + } + if server.Command != "npx" { + t.Fatalf("context7 command = %q, want npx", server.Command) + } + if server.Protocol != "jsonl" { + t.Fatalf("context7 protocol = %q, want jsonl", server.Protocol) + } + if len(server.Args) == 0 { + t.Fatal("context7 args should be mapped") + } +} diff --git a/pkg/mcp/client.go b/pkg/mcp/client.go new file mode 100644 index 000000000..66dda1f61 --- /dev/null +++ b/pkg/mcp/client.go @@ -0,0 +1,603 @@ +package mcp + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "os/exec" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// Client is the transport-agnostic MCP client contract. +type Client interface { + Start(ctx context.Context) error + ListTools(ctx context.Context) ([]RemoteTool, error) + CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error) + Close() error +} + +// StdioClient speaks MCP over stdio (JSON-RPC framed with Content-Length headers). +type StdioClient struct { + config ServerConfig + mode string + + mu sync.Mutex + writeMu sync.Mutex + + started bool + closed bool + + cmd *exec.Cmd + stdin io.WriteCloser + stdout io.ReadCloser + stderr io.ReadCloser + waitCh chan struct{} + pending map[string]chan rpcResponse + + nextID uint64 +} + +type rpcRequest struct { + JSONRPC string `json:"jsonrpc"` + ID string `json:"id,omitempty"` + Method string `json:"method"` + Params any `json:"params,omitempty"` +} + +type rpcResponseEnvelope struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *rpcError `json:"error,omitempty"` + Method string `json:"method,omitempty"` +} + +type rpcError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type rpcResponse struct { + result json.RawMessage + rpcErr *rpcError + err error +} + +type initializeParams struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities map[string]any `json:"capabilities"` + ClientInfo map[string]interface{} `json:"clientInfo"` +} + +func NewStdioClient(config ServerConfig) *StdioClient { + return &StdioClient{ + config: config, + mode: normalizeProtocol(config.Protocol), + } +} + +func (c *StdioClient) Start(ctx context.Context) error { + c.mu.Lock() + if c.started { + c.mu.Unlock() + return nil + } + if strings.TrimSpace(c.config.Command) == "" { + c.mu.Unlock() + return fmt.Errorf("mcp server %q command is empty", c.config.Name) + } + + cmd := exec.Command(c.config.Command, c.config.Args...) + if c.config.WorkingDir != "" { + cmd.Dir = c.config.WorkingDir + } + cmd.Env = buildProcessEnv(c.config.Env) + + stdin, err := cmd.StdinPipe() + if err != nil { + c.mu.Unlock() + return fmt.Errorf("create stdin pipe: %w", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + c.mu.Unlock() + return fmt.Errorf("create stdout pipe: %w", err) + } + stderr, err := cmd.StderrPipe() + if err != nil { + c.mu.Unlock() + return fmt.Errorf("create stderr pipe: %w", err) + } + if err := cmd.Start(); err != nil { + c.mu.Unlock() + return fmt.Errorf("start process: %w", err) + } + + c.started = true + c.closed = false + c.cmd = cmd + c.stdin = stdin + c.stdout = stdout + c.stderr = stderr + c.waitCh = make(chan struct{}) + c.pending = make(map[string]chan rpcResponse) + c.mu.Unlock() + + go c.readLoop() + go c.waitLoop() + go c.drainStderr() + + initCtx, cancel := withTimeoutIfMissing(ctx, c.config.InitTimeout()) + defer cancel() + + _, err = c.request(initCtx, "initialize", initializeParams{ + ProtocolVersion: "2024-11-05", + Capabilities: map[string]any{ + "tools": map[string]any{}, + }, + ClientInfo: map[string]any{ + "name": "picoclaw", + "version": "dev", + }, + }) + if err != nil { + _ = c.Close() + return fmt.Errorf("initialize failed: %w", err) + } + + if err := c.notify("notifications/initialized", map[string]any{}); err != nil { + _ = c.Close() + return fmt.Errorf("initialized notification failed: %w", err) + } + + return nil +} + +func (c *StdioClient) ListTools(ctx context.Context) ([]RemoteTool, error) { + if err := c.Start(ctx); err != nil { + return nil, err + } + + type listToolsResponse struct { + Tools []struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema map[string]any `json:"inputSchema"` + } `json:"tools"` + NextCursor string `json:"nextCursor,omitempty"` + } + + allTools := make([]RemoteTool, 0, 8) + cursor := "" + + for page := 0; page < maxToolListPages; page++ { + params := map[string]any{} + if cursor != "" { + params["cursor"] = cursor + } + + callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout()) + raw, err := c.request(callCtx, "tools/list", params) + cancel() + if err != nil { + return nil, err + } + + var response listToolsResponse + if err := json.Unmarshal(raw, &response); err != nil { + return nil, fmt.Errorf("decode tools/list response: %w", err) + } + + for _, tool := range response.Tools { + allTools = append(allTools, RemoteTool{ + Name: tool.Name, + Description: tool.Description, + InputSchema: tool.InputSchema, + }) + } + + if response.NextCursor == "" { + return allTools, nil + } + cursor = response.NextCursor + } + + return nil, fmt.Errorf("tools/list exceeded %d pages", maxToolListPages) +} + +func (c *StdioClient) CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error) { + if err := c.Start(ctx); err != nil { + return CallResult{}, err + } + + callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout()) + defer cancel() + + raw, err := c.request(callCtx, "tools/call", map[string]any{ + "name": toolName, + "arguments": arguments, + }) + if err != nil { + return CallResult{}, err + } + + return formatCallPayload(raw, c.config.ResponseLimit()) +} + +func (c *StdioClient) Close() error { + c.mu.Lock() + if !c.started || c.closed { + c.mu.Unlock() + return nil + } + c.closed = true + cmd := c.cmd + stdin := c.stdin + waitCh := c.waitCh + c.mu.Unlock() + + c.failPending(errors.New("mcp client closed")) + + if stdin != nil { + _ = stdin.Close() + } + if cmd != nil && cmd.Process != nil { + _ = cmd.Process.Kill() + } + + if waitCh != nil { + select { + case <-waitCh: + case <-time.After(2 * time.Second): + } + } + return nil +} + +func (c *StdioClient) request(ctx context.Context, method string, params any) (json.RawMessage, error) { + id := strconv.FormatUint(atomic.AddUint64(&c.nextID, 1), 10) + responseCh := make(chan rpcResponse, 1) + + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, fmt.Errorf("mcp server %q is closed", c.config.Name) + } + c.pending[id] = responseCh + c.mu.Unlock() + + req := rpcRequest{ + JSONRPC: "2.0", + ID: id, + Method: method, + Params: params, + } + if err := c.writeMessage(req); err != nil { + c.removePending(id) + return nil, err + } + + select { + case <-ctx.Done(): + c.removePending(id) + return nil, ctx.Err() + case response := <-responseCh: + if response.err != nil { + return nil, response.err + } + if response.rpcErr != nil { + return nil, fmt.Errorf("mcp error %d: %s", response.rpcErr.Code, response.rpcErr.Message) + } + return response.result, nil + } +} + +func (c *StdioClient) notify(method string, params any) error { + req := rpcRequest{ + JSONRPC: "2.0", + Method: method, + Params: params, + } + return c.writeMessage(req) +} + +func (c *StdioClient) writeMessage(payload any) error { + c.mu.Lock() + if c.closed || c.stdin == nil { + c.mu.Unlock() + return fmt.Errorf("mcp server %q is not writable", c.config.Name) + } + stdin := c.stdin + c.mu.Unlock() + + data, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("marshal json-rpc payload: %w", err) + } + + if c.mode == ProtocolJSONLines { + c.writeMu.Lock() + defer c.writeMu.Unlock() + + if _, err := stdin.Write(append(data, '\n')); err != nil { + return fmt.Errorf("write jsonl body: %w", err) + } + return nil + } + + frameHeader := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data)) + + c.writeMu.Lock() + defer c.writeMu.Unlock() + + if _, err := io.WriteString(stdin, frameHeader); err != nil { + return fmt.Errorf("write frame header: %w", err) + } + if _, err := stdin.Write(data); err != nil { + return fmt.Errorf("write frame body: %w", err) + } + return nil +} + +func (c *StdioClient) readLoop() { + if c.mode == ProtocolJSONLines { + c.readJSONLLoop() + return + } + + c.readMCPFrameLoop() +} + +func (c *StdioClient) readMCPFrameLoop() { + reader := bufio.NewReader(c.stdout) + + for { + payload, err := readFramePayload(reader) + if err != nil { + c.failPending(err) + return + } + + var envelope rpcResponseEnvelope + if err := json.Unmarshal(payload, &envelope); err != nil { + continue + } + c.dispatchResponse(envelope) + } +} + +func (c *StdioClient) readJSONLLoop() { + scanner := bufio.NewScanner(c.stdout) + scanner.Buffer(make([]byte, 0, defaultScannerBufferBytes), maxFrameBytes) + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + var envelope rpcResponseEnvelope + if err := json.Unmarshal([]byte(line), &envelope); err != nil { + continue + } + c.dispatchResponse(envelope) + } + + if err := scanner.Err(); err != nil { + c.failPending(err) + return + } + c.failPending(io.EOF) +} + +func (c *StdioClient) dispatchResponse(envelope rpcResponseEnvelope) { + if len(envelope.ID) == 0 { + return + } + + id, ok := parseRPCID(envelope.ID) + if !ok { + return + } + + c.mu.Lock() + responseCh := c.pending[id] + if responseCh != nil { + delete(c.pending, id) + } + c.mu.Unlock() + + if responseCh == nil { + return + } + + response := rpcResponse{ + result: envelope.Result, + rpcErr: envelope.Error, + } + select { + case responseCh <- response: + default: + } +} + +func (c *StdioClient) waitLoop() { + c.mu.Lock() + cmd := c.cmd + waitCh := c.waitCh + serverName := c.config.Name + c.mu.Unlock() + + if cmd == nil { + if waitCh != nil { + close(waitCh) + } + return + } + + err := cmd.Wait() + if waitCh != nil { + close(waitCh) + } + if err != nil { + logger.WarnCF("mcp", "MCP process exited with error", + map[string]any{ + "server": serverName, + "error": err.Error(), + }) + } +} + +func (c *StdioClient) drainStderr() { + c.mu.Lock() + stderr := c.stderr + serverName := c.config.Name + c.mu.Unlock() + + if stderr == nil { + return + } + + scanner := bufio.NewScanner(stderr) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + logger.DebugCF("mcp", "MCP server stderr", + map[string]any{ + "server": serverName, + "line": line, + }) + } +} + +func (c *StdioClient) failPending(err error) { + c.mu.Lock() + pending := c.pending + c.pending = make(map[string]chan rpcResponse) + c.mu.Unlock() + + if len(pending) == 0 { + return + } + + for _, ch := range pending { + select { + case ch <- rpcResponse{err: err}: + default: + } + } +} + +func (c *StdioClient) removePending(id string) { + c.mu.Lock() + delete(c.pending, id) + c.mu.Unlock() +} + +func readFramePayload(reader *bufio.Reader) ([]byte, error) { + contentLength := -1 + + for { + line, err := reader.ReadString('\n') + if err != nil { + return nil, err + } + trimmed := strings.TrimRight(line, "\r\n") + if trimmed == "" { + break + } + + parts := strings.SplitN(trimmed, ":", 2) + if len(parts) != 2 { + continue + } + headerName := strings.TrimSpace(strings.ToLower(parts[0])) + if headerName != "content-length" { + continue + } + value := strings.TrimSpace(parts[1]) + length, err := strconv.Atoi(value) + if err != nil { + return nil, fmt.Errorf("invalid content-length %q: %w", value, err) + } + contentLength = length + } + + if contentLength <= 0 { + return nil, fmt.Errorf("missing content-length") + } + if contentLength > maxFrameBytes { + return nil, fmt.Errorf("frame too large (%d bytes)", contentLength) + } + + payload := make([]byte, contentLength) + if _, err := io.ReadFull(reader, payload); err != nil { + return nil, err + } + return payload, nil +} + +func parseRPCID(raw json.RawMessage) (string, bool) { + var stringID string + if err := json.Unmarshal(raw, &stringID); err == nil { + return stringID, true + } + + var numberID float64 + if err := json.Unmarshal(raw, &numberID); err == nil { + return strconv.FormatInt(int64(numberID), 10), true + } + + return "", false +} + +func withTimeoutIfMissing(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + if _, hasDeadline := parent.Deadline(); hasDeadline { + return context.WithCancel(parent) + } + return context.WithTimeout(parent, timeout) +} + +func buildProcessEnv(custom map[string]string) []string { + base := os.Environ() + if len(custom) == 0 { + return base + } + + keys := make([]string, 0, len(custom)) + for key := range custom { + keys = append(keys, key) + } + sort.Strings(keys) + + env := make([]string, 0, len(base)+len(keys)) + env = append(env, base...) + for _, key := range keys { + env = append(env, key+"="+custom[key]) + } + return env +} + +func normalizeProtocol(protocol string) string { + switch strings.ToLower(strings.TrimSpace(protocol)) { + case "", ProtocolMCPFrames: + return ProtocolMCPFrames + case ProtocolJSONLines: + return ProtocolJSONLines + default: + return ProtocolMCPFrames + } +} diff --git a/pkg/mcp/client_test.go b/pkg/mcp/client_test.go new file mode 100644 index 000000000..5411f93e1 --- /dev/null +++ b/pkg/mcp/client_test.go @@ -0,0 +1,23 @@ +package mcp + +import "testing" + +func TestNormalizeProtocol(t *testing.T) { + tests := []struct { + input string + want string + }{ + {input: "", want: ProtocolMCPFrames}, + {input: "mcp", want: ProtocolMCPFrames}, + {input: "jsonl", want: ProtocolJSONLines}, + {input: "JSONL", want: ProtocolJSONLines}, + {input: "unknown", want: ProtocolMCPFrames}, + } + + for _, tt := range tests { + got := normalizeProtocol(tt.input) + if got != tt.want { + t.Fatalf("normalizeProtocol(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} diff --git a/pkg/mcp/format.go b/pkg/mcp/format.go new file mode 100644 index 000000000..e96389389 --- /dev/null +++ b/pkg/mcp/format.go @@ -0,0 +1,61 @@ +package mcp + +import ( + "encoding/json" + "strings" +) + +type callResponse struct { + Content []contentBlock `json:"content"` + StructuredContent any `json:"structuredContent,omitempty"` + IsError bool `json:"isError,omitempty"` +} + +type contentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +func formatCallPayload(raw json.RawMessage, responseLimit int) (CallResult, error) { + var payload callResponse + if err := json.Unmarshal(raw, &payload); err != nil { + // Fallback for servers that return non-standard payloads. + return CallResult{ + Content: truncateString(strings.TrimSpace(string(raw)), responseLimit), + IsError: false, + }, nil + } + + parts := make([]string, 0, len(payload.Content)+1) + for _, block := range payload.Content { + if block.Type == "text" && strings.TrimSpace(block.Text) != "" { + parts = append(parts, block.Text) + } + } + + if payload.StructuredContent != nil { + if encoded, err := json.Marshal(payload.StructuredContent); err == nil { + parts = append(parts, string(encoded)) + } + } + + content := strings.TrimSpace(strings.Join(parts, "\n")) + if content == "" { + content = "{}" + } + + return CallResult{ + Content: truncateString(content, responseLimit), + IsError: payload.IsError, + }, nil +} + +func truncateString(value string, maxBytes int) string { + if maxBytes <= 0 || len(value) <= maxBytes { + return value + } + if maxBytes <= 12 { + return value[:maxBytes] + } + return value[:maxBytes-12] + "\n...[truncated]" +} diff --git a/pkg/mcp/format_test.go b/pkg/mcp/format_test.go new file mode 100644 index 000000000..cf4e4bcb4 --- /dev/null +++ b/pkg/mcp/format_test.go @@ -0,0 +1,52 @@ +package mcp + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestFormatCallPayload_TextAndStructured(t *testing.T) { + raw := json.RawMessage(`{ + "content":[{"type":"text","text":"hello"}], + "structuredContent":{"ok":true} + }`) + + result, err := formatCallPayload(raw, 4096) + if err != nil { + t.Fatalf("formatCallPayload() error = %v", err) + } + if result.IsError { + t.Fatalf("expected IsError=false") + } + if !strings.Contains(result.Content, "hello") { + t.Fatalf("expected content to contain text block, got %q", result.Content) + } + if !strings.Contains(result.Content, `"ok":true`) { + t.Fatalf("expected content to contain structured content, got %q", result.Content) + } +} + +func TestFormatCallPayload_Truncates(t *testing.T) { + raw := json.RawMessage(`{"content":[{"type":"text","text":"abcdefghijklmnopqrstuvwxyz"}]}`) + + result, err := formatCallPayload(raw, 12) + if err != nil { + t.Fatalf("formatCallPayload() error = %v", err) + } + if len(result.Content) != 12 { + t.Fatalf("expected truncated length 12, got %d", len(result.Content)) + } +} + +func TestFormatCallPayload_RespectsIsError(t *testing.T) { + raw := json.RawMessage(`{"content":[{"type":"text","text":"failed"}],"isError":true}`) + + result, err := formatCallPayload(raw, 4096) + if err != nil { + t.Fatalf("formatCallPayload() error = %v", err) + } + if !result.IsError { + t.Fatalf("expected IsError=true") + } +} diff --git a/pkg/mcp/manager.go b/pkg/mcp/manager.go new file mode 100644 index 000000000..16edf1744 --- /dev/null +++ b/pkg/mcp/manager.go @@ -0,0 +1,190 @@ +package mcp + +import ( + "context" + "fmt" + "slices" + "strings" + "sync" +) + +type clientFactory func(config ServerConfig) Client + +type managedServer struct { + config ServerConfig + client Client +} + +// Manager owns MCP servers and maps discovered MCP tools to PicoClaw tools. +type Manager struct { + mu sync.RWMutex + + servers map[string]*managedServer + tools map[string]RegisteredTool + + discovered bool + newClient clientFactory +} + +func NewManager(configs map[string]ServerConfig) *Manager { + servers := make(map[string]*managedServer, len(configs)) + for name, cfg := range configs { + copied := cfg + copied.Name = name + servers[name] = &managedServer{config: copied} + } + return &Manager{ + servers: servers, + tools: make(map[string]RegisteredTool), + discovered: false, + newClient: func(config ServerConfig) Client { + return NewStdioClient(config) + }, + } +} + +// DiscoverTools starts configured MCP servers and returns discovered tool metadata. +func (m *Manager) DiscoverTools(ctx context.Context) ([]RegisteredTool, error) { + m.mu.Lock() + if m.discovered { + tools := toolsFromMap(m.tools) + m.mu.Unlock() + return tools, nil + } + + discoveryErrors := make([]string, 0) + + for serverName, server := range m.servers { + client := m.newClient(server.config) + if err := client.Start(ctx); err != nil { + discoveryErrors = append(discoveryErrors, fmt.Sprintf("%s: %v", serverName, err)) + continue + } + + remoteTools, err := client.ListTools(ctx) + if err != nil { + _ = client.Close() + discoveryErrors = append(discoveryErrors, fmt.Sprintf("%s: %v", serverName, err)) + continue + } + + server.client = client + for _, remoteTool := range remoteTools { + if !isToolAllowed(remoteTool.Name, server.config.IncludeTools, server.config.ExcludeTools) { + continue + } + + qualifiedName := m.makeUniqueToolName(serverName, remoteTool.Name) + parameters := normalizeSchema(remoteTool.InputSchema) + m.tools[qualifiedName] = RegisteredTool{ + QualifiedName: qualifiedName, + ServerName: serverName, + ToolName: remoteTool.Name, + Description: remoteTool.Description, + Parameters: parameters, + } + } + } + + m.discovered = true + tools := toolsFromMap(m.tools) + m.mu.Unlock() + + if len(tools) == 0 && len(discoveryErrors) > 0 { + return nil, fmt.Errorf("mcp tool discovery failed: %s", strings.Join(discoveryErrors, "; ")) + } + return tools, nil +} + +func (m *Manager) CallTool(ctx context.Context, qualifiedName string, args map[string]any) (CallResult, error) { + m.mu.RLock() + tool, ok := m.tools[qualifiedName] + if !ok { + m.mu.RUnlock() + return CallResult{}, fmt.Errorf("mcp tool %q not found", qualifiedName) + } + + server := m.servers[tool.ServerName] + if server == nil || server.client == nil { + m.mu.RUnlock() + return CallResult{}, fmt.Errorf("mcp server %q is not active", tool.ServerName) + } + client := server.client + toolName := tool.ToolName + m.mu.RUnlock() + + if args == nil { + args = map[string]any{} + } + return client.CallTool(ctx, toolName, args) +} + +func (m *Manager) Close() error { + m.mu.Lock() + servers := make([]*managedServer, 0, len(m.servers)) + for _, server := range m.servers { + servers = append(servers, server) + } + m.mu.Unlock() + + var firstErr error + for _, server := range servers { + if server.client == nil { + continue + } + if err := server.client.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + +func (m *Manager) makeUniqueToolName(serverName, toolName string) string { + base := QualifiedToolName(serverName, toolName) + if _, exists := m.tools[base]; !exists { + return base + } + + for index := 2; ; index++ { + candidate := fmt.Sprintf("%s_%d", base, index) + if len(candidate) > qualifiedNameMaxLen { + overflow := len(candidate) - qualifiedNameMaxLen + if overflow < len(base) { + candidate = base[:len(base)-overflow] + fmt.Sprintf("_%d", index) + } else { + candidate = candidate[:qualifiedNameMaxLen] + } + } + if _, exists := m.tools[candidate]; !exists { + return candidate + } + } +} + +func normalizeSchema(schema map[string]any) map[string]any { + if len(schema) == 0 { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + return schema +} + +func isToolAllowed(name string, include, exclude []string) bool { + if len(include) > 0 && !slices.Contains(include, name) { + return false + } + if slices.Contains(exclude, name) { + return false + } + return true +} + +func toolsFromMap(tools map[string]RegisteredTool) []RegisteredTool { + out := make([]RegisteredTool, 0, len(tools)) + for _, tool := range tools { + out = append(out, tool) + } + return out +} diff --git a/pkg/mcp/manager_test.go b/pkg/mcp/manager_test.go new file mode 100644 index 000000000..3560a7533 --- /dev/null +++ b/pkg/mcp/manager_test.go @@ -0,0 +1,101 @@ +package mcp + +import ( + "context" + "testing" +) + +type fakeClient struct { + tools []RemoteTool + callResult CallResult + callErr error + + lastToolName string + lastArgs map[string]any +} + +func (f *fakeClient) Start(_ context.Context) error { return nil } +func (f *fakeClient) ListTools(_ context.Context) ([]RemoteTool, error) { + return f.tools, nil +} +func (f *fakeClient) CallTool(_ context.Context, toolName string, arguments map[string]any) (CallResult, error) { + f.lastToolName = toolName + f.lastArgs = arguments + if f.callErr != nil { + return CallResult{}, f.callErr + } + return f.callResult, nil +} +func (f *fakeClient) Close() error { return nil } + +func TestManager_DiscoverTools_FilterAndCall(t *testing.T) { + serverCfg := map[string]ServerConfig{ + "Local Dev": { + Command: "fake", + IncludeTools: []string{"alpha", "beta"}, + ExcludeTools: []string{"beta"}, + }, + } + manager := NewManager(serverCfg) + + client := &fakeClient{ + tools: []RemoteTool{ + {Name: "alpha", Description: "tool alpha"}, + {Name: "beta", Description: "tool beta"}, + {Name: "gamma", Description: "tool gamma"}, + }, + callResult: CallResult{Content: "ok"}, + } + manager.newClient = func(_ ServerConfig) Client { + return client + } + + tools, err := manager.DiscoverTools(context.Background()) + if err != nil { + t.Fatalf("DiscoverTools() error = %v", err) + } + if len(tools) != 1 { + t.Fatalf("DiscoverTools() returned %d tools, want 1", len(tools)) + } + + tool := tools[0] + if tool.ToolName != "alpha" { + t.Fatalf("discovered tool = %q, want alpha", tool.ToolName) + } + + result, err := manager.CallTool(context.Background(), tool.QualifiedName, map[string]any{"x": 1}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if result.Content != "ok" { + t.Fatalf("CallTool() content = %q, want ok", result.Content) + } + if client.lastToolName != "alpha" { + t.Fatalf("called MCP tool = %q, want alpha", client.lastToolName) + } +} + +func TestManager_NormalizeEmptySchema(t *testing.T) { + serverCfg := map[string]ServerConfig{ + "srv": {Command: "fake"}, + } + manager := NewManager(serverCfg) + manager.newClient = func(_ ServerConfig) Client { + return &fakeClient{ + tools: []RemoteTool{{Name: "empty_schema", InputSchema: nil}}, + } + } + + tools, err := manager.DiscoverTools(context.Background()) + if err != nil { + t.Fatalf("DiscoverTools() error = %v", err) + } + if len(tools) != 1 { + t.Fatalf("DiscoverTools() returned %d tools, want 1", len(tools)) + } + + parameters := tools[0].Parameters + if parameters["type"] != "object" { + t.Fatalf("normalized schema type = %v, want object", parameters["type"]) + } +} diff --git a/pkg/mcp/naming.go b/pkg/mcp/naming.go new file mode 100644 index 000000000..1a20895a3 --- /dev/null +++ b/pkg/mcp/naming.go @@ -0,0 +1,53 @@ +package mcp + +import "strings" + +const qualifiedNameMaxLen = 64 + +// QualifiedToolName creates a stable, provider-safe function name. +func QualifiedToolName(serverName, toolName string) string { + prefix := "mcp_" + sanitizeName(serverName) + "__" + tool := sanitizeName(toolName) + maxToolLen := qualifiedNameMaxLen - len(prefix) + if maxToolLen <= 0 { + return prefix[:qualifiedNameMaxLen] + } + if len(tool) > maxToolLen { + tool = tool[:maxToolLen] + } + return prefix + tool +} + +func sanitizeName(value string) string { + trimmed := strings.TrimSpace(strings.ToLower(value)) + if trimmed == "" { + return "unknown" + } + + var b strings.Builder + b.Grow(len(trimmed)) + + lastUnderscore := false + for i := 0; i < len(trimmed); i++ { + ch := trimmed[i] + isAlphaNum := (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') + if isAlphaNum { + b.WriteByte(ch) + lastUnderscore = false + continue + } + if !lastUnderscore { + b.WriteByte('_') + lastUnderscore = true + } + } + + s := strings.Trim(b.String(), "_") + if s == "" { + s = "unknown" + } + if s[0] >= '0' && s[0] <= '9' { + return "t_" + s + } + return s +} diff --git a/pkg/mcp/naming_test.go b/pkg/mcp/naming_test.go new file mode 100644 index 000000000..f23aefe5a --- /dev/null +++ b/pkg/mcp/naming_test.go @@ -0,0 +1,19 @@ +package mcp + +import "testing" + +func TestQualifiedToolName_SanitizesAndPrefixes(t *testing.T) { + got := QualifiedToolName("My Server", "Read-File!") + want := "mcp_my_server__read_file" + if got != want { + t.Fatalf("QualifiedToolName() = %q, want %q", got, want) + } +} + +func TestQualifiedToolName_TrimToMaxLen(t *testing.T) { + longToolName := "tool_name_with_many_segments_and_extra_text_that_exceeds_the_limit_significantly" + got := QualifiedToolName("server", longToolName) + if len(got) > qualifiedNameMaxLen { + t.Fatalf("qualified name length = %d, want <= %d", len(got), qualifiedNameMaxLen) + } +} diff --git a/pkg/mcp/types.go b/pkg/mcp/types.go new file mode 100644 index 000000000..1a1e95ea8 --- /dev/null +++ b/pkg/mcp/types.go @@ -0,0 +1,77 @@ +package mcp + +import "time" + +const ( + defaultInitTimeoutSeconds = 60 + defaultCallTimeoutSeconds = 30 + defaultMaxResponseBytes = 64 * 1024 + defaultScannerBufferBytes = 64 * 1024 + maxFrameBytes = 2 * 1024 * 1024 + maxToolListPages = 50 +) + +const ( + ProtocolMCPFrames = "mcp" + ProtocolJSONLines = "jsonl" +) + +// ServerConfig defines one MCP server connection. +type ServerConfig struct { + Name string + Command string + Args []string + Env map[string]string + WorkingDir string + Protocol string + InitTimeoutSeconds int + CallTimeoutSeconds int + MaxResponseBytes int + IncludeTools []string + ExcludeTools []string +} + +func (c ServerConfig) InitTimeout() time.Duration { + seconds := c.InitTimeoutSeconds + if seconds <= 0 { + seconds = defaultInitTimeoutSeconds + } + return time.Duration(seconds) * time.Second +} + +func (c ServerConfig) CallTimeout() time.Duration { + seconds := c.CallTimeoutSeconds + if seconds <= 0 { + seconds = defaultCallTimeoutSeconds + } + return time.Duration(seconds) * time.Second +} + +func (c ServerConfig) ResponseLimit() int { + if c.MaxResponseBytes <= 0 { + return defaultMaxResponseBytes + } + return c.MaxResponseBytes +} + +// RemoteTool is an MCP tool discovered from a server. +type RemoteTool struct { + Name string + Description string + InputSchema map[string]any +} + +// RegisteredTool is a discovered tool with a PicoClaw-facing qualified name. +type RegisteredTool struct { + QualifiedName string + ServerName string + ToolName string + Description string + Parameters map[string]any +} + +// CallResult is a normalized MCP tool call result. +type CallResult struct { + Content string + IsError bool +} diff --git a/pkg/tools/mcp.go b/pkg/tools/mcp.go new file mode 100644 index 000000000..518998844 --- /dev/null +++ b/pkg/tools/mcp.go @@ -0,0 +1,85 @@ +package tools + +import ( + "context" + "errors" + "fmt" + + "github.com/sipeed/picoclaw/pkg/mcp" +) + +type MCPTool struct { + manager *mcp.Manager + name string + description string + parameters map[string]any +} + +func NewMCPTool(manager *mcp.Manager, tool mcp.RegisteredTool) *MCPTool { + description := tool.Description + if description == "" { + description = fmt.Sprintf("MCP tool %s from server %s", tool.ToolName, tool.ServerName) + } + + return &MCPTool{ + manager: manager, + name: tool.QualifiedName, + description: description, + parameters: tool.Parameters, + } +} + +func (t *MCPTool) Name() string { + return t.name +} + +func (t *MCPTool) Description() string { + return t.description +} + +func (t *MCPTool) Parameters() map[string]interface{} { + return t.parameters +} + +func (t *MCPTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { + if t.manager == nil { + return ErrorResult("MCP manager is not configured") + } + + result, err := t.manager.CallTool(ctx, t.name, args) + if err != nil { + return ErrorResult(fmt.Sprintf("MCP tool %s failed: %v", t.name, err)).WithError(err) + } + if result.IsError { + err := errors.New(result.Content) + return ErrorResult(result.Content).WithError(err) + } + return SilentResult(result.Content) +} + +// RegisterMCPTools discovers tools from MCP servers and registers them into the registry. +func RegisterMCPTools(ctx context.Context, registry *ToolRegistry, manager *mcp.Manager) (int, error) { + if registry == nil || manager == nil { + return 0, nil + } + + discoveredTools, err := manager.DiscoverTools(ctx) + if err != nil { + return 0, err + } + + return RegisterKnownMCPTools(registry, manager, discoveredTools), nil +} + +// RegisterKnownMCPTools registers already-discovered MCP tools. +// This avoids repeated discovery work when multiple registries share one manager. +func RegisterKnownMCPTools(registry *ToolRegistry, manager *mcp.Manager, discoveredTools []mcp.RegisteredTool) int { + if registry == nil || manager == nil || len(discoveredTools) == 0 { + return 0 + } + + for _, tool := range discoveredTools { + registry.Register(NewMCPTool(manager, tool)) + } + return len(discoveredTools) +} diff --git a/pkg/tools/mcp_test.go b/pkg/tools/mcp_test.go new file mode 100644 index 000000000..ee846a041 --- /dev/null +++ b/pkg/tools/mcp_test.go @@ -0,0 +1,44 @@ +package tools + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/mcp" +) + +func TestRegisterKnownMCPTools_RegistersAllTools(t *testing.T) { + registry := NewToolRegistry() + manager := &mcp.Manager{} + discovered := []mcp.RegisteredTool{ + { + QualifiedName: "mcp_context7__resolve_library_id", + ServerName: "context7", + ToolName: "resolve-library-id", + Description: "Resolve library ID", + Parameters: map[string]any{ + "type": "object", + }, + }, + { + QualifiedName: "mcp_context7__query_docs", + ServerName: "context7", + ToolName: "query-docs", + Description: "Query docs", + Parameters: map[string]any{ + "type": "object", + }, + }, + } + + count := RegisterKnownMCPTools(registry, manager, discovered) + if count != 2 { + t.Fatalf("RegisterKnownMCPTools count = %d, want 2", count) + } + + if _, ok := registry.Get("mcp_context7__resolve_library_id"); !ok { + t.Fatalf("expected mcp_context7__resolve_library_id to be registered") + } + if _, ok := registry.Get("mcp_context7__query_docs"); !ok { + t.Fatalf("expected mcp_context7__query_docs to be registered") + } +}