feat: add MCP integration with context7 compatibility

This commit is contained in:
Danieldd28
2026-02-18 00:35:35 +07:00
parent 4fde0175cf
commit 403e048821
17 changed files with 1727 additions and 6 deletions
+23 -1
View File
@@ -130,6 +130,28 @@
},
"cron": {
"exec_timeout_minutes": 5
},
"mcp": {
"enabled": false,
"servers": {
"filesystem": {
"enabled": false,
"command": "npx",
"args": [
"-y",
"@modelcontextprotocol/server-filesystem",
"/tmp"
],
"protocol": "mcp",
"env": {},
"working_dir": "",
"init_timeout_seconds": 60,
"call_timeout_seconds": 30,
"max_response_bytes": 65536,
"include_tools": [],
"exclude_tools": []
}
}
}
},
"heartbeat": {
@@ -144,4 +166,4 @@
"host": "0.0.0.0",
"port": 18790
}
}
}
+59 -4
View File
@@ -23,6 +23,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/mcp"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/state"
@@ -44,8 +45,12 @@ type AgentLoop struct {
running atomic.Bool
summarizing sync.Map // Tracks which sessions are currently being summarized
channelManager *channels.Manager
mcpManager *mcp.Manager
mcpCloseOnce sync.Once
}
const defaultWebFetchMaxChars = 50000
// processOptions configures how a message is processed
type processOptions struct {
SessionKey string // Session identifier for history/context
@@ -60,7 +65,14 @@ type processOptions struct {
// createToolRegistry creates a tool registry with common tools.
// This is shared between main agent and subagents.
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
func createToolRegistry(
workspace string,
restrict bool,
cfg *config.Config,
msgBus *bus.MessageBus,
mcpManager *mcp.Manager,
discoveredMCPTools []mcp.RegisteredTool,
) *tools.ToolRegistry {
registry := tools.NewToolRegistry()
// File system tools
@@ -85,7 +97,9 @@ func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msg
}); searchTool != nil {
registry.Register(searchTool)
}
registry.Register(tools.NewWebFetchTool(50000))
registry.Register(tools.NewWebFetchTool(defaultWebFetchMaxChars))
tools.RegisterKnownMCPTools(registry, mcpManager, discoveredMCPTools)
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
registry.Register(tools.NewI2CTool())
@@ -113,12 +127,35 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
restrict := cfg.Agents.Defaults.RestrictToWorkspace
var (
mcpManager *mcp.Manager
discoveredMCPTools []mcp.RegisteredTool
)
if cfg.Tools.MCP.Enabled {
bootstrap, err := bootstrapMCP(cfg.Tools.MCP)
if err != nil {
logger.WarnCF("agent", "MCP tool bootstrap failed",
map[string]interface{}{
"error": err.Error(),
})
} else if bootstrap != nil {
mcpManager = bootstrap.Manager
discoveredMCPTools = bootstrap.Tools
if len(discoveredMCPTools) > 0 {
logger.InfoCF("agent", "MCP tools registered",
map[string]interface{}{
"count": len(discoveredMCPTools),
})
}
}
}
// Create tool registry for main agent
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager, discoveredMCPTools)
// Create subagent manager with its own tool registry
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager, discoveredMCPTools)
// Subagent doesn't need spawn/subagent tools to avoid recursion
subagentManager.SetTools(subagentTools)
@@ -151,11 +188,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
contextBuilder: contextBuilder,
tools: toolsRegistry,
summarizing: sync.Map{},
mcpManager: mcpManager,
}
}
func (al *AgentLoop) Run(ctx context.Context) error {
al.running.Store(true)
defer al.closeMCP()
for al.running.Load() {
select {
@@ -198,6 +237,22 @@ func (al *AgentLoop) Run(ctx context.Context) error {
func (al *AgentLoop) Stop() {
al.running.Store(false)
al.closeMCP()
}
func (al *AgentLoop) closeMCP() {
if al.mcpManager == nil {
return
}
al.mcpCloseOnce.Do(func() {
if err := al.mcpManager.Close(); err != nil {
logger.WarnCF("agent", "Failed to close MCP manager",
map[string]interface{}{
"error": err.Error(),
})
}
})
}
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
+110
View File
@@ -0,0 +1,110 @@
package agent
import (
"context"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/mcp"
)
const (
mcpBootstrapMinTimeout = 10 * time.Second
mcpBootstrapMaxTimeout = 5 * time.Minute
mcpBootstrapGraceTimeout = 5 * time.Second
)
type mcpBootstrapResult struct {
Manager *mcp.Manager
Tools []mcp.RegisteredTool
}
func bootstrapMCP(cfg config.MCPToolsConfig) (*mcpBootstrapResult, error) {
serverConfigs := buildMCPServerConfigs(cfg)
if len(serverConfigs) == 0 {
return nil, nil
}
manager := mcp.NewManager(serverConfigs)
discoveryTimeout := calculateMCPDiscoveryTimeout(serverConfigs)
discoveryCtx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
defer cancel()
discoveredTools, err := manager.DiscoverTools(discoveryCtx)
if err != nil {
_ = manager.Close()
return nil, err
}
return &mcpBootstrapResult{
Manager: manager,
Tools: discoveredTools,
}, nil
}
func calculateMCPDiscoveryTimeout(serverConfigs map[string]mcp.ServerConfig) time.Duration {
maxInitTimeout := mcpBootstrapMinTimeout
for _, serverConfig := range serverConfigs {
initTimeout := serverConfig.InitTimeout()
if initTimeout > maxInitTimeout {
maxInitTimeout = initTimeout
}
}
timeout := maxInitTimeout + mcpBootstrapGraceTimeout
if timeout < mcpBootstrapMinTimeout {
return mcpBootstrapMinTimeout
}
if timeout > mcpBootstrapMaxTimeout {
return mcpBootstrapMaxTimeout
}
return timeout
}
func buildMCPServerConfigs(cfg config.MCPToolsConfig) map[string]mcp.ServerConfig {
servers := make(map[string]mcp.ServerConfig, len(cfg.Servers))
for serverName, serverCfg := range cfg.Servers {
if !serverCfg.Enabled {
continue
}
envCopy := make(map[string]string, len(serverCfg.Env))
for key, value := range serverCfg.Env {
envCopy[key] = value
}
servers[serverName] = mcp.ServerConfig{
Name: serverName,
Command: serverCfg.Command,
Args: append([]string{}, serverCfg.Args...),
Env: envCopy,
WorkingDir: serverCfg.WorkingDir,
Protocol: inferMCPProtocol(serverCfg.Protocol, serverCfg.Command),
InitTimeoutSeconds: serverCfg.InitTimeoutSeconds,
CallTimeoutSeconds: serverCfg.CallTimeoutSeconds,
MaxResponseBytes: serverCfg.MaxResponseBytes,
IncludeTools: append([]string{}, serverCfg.IncludeTools...),
ExcludeTools: append([]string{}, serverCfg.ExcludeTools...),
}
}
return servers
}
func inferMCPProtocol(configuredProtocol, command string) string {
if protocol := strings.TrimSpace(configuredProtocol); protocol != "" {
return protocol
}
// Context7 currently emits JSON-RPC messages as JSONL on stdio,
// so defaulting avoids long startup waits when protocol is omitted.
if strings.Contains(strings.ToLower(command), "context7-mcp") {
return mcp.ProtocolJSONLines
}
return ""
}
+79
View File
@@ -0,0 +1,79 @@
package agent
import (
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/mcp"
)
func TestCalculateMCPDiscoveryTimeout_UsesMaxInitWithGrace(t *testing.T) {
serverConfigs := map[string]struct {
initSeconds int
}{
"fast": {initSeconds: 5},
"slow": {initSeconds: 60},
}
cfg := config.MCPToolsConfig{
Enabled: true,
Servers: map[string]config.MCPServerConfig{
"fast": {
Enabled: true,
Command: "fast",
InitTimeoutSeconds: serverConfigs["fast"].initSeconds,
},
"slow": {
Enabled: true,
Command: "slow",
InitTimeoutSeconds: serverConfigs["slow"].initSeconds,
},
},
}
mcpConfigs := buildMCPServerConfigs(cfg)
timeout := calculateMCPDiscoveryTimeout(mcpConfigs)
want := 65 * time.Second
if timeout != want {
t.Fatalf("calculateMCPDiscoveryTimeout() = %v, want %v", timeout, want)
}
}
func TestBuildMCPServerConfigs_SkipsDisabledServers(t *testing.T) {
cfg := config.MCPToolsConfig{
Enabled: true,
Servers: map[string]config.MCPServerConfig{
"context7": {
Enabled: true,
Command: "context7-mcp",
Protocol: "jsonl",
},
"disabled": {
Enabled: false,
Command: "ignored",
},
},
}
mcpConfigs := buildMCPServerConfigs(cfg)
if len(mcpConfigs) != 1 {
t.Fatalf("buildMCPServerConfigs() count = %d, want 1", len(mcpConfigs))
}
context7, ok := mcpConfigs["context7"]
if !ok {
t.Fatalf("context7 not found in buildMCPServerConfigs output")
}
if context7.Protocol != "jsonl" {
t.Fatalf("context7 protocol = %q, want jsonl", context7.Protocol)
}
}
func TestInferMCPProtocol_Context7DefaultsToJSONL(t *testing.T) {
got := inferMCPProtocol("", "context7-mcp")
if got != mcp.ProtocolJSONLines {
t.Fatalf("inferMCPProtocol() = %q, want %s", got, mcp.ProtocolJSONLines)
}
}
+82 -1
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"github.com/caarlos0/env/v11"
@@ -51,7 +52,10 @@ type Config struct {
Tools ToolsConfig `json:"tools"`
Heartbeat HeartbeatConfig `json:"heartbeat"`
Devices DevicesConfig `json:"devices"`
mu sync.RWMutex
// MCPServers is a compatibility alias for configs using top-level "mcpServers".
// Canonical config remains tools.mcp.servers.
MCPServers map[string]LegacyMCPServerConfig `json:"mcpServers,omitempty"`
mu sync.RWMutex
}
type AgentsConfig struct {
@@ -222,9 +226,38 @@ type CronToolsConfig struct {
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
}
type MCPServerConfig struct {
Enabled bool `json:"enabled"`
Command string `json:"command"`
Args []string `json:"args"`
Env map[string]string `json:"env"`
WorkingDir string `json:"working_dir"`
Protocol string `json:"protocol"`
InitTimeoutSeconds int `json:"init_timeout_seconds"`
CallTimeoutSeconds int `json:"call_timeout_seconds"`
MaxResponseBytes int `json:"max_response_bytes"`
IncludeTools []string `json:"include_tools"`
ExcludeTools []string `json:"exclude_tools"`
}
type MCPToolsConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
Servers map[string]MCPServerConfig `json:"servers"`
}
// LegacyMCPServerConfig supports compatibility with "mcpServers" style config.
type LegacyMCPServerConfig struct {
Type string `json:"type"`
Command string `json:"command"`
Args []string `json:"args"`
Env map[string]string `json:"env"`
Protocol string `json:"protocol"`
}
type ToolsConfig struct {
Web WebToolsConfig `json:"web"`
Cron CronToolsConfig `json:"cron"`
MCP MCPToolsConfig `json:"mcp"`
}
func DefaultConfig() *Config {
@@ -342,6 +375,10 @@ func DefaultConfig() *Config {
Cron: CronToolsConfig{
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
},
MCP: MCPToolsConfig{
Enabled: false,
Servers: map[string]MCPServerConfig{},
},
},
Heartbeat: HeartbeatConfig{
Enabled: true,
@@ -373,9 +410,53 @@ func LoadConfig(path string) (*Config, error) {
return nil, err
}
cfg.applyLegacyMCPServers()
return cfg, nil
}
func (c *Config) applyLegacyMCPServers() {
// If canonical MCP config already exists, keep it as source of truth.
if len(c.Tools.MCP.Servers) > 0 {
return
}
if len(c.MCPServers) == 0 {
return
}
if c.Tools.MCP.Servers == nil {
c.Tools.MCP.Servers = map[string]MCPServerConfig{}
}
for name, legacy := range c.MCPServers {
if strings.TrimSpace(legacy.Command) == "" {
continue
}
enabled := true
if legacy.Type != "" && legacy.Type != "stdio" {
enabled = false
}
envCopy := make(map[string]string, len(legacy.Env))
for key, value := range legacy.Env {
envCopy[key] = value
}
c.Tools.MCP.Servers[name] = MCPServerConfig{
Enabled: enabled,
Command: legacy.Command,
Args: append([]string{}, legacy.Args...),
Env: envCopy,
Protocol: legacy.Protocol,
}
}
if len(c.Tools.MCP.Servers) > 0 {
c.Tools.MCP.Enabled = true
}
}
func SaveConfig(path string, cfg *Config) error {
cfg.mu.RLock()
defer cfg.mu.RUnlock()
+66
View File
@@ -150,6 +150,17 @@ func TestDefaultConfig_WebTools(t *testing.T) {
}
}
func TestDefaultConfig_MCPTools(t *testing.T) {
cfg := DefaultConfig()
if cfg.Tools.MCP.Enabled {
t.Error("MCP tools should be disabled by default")
}
if cfg.Tools.MCP.Servers == nil {
t.Error("MCP servers map should be initialized")
}
}
func TestSaveConfig_FilePermissions(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("file permission bits are not enforced on Windows")
@@ -204,3 +215,58 @@ func TestConfig_Complete(t *testing.T) {
t.Error("Heartbeat should be enabled by default")
}
}
func TestLoadConfig_LegacyMCPServersCompatibility(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
configJSON := `{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "test-model",
"max_tokens": 1024,
"temperature": 0.7,
"max_tool_iterations": 10
}
},
"mcpServers": {
"context7": {
"type": "stdio",
"protocol": "jsonl",
"command": "npx",
"args": ["-y", "@upstash/context7-mcp", "--api-key", "test-key"]
}
}
}`
if err := os.WriteFile(configPath, []byte(configJSON), 0600); err != nil {
t.Fatalf("WriteFile failed: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig failed: %v", err)
}
if !cfg.Tools.MCP.Enabled {
t.Fatal("Tools.MCP should be enabled from legacy mcpServers")
}
server, ok := cfg.Tools.MCP.Servers["context7"]
if !ok {
t.Fatal("context7 server not mapped from legacy mcpServers")
}
if !server.Enabled {
t.Fatal("context7 server should be enabled")
}
if server.Command != "npx" {
t.Fatalf("context7 command = %q, want npx", server.Command)
}
if server.Protocol != "jsonl" {
t.Fatalf("context7 protocol = %q, want jsonl", server.Protocol)
}
if len(server.Args) == 0 {
t.Fatal("context7 args should be mapped")
}
}
+603
View File
@@ -0,0 +1,603 @@
package mcp
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
// Client is the transport-agnostic MCP client contract.
type Client interface {
Start(ctx context.Context) error
ListTools(ctx context.Context) ([]RemoteTool, error)
CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error)
Close() error
}
// StdioClient speaks MCP over stdio (JSON-RPC framed with Content-Length headers).
type StdioClient struct {
config ServerConfig
mode string
mu sync.Mutex
writeMu sync.Mutex
started bool
closed bool
cmd *exec.Cmd
stdin io.WriteCloser
stdout io.ReadCloser
stderr io.ReadCloser
waitCh chan struct{}
pending map[string]chan rpcResponse
nextID uint64
}
type rpcRequest struct {
JSONRPC string `json:"jsonrpc"`
ID string `json:"id,omitempty"`
Method string `json:"method"`
Params any `json:"params,omitempty"`
}
type rpcResponseEnvelope struct {
JSONRPC string `json:"jsonrpc"`
ID json.RawMessage `json:"id,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *rpcError `json:"error,omitempty"`
Method string `json:"method,omitempty"`
}
type rpcError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type rpcResponse struct {
result json.RawMessage
rpcErr *rpcError
err error
}
type initializeParams struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities map[string]any `json:"capabilities"`
ClientInfo map[string]interface{} `json:"clientInfo"`
}
func NewStdioClient(config ServerConfig) *StdioClient {
return &StdioClient{
config: config,
mode: normalizeProtocol(config.Protocol),
}
}
func (c *StdioClient) Start(ctx context.Context) error {
c.mu.Lock()
if c.started {
c.mu.Unlock()
return nil
}
if strings.TrimSpace(c.config.Command) == "" {
c.mu.Unlock()
return fmt.Errorf("mcp server %q command is empty", c.config.Name)
}
cmd := exec.Command(c.config.Command, c.config.Args...)
if c.config.WorkingDir != "" {
cmd.Dir = c.config.WorkingDir
}
cmd.Env = buildProcessEnv(c.config.Env)
stdin, err := cmd.StdinPipe()
if err != nil {
c.mu.Unlock()
return fmt.Errorf("create stdin pipe: %w", err)
}
stdout, err := cmd.StdoutPipe()
if err != nil {
c.mu.Unlock()
return fmt.Errorf("create stdout pipe: %w", err)
}
stderr, err := cmd.StderrPipe()
if err != nil {
c.mu.Unlock()
return fmt.Errorf("create stderr pipe: %w", err)
}
if err := cmd.Start(); err != nil {
c.mu.Unlock()
return fmt.Errorf("start process: %w", err)
}
c.started = true
c.closed = false
c.cmd = cmd
c.stdin = stdin
c.stdout = stdout
c.stderr = stderr
c.waitCh = make(chan struct{})
c.pending = make(map[string]chan rpcResponse)
c.mu.Unlock()
go c.readLoop()
go c.waitLoop()
go c.drainStderr()
initCtx, cancel := withTimeoutIfMissing(ctx, c.config.InitTimeout())
defer cancel()
_, err = c.request(initCtx, "initialize", initializeParams{
ProtocolVersion: "2024-11-05",
Capabilities: map[string]any{
"tools": map[string]any{},
},
ClientInfo: map[string]any{
"name": "picoclaw",
"version": "dev",
},
})
if err != nil {
_ = c.Close()
return fmt.Errorf("initialize failed: %w", err)
}
if err := c.notify("notifications/initialized", map[string]any{}); err != nil {
_ = c.Close()
return fmt.Errorf("initialized notification failed: %w", err)
}
return nil
}
func (c *StdioClient) ListTools(ctx context.Context) ([]RemoteTool, error) {
if err := c.Start(ctx); err != nil {
return nil, err
}
type listToolsResponse struct {
Tools []struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema map[string]any `json:"inputSchema"`
} `json:"tools"`
NextCursor string `json:"nextCursor,omitempty"`
}
allTools := make([]RemoteTool, 0, 8)
cursor := ""
for page := 0; page < maxToolListPages; page++ {
params := map[string]any{}
if cursor != "" {
params["cursor"] = cursor
}
callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout())
raw, err := c.request(callCtx, "tools/list", params)
cancel()
if err != nil {
return nil, err
}
var response listToolsResponse
if err := json.Unmarshal(raw, &response); err != nil {
return nil, fmt.Errorf("decode tools/list response: %w", err)
}
for _, tool := range response.Tools {
allTools = append(allTools, RemoteTool{
Name: tool.Name,
Description: tool.Description,
InputSchema: tool.InputSchema,
})
}
if response.NextCursor == "" {
return allTools, nil
}
cursor = response.NextCursor
}
return nil, fmt.Errorf("tools/list exceeded %d pages", maxToolListPages)
}
func (c *StdioClient) CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error) {
if err := c.Start(ctx); err != nil {
return CallResult{}, err
}
callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout())
defer cancel()
raw, err := c.request(callCtx, "tools/call", map[string]any{
"name": toolName,
"arguments": arguments,
})
if err != nil {
return CallResult{}, err
}
return formatCallPayload(raw, c.config.ResponseLimit())
}
func (c *StdioClient) Close() error {
c.mu.Lock()
if !c.started || c.closed {
c.mu.Unlock()
return nil
}
c.closed = true
cmd := c.cmd
stdin := c.stdin
waitCh := c.waitCh
c.mu.Unlock()
c.failPending(errors.New("mcp client closed"))
if stdin != nil {
_ = stdin.Close()
}
if cmd != nil && cmd.Process != nil {
_ = cmd.Process.Kill()
}
if waitCh != nil {
select {
case <-waitCh:
case <-time.After(2 * time.Second):
}
}
return nil
}
func (c *StdioClient) request(ctx context.Context, method string, params any) (json.RawMessage, error) {
id := strconv.FormatUint(atomic.AddUint64(&c.nextID, 1), 10)
responseCh := make(chan rpcResponse, 1)
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return nil, fmt.Errorf("mcp server %q is closed", c.config.Name)
}
c.pending[id] = responseCh
c.mu.Unlock()
req := rpcRequest{
JSONRPC: "2.0",
ID: id,
Method: method,
Params: params,
}
if err := c.writeMessage(req); err != nil {
c.removePending(id)
return nil, err
}
select {
case <-ctx.Done():
c.removePending(id)
return nil, ctx.Err()
case response := <-responseCh:
if response.err != nil {
return nil, response.err
}
if response.rpcErr != nil {
return nil, fmt.Errorf("mcp error %d: %s", response.rpcErr.Code, response.rpcErr.Message)
}
return response.result, nil
}
}
func (c *StdioClient) notify(method string, params any) error {
req := rpcRequest{
JSONRPC: "2.0",
Method: method,
Params: params,
}
return c.writeMessage(req)
}
func (c *StdioClient) writeMessage(payload any) error {
c.mu.Lock()
if c.closed || c.stdin == nil {
c.mu.Unlock()
return fmt.Errorf("mcp server %q is not writable", c.config.Name)
}
stdin := c.stdin
c.mu.Unlock()
data, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal json-rpc payload: %w", err)
}
if c.mode == ProtocolJSONLines {
c.writeMu.Lock()
defer c.writeMu.Unlock()
if _, err := stdin.Write(append(data, '\n')); err != nil {
return fmt.Errorf("write jsonl body: %w", err)
}
return nil
}
frameHeader := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data))
c.writeMu.Lock()
defer c.writeMu.Unlock()
if _, err := io.WriteString(stdin, frameHeader); err != nil {
return fmt.Errorf("write frame header: %w", err)
}
if _, err := stdin.Write(data); err != nil {
return fmt.Errorf("write frame body: %w", err)
}
return nil
}
func (c *StdioClient) readLoop() {
if c.mode == ProtocolJSONLines {
c.readJSONLLoop()
return
}
c.readMCPFrameLoop()
}
func (c *StdioClient) readMCPFrameLoop() {
reader := bufio.NewReader(c.stdout)
for {
payload, err := readFramePayload(reader)
if err != nil {
c.failPending(err)
return
}
var envelope rpcResponseEnvelope
if err := json.Unmarshal(payload, &envelope); err != nil {
continue
}
c.dispatchResponse(envelope)
}
}
func (c *StdioClient) readJSONLLoop() {
scanner := bufio.NewScanner(c.stdout)
scanner.Buffer(make([]byte, 0, defaultScannerBufferBytes), maxFrameBytes)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
var envelope rpcResponseEnvelope
if err := json.Unmarshal([]byte(line), &envelope); err != nil {
continue
}
c.dispatchResponse(envelope)
}
if err := scanner.Err(); err != nil {
c.failPending(err)
return
}
c.failPending(io.EOF)
}
func (c *StdioClient) dispatchResponse(envelope rpcResponseEnvelope) {
if len(envelope.ID) == 0 {
return
}
id, ok := parseRPCID(envelope.ID)
if !ok {
return
}
c.mu.Lock()
responseCh := c.pending[id]
if responseCh != nil {
delete(c.pending, id)
}
c.mu.Unlock()
if responseCh == nil {
return
}
response := rpcResponse{
result: envelope.Result,
rpcErr: envelope.Error,
}
select {
case responseCh <- response:
default:
}
}
func (c *StdioClient) waitLoop() {
c.mu.Lock()
cmd := c.cmd
waitCh := c.waitCh
serverName := c.config.Name
c.mu.Unlock()
if cmd == nil {
if waitCh != nil {
close(waitCh)
}
return
}
err := cmd.Wait()
if waitCh != nil {
close(waitCh)
}
if err != nil {
logger.WarnCF("mcp", "MCP process exited with error",
map[string]any{
"server": serverName,
"error": err.Error(),
})
}
}
func (c *StdioClient) drainStderr() {
c.mu.Lock()
stderr := c.stderr
serverName := c.config.Name
c.mu.Unlock()
if stderr == nil {
return
}
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
logger.DebugCF("mcp", "MCP server stderr",
map[string]any{
"server": serverName,
"line": line,
})
}
}
func (c *StdioClient) failPending(err error) {
c.mu.Lock()
pending := c.pending
c.pending = make(map[string]chan rpcResponse)
c.mu.Unlock()
if len(pending) == 0 {
return
}
for _, ch := range pending {
select {
case ch <- rpcResponse{err: err}:
default:
}
}
}
func (c *StdioClient) removePending(id string) {
c.mu.Lock()
delete(c.pending, id)
c.mu.Unlock()
}
func readFramePayload(reader *bufio.Reader) ([]byte, error) {
contentLength := -1
for {
line, err := reader.ReadString('\n')
if err != nil {
return nil, err
}
trimmed := strings.TrimRight(line, "\r\n")
if trimmed == "" {
break
}
parts := strings.SplitN(trimmed, ":", 2)
if len(parts) != 2 {
continue
}
headerName := strings.TrimSpace(strings.ToLower(parts[0]))
if headerName != "content-length" {
continue
}
value := strings.TrimSpace(parts[1])
length, err := strconv.Atoi(value)
if err != nil {
return nil, fmt.Errorf("invalid content-length %q: %w", value, err)
}
contentLength = length
}
if contentLength <= 0 {
return nil, fmt.Errorf("missing content-length")
}
if contentLength > maxFrameBytes {
return nil, fmt.Errorf("frame too large (%d bytes)", contentLength)
}
payload := make([]byte, contentLength)
if _, err := io.ReadFull(reader, payload); err != nil {
return nil, err
}
return payload, nil
}
func parseRPCID(raw json.RawMessage) (string, bool) {
var stringID string
if err := json.Unmarshal(raw, &stringID); err == nil {
return stringID, true
}
var numberID float64
if err := json.Unmarshal(raw, &numberID); err == nil {
return strconv.FormatInt(int64(numberID), 10), true
}
return "", false
}
func withTimeoutIfMissing(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
if _, hasDeadline := parent.Deadline(); hasDeadline {
return context.WithCancel(parent)
}
return context.WithTimeout(parent, timeout)
}
func buildProcessEnv(custom map[string]string) []string {
base := os.Environ()
if len(custom) == 0 {
return base
}
keys := make([]string, 0, len(custom))
for key := range custom {
keys = append(keys, key)
}
sort.Strings(keys)
env := make([]string, 0, len(base)+len(keys))
env = append(env, base...)
for _, key := range keys {
env = append(env, key+"="+custom[key])
}
return env
}
func normalizeProtocol(protocol string) string {
switch strings.ToLower(strings.TrimSpace(protocol)) {
case "", ProtocolMCPFrames:
return ProtocolMCPFrames
case ProtocolJSONLines:
return ProtocolJSONLines
default:
return ProtocolMCPFrames
}
}
+23
View File
@@ -0,0 +1,23 @@
package mcp
import "testing"
func TestNormalizeProtocol(t *testing.T) {
tests := []struct {
input string
want string
}{
{input: "", want: ProtocolMCPFrames},
{input: "mcp", want: ProtocolMCPFrames},
{input: "jsonl", want: ProtocolJSONLines},
{input: "JSONL", want: ProtocolJSONLines},
{input: "unknown", want: ProtocolMCPFrames},
}
for _, tt := range tests {
got := normalizeProtocol(tt.input)
if got != tt.want {
t.Fatalf("normalizeProtocol(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
+61
View File
@@ -0,0 +1,61 @@
package mcp
import (
"encoding/json"
"strings"
)
type callResponse struct {
Content []contentBlock `json:"content"`
StructuredContent any `json:"structuredContent,omitempty"`
IsError bool `json:"isError,omitempty"`
}
type contentBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
}
func formatCallPayload(raw json.RawMessage, responseLimit int) (CallResult, error) {
var payload callResponse
if err := json.Unmarshal(raw, &payload); err != nil {
// Fallback for servers that return non-standard payloads.
return CallResult{
Content: truncateString(strings.TrimSpace(string(raw)), responseLimit),
IsError: false,
}, nil
}
parts := make([]string, 0, len(payload.Content)+1)
for _, block := range payload.Content {
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
parts = append(parts, block.Text)
}
}
if payload.StructuredContent != nil {
if encoded, err := json.Marshal(payload.StructuredContent); err == nil {
parts = append(parts, string(encoded))
}
}
content := strings.TrimSpace(strings.Join(parts, "\n"))
if content == "" {
content = "{}"
}
return CallResult{
Content: truncateString(content, responseLimit),
IsError: payload.IsError,
}, nil
}
func truncateString(value string, maxBytes int) string {
if maxBytes <= 0 || len(value) <= maxBytes {
return value
}
if maxBytes <= 12 {
return value[:maxBytes]
}
return value[:maxBytes-12] + "\n...[truncated]"
}
+52
View File
@@ -0,0 +1,52 @@
package mcp
import (
"encoding/json"
"strings"
"testing"
)
func TestFormatCallPayload_TextAndStructured(t *testing.T) {
raw := json.RawMessage(`{
"content":[{"type":"text","text":"hello"}],
"structuredContent":{"ok":true}
}`)
result, err := formatCallPayload(raw, 4096)
if err != nil {
t.Fatalf("formatCallPayload() error = %v", err)
}
if result.IsError {
t.Fatalf("expected IsError=false")
}
if !strings.Contains(result.Content, "hello") {
t.Fatalf("expected content to contain text block, got %q", result.Content)
}
if !strings.Contains(result.Content, `"ok":true`) {
t.Fatalf("expected content to contain structured content, got %q", result.Content)
}
}
func TestFormatCallPayload_Truncates(t *testing.T) {
raw := json.RawMessage(`{"content":[{"type":"text","text":"abcdefghijklmnopqrstuvwxyz"}]}`)
result, err := formatCallPayload(raw, 12)
if err != nil {
t.Fatalf("formatCallPayload() error = %v", err)
}
if len(result.Content) != 12 {
t.Fatalf("expected truncated length 12, got %d", len(result.Content))
}
}
func TestFormatCallPayload_RespectsIsError(t *testing.T) {
raw := json.RawMessage(`{"content":[{"type":"text","text":"failed"}],"isError":true}`)
result, err := formatCallPayload(raw, 4096)
if err != nil {
t.Fatalf("formatCallPayload() error = %v", err)
}
if !result.IsError {
t.Fatalf("expected IsError=true")
}
}
+190
View File
@@ -0,0 +1,190 @@
package mcp
import (
"context"
"fmt"
"slices"
"strings"
"sync"
)
type clientFactory func(config ServerConfig) Client
type managedServer struct {
config ServerConfig
client Client
}
// Manager owns MCP servers and maps discovered MCP tools to PicoClaw tools.
type Manager struct {
mu sync.RWMutex
servers map[string]*managedServer
tools map[string]RegisteredTool
discovered bool
newClient clientFactory
}
func NewManager(configs map[string]ServerConfig) *Manager {
servers := make(map[string]*managedServer, len(configs))
for name, cfg := range configs {
copied := cfg
copied.Name = name
servers[name] = &managedServer{config: copied}
}
return &Manager{
servers: servers,
tools: make(map[string]RegisteredTool),
discovered: false,
newClient: func(config ServerConfig) Client {
return NewStdioClient(config)
},
}
}
// DiscoverTools starts configured MCP servers and returns discovered tool metadata.
func (m *Manager) DiscoverTools(ctx context.Context) ([]RegisteredTool, error) {
m.mu.Lock()
if m.discovered {
tools := toolsFromMap(m.tools)
m.mu.Unlock()
return tools, nil
}
discoveryErrors := make([]string, 0)
for serverName, server := range m.servers {
client := m.newClient(server.config)
if err := client.Start(ctx); err != nil {
discoveryErrors = append(discoveryErrors, fmt.Sprintf("%s: %v", serverName, err))
continue
}
remoteTools, err := client.ListTools(ctx)
if err != nil {
_ = client.Close()
discoveryErrors = append(discoveryErrors, fmt.Sprintf("%s: %v", serverName, err))
continue
}
server.client = client
for _, remoteTool := range remoteTools {
if !isToolAllowed(remoteTool.Name, server.config.IncludeTools, server.config.ExcludeTools) {
continue
}
qualifiedName := m.makeUniqueToolName(serverName, remoteTool.Name)
parameters := normalizeSchema(remoteTool.InputSchema)
m.tools[qualifiedName] = RegisteredTool{
QualifiedName: qualifiedName,
ServerName: serverName,
ToolName: remoteTool.Name,
Description: remoteTool.Description,
Parameters: parameters,
}
}
}
m.discovered = true
tools := toolsFromMap(m.tools)
m.mu.Unlock()
if len(tools) == 0 && len(discoveryErrors) > 0 {
return nil, fmt.Errorf("mcp tool discovery failed: %s", strings.Join(discoveryErrors, "; "))
}
return tools, nil
}
func (m *Manager) CallTool(ctx context.Context, qualifiedName string, args map[string]any) (CallResult, error) {
m.mu.RLock()
tool, ok := m.tools[qualifiedName]
if !ok {
m.mu.RUnlock()
return CallResult{}, fmt.Errorf("mcp tool %q not found", qualifiedName)
}
server := m.servers[tool.ServerName]
if server == nil || server.client == nil {
m.mu.RUnlock()
return CallResult{}, fmt.Errorf("mcp server %q is not active", tool.ServerName)
}
client := server.client
toolName := tool.ToolName
m.mu.RUnlock()
if args == nil {
args = map[string]any{}
}
return client.CallTool(ctx, toolName, args)
}
func (m *Manager) Close() error {
m.mu.Lock()
servers := make([]*managedServer, 0, len(m.servers))
for _, server := range m.servers {
servers = append(servers, server)
}
m.mu.Unlock()
var firstErr error
for _, server := range servers {
if server.client == nil {
continue
}
if err := server.client.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
func (m *Manager) makeUniqueToolName(serverName, toolName string) string {
base := QualifiedToolName(serverName, toolName)
if _, exists := m.tools[base]; !exists {
return base
}
for index := 2; ; index++ {
candidate := fmt.Sprintf("%s_%d", base, index)
if len(candidate) > qualifiedNameMaxLen {
overflow := len(candidate) - qualifiedNameMaxLen
if overflow < len(base) {
candidate = base[:len(base)-overflow] + fmt.Sprintf("_%d", index)
} else {
candidate = candidate[:qualifiedNameMaxLen]
}
}
if _, exists := m.tools[candidate]; !exists {
return candidate
}
}
}
func normalizeSchema(schema map[string]any) map[string]any {
if len(schema) == 0 {
return map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
return schema
}
func isToolAllowed(name string, include, exclude []string) bool {
if len(include) > 0 && !slices.Contains(include, name) {
return false
}
if slices.Contains(exclude, name) {
return false
}
return true
}
func toolsFromMap(tools map[string]RegisteredTool) []RegisteredTool {
out := make([]RegisteredTool, 0, len(tools))
for _, tool := range tools {
out = append(out, tool)
}
return out
}
+101
View File
@@ -0,0 +1,101 @@
package mcp
import (
"context"
"testing"
)
type fakeClient struct {
tools []RemoteTool
callResult CallResult
callErr error
lastToolName string
lastArgs map[string]any
}
func (f *fakeClient) Start(_ context.Context) error { return nil }
func (f *fakeClient) ListTools(_ context.Context) ([]RemoteTool, error) {
return f.tools, nil
}
func (f *fakeClient) CallTool(_ context.Context, toolName string, arguments map[string]any) (CallResult, error) {
f.lastToolName = toolName
f.lastArgs = arguments
if f.callErr != nil {
return CallResult{}, f.callErr
}
return f.callResult, nil
}
func (f *fakeClient) Close() error { return nil }
func TestManager_DiscoverTools_FilterAndCall(t *testing.T) {
serverCfg := map[string]ServerConfig{
"Local Dev": {
Command: "fake",
IncludeTools: []string{"alpha", "beta"},
ExcludeTools: []string{"beta"},
},
}
manager := NewManager(serverCfg)
client := &fakeClient{
tools: []RemoteTool{
{Name: "alpha", Description: "tool alpha"},
{Name: "beta", Description: "tool beta"},
{Name: "gamma", Description: "tool gamma"},
},
callResult: CallResult{Content: "ok"},
}
manager.newClient = func(_ ServerConfig) Client {
return client
}
tools, err := manager.DiscoverTools(context.Background())
if err != nil {
t.Fatalf("DiscoverTools() error = %v", err)
}
if len(tools) != 1 {
t.Fatalf("DiscoverTools() returned %d tools, want 1", len(tools))
}
tool := tools[0]
if tool.ToolName != "alpha" {
t.Fatalf("discovered tool = %q, want alpha", tool.ToolName)
}
result, err := manager.CallTool(context.Background(), tool.QualifiedName, map[string]any{"x": 1})
if err != nil {
t.Fatalf("CallTool() error = %v", err)
}
if result.Content != "ok" {
t.Fatalf("CallTool() content = %q, want ok", result.Content)
}
if client.lastToolName != "alpha" {
t.Fatalf("called MCP tool = %q, want alpha", client.lastToolName)
}
}
func TestManager_NormalizeEmptySchema(t *testing.T) {
serverCfg := map[string]ServerConfig{
"srv": {Command: "fake"},
}
manager := NewManager(serverCfg)
manager.newClient = func(_ ServerConfig) Client {
return &fakeClient{
tools: []RemoteTool{{Name: "empty_schema", InputSchema: nil}},
}
}
tools, err := manager.DiscoverTools(context.Background())
if err != nil {
t.Fatalf("DiscoverTools() error = %v", err)
}
if len(tools) != 1 {
t.Fatalf("DiscoverTools() returned %d tools, want 1", len(tools))
}
parameters := tools[0].Parameters
if parameters["type"] != "object" {
t.Fatalf("normalized schema type = %v, want object", parameters["type"])
}
}
+53
View File
@@ -0,0 +1,53 @@
package mcp
import "strings"
const qualifiedNameMaxLen = 64
// QualifiedToolName creates a stable, provider-safe function name.
func QualifiedToolName(serverName, toolName string) string {
prefix := "mcp_" + sanitizeName(serverName) + "__"
tool := sanitizeName(toolName)
maxToolLen := qualifiedNameMaxLen - len(prefix)
if maxToolLen <= 0 {
return prefix[:qualifiedNameMaxLen]
}
if len(tool) > maxToolLen {
tool = tool[:maxToolLen]
}
return prefix + tool
}
func sanitizeName(value string) string {
trimmed := strings.TrimSpace(strings.ToLower(value))
if trimmed == "" {
return "unknown"
}
var b strings.Builder
b.Grow(len(trimmed))
lastUnderscore := false
for i := 0; i < len(trimmed); i++ {
ch := trimmed[i]
isAlphaNum := (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9')
if isAlphaNum {
b.WriteByte(ch)
lastUnderscore = false
continue
}
if !lastUnderscore {
b.WriteByte('_')
lastUnderscore = true
}
}
s := strings.Trim(b.String(), "_")
if s == "" {
s = "unknown"
}
if s[0] >= '0' && s[0] <= '9' {
return "t_" + s
}
return s
}
+19
View File
@@ -0,0 +1,19 @@
package mcp
import "testing"
func TestQualifiedToolName_SanitizesAndPrefixes(t *testing.T) {
got := QualifiedToolName("My Server", "Read-File!")
want := "mcp_my_server__read_file"
if got != want {
t.Fatalf("QualifiedToolName() = %q, want %q", got, want)
}
}
func TestQualifiedToolName_TrimToMaxLen(t *testing.T) {
longToolName := "tool_name_with_many_segments_and_extra_text_that_exceeds_the_limit_significantly"
got := QualifiedToolName("server", longToolName)
if len(got) > qualifiedNameMaxLen {
t.Fatalf("qualified name length = %d, want <= %d", len(got), qualifiedNameMaxLen)
}
}
+77
View File
@@ -0,0 +1,77 @@
package mcp
import "time"
const (
defaultInitTimeoutSeconds = 60
defaultCallTimeoutSeconds = 30
defaultMaxResponseBytes = 64 * 1024
defaultScannerBufferBytes = 64 * 1024
maxFrameBytes = 2 * 1024 * 1024
maxToolListPages = 50
)
const (
ProtocolMCPFrames = "mcp"
ProtocolJSONLines = "jsonl"
)
// ServerConfig defines one MCP server connection.
type ServerConfig struct {
Name string
Command string
Args []string
Env map[string]string
WorkingDir string
Protocol string
InitTimeoutSeconds int
CallTimeoutSeconds int
MaxResponseBytes int
IncludeTools []string
ExcludeTools []string
}
func (c ServerConfig) InitTimeout() time.Duration {
seconds := c.InitTimeoutSeconds
if seconds <= 0 {
seconds = defaultInitTimeoutSeconds
}
return time.Duration(seconds) * time.Second
}
func (c ServerConfig) CallTimeout() time.Duration {
seconds := c.CallTimeoutSeconds
if seconds <= 0 {
seconds = defaultCallTimeoutSeconds
}
return time.Duration(seconds) * time.Second
}
func (c ServerConfig) ResponseLimit() int {
if c.MaxResponseBytes <= 0 {
return defaultMaxResponseBytes
}
return c.MaxResponseBytes
}
// RemoteTool is an MCP tool discovered from a server.
type RemoteTool struct {
Name string
Description string
InputSchema map[string]any
}
// RegisteredTool is a discovered tool with a PicoClaw-facing qualified name.
type RegisteredTool struct {
QualifiedName string
ServerName string
ToolName string
Description string
Parameters map[string]any
}
// CallResult is a normalized MCP tool call result.
type CallResult struct {
Content string
IsError bool
}
+85
View File
@@ -0,0 +1,85 @@
package tools
import (
"context"
"errors"
"fmt"
"github.com/sipeed/picoclaw/pkg/mcp"
)
type MCPTool struct {
manager *mcp.Manager
name string
description string
parameters map[string]any
}
func NewMCPTool(manager *mcp.Manager, tool mcp.RegisteredTool) *MCPTool {
description := tool.Description
if description == "" {
description = fmt.Sprintf("MCP tool %s from server %s", tool.ToolName, tool.ServerName)
}
return &MCPTool{
manager: manager,
name: tool.QualifiedName,
description: description,
parameters: tool.Parameters,
}
}
func (t *MCPTool) Name() string {
return t.name
}
func (t *MCPTool) Description() string {
return t.description
}
func (t *MCPTool) Parameters() map[string]interface{} {
return t.parameters
}
func (t *MCPTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
if t.manager == nil {
return ErrorResult("MCP manager is not configured")
}
result, err := t.manager.CallTool(ctx, t.name, args)
if err != nil {
return ErrorResult(fmt.Sprintf("MCP tool %s failed: %v", t.name, err)).WithError(err)
}
if result.IsError {
err := errors.New(result.Content)
return ErrorResult(result.Content).WithError(err)
}
return SilentResult(result.Content)
}
// RegisterMCPTools discovers tools from MCP servers and registers them into the registry.
func RegisterMCPTools(ctx context.Context, registry *ToolRegistry, manager *mcp.Manager) (int, error) {
if registry == nil || manager == nil {
return 0, nil
}
discoveredTools, err := manager.DiscoverTools(ctx)
if err != nil {
return 0, err
}
return RegisterKnownMCPTools(registry, manager, discoveredTools), nil
}
// RegisterKnownMCPTools registers already-discovered MCP tools.
// This avoids repeated discovery work when multiple registries share one manager.
func RegisterKnownMCPTools(registry *ToolRegistry, manager *mcp.Manager, discoveredTools []mcp.RegisteredTool) int {
if registry == nil || manager == nil || len(discoveredTools) == 0 {
return 0
}
for _, tool := range discoveredTools {
registry.Register(NewMCPTool(manager, tool))
}
return len(discoveredTools)
}
+44
View File
@@ -0,0 +1,44 @@
package tools
import (
"testing"
"github.com/sipeed/picoclaw/pkg/mcp"
)
func TestRegisterKnownMCPTools_RegistersAllTools(t *testing.T) {
registry := NewToolRegistry()
manager := &mcp.Manager{}
discovered := []mcp.RegisteredTool{
{
QualifiedName: "mcp_context7__resolve_library_id",
ServerName: "context7",
ToolName: "resolve-library-id",
Description: "Resolve library ID",
Parameters: map[string]any{
"type": "object",
},
},
{
QualifiedName: "mcp_context7__query_docs",
ServerName: "context7",
ToolName: "query-docs",
Description: "Query docs",
Parameters: map[string]any{
"type": "object",
},
},
}
count := RegisterKnownMCPTools(registry, manager, discovered)
if count != 2 {
t.Fatalf("RegisterKnownMCPTools count = %d, want 2", count)
}
if _, ok := registry.Get("mcp_context7__resolve_library_id"); !ok {
t.Fatalf("expected mcp_context7__resolve_library_id to be registered")
}
if _, ok := registry.Get("mcp_context7__query_docs"); !ok {
t.Fatalf("expected mcp_context7__query_docs to be registered")
}
}