mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 50c58a3462 | |||
| 403e048821 |
@@ -130,6 +130,28 @@
|
|||||||
},
|
},
|
||||||
"cron": {
|
"cron": {
|
||||||
"exec_timeout_minutes": 5
|
"exec_timeout_minutes": 5
|
||||||
|
},
|
||||||
|
"mcp": {
|
||||||
|
"enabled": false,
|
||||||
|
"servers": {
|
||||||
|
"filesystem": {
|
||||||
|
"enabled": false,
|
||||||
|
"command": "npx",
|
||||||
|
"args": [
|
||||||
|
"-y",
|
||||||
|
"@modelcontextprotocol/server-filesystem",
|
||||||
|
"/tmp"
|
||||||
|
],
|
||||||
|
"protocol": "mcp",
|
||||||
|
"env": {},
|
||||||
|
"working_dir": "",
|
||||||
|
"init_timeout_seconds": 60,
|
||||||
|
"call_timeout_seconds": 30,
|
||||||
|
"max_response_bytes": 65536,
|
||||||
|
"include_tools": [],
|
||||||
|
"exclude_tools": []
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"heartbeat": {
|
"heartbeat": {
|
||||||
@@ -144,4 +166,4 @@
|
|||||||
"host": "0.0.0.0",
|
"host": "0.0.0.0",
|
||||||
"port": 18790
|
"port": 18790
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+59
-4
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
"github.com/sipeed/picoclaw/pkg/constants"
|
"github.com/sipeed/picoclaw/pkg/constants"
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||||
"github.com/sipeed/picoclaw/pkg/providers"
|
"github.com/sipeed/picoclaw/pkg/providers"
|
||||||
"github.com/sipeed/picoclaw/pkg/session"
|
"github.com/sipeed/picoclaw/pkg/session"
|
||||||
"github.com/sipeed/picoclaw/pkg/state"
|
"github.com/sipeed/picoclaw/pkg/state"
|
||||||
@@ -44,8 +45,12 @@ type AgentLoop struct {
|
|||||||
running atomic.Bool
|
running atomic.Bool
|
||||||
summarizing sync.Map // Tracks which sessions are currently being summarized
|
summarizing sync.Map // Tracks which sessions are currently being summarized
|
||||||
channelManager *channels.Manager
|
channelManager *channels.Manager
|
||||||
|
mcpManager *mcp.Manager
|
||||||
|
mcpCloseOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const defaultWebFetchMaxChars = 50000
|
||||||
|
|
||||||
// processOptions configures how a message is processed
|
// processOptions configures how a message is processed
|
||||||
type processOptions struct {
|
type processOptions struct {
|
||||||
SessionKey string // Session identifier for history/context
|
SessionKey string // Session identifier for history/context
|
||||||
@@ -60,7 +65,14 @@ type processOptions struct {
|
|||||||
|
|
||||||
// createToolRegistry creates a tool registry with common tools.
|
// createToolRegistry creates a tool registry with common tools.
|
||||||
// This is shared between main agent and subagents.
|
// This is shared between main agent and subagents.
|
||||||
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
|
func createToolRegistry(
|
||||||
|
workspace string,
|
||||||
|
restrict bool,
|
||||||
|
cfg *config.Config,
|
||||||
|
msgBus *bus.MessageBus,
|
||||||
|
mcpManager *mcp.Manager,
|
||||||
|
discoveredMCPTools []mcp.RegisteredTool,
|
||||||
|
) *tools.ToolRegistry {
|
||||||
registry := tools.NewToolRegistry()
|
registry := tools.NewToolRegistry()
|
||||||
|
|
||||||
// File system tools
|
// File system tools
|
||||||
@@ -85,7 +97,9 @@ func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msg
|
|||||||
}); searchTool != nil {
|
}); searchTool != nil {
|
||||||
registry.Register(searchTool)
|
registry.Register(searchTool)
|
||||||
}
|
}
|
||||||
registry.Register(tools.NewWebFetchTool(50000))
|
registry.Register(tools.NewWebFetchTool(defaultWebFetchMaxChars))
|
||||||
|
|
||||||
|
tools.RegisterKnownMCPTools(registry, mcpManager, discoveredMCPTools)
|
||||||
|
|
||||||
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
||||||
registry.Register(tools.NewI2CTool())
|
registry.Register(tools.NewI2CTool())
|
||||||
@@ -113,12 +127,35 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
|||||||
|
|
||||||
restrict := cfg.Agents.Defaults.RestrictToWorkspace
|
restrict := cfg.Agents.Defaults.RestrictToWorkspace
|
||||||
|
|
||||||
|
var (
|
||||||
|
mcpManager *mcp.Manager
|
||||||
|
discoveredMCPTools []mcp.RegisteredTool
|
||||||
|
)
|
||||||
|
if cfg.Tools.MCP.Enabled {
|
||||||
|
bootstrap, err := bootstrapMCP(cfg.Tools.MCP)
|
||||||
|
if err != nil {
|
||||||
|
logger.WarnCF("agent", "MCP tool bootstrap failed",
|
||||||
|
map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
} else if bootstrap != nil {
|
||||||
|
mcpManager = bootstrap.Manager
|
||||||
|
discoveredMCPTools = bootstrap.Tools
|
||||||
|
if len(discoveredMCPTools) > 0 {
|
||||||
|
logger.InfoCF("agent", "MCP tools registered",
|
||||||
|
map[string]interface{}{
|
||||||
|
"count": len(discoveredMCPTools),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create tool registry for main agent
|
// Create tool registry for main agent
|
||||||
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
|
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager, discoveredMCPTools)
|
||||||
|
|
||||||
// Create subagent manager with its own tool registry
|
// Create subagent manager with its own tool registry
|
||||||
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
|
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
|
||||||
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
|
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager, discoveredMCPTools)
|
||||||
// Subagent doesn't need spawn/subagent tools to avoid recursion
|
// Subagent doesn't need spawn/subagent tools to avoid recursion
|
||||||
subagentManager.SetTools(subagentTools)
|
subagentManager.SetTools(subagentTools)
|
||||||
|
|
||||||
@@ -151,11 +188,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
|||||||
contextBuilder: contextBuilder,
|
contextBuilder: contextBuilder,
|
||||||
tools: toolsRegistry,
|
tools: toolsRegistry,
|
||||||
summarizing: sync.Map{},
|
summarizing: sync.Map{},
|
||||||
|
mcpManager: mcpManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||||
al.running.Store(true)
|
al.running.Store(true)
|
||||||
|
defer al.closeMCP()
|
||||||
|
|
||||||
for al.running.Load() {
|
for al.running.Load() {
|
||||||
select {
|
select {
|
||||||
@@ -198,6 +237,22 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
|||||||
|
|
||||||
func (al *AgentLoop) Stop() {
|
func (al *AgentLoop) Stop() {
|
||||||
al.running.Store(false)
|
al.running.Store(false)
|
||||||
|
al.closeMCP()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (al *AgentLoop) closeMCP() {
|
||||||
|
if al.mcpManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
al.mcpCloseOnce.Do(func() {
|
||||||
|
if err := al.mcpManager.Close(); err != nil {
|
||||||
|
logger.WarnCF("agent", "Failed to close MCP manager",
|
||||||
|
map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||||
|
|||||||
@@ -0,0 +1,110 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
mcpBootstrapMinTimeout = 10 * time.Second
|
||||||
|
mcpBootstrapMaxTimeout = 5 * time.Minute
|
||||||
|
mcpBootstrapGraceTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type mcpBootstrapResult struct {
|
||||||
|
Manager *mcp.Manager
|
||||||
|
Tools []mcp.RegisteredTool
|
||||||
|
}
|
||||||
|
|
||||||
|
func bootstrapMCP(cfg config.MCPToolsConfig) (*mcpBootstrapResult, error) {
|
||||||
|
serverConfigs := buildMCPServerConfigs(cfg)
|
||||||
|
if len(serverConfigs) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := mcp.NewManager(serverConfigs)
|
||||||
|
|
||||||
|
discoveryTimeout := calculateMCPDiscoveryTimeout(serverConfigs)
|
||||||
|
discoveryCtx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
discoveredTools, err := manager.DiscoverTools(discoveryCtx)
|
||||||
|
if err != nil {
|
||||||
|
_ = manager.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &mcpBootstrapResult{
|
||||||
|
Manager: manager,
|
||||||
|
Tools: discoveredTools,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func calculateMCPDiscoveryTimeout(serverConfigs map[string]mcp.ServerConfig) time.Duration {
|
||||||
|
maxInitTimeout := mcpBootstrapMinTimeout
|
||||||
|
|
||||||
|
for _, serverConfig := range serverConfigs {
|
||||||
|
initTimeout := serverConfig.InitTimeout()
|
||||||
|
if initTimeout > maxInitTimeout {
|
||||||
|
maxInitTimeout = initTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := maxInitTimeout + mcpBootstrapGraceTimeout
|
||||||
|
if timeout < mcpBootstrapMinTimeout {
|
||||||
|
return mcpBootstrapMinTimeout
|
||||||
|
}
|
||||||
|
if timeout > mcpBootstrapMaxTimeout {
|
||||||
|
return mcpBootstrapMaxTimeout
|
||||||
|
}
|
||||||
|
return timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildMCPServerConfigs(cfg config.MCPToolsConfig) map[string]mcp.ServerConfig {
|
||||||
|
servers := make(map[string]mcp.ServerConfig, len(cfg.Servers))
|
||||||
|
|
||||||
|
for serverName, serverCfg := range cfg.Servers {
|
||||||
|
if !serverCfg.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
envCopy := make(map[string]string, len(serverCfg.Env))
|
||||||
|
for key, value := range serverCfg.Env {
|
||||||
|
envCopy[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
servers[serverName] = mcp.ServerConfig{
|
||||||
|
Name: serverName,
|
||||||
|
Command: serverCfg.Command,
|
||||||
|
Args: append([]string{}, serverCfg.Args...),
|
||||||
|
Env: envCopy,
|
||||||
|
WorkingDir: serverCfg.WorkingDir,
|
||||||
|
Protocol: inferMCPProtocol(serverCfg.Protocol, serverCfg.Command),
|
||||||
|
InitTimeoutSeconds: serverCfg.InitTimeoutSeconds,
|
||||||
|
CallTimeoutSeconds: serverCfg.CallTimeoutSeconds,
|
||||||
|
MaxResponseBytes: serverCfg.MaxResponseBytes,
|
||||||
|
IncludeTools: append([]string{}, serverCfg.IncludeTools...),
|
||||||
|
ExcludeTools: append([]string{}, serverCfg.ExcludeTools...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return servers
|
||||||
|
}
|
||||||
|
|
||||||
|
func inferMCPProtocol(configuredProtocol, command string) string {
|
||||||
|
if protocol := strings.TrimSpace(configuredProtocol); protocol != "" {
|
||||||
|
return protocol
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context7 currently emits JSON-RPC messages as JSONL on stdio,
|
||||||
|
// so defaulting avoids long startup waits when protocol is omitted.
|
||||||
|
if strings.Contains(strings.ToLower(command), "context7-mcp") {
|
||||||
|
return mcp.ProtocolJSONLines
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
+82
-1
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/caarlos0/env/v11"
|
"github.com/caarlos0/env/v11"
|
||||||
@@ -51,7 +52,10 @@ type Config struct {
|
|||||||
Tools ToolsConfig `json:"tools"`
|
Tools ToolsConfig `json:"tools"`
|
||||||
Heartbeat HeartbeatConfig `json:"heartbeat"`
|
Heartbeat HeartbeatConfig `json:"heartbeat"`
|
||||||
Devices DevicesConfig `json:"devices"`
|
Devices DevicesConfig `json:"devices"`
|
||||||
mu sync.RWMutex
|
// MCPServers is a compatibility alias for configs using top-level "mcpServers".
|
||||||
|
// Canonical config remains tools.mcp.servers.
|
||||||
|
MCPServers map[string]LegacyMCPServerConfig `json:"mcpServers,omitempty"`
|
||||||
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type AgentsConfig struct {
|
type AgentsConfig struct {
|
||||||
@@ -222,9 +226,38 @@ type CronToolsConfig struct {
|
|||||||
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
|
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MCPServerConfig struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Command string `json:"command"`
|
||||||
|
Args []string `json:"args"`
|
||||||
|
Env map[string]string `json:"env"`
|
||||||
|
WorkingDir string `json:"working_dir"`
|
||||||
|
Protocol string `json:"protocol"`
|
||||||
|
InitTimeoutSeconds int `json:"init_timeout_seconds"`
|
||||||
|
CallTimeoutSeconds int `json:"call_timeout_seconds"`
|
||||||
|
MaxResponseBytes int `json:"max_response_bytes"`
|
||||||
|
IncludeTools []string `json:"include_tools"`
|
||||||
|
ExcludeTools []string `json:"exclude_tools"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MCPToolsConfig struct {
|
||||||
|
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
|
||||||
|
Servers map[string]MCPServerConfig `json:"servers"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LegacyMCPServerConfig supports compatibility with "mcpServers" style config.
|
||||||
|
type LegacyMCPServerConfig struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Command string `json:"command"`
|
||||||
|
Args []string `json:"args"`
|
||||||
|
Env map[string]string `json:"env"`
|
||||||
|
Protocol string `json:"protocol"`
|
||||||
|
}
|
||||||
|
|
||||||
type ToolsConfig struct {
|
type ToolsConfig struct {
|
||||||
Web WebToolsConfig `json:"web"`
|
Web WebToolsConfig `json:"web"`
|
||||||
Cron CronToolsConfig `json:"cron"`
|
Cron CronToolsConfig `json:"cron"`
|
||||||
|
MCP MCPToolsConfig `json:"mcp"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultConfig() *Config {
|
func DefaultConfig() *Config {
|
||||||
@@ -342,6 +375,10 @@ func DefaultConfig() *Config {
|
|||||||
Cron: CronToolsConfig{
|
Cron: CronToolsConfig{
|
||||||
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
|
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
|
||||||
},
|
},
|
||||||
|
MCP: MCPToolsConfig{
|
||||||
|
Enabled: false,
|
||||||
|
Servers: map[string]MCPServerConfig{},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Heartbeat: HeartbeatConfig{
|
Heartbeat: HeartbeatConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
@@ -373,9 +410,53 @@ func LoadConfig(path string) (*Config, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cfg.applyLegacyMCPServers()
|
||||||
|
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Config) applyLegacyMCPServers() {
|
||||||
|
// If canonical MCP config already exists, keep it as source of truth.
|
||||||
|
if len(c.Tools.MCP.Servers) > 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(c.MCPServers) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Tools.MCP.Servers == nil {
|
||||||
|
c.Tools.MCP.Servers = map[string]MCPServerConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, legacy := range c.MCPServers {
|
||||||
|
if strings.TrimSpace(legacy.Command) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
enabled := true
|
||||||
|
if legacy.Type != "" && legacy.Type != "stdio" {
|
||||||
|
enabled = false
|
||||||
|
}
|
||||||
|
|
||||||
|
envCopy := make(map[string]string, len(legacy.Env))
|
||||||
|
for key, value := range legacy.Env {
|
||||||
|
envCopy[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Tools.MCP.Servers[name] = MCPServerConfig{
|
||||||
|
Enabled: enabled,
|
||||||
|
Command: legacy.Command,
|
||||||
|
Args: append([]string{}, legacy.Args...),
|
||||||
|
Env: envCopy,
|
||||||
|
Protocol: legacy.Protocol,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.Tools.MCP.Servers) > 0 {
|
||||||
|
c.Tools.MCP.Enabled = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func SaveConfig(path string, cfg *Config) error {
|
func SaveConfig(path string, cfg *Config) error {
|
||||||
cfg.mu.RLock()
|
cfg.mu.RLock()
|
||||||
defer cfg.mu.RUnlock()
|
defer cfg.mu.RUnlock()
|
||||||
|
|||||||
@@ -0,0 +1,603 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client is the transport-agnostic MCP client contract.
|
||||||
|
type Client interface {
|
||||||
|
Start(ctx context.Context) error
|
||||||
|
ListTools(ctx context.Context) ([]RemoteTool, error)
|
||||||
|
CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// StdioClient speaks MCP over stdio (JSON-RPC framed with Content-Length headers).
|
||||||
|
type StdioClient struct {
|
||||||
|
config ServerConfig
|
||||||
|
mode string
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
writeMu sync.Mutex
|
||||||
|
|
||||||
|
started bool
|
||||||
|
closed bool
|
||||||
|
|
||||||
|
cmd *exec.Cmd
|
||||||
|
stdin io.WriteCloser
|
||||||
|
stdout io.ReadCloser
|
||||||
|
stderr io.ReadCloser
|
||||||
|
waitCh chan struct{}
|
||||||
|
pending map[string]chan rpcResponse
|
||||||
|
|
||||||
|
nextID uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpcRequest struct {
|
||||||
|
JSONRPC string `json:"jsonrpc"`
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Method string `json:"method"`
|
||||||
|
Params any `json:"params,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpcResponseEnvelope struct {
|
||||||
|
JSONRPC string `json:"jsonrpc"`
|
||||||
|
ID json.RawMessage `json:"id,omitempty"`
|
||||||
|
Result json.RawMessage `json:"result,omitempty"`
|
||||||
|
Error *rpcError `json:"error,omitempty"`
|
||||||
|
Method string `json:"method,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpcError struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpcResponse struct {
|
||||||
|
result json.RawMessage
|
||||||
|
rpcErr *rpcError
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type initializeParams struct {
|
||||||
|
ProtocolVersion string `json:"protocolVersion"`
|
||||||
|
Capabilities map[string]any `json:"capabilities"`
|
||||||
|
ClientInfo map[string]interface{} `json:"clientInfo"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStdioClient(config ServerConfig) *StdioClient {
|
||||||
|
return &StdioClient{
|
||||||
|
config: config,
|
||||||
|
mode: normalizeProtocol(config.Protocol),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) Start(ctx context.Context) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.started {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(c.config.Command) == "" {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return fmt.Errorf("mcp server %q command is empty", c.config.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(c.config.Command, c.config.Args...)
|
||||||
|
if c.config.WorkingDir != "" {
|
||||||
|
cmd.Dir = c.config.WorkingDir
|
||||||
|
}
|
||||||
|
cmd.Env = buildProcessEnv(c.config.Env)
|
||||||
|
|
||||||
|
stdin, err := cmd.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return fmt.Errorf("create stdin pipe: %w", err)
|
||||||
|
}
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return fmt.Errorf("create stdout pipe: %w", err)
|
||||||
|
}
|
||||||
|
stderr, err := cmd.StderrPipe()
|
||||||
|
if err != nil {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return fmt.Errorf("create stderr pipe: %w", err)
|
||||||
|
}
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return fmt.Errorf("start process: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.started = true
|
||||||
|
c.closed = false
|
||||||
|
c.cmd = cmd
|
||||||
|
c.stdin = stdin
|
||||||
|
c.stdout = stdout
|
||||||
|
c.stderr = stderr
|
||||||
|
c.waitCh = make(chan struct{})
|
||||||
|
c.pending = make(map[string]chan rpcResponse)
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
go c.readLoop()
|
||||||
|
go c.waitLoop()
|
||||||
|
go c.drainStderr()
|
||||||
|
|
||||||
|
initCtx, cancel := withTimeoutIfMissing(ctx, c.config.InitTimeout())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err = c.request(initCtx, "initialize", initializeParams{
|
||||||
|
ProtocolVersion: "2024-11-05",
|
||||||
|
Capabilities: map[string]any{
|
||||||
|
"tools": map[string]any{},
|
||||||
|
},
|
||||||
|
ClientInfo: map[string]any{
|
||||||
|
"name": "picoclaw",
|
||||||
|
"version": "dev",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
_ = c.Close()
|
||||||
|
return fmt.Errorf("initialize failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.notify("notifications/initialized", map[string]any{}); err != nil {
|
||||||
|
_ = c.Close()
|
||||||
|
return fmt.Errorf("initialized notification failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) ListTools(ctx context.Context) ([]RemoteTool, error) {
|
||||||
|
if err := c.Start(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type listToolsResponse struct {
|
||||||
|
Tools []struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
InputSchema map[string]any `json:"inputSchema"`
|
||||||
|
} `json:"tools"`
|
||||||
|
NextCursor string `json:"nextCursor,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
allTools := make([]RemoteTool, 0, 8)
|
||||||
|
cursor := ""
|
||||||
|
|
||||||
|
for page := 0; page < maxToolListPages; page++ {
|
||||||
|
params := map[string]any{}
|
||||||
|
if cursor != "" {
|
||||||
|
params["cursor"] = cursor
|
||||||
|
}
|
||||||
|
|
||||||
|
callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout())
|
||||||
|
raw, err := c.request(callCtx, "tools/list", params)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var response listToolsResponse
|
||||||
|
if err := json.Unmarshal(raw, &response); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode tools/list response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tool := range response.Tools {
|
||||||
|
allTools = append(allTools, RemoteTool{
|
||||||
|
Name: tool.Name,
|
||||||
|
Description: tool.Description,
|
||||||
|
InputSchema: tool.InputSchema,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.NextCursor == "" {
|
||||||
|
return allTools, nil
|
||||||
|
}
|
||||||
|
cursor = response.NextCursor
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("tools/list exceeded %d pages", maxToolListPages)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error) {
|
||||||
|
if err := c.Start(ctx); err != nil {
|
||||||
|
return CallResult{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
raw, err := c.request(callCtx, "tools/call", map[string]any{
|
||||||
|
"name": toolName,
|
||||||
|
"arguments": arguments,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return CallResult{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return formatCallPayload(raw, c.config.ResponseLimit())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) Close() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
if !c.started || c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.closed = true
|
||||||
|
cmd := c.cmd
|
||||||
|
stdin := c.stdin
|
||||||
|
waitCh := c.waitCh
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
c.failPending(errors.New("mcp client closed"))
|
||||||
|
|
||||||
|
if stdin != nil {
|
||||||
|
_ = stdin.Close()
|
||||||
|
}
|
||||||
|
if cmd != nil && cmd.Process != nil {
|
||||||
|
_ = cmd.Process.Kill()
|
||||||
|
}
|
||||||
|
|
||||||
|
if waitCh != nil {
|
||||||
|
select {
|
||||||
|
case <-waitCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) request(ctx context.Context, method string, params any) (json.RawMessage, error) {
|
||||||
|
id := strconv.FormatUint(atomic.AddUint64(&c.nextID, 1), 10)
|
||||||
|
responseCh := make(chan rpcResponse, 1)
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil, fmt.Errorf("mcp server %q is closed", c.config.Name)
|
||||||
|
}
|
||||||
|
c.pending[id] = responseCh
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
req := rpcRequest{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
ID: id,
|
||||||
|
Method: method,
|
||||||
|
Params: params,
|
||||||
|
}
|
||||||
|
if err := c.writeMessage(req); err != nil {
|
||||||
|
c.removePending(id)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
c.removePending(id)
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case response := <-responseCh:
|
||||||
|
if response.err != nil {
|
||||||
|
return nil, response.err
|
||||||
|
}
|
||||||
|
if response.rpcErr != nil {
|
||||||
|
return nil, fmt.Errorf("mcp error %d: %s", response.rpcErr.Code, response.rpcErr.Message)
|
||||||
|
}
|
||||||
|
return response.result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) notify(method string, params any) error {
|
||||||
|
req := rpcRequest{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
Method: method,
|
||||||
|
Params: params,
|
||||||
|
}
|
||||||
|
return c.writeMessage(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) writeMessage(payload any) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed || c.stdin == nil {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return fmt.Errorf("mcp server %q is not writable", c.config.Name)
|
||||||
|
}
|
||||||
|
stdin := c.stdin
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
data, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal json-rpc payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.mode == ProtocolJSONLines {
|
||||||
|
c.writeMu.Lock()
|
||||||
|
defer c.writeMu.Unlock()
|
||||||
|
|
||||||
|
if _, err := stdin.Write(append(data, '\n')); err != nil {
|
||||||
|
return fmt.Errorf("write jsonl body: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
frameHeader := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data))
|
||||||
|
|
||||||
|
c.writeMu.Lock()
|
||||||
|
defer c.writeMu.Unlock()
|
||||||
|
|
||||||
|
if _, err := io.WriteString(stdin, frameHeader); err != nil {
|
||||||
|
return fmt.Errorf("write frame header: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := stdin.Write(data); err != nil {
|
||||||
|
return fmt.Errorf("write frame body: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) readLoop() {
|
||||||
|
if c.mode == ProtocolJSONLines {
|
||||||
|
c.readJSONLLoop()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.readMCPFrameLoop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) readMCPFrameLoop() {
|
||||||
|
reader := bufio.NewReader(c.stdout)
|
||||||
|
|
||||||
|
for {
|
||||||
|
payload, err := readFramePayload(reader)
|
||||||
|
if err != nil {
|
||||||
|
c.failPending(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var envelope rpcResponseEnvelope
|
||||||
|
if err := json.Unmarshal(payload, &envelope); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.dispatchResponse(envelope)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) readJSONLLoop() {
|
||||||
|
scanner := bufio.NewScanner(c.stdout)
|
||||||
|
scanner.Buffer(make([]byte, 0, defaultScannerBufferBytes), maxFrameBytes)
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var envelope rpcResponseEnvelope
|
||||||
|
if err := json.Unmarshal([]byte(line), &envelope); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.dispatchResponse(envelope)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
c.failPending(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.failPending(io.EOF)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) dispatchResponse(envelope rpcResponseEnvelope) {
|
||||||
|
if len(envelope.ID) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id, ok := parseRPCID(envelope.ID)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
responseCh := c.pending[id]
|
||||||
|
if responseCh != nil {
|
||||||
|
delete(c.pending, id)
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if responseCh == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := rpcResponse{
|
||||||
|
result: envelope.Result,
|
||||||
|
rpcErr: envelope.Error,
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case responseCh <- response:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) waitLoop() {
|
||||||
|
c.mu.Lock()
|
||||||
|
cmd := c.cmd
|
||||||
|
waitCh := c.waitCh
|
||||||
|
serverName := c.config.Name
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if cmd == nil {
|
||||||
|
if waitCh != nil {
|
||||||
|
close(waitCh)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cmd.Wait()
|
||||||
|
if waitCh != nil {
|
||||||
|
close(waitCh)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logger.WarnCF("mcp", "MCP process exited with error",
|
||||||
|
map[string]any{
|
||||||
|
"server": serverName,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) drainStderr() {
|
||||||
|
c.mu.Lock()
|
||||||
|
stderr := c.stderr
|
||||||
|
serverName := c.config.Name
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if stderr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(stderr)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.DebugCF("mcp", "MCP server stderr",
|
||||||
|
map[string]any{
|
||||||
|
"server": serverName,
|
||||||
|
"line": line,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) failPending(err error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
pending := c.pending
|
||||||
|
c.pending = make(map[string]chan rpcResponse)
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if len(pending) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ch := range pending {
|
||||||
|
select {
|
||||||
|
case ch <- rpcResponse{err: err}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StdioClient) removePending(id string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
delete(c.pending, id)
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func readFramePayload(reader *bufio.Reader) ([]byte, error) {
|
||||||
|
contentLength := -1
|
||||||
|
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
trimmed := strings.TrimRight(line, "\r\n")
|
||||||
|
if trimmed == "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.SplitN(trimmed, ":", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
headerName := strings.TrimSpace(strings.ToLower(parts[0]))
|
||||||
|
if headerName != "content-length" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
value := strings.TrimSpace(parts[1])
|
||||||
|
length, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid content-length %q: %w", value, err)
|
||||||
|
}
|
||||||
|
contentLength = length
|
||||||
|
}
|
||||||
|
|
||||||
|
if contentLength <= 0 {
|
||||||
|
return nil, fmt.Errorf("missing content-length")
|
||||||
|
}
|
||||||
|
if contentLength > maxFrameBytes {
|
||||||
|
return nil, fmt.Errorf("frame too large (%d bytes)", contentLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := make([]byte, contentLength)
|
||||||
|
if _, err := io.ReadFull(reader, payload); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRPCID(raw json.RawMessage) (string, bool) {
|
||||||
|
var stringID string
|
||||||
|
if err := json.Unmarshal(raw, &stringID); err == nil {
|
||||||
|
return stringID, true
|
||||||
|
}
|
||||||
|
|
||||||
|
var numberID float64
|
||||||
|
if err := json.Unmarshal(raw, &numberID); err == nil {
|
||||||
|
return strconv.FormatInt(int64(numberID), 10), true
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func withTimeoutIfMissing(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||||
|
if _, hasDeadline := parent.Deadline(); hasDeadline {
|
||||||
|
return context.WithCancel(parent)
|
||||||
|
}
|
||||||
|
return context.WithTimeout(parent, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildProcessEnv(custom map[string]string) []string {
|
||||||
|
base := os.Environ()
|
||||||
|
if len(custom) == 0 {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make([]string, 0, len(custom))
|
||||||
|
for key := range custom {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
env := make([]string, 0, len(base)+len(keys))
|
||||||
|
env = append(env, base...)
|
||||||
|
for _, key := range keys {
|
||||||
|
env = append(env, key+"="+custom[key])
|
||||||
|
}
|
||||||
|
return env
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeProtocol(protocol string) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(protocol)) {
|
||||||
|
case "", ProtocolMCPFrames:
|
||||||
|
return ProtocolMCPFrames
|
||||||
|
case ProtocolJSONLines:
|
||||||
|
return ProtocolJSONLines
|
||||||
|
default:
|
||||||
|
return ProtocolMCPFrames
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type callResponse struct {
|
||||||
|
Content []contentBlock `json:"content"`
|
||||||
|
StructuredContent any `json:"structuredContent,omitempty"`
|
||||||
|
IsError bool `json:"isError,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type contentBlock struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatCallPayload(raw json.RawMessage, responseLimit int) (CallResult, error) {
|
||||||
|
var payload callResponse
|
||||||
|
if err := json.Unmarshal(raw, &payload); err != nil {
|
||||||
|
// Fallback for servers that return non-standard payloads.
|
||||||
|
return CallResult{
|
||||||
|
Content: truncateString(strings.TrimSpace(string(raw)), responseLimit),
|
||||||
|
IsError: false,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := make([]string, 0, len(payload.Content)+1)
|
||||||
|
for _, block := range payload.Content {
|
||||||
|
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
|
||||||
|
parts = append(parts, block.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if payload.StructuredContent != nil {
|
||||||
|
if encoded, err := json.Marshal(payload.StructuredContent); err == nil {
|
||||||
|
parts = append(parts, string(encoded))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
content := strings.TrimSpace(strings.Join(parts, "\n"))
|
||||||
|
if content == "" {
|
||||||
|
content = "{}"
|
||||||
|
}
|
||||||
|
|
||||||
|
return CallResult{
|
||||||
|
Content: truncateString(content, responseLimit),
|
||||||
|
IsError: payload.IsError,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateString(value string, maxBytes int) string {
|
||||||
|
if maxBytes <= 0 || len(value) <= maxBytes {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
if maxBytes <= 12 {
|
||||||
|
return value[:maxBytes]
|
||||||
|
}
|
||||||
|
return value[:maxBytes-12] + "\n...[truncated]"
|
||||||
|
}
|
||||||
@@ -0,0 +1,190 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type clientFactory func(config ServerConfig) Client
|
||||||
|
|
||||||
|
type managedServer struct {
|
||||||
|
config ServerConfig
|
||||||
|
client Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager owns MCP servers and maps discovered MCP tools to PicoClaw tools.
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
servers map[string]*managedServer
|
||||||
|
tools map[string]RegisteredTool
|
||||||
|
|
||||||
|
discovered bool
|
||||||
|
newClient clientFactory
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(configs map[string]ServerConfig) *Manager {
|
||||||
|
servers := make(map[string]*managedServer, len(configs))
|
||||||
|
for name, cfg := range configs {
|
||||||
|
copied := cfg
|
||||||
|
copied.Name = name
|
||||||
|
servers[name] = &managedServer{config: copied}
|
||||||
|
}
|
||||||
|
return &Manager{
|
||||||
|
servers: servers,
|
||||||
|
tools: make(map[string]RegisteredTool),
|
||||||
|
discovered: false,
|
||||||
|
newClient: func(config ServerConfig) Client {
|
||||||
|
return NewStdioClient(config)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DiscoverTools starts configured MCP servers and returns discovered tool metadata.
|
||||||
|
func (m *Manager) DiscoverTools(ctx context.Context) ([]RegisteredTool, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
if m.discovered {
|
||||||
|
tools := toolsFromMap(m.tools)
|
||||||
|
m.mu.Unlock()
|
||||||
|
return tools, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
discoveryErrors := make([]string, 0)
|
||||||
|
|
||||||
|
for serverName, server := range m.servers {
|
||||||
|
client := m.newClient(server.config)
|
||||||
|
if err := client.Start(ctx); err != nil {
|
||||||
|
discoveryErrors = append(discoveryErrors, fmt.Sprintf("%s: %v", serverName, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteTools, err := client.ListTools(ctx)
|
||||||
|
if err != nil {
|
||||||
|
_ = client.Close()
|
||||||
|
discoveryErrors = append(discoveryErrors, fmt.Sprintf("%s: %v", serverName, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
server.client = client
|
||||||
|
for _, remoteTool := range remoteTools {
|
||||||
|
if !isToolAllowed(remoteTool.Name, server.config.IncludeTools, server.config.ExcludeTools) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
qualifiedName := m.makeUniqueToolName(serverName, remoteTool.Name)
|
||||||
|
parameters := normalizeSchema(remoteTool.InputSchema)
|
||||||
|
m.tools[qualifiedName] = RegisteredTool{
|
||||||
|
QualifiedName: qualifiedName,
|
||||||
|
ServerName: serverName,
|
||||||
|
ToolName: remoteTool.Name,
|
||||||
|
Description: remoteTool.Description,
|
||||||
|
Parameters: parameters,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.discovered = true
|
||||||
|
tools := toolsFromMap(m.tools)
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if len(tools) == 0 && len(discoveryErrors) > 0 {
|
||||||
|
return nil, fmt.Errorf("mcp tool discovery failed: %s", strings.Join(discoveryErrors, "; "))
|
||||||
|
}
|
||||||
|
return tools, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) CallTool(ctx context.Context, qualifiedName string, args map[string]any) (CallResult, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
tool, ok := m.tools[qualifiedName]
|
||||||
|
if !ok {
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return CallResult{}, fmt.Errorf("mcp tool %q not found", qualifiedName)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := m.servers[tool.ServerName]
|
||||||
|
if server == nil || server.client == nil {
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return CallResult{}, fmt.Errorf("mcp server %q is not active", tool.ServerName)
|
||||||
|
}
|
||||||
|
client := server.client
|
||||||
|
toolName := tool.ToolName
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
if args == nil {
|
||||||
|
args = map[string]any{}
|
||||||
|
}
|
||||||
|
return client.CallTool(ctx, toolName, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Close() error {
|
||||||
|
m.mu.Lock()
|
||||||
|
servers := make([]*managedServer, 0, len(m.servers))
|
||||||
|
for _, server := range m.servers {
|
||||||
|
servers = append(servers, server)
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
var firstErr error
|
||||||
|
for _, server := range servers {
|
||||||
|
if server.client == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := server.client.Close(); err != nil && firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return firstErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) makeUniqueToolName(serverName, toolName string) string {
|
||||||
|
base := QualifiedToolName(serverName, toolName)
|
||||||
|
if _, exists := m.tools[base]; !exists {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
for index := 2; ; index++ {
|
||||||
|
candidate := fmt.Sprintf("%s_%d", base, index)
|
||||||
|
if len(candidate) > qualifiedNameMaxLen {
|
||||||
|
overflow := len(candidate) - qualifiedNameMaxLen
|
||||||
|
if overflow < len(base) {
|
||||||
|
candidate = base[:len(base)-overflow] + fmt.Sprintf("_%d", index)
|
||||||
|
} else {
|
||||||
|
candidate = candidate[:qualifiedNameMaxLen]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, exists := m.tools[candidate]; !exists {
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeSchema(schema map[string]any) map[string]any {
|
||||||
|
if len(schema) == 0 {
|
||||||
|
return map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func isToolAllowed(name string, include, exclude []string) bool {
|
||||||
|
if len(include) > 0 && !slices.Contains(include, name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if slices.Contains(exclude, name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolsFromMap(tools map[string]RegisteredTool) []RegisteredTool {
|
||||||
|
out := make([]RegisteredTool, 0, len(tools))
|
||||||
|
for _, tool := range tools {
|
||||||
|
out = append(out, tool)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
const qualifiedNameMaxLen = 64
|
||||||
|
|
||||||
|
// QualifiedToolName creates a stable, provider-safe function name.
|
||||||
|
func QualifiedToolName(serverName, toolName string) string {
|
||||||
|
prefix := "mcp_" + sanitizeName(serverName) + "__"
|
||||||
|
tool := sanitizeName(toolName)
|
||||||
|
maxToolLen := qualifiedNameMaxLen - len(prefix)
|
||||||
|
if maxToolLen <= 0 {
|
||||||
|
return prefix[:qualifiedNameMaxLen]
|
||||||
|
}
|
||||||
|
if len(tool) > maxToolLen {
|
||||||
|
tool = tool[:maxToolLen]
|
||||||
|
}
|
||||||
|
return prefix + tool
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeName(value string) string {
|
||||||
|
trimmed := strings.TrimSpace(strings.ToLower(value))
|
||||||
|
if trimmed == "" {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(trimmed))
|
||||||
|
|
||||||
|
lastUnderscore := false
|
||||||
|
for i := 0; i < len(trimmed); i++ {
|
||||||
|
ch := trimmed[i]
|
||||||
|
isAlphaNum := (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9')
|
||||||
|
if isAlphaNum {
|
||||||
|
b.WriteByte(ch)
|
||||||
|
lastUnderscore = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !lastUnderscore {
|
||||||
|
b.WriteByte('_')
|
||||||
|
lastUnderscore = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s := strings.Trim(b.String(), "_")
|
||||||
|
if s == "" {
|
||||||
|
s = "unknown"
|
||||||
|
}
|
||||||
|
if s[0] >= '0' && s[0] <= '9' {
|
||||||
|
return "t_" + s
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultInitTimeoutSeconds = 60
|
||||||
|
defaultCallTimeoutSeconds = 30
|
||||||
|
defaultMaxResponseBytes = 64 * 1024
|
||||||
|
defaultScannerBufferBytes = 64 * 1024
|
||||||
|
maxFrameBytes = 2 * 1024 * 1024
|
||||||
|
maxToolListPages = 50
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProtocolMCPFrames = "mcp"
|
||||||
|
ProtocolJSONLines = "jsonl"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServerConfig defines one MCP server connection.
|
||||||
|
type ServerConfig struct {
|
||||||
|
Name string
|
||||||
|
Command string
|
||||||
|
Args []string
|
||||||
|
Env map[string]string
|
||||||
|
WorkingDir string
|
||||||
|
Protocol string
|
||||||
|
InitTimeoutSeconds int
|
||||||
|
CallTimeoutSeconds int
|
||||||
|
MaxResponseBytes int
|
||||||
|
IncludeTools []string
|
||||||
|
ExcludeTools []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ServerConfig) InitTimeout() time.Duration {
|
||||||
|
seconds := c.InitTimeoutSeconds
|
||||||
|
if seconds <= 0 {
|
||||||
|
seconds = defaultInitTimeoutSeconds
|
||||||
|
}
|
||||||
|
return time.Duration(seconds) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ServerConfig) CallTimeout() time.Duration {
|
||||||
|
seconds := c.CallTimeoutSeconds
|
||||||
|
if seconds <= 0 {
|
||||||
|
seconds = defaultCallTimeoutSeconds
|
||||||
|
}
|
||||||
|
return time.Duration(seconds) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ServerConfig) ResponseLimit() int {
|
||||||
|
if c.MaxResponseBytes <= 0 {
|
||||||
|
return defaultMaxResponseBytes
|
||||||
|
}
|
||||||
|
return c.MaxResponseBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteTool is an MCP tool discovered from a server.
|
||||||
|
type RemoteTool struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
InputSchema map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisteredTool is a discovered tool with a PicoClaw-facing qualified name.
|
||||||
|
type RegisteredTool struct {
|
||||||
|
QualifiedName string
|
||||||
|
ServerName string
|
||||||
|
ToolName string
|
||||||
|
Description string
|
||||||
|
Parameters map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
// CallResult is a normalized MCP tool call result.
|
||||||
|
type CallResult struct {
|
||||||
|
Content string
|
||||||
|
IsError bool
|
||||||
|
}
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MCPTool struct {
|
||||||
|
manager *mcp.Manager
|
||||||
|
name string
|
||||||
|
description string
|
||||||
|
parameters map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMCPTool(manager *mcp.Manager, tool mcp.RegisteredTool) *MCPTool {
|
||||||
|
description := tool.Description
|
||||||
|
if description == "" {
|
||||||
|
description = fmt.Sprintf("MCP tool %s from server %s", tool.ToolName, tool.ServerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MCPTool{
|
||||||
|
manager: manager,
|
||||||
|
name: tool.QualifiedName,
|
||||||
|
description: description,
|
||||||
|
parameters: tool.Parameters,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *MCPTool) Name() string {
|
||||||
|
return t.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *MCPTool) Description() string {
|
||||||
|
return t.description
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *MCPTool) Parameters() map[string]interface{} {
|
||||||
|
return t.parameters
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *MCPTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||||
|
if t.manager == nil {
|
||||||
|
return ErrorResult("MCP manager is not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := t.manager.CallTool(ctx, t.name, args)
|
||||||
|
if err != nil {
|
||||||
|
return ErrorResult(fmt.Sprintf("MCP tool %s failed: %v", t.name, err)).WithError(err)
|
||||||
|
}
|
||||||
|
if result.IsError {
|
||||||
|
err := errors.New(result.Content)
|
||||||
|
return ErrorResult(result.Content).WithError(err)
|
||||||
|
}
|
||||||
|
return SilentResult(result.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterMCPTools discovers tools from MCP servers and registers them into the registry.
|
||||||
|
func RegisterMCPTools(ctx context.Context, registry *ToolRegistry, manager *mcp.Manager) (int, error) {
|
||||||
|
if registry == nil || manager == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
discoveredTools, err := manager.DiscoverTools(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return RegisterKnownMCPTools(registry, manager, discoveredTools), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterKnownMCPTools registers already-discovered MCP tools.
|
||||||
|
// This avoids repeated discovery work when multiple registries share one manager.
|
||||||
|
func RegisterKnownMCPTools(registry *ToolRegistry, manager *mcp.Manager, discoveredTools []mcp.RegisteredTool) int {
|
||||||
|
if registry == nil || manager == nil || len(discoveredTools) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tool := range discoveredTools {
|
||||||
|
registry.Register(NewMCPTool(manager, tool))
|
||||||
|
}
|
||||||
|
return len(discoveredTools)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user