mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 50c58a3462 | |||
| 403e048821 |
@@ -130,6 +130,28 @@
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
},
|
||||
"mcp": {
|
||||
"enabled": false,
|
||||
"servers": {
|
||||
"filesystem": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-filesystem",
|
||||
"/tmp"
|
||||
],
|
||||
"protocol": "mcp",
|
||||
"env": {},
|
||||
"working_dir": "",
|
||||
"init_timeout_seconds": 60,
|
||||
"call_timeout_seconds": 30,
|
||||
"max_response_bytes": 65536,
|
||||
"include_tools": [],
|
||||
"exclude_tools": []
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
@@ -144,4 +166,4 @@
|
||||
"host": "0.0.0.0",
|
||||
"port": 18790
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+59
-4
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
@@ -44,8 +45,12 @@ type AgentLoop struct {
|
||||
running atomic.Bool
|
||||
summarizing sync.Map // Tracks which sessions are currently being summarized
|
||||
channelManager *channels.Manager
|
||||
mcpManager *mcp.Manager
|
||||
mcpCloseOnce sync.Once
|
||||
}
|
||||
|
||||
const defaultWebFetchMaxChars = 50000
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
type processOptions struct {
|
||||
SessionKey string // Session identifier for history/context
|
||||
@@ -60,7 +65,14 @@ type processOptions struct {
|
||||
|
||||
// createToolRegistry creates a tool registry with common tools.
|
||||
// This is shared between main agent and subagents.
|
||||
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
|
||||
func createToolRegistry(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
cfg *config.Config,
|
||||
msgBus *bus.MessageBus,
|
||||
mcpManager *mcp.Manager,
|
||||
discoveredMCPTools []mcp.RegisteredTool,
|
||||
) *tools.ToolRegistry {
|
||||
registry := tools.NewToolRegistry()
|
||||
|
||||
// File system tools
|
||||
@@ -85,7 +97,9 @@ func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msg
|
||||
}); searchTool != nil {
|
||||
registry.Register(searchTool)
|
||||
}
|
||||
registry.Register(tools.NewWebFetchTool(50000))
|
||||
registry.Register(tools.NewWebFetchTool(defaultWebFetchMaxChars))
|
||||
|
||||
tools.RegisterKnownMCPTools(registry, mcpManager, discoveredMCPTools)
|
||||
|
||||
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
||||
registry.Register(tools.NewI2CTool())
|
||||
@@ -113,12 +127,35 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
|
||||
restrict := cfg.Agents.Defaults.RestrictToWorkspace
|
||||
|
||||
var (
|
||||
mcpManager *mcp.Manager
|
||||
discoveredMCPTools []mcp.RegisteredTool
|
||||
)
|
||||
if cfg.Tools.MCP.Enabled {
|
||||
bootstrap, err := bootstrapMCP(cfg.Tools.MCP)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "MCP tool bootstrap failed",
|
||||
map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else if bootstrap != nil {
|
||||
mcpManager = bootstrap.Manager
|
||||
discoveredMCPTools = bootstrap.Tools
|
||||
if len(discoveredMCPTools) > 0 {
|
||||
logger.InfoCF("agent", "MCP tools registered",
|
||||
map[string]interface{}{
|
||||
"count": len(discoveredMCPTools),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create tool registry for main agent
|
||||
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
|
||||
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager, discoveredMCPTools)
|
||||
|
||||
// Create subagent manager with its own tool registry
|
||||
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
|
||||
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
|
||||
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager, discoveredMCPTools)
|
||||
// Subagent doesn't need spawn/subagent tools to avoid recursion
|
||||
subagentManager.SetTools(subagentTools)
|
||||
|
||||
@@ -151,11 +188,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
contextBuilder: contextBuilder,
|
||||
tools: toolsRegistry,
|
||||
summarizing: sync.Map{},
|
||||
mcpManager: mcpManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running.Store(true)
|
||||
defer al.closeMCP()
|
||||
|
||||
for al.running.Load() {
|
||||
select {
|
||||
@@ -198,6 +237,22 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
|
||||
func (al *AgentLoop) Stop() {
|
||||
al.running.Store(false)
|
||||
al.closeMCP()
|
||||
}
|
||||
|
||||
func (al *AgentLoop) closeMCP() {
|
||||
if al.mcpManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
al.mcpCloseOnce.Do(func() {
|
||||
if err := al.mcpManager.Close(); err != nil {
|
||||
logger.WarnCF("agent", "Failed to close MCP manager",
|
||||
map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
mcpBootstrapMinTimeout = 10 * time.Second
|
||||
mcpBootstrapMaxTimeout = 5 * time.Minute
|
||||
mcpBootstrapGraceTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type mcpBootstrapResult struct {
|
||||
Manager *mcp.Manager
|
||||
Tools []mcp.RegisteredTool
|
||||
}
|
||||
|
||||
func bootstrapMCP(cfg config.MCPToolsConfig) (*mcpBootstrapResult, error) {
|
||||
serverConfigs := buildMCPServerConfigs(cfg)
|
||||
if len(serverConfigs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
manager := mcp.NewManager(serverConfigs)
|
||||
|
||||
discoveryTimeout := calculateMCPDiscoveryTimeout(serverConfigs)
|
||||
discoveryCtx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
|
||||
defer cancel()
|
||||
|
||||
discoveredTools, err := manager.DiscoverTools(discoveryCtx)
|
||||
if err != nil {
|
||||
_ = manager.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mcpBootstrapResult{
|
||||
Manager: manager,
|
||||
Tools: discoveredTools,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func calculateMCPDiscoveryTimeout(serverConfigs map[string]mcp.ServerConfig) time.Duration {
|
||||
maxInitTimeout := mcpBootstrapMinTimeout
|
||||
|
||||
for _, serverConfig := range serverConfigs {
|
||||
initTimeout := serverConfig.InitTimeout()
|
||||
if initTimeout > maxInitTimeout {
|
||||
maxInitTimeout = initTimeout
|
||||
}
|
||||
}
|
||||
|
||||
timeout := maxInitTimeout + mcpBootstrapGraceTimeout
|
||||
if timeout < mcpBootstrapMinTimeout {
|
||||
return mcpBootstrapMinTimeout
|
||||
}
|
||||
if timeout > mcpBootstrapMaxTimeout {
|
||||
return mcpBootstrapMaxTimeout
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
func buildMCPServerConfigs(cfg config.MCPToolsConfig) map[string]mcp.ServerConfig {
|
||||
servers := make(map[string]mcp.ServerConfig, len(cfg.Servers))
|
||||
|
||||
for serverName, serverCfg := range cfg.Servers {
|
||||
if !serverCfg.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
envCopy := make(map[string]string, len(serverCfg.Env))
|
||||
for key, value := range serverCfg.Env {
|
||||
envCopy[key] = value
|
||||
}
|
||||
|
||||
servers[serverName] = mcp.ServerConfig{
|
||||
Name: serverName,
|
||||
Command: serverCfg.Command,
|
||||
Args: append([]string{}, serverCfg.Args...),
|
||||
Env: envCopy,
|
||||
WorkingDir: serverCfg.WorkingDir,
|
||||
Protocol: inferMCPProtocol(serverCfg.Protocol, serverCfg.Command),
|
||||
InitTimeoutSeconds: serverCfg.InitTimeoutSeconds,
|
||||
CallTimeoutSeconds: serverCfg.CallTimeoutSeconds,
|
||||
MaxResponseBytes: serverCfg.MaxResponseBytes,
|
||||
IncludeTools: append([]string{}, serverCfg.IncludeTools...),
|
||||
ExcludeTools: append([]string{}, serverCfg.ExcludeTools...),
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
func inferMCPProtocol(configuredProtocol, command string) string {
|
||||
if protocol := strings.TrimSpace(configuredProtocol); protocol != "" {
|
||||
return protocol
|
||||
}
|
||||
|
||||
// Context7 currently emits JSON-RPC messages as JSONL on stdio,
|
||||
// so defaulting avoids long startup waits when protocol is omitted.
|
||||
if strings.Contains(strings.ToLower(command), "context7-mcp") {
|
||||
return mcp.ProtocolJSONLines
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
+82
-1
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
@@ -51,7 +52,10 @@ type Config struct {
|
||||
Tools ToolsConfig `json:"tools"`
|
||||
Heartbeat HeartbeatConfig `json:"heartbeat"`
|
||||
Devices DevicesConfig `json:"devices"`
|
||||
mu sync.RWMutex
|
||||
// MCPServers is a compatibility alias for configs using top-level "mcpServers".
|
||||
// Canonical config remains tools.mcp.servers.
|
||||
MCPServers map[string]LegacyMCPServerConfig `json:"mcpServers,omitempty"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type AgentsConfig struct {
|
||||
@@ -222,9 +226,38 @@ type CronToolsConfig struct {
|
||||
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
|
||||
}
|
||||
|
||||
type MCPServerConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Env map[string]string `json:"env"`
|
||||
WorkingDir string `json:"working_dir"`
|
||||
Protocol string `json:"protocol"`
|
||||
InitTimeoutSeconds int `json:"init_timeout_seconds"`
|
||||
CallTimeoutSeconds int `json:"call_timeout_seconds"`
|
||||
MaxResponseBytes int `json:"max_response_bytes"`
|
||||
IncludeTools []string `json:"include_tools"`
|
||||
ExcludeTools []string `json:"exclude_tools"`
|
||||
}
|
||||
|
||||
type MCPToolsConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
|
||||
Servers map[string]MCPServerConfig `json:"servers"`
|
||||
}
|
||||
|
||||
// LegacyMCPServerConfig supports compatibility with "mcpServers" style config.
|
||||
type LegacyMCPServerConfig struct {
|
||||
Type string `json:"type"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Env map[string]string `json:"env"`
|
||||
Protocol string `json:"protocol"`
|
||||
}
|
||||
|
||||
type ToolsConfig struct {
|
||||
Web WebToolsConfig `json:"web"`
|
||||
Cron CronToolsConfig `json:"cron"`
|
||||
MCP MCPToolsConfig `json:"mcp"`
|
||||
}
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
@@ -342,6 +375,10 @@ func DefaultConfig() *Config {
|
||||
Cron: CronToolsConfig{
|
||||
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
|
||||
},
|
||||
MCP: MCPToolsConfig{
|
||||
Enabled: false,
|
||||
Servers: map[string]MCPServerConfig{},
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
Enabled: true,
|
||||
@@ -373,9 +410,53 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg.applyLegacyMCPServers()
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) applyLegacyMCPServers() {
|
||||
// If canonical MCP config already exists, keep it as source of truth.
|
||||
if len(c.Tools.MCP.Servers) > 0 {
|
||||
return
|
||||
}
|
||||
if len(c.MCPServers) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if c.Tools.MCP.Servers == nil {
|
||||
c.Tools.MCP.Servers = map[string]MCPServerConfig{}
|
||||
}
|
||||
|
||||
for name, legacy := range c.MCPServers {
|
||||
if strings.TrimSpace(legacy.Command) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
enabled := true
|
||||
if legacy.Type != "" && legacy.Type != "stdio" {
|
||||
enabled = false
|
||||
}
|
||||
|
||||
envCopy := make(map[string]string, len(legacy.Env))
|
||||
for key, value := range legacy.Env {
|
||||
envCopy[key] = value
|
||||
}
|
||||
|
||||
c.Tools.MCP.Servers[name] = MCPServerConfig{
|
||||
Enabled: enabled,
|
||||
Command: legacy.Command,
|
||||
Args: append([]string{}, legacy.Args...),
|
||||
Env: envCopy,
|
||||
Protocol: legacy.Protocol,
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.Tools.MCP.Servers) > 0 {
|
||||
c.Tools.MCP.Enabled = true
|
||||
}
|
||||
}
|
||||
|
||||
func SaveConfig(path string, cfg *Config) error {
|
||||
cfg.mu.RLock()
|
||||
defer cfg.mu.RUnlock()
|
||||
|
||||
@@ -0,0 +1,603 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// Client is the transport-agnostic MCP client contract.
|
||||
type Client interface {
|
||||
Start(ctx context.Context) error
|
||||
ListTools(ctx context.Context) ([]RemoteTool, error)
|
||||
CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// StdioClient speaks MCP over stdio (JSON-RPC framed with Content-Length headers).
|
||||
type StdioClient struct {
|
||||
config ServerConfig
|
||||
mode string
|
||||
|
||||
mu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
|
||||
started bool
|
||||
closed bool
|
||||
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
stdout io.ReadCloser
|
||||
stderr io.ReadCloser
|
||||
waitCh chan struct{}
|
||||
pending map[string]chan rpcResponse
|
||||
|
||||
nextID uint64
|
||||
}
|
||||
|
||||
type rpcRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Method string `json:"method"`
|
||||
Params any `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type rpcResponseEnvelope struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *rpcError `json:"error,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
}
|
||||
|
||||
type rpcError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type rpcResponse struct {
|
||||
result json.RawMessage
|
||||
rpcErr *rpcError
|
||||
err error
|
||||
}
|
||||
|
||||
type initializeParams struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities map[string]any `json:"capabilities"`
|
||||
ClientInfo map[string]interface{} `json:"clientInfo"`
|
||||
}
|
||||
|
||||
func NewStdioClient(config ServerConfig) *StdioClient {
|
||||
return &StdioClient{
|
||||
config: config,
|
||||
mode: normalizeProtocol(config.Protocol),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) Start(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
if c.started {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(c.config.Command) == "" {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("mcp server %q command is empty", c.config.Name)
|
||||
}
|
||||
|
||||
cmd := exec.Command(c.config.Command, c.config.Args...)
|
||||
if c.config.WorkingDir != "" {
|
||||
cmd.Dir = c.config.WorkingDir
|
||||
}
|
||||
cmd.Env = buildProcessEnv(c.config.Env)
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("create stdin pipe: %w", err)
|
||||
}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("create stderr pipe: %w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("start process: %w", err)
|
||||
}
|
||||
|
||||
c.started = true
|
||||
c.closed = false
|
||||
c.cmd = cmd
|
||||
c.stdin = stdin
|
||||
c.stdout = stdout
|
||||
c.stderr = stderr
|
||||
c.waitCh = make(chan struct{})
|
||||
c.pending = make(map[string]chan rpcResponse)
|
||||
c.mu.Unlock()
|
||||
|
||||
go c.readLoop()
|
||||
go c.waitLoop()
|
||||
go c.drainStderr()
|
||||
|
||||
initCtx, cancel := withTimeoutIfMissing(ctx, c.config.InitTimeout())
|
||||
defer cancel()
|
||||
|
||||
_, err = c.request(initCtx, "initialize", initializeParams{
|
||||
ProtocolVersion: "2024-11-05",
|
||||
Capabilities: map[string]any{
|
||||
"tools": map[string]any{},
|
||||
},
|
||||
ClientInfo: map[string]any{
|
||||
"name": "picoclaw",
|
||||
"version": "dev",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("initialize failed: %w", err)
|
||||
}
|
||||
|
||||
if err := c.notify("notifications/initialized", map[string]any{}); err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("initialized notification failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *StdioClient) ListTools(ctx context.Context) ([]RemoteTool, error) {
|
||||
if err := c.Start(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type listToolsResponse struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]any `json:"inputSchema"`
|
||||
} `json:"tools"`
|
||||
NextCursor string `json:"nextCursor,omitempty"`
|
||||
}
|
||||
|
||||
allTools := make([]RemoteTool, 0, 8)
|
||||
cursor := ""
|
||||
|
||||
for page := 0; page < maxToolListPages; page++ {
|
||||
params := map[string]any{}
|
||||
if cursor != "" {
|
||||
params["cursor"] = cursor
|
||||
}
|
||||
|
||||
callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout())
|
||||
raw, err := c.request(callCtx, "tools/list", params)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var response listToolsResponse
|
||||
if err := json.Unmarshal(raw, &response); err != nil {
|
||||
return nil, fmt.Errorf("decode tools/list response: %w", err)
|
||||
}
|
||||
|
||||
for _, tool := range response.Tools {
|
||||
allTools = append(allTools, RemoteTool{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
InputSchema: tool.InputSchema,
|
||||
})
|
||||
}
|
||||
|
||||
if response.NextCursor == "" {
|
||||
return allTools, nil
|
||||
}
|
||||
cursor = response.NextCursor
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("tools/list exceeded %d pages", maxToolListPages)
|
||||
}
|
||||
|
||||
func (c *StdioClient) CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error) {
|
||||
if err := c.Start(ctx); err != nil {
|
||||
return CallResult{}, err
|
||||
}
|
||||
|
||||
callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout())
|
||||
defer cancel()
|
||||
|
||||
raw, err := c.request(callCtx, "tools/call", map[string]any{
|
||||
"name": toolName,
|
||||
"arguments": arguments,
|
||||
})
|
||||
if err != nil {
|
||||
return CallResult{}, err
|
||||
}
|
||||
|
||||
return formatCallPayload(raw, c.config.ResponseLimit())
|
||||
}
|
||||
|
||||
func (c *StdioClient) Close() error {
|
||||
c.mu.Lock()
|
||||
if !c.started || c.closed {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
c.closed = true
|
||||
cmd := c.cmd
|
||||
stdin := c.stdin
|
||||
waitCh := c.waitCh
|
||||
c.mu.Unlock()
|
||||
|
||||
c.failPending(errors.New("mcp client closed"))
|
||||
|
||||
if stdin != nil {
|
||||
_ = stdin.Close()
|
||||
}
|
||||
if cmd != nil && cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
|
||||
if waitCh != nil {
|
||||
select {
|
||||
case <-waitCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *StdioClient) request(ctx context.Context, method string, params any) (json.RawMessage, error) {
|
||||
id := strconv.FormatUint(atomic.AddUint64(&c.nextID, 1), 10)
|
||||
responseCh := make(chan rpcResponse, 1)
|
||||
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return nil, fmt.Errorf("mcp server %q is closed", c.config.Name)
|
||||
}
|
||||
c.pending[id] = responseCh
|
||||
c.mu.Unlock()
|
||||
|
||||
req := rpcRequest{
|
||||
JSONRPC: "2.0",
|
||||
ID: id,
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
if err := c.writeMessage(req); err != nil {
|
||||
c.removePending(id)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.removePending(id)
|
||||
return nil, ctx.Err()
|
||||
case response := <-responseCh:
|
||||
if response.err != nil {
|
||||
return nil, response.err
|
||||
}
|
||||
if response.rpcErr != nil {
|
||||
return nil, fmt.Errorf("mcp error %d: %s", response.rpcErr.Code, response.rpcErr.Message)
|
||||
}
|
||||
return response.result, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) notify(method string, params any) error {
|
||||
req := rpcRequest{
|
||||
JSONRPC: "2.0",
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
return c.writeMessage(req)
|
||||
}
|
||||
|
||||
func (c *StdioClient) writeMessage(payload any) error {
|
||||
c.mu.Lock()
|
||||
if c.closed || c.stdin == nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("mcp server %q is not writable", c.config.Name)
|
||||
}
|
||||
stdin := c.stdin
|
||||
c.mu.Unlock()
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal json-rpc payload: %w", err)
|
||||
}
|
||||
|
||||
if c.mode == ProtocolJSONLines {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
if _, err := stdin.Write(append(data, '\n')); err != nil {
|
||||
return fmt.Errorf("write jsonl body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
frameHeader := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data))
|
||||
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
if _, err := io.WriteString(stdin, frameHeader); err != nil {
|
||||
return fmt.Errorf("write frame header: %w", err)
|
||||
}
|
||||
if _, err := stdin.Write(data); err != nil {
|
||||
return fmt.Errorf("write frame body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *StdioClient) readLoop() {
|
||||
if c.mode == ProtocolJSONLines {
|
||||
c.readJSONLLoop()
|
||||
return
|
||||
}
|
||||
|
||||
c.readMCPFrameLoop()
|
||||
}
|
||||
|
||||
func (c *StdioClient) readMCPFrameLoop() {
|
||||
reader := bufio.NewReader(c.stdout)
|
||||
|
||||
for {
|
||||
payload, err := readFramePayload(reader)
|
||||
if err != nil {
|
||||
c.failPending(err)
|
||||
return
|
||||
}
|
||||
|
||||
var envelope rpcResponseEnvelope
|
||||
if err := json.Unmarshal(payload, &envelope); err != nil {
|
||||
continue
|
||||
}
|
||||
c.dispatchResponse(envelope)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) readJSONLLoop() {
|
||||
scanner := bufio.NewScanner(c.stdout)
|
||||
scanner.Buffer(make([]byte, 0, defaultScannerBufferBytes), maxFrameBytes)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var envelope rpcResponseEnvelope
|
||||
if err := json.Unmarshal([]byte(line), &envelope); err != nil {
|
||||
continue
|
||||
}
|
||||
c.dispatchResponse(envelope)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
c.failPending(err)
|
||||
return
|
||||
}
|
||||
c.failPending(io.EOF)
|
||||
}
|
||||
|
||||
func (c *StdioClient) dispatchResponse(envelope rpcResponseEnvelope) {
|
||||
if len(envelope.ID) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
id, ok := parseRPCID(envelope.ID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
responseCh := c.pending[id]
|
||||
if responseCh != nil {
|
||||
delete(c.pending, id)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if responseCh == nil {
|
||||
return
|
||||
}
|
||||
|
||||
response := rpcResponse{
|
||||
result: envelope.Result,
|
||||
rpcErr: envelope.Error,
|
||||
}
|
||||
select {
|
||||
case responseCh <- response:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) waitLoop() {
|
||||
c.mu.Lock()
|
||||
cmd := c.cmd
|
||||
waitCh := c.waitCh
|
||||
serverName := c.config.Name
|
||||
c.mu.Unlock()
|
||||
|
||||
if cmd == nil {
|
||||
if waitCh != nil {
|
||||
close(waitCh)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err := cmd.Wait()
|
||||
if waitCh != nil {
|
||||
close(waitCh)
|
||||
}
|
||||
if err != nil {
|
||||
logger.WarnCF("mcp", "MCP process exited with error",
|
||||
map[string]any{
|
||||
"server": serverName,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) drainStderr() {
|
||||
c.mu.Lock()
|
||||
stderr := c.stderr
|
||||
serverName := c.config.Name
|
||||
c.mu.Unlock()
|
||||
|
||||
if stderr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
logger.DebugCF("mcp", "MCP server stderr",
|
||||
map[string]any{
|
||||
"server": serverName,
|
||||
"line": line,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) failPending(err error) {
|
||||
c.mu.Lock()
|
||||
pending := c.pending
|
||||
c.pending = make(map[string]chan rpcResponse)
|
||||
c.mu.Unlock()
|
||||
|
||||
if len(pending) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ch := range pending {
|
||||
select {
|
||||
case ch <- rpcResponse{err: err}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) removePending(id string) {
|
||||
c.mu.Lock()
|
||||
delete(c.pending, id)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func readFramePayload(reader *bufio.Reader) ([]byte, error) {
|
||||
contentLength := -1
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
trimmed := strings.TrimRight(line, "\r\n")
|
||||
if trimmed == "" {
|
||||
break
|
||||
}
|
||||
|
||||
parts := strings.SplitN(trimmed, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
headerName := strings.TrimSpace(strings.ToLower(parts[0]))
|
||||
if headerName != "content-length" {
|
||||
continue
|
||||
}
|
||||
value := strings.TrimSpace(parts[1])
|
||||
length, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid content-length %q: %w", value, err)
|
||||
}
|
||||
contentLength = length
|
||||
}
|
||||
|
||||
if contentLength <= 0 {
|
||||
return nil, fmt.Errorf("missing content-length")
|
||||
}
|
||||
if contentLength > maxFrameBytes {
|
||||
return nil, fmt.Errorf("frame too large (%d bytes)", contentLength)
|
||||
}
|
||||
|
||||
payload := make([]byte, contentLength)
|
||||
if _, err := io.ReadFull(reader, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func parseRPCID(raw json.RawMessage) (string, bool) {
|
||||
var stringID string
|
||||
if err := json.Unmarshal(raw, &stringID); err == nil {
|
||||
return stringID, true
|
||||
}
|
||||
|
||||
var numberID float64
|
||||
if err := json.Unmarshal(raw, &numberID); err == nil {
|
||||
return strconv.FormatInt(int64(numberID), 10), true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func withTimeoutIfMissing(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
if _, hasDeadline := parent.Deadline(); hasDeadline {
|
||||
return context.WithCancel(parent)
|
||||
}
|
||||
return context.WithTimeout(parent, timeout)
|
||||
}
|
||||
|
||||
func buildProcessEnv(custom map[string]string) []string {
|
||||
base := os.Environ()
|
||||
if len(custom) == 0 {
|
||||
return base
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(custom))
|
||||
for key := range custom {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
env := make([]string, 0, len(base)+len(keys))
|
||||
env = append(env, base...)
|
||||
for _, key := range keys {
|
||||
env = append(env, key+"="+custom[key])
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func normalizeProtocol(protocol string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(protocol)) {
|
||||
case "", ProtocolMCPFrames:
|
||||
return ProtocolMCPFrames
|
||||
case ProtocolJSONLines:
|
||||
return ProtocolJSONLines
|
||||
default:
|
||||
return ProtocolMCPFrames
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type callResponse struct {
|
||||
Content []contentBlock `json:"content"`
|
||||
StructuredContent any `json:"structuredContent,omitempty"`
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
|
||||
type contentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
func formatCallPayload(raw json.RawMessage, responseLimit int) (CallResult, error) {
|
||||
var payload callResponse
|
||||
if err := json.Unmarshal(raw, &payload); err != nil {
|
||||
// Fallback for servers that return non-standard payloads.
|
||||
return CallResult{
|
||||
Content: truncateString(strings.TrimSpace(string(raw)), responseLimit),
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(payload.Content)+1)
|
||||
for _, block := range payload.Content {
|
||||
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
|
||||
parts = append(parts, block.Text)
|
||||
}
|
||||
}
|
||||
|
||||
if payload.StructuredContent != nil {
|
||||
if encoded, err := json.Marshal(payload.StructuredContent); err == nil {
|
||||
parts = append(parts, string(encoded))
|
||||
}
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(strings.Join(parts, "\n"))
|
||||
if content == "" {
|
||||
content = "{}"
|
||||
}
|
||||
|
||||
return CallResult{
|
||||
Content: truncateString(content, responseLimit),
|
||||
IsError: payload.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func truncateString(value string, maxBytes int) string {
|
||||
if maxBytes <= 0 || len(value) <= maxBytes {
|
||||
return value
|
||||
}
|
||||
if maxBytes <= 12 {
|
||||
return value[:maxBytes]
|
||||
}
|
||||
return value[:maxBytes-12] + "\n...[truncated]"
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type clientFactory func(config ServerConfig) Client
|
||||
|
||||
type managedServer struct {
|
||||
config ServerConfig
|
||||
client Client
|
||||
}
|
||||
|
||||
// Manager owns MCP servers and maps discovered MCP tools to PicoClaw tools.
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
servers map[string]*managedServer
|
||||
tools map[string]RegisteredTool
|
||||
|
||||
discovered bool
|
||||
newClient clientFactory
|
||||
}
|
||||
|
||||
func NewManager(configs map[string]ServerConfig) *Manager {
|
||||
servers := make(map[string]*managedServer, len(configs))
|
||||
for name, cfg := range configs {
|
||||
copied := cfg
|
||||
copied.Name = name
|
||||
servers[name] = &managedServer{config: copied}
|
||||
}
|
||||
return &Manager{
|
||||
servers: servers,
|
||||
tools: make(map[string]RegisteredTool),
|
||||
discovered: false,
|
||||
newClient: func(config ServerConfig) Client {
|
||||
return NewStdioClient(config)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DiscoverTools starts configured MCP servers and returns discovered tool metadata.
|
||||
func (m *Manager) DiscoverTools(ctx context.Context) ([]RegisteredTool, error) {
|
||||
m.mu.Lock()
|
||||
if m.discovered {
|
||||
tools := toolsFromMap(m.tools)
|
||||
m.mu.Unlock()
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
discoveryErrors := make([]string, 0)
|
||||
|
||||
for serverName, server := range m.servers {
|
||||
client := m.newClient(server.config)
|
||||
if err := client.Start(ctx); err != nil {
|
||||
discoveryErrors = append(discoveryErrors, fmt.Sprintf("%s: %v", serverName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
remoteTools, err := client.ListTools(ctx)
|
||||
if err != nil {
|
||||
_ = client.Close()
|
||||
discoveryErrors = append(discoveryErrors, fmt.Sprintf("%s: %v", serverName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
server.client = client
|
||||
for _, remoteTool := range remoteTools {
|
||||
if !isToolAllowed(remoteTool.Name, server.config.IncludeTools, server.config.ExcludeTools) {
|
||||
continue
|
||||
}
|
||||
|
||||
qualifiedName := m.makeUniqueToolName(serverName, remoteTool.Name)
|
||||
parameters := normalizeSchema(remoteTool.InputSchema)
|
||||
m.tools[qualifiedName] = RegisteredTool{
|
||||
QualifiedName: qualifiedName,
|
||||
ServerName: serverName,
|
||||
ToolName: remoteTool.Name,
|
||||
Description: remoteTool.Description,
|
||||
Parameters: parameters,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.discovered = true
|
||||
tools := toolsFromMap(m.tools)
|
||||
m.mu.Unlock()
|
||||
|
||||
if len(tools) == 0 && len(discoveryErrors) > 0 {
|
||||
return nil, fmt.Errorf("mcp tool discovery failed: %s", strings.Join(discoveryErrors, "; "))
|
||||
}
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
func (m *Manager) CallTool(ctx context.Context, qualifiedName string, args map[string]any) (CallResult, error) {
|
||||
m.mu.RLock()
|
||||
tool, ok := m.tools[qualifiedName]
|
||||
if !ok {
|
||||
m.mu.RUnlock()
|
||||
return CallResult{}, fmt.Errorf("mcp tool %q not found", qualifiedName)
|
||||
}
|
||||
|
||||
server := m.servers[tool.ServerName]
|
||||
if server == nil || server.client == nil {
|
||||
m.mu.RUnlock()
|
||||
return CallResult{}, fmt.Errorf("mcp server %q is not active", tool.ServerName)
|
||||
}
|
||||
client := server.client
|
||||
toolName := tool.ToolName
|
||||
m.mu.RUnlock()
|
||||
|
||||
if args == nil {
|
||||
args = map[string]any{}
|
||||
}
|
||||
return client.CallTool(ctx, toolName, args)
|
||||
}
|
||||
|
||||
func (m *Manager) Close() error {
|
||||
m.mu.Lock()
|
||||
servers := make([]*managedServer, 0, len(m.servers))
|
||||
for _, server := range m.servers {
|
||||
servers = append(servers, server)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
var firstErr error
|
||||
for _, server := range servers {
|
||||
if server.client == nil {
|
||||
continue
|
||||
}
|
||||
if err := server.client.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (m *Manager) makeUniqueToolName(serverName, toolName string) string {
|
||||
base := QualifiedToolName(serverName, toolName)
|
||||
if _, exists := m.tools[base]; !exists {
|
||||
return base
|
||||
}
|
||||
|
||||
for index := 2; ; index++ {
|
||||
candidate := fmt.Sprintf("%s_%d", base, index)
|
||||
if len(candidate) > qualifiedNameMaxLen {
|
||||
overflow := len(candidate) - qualifiedNameMaxLen
|
||||
if overflow < len(base) {
|
||||
candidate = base[:len(base)-overflow] + fmt.Sprintf("_%d", index)
|
||||
} else {
|
||||
candidate = candidate[:qualifiedNameMaxLen]
|
||||
}
|
||||
}
|
||||
if _, exists := m.tools[candidate]; !exists {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSchema(schema map[string]any) map[string]any {
|
||||
if len(schema) == 0 {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func isToolAllowed(name string, include, exclude []string) bool {
|
||||
if len(include) > 0 && !slices.Contains(include, name) {
|
||||
return false
|
||||
}
|
||||
if slices.Contains(exclude, name) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func toolsFromMap(tools map[string]RegisteredTool) []RegisteredTool {
|
||||
out := make([]RegisteredTool, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
out = append(out, tool)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package mcp
|
||||
|
||||
import "strings"
|
||||
|
||||
const qualifiedNameMaxLen = 64
|
||||
|
||||
// QualifiedToolName creates a stable, provider-safe function name.
|
||||
func QualifiedToolName(serverName, toolName string) string {
|
||||
prefix := "mcp_" + sanitizeName(serverName) + "__"
|
||||
tool := sanitizeName(toolName)
|
||||
maxToolLen := qualifiedNameMaxLen - len(prefix)
|
||||
if maxToolLen <= 0 {
|
||||
return prefix[:qualifiedNameMaxLen]
|
||||
}
|
||||
if len(tool) > maxToolLen {
|
||||
tool = tool[:maxToolLen]
|
||||
}
|
||||
return prefix + tool
|
||||
}
|
||||
|
||||
func sanitizeName(value string) string {
|
||||
trimmed := strings.TrimSpace(strings.ToLower(value))
|
||||
if trimmed == "" {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(trimmed))
|
||||
|
||||
lastUnderscore := false
|
||||
for i := 0; i < len(trimmed); i++ {
|
||||
ch := trimmed[i]
|
||||
isAlphaNum := (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9')
|
||||
if isAlphaNum {
|
||||
b.WriteByte(ch)
|
||||
lastUnderscore = false
|
||||
continue
|
||||
}
|
||||
if !lastUnderscore {
|
||||
b.WriteByte('_')
|
||||
lastUnderscore = true
|
||||
}
|
||||
}
|
||||
|
||||
s := strings.Trim(b.String(), "_")
|
||||
if s == "" {
|
||||
s = "unknown"
|
||||
}
|
||||
if s[0] >= '0' && s[0] <= '9' {
|
||||
return "t_" + s
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package mcp
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
defaultInitTimeoutSeconds = 60
|
||||
defaultCallTimeoutSeconds = 30
|
||||
defaultMaxResponseBytes = 64 * 1024
|
||||
defaultScannerBufferBytes = 64 * 1024
|
||||
maxFrameBytes = 2 * 1024 * 1024
|
||||
maxToolListPages = 50
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolMCPFrames = "mcp"
|
||||
ProtocolJSONLines = "jsonl"
|
||||
)
|
||||
|
||||
// ServerConfig defines one MCP server connection.
|
||||
type ServerConfig struct {
|
||||
Name string
|
||||
Command string
|
||||
Args []string
|
||||
Env map[string]string
|
||||
WorkingDir string
|
||||
Protocol string
|
||||
InitTimeoutSeconds int
|
||||
CallTimeoutSeconds int
|
||||
MaxResponseBytes int
|
||||
IncludeTools []string
|
||||
ExcludeTools []string
|
||||
}
|
||||
|
||||
func (c ServerConfig) InitTimeout() time.Duration {
|
||||
seconds := c.InitTimeoutSeconds
|
||||
if seconds <= 0 {
|
||||
seconds = defaultInitTimeoutSeconds
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
func (c ServerConfig) CallTimeout() time.Duration {
|
||||
seconds := c.CallTimeoutSeconds
|
||||
if seconds <= 0 {
|
||||
seconds = defaultCallTimeoutSeconds
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
func (c ServerConfig) ResponseLimit() int {
|
||||
if c.MaxResponseBytes <= 0 {
|
||||
return defaultMaxResponseBytes
|
||||
}
|
||||
return c.MaxResponseBytes
|
||||
}
|
||||
|
||||
// RemoteTool is an MCP tool discovered from a server.
|
||||
type RemoteTool struct {
|
||||
Name string
|
||||
Description string
|
||||
InputSchema map[string]any
|
||||
}
|
||||
|
||||
// RegisteredTool is a discovered tool with a PicoClaw-facing qualified name.
|
||||
type RegisteredTool struct {
|
||||
QualifiedName string
|
||||
ServerName string
|
||||
ToolName string
|
||||
Description string
|
||||
Parameters map[string]any
|
||||
}
|
||||
|
||||
// CallResult is a normalized MCP tool call result.
|
||||
type CallResult struct {
|
||||
Content string
|
||||
IsError bool
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
)
|
||||
|
||||
type MCPTool struct {
|
||||
manager *mcp.Manager
|
||||
name string
|
||||
description string
|
||||
parameters map[string]any
|
||||
}
|
||||
|
||||
func NewMCPTool(manager *mcp.Manager, tool mcp.RegisteredTool) *MCPTool {
|
||||
description := tool.Description
|
||||
if description == "" {
|
||||
description = fmt.Sprintf("MCP tool %s from server %s", tool.ToolName, tool.ServerName)
|
||||
}
|
||||
|
||||
return &MCPTool{
|
||||
manager: manager,
|
||||
name: tool.QualifiedName,
|
||||
description: description,
|
||||
parameters: tool.Parameters,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MCPTool) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *MCPTool) Description() string {
|
||||
return t.description
|
||||
}
|
||||
|
||||
func (t *MCPTool) Parameters() map[string]interface{} {
|
||||
return t.parameters
|
||||
}
|
||||
|
||||
func (t *MCPTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
if t.manager == nil {
|
||||
return ErrorResult("MCP manager is not configured")
|
||||
}
|
||||
|
||||
result, err := t.manager.CallTool(ctx, t.name, args)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("MCP tool %s failed: %v", t.name, err)).WithError(err)
|
||||
}
|
||||
if result.IsError {
|
||||
err := errors.New(result.Content)
|
||||
return ErrorResult(result.Content).WithError(err)
|
||||
}
|
||||
return SilentResult(result.Content)
|
||||
}
|
||||
|
||||
// RegisterMCPTools discovers tools from MCP servers and registers them into the registry.
|
||||
func RegisterMCPTools(ctx context.Context, registry *ToolRegistry, manager *mcp.Manager) (int, error) {
|
||||
if registry == nil || manager == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
discoveredTools, err := manager.DiscoverTools(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return RegisterKnownMCPTools(registry, manager, discoveredTools), nil
|
||||
}
|
||||
|
||||
// RegisterKnownMCPTools registers already-discovered MCP tools.
|
||||
// This avoids repeated discovery work when multiple registries share one manager.
|
||||
func RegisterKnownMCPTools(registry *ToolRegistry, manager *mcp.Manager, discoveredTools []mcp.RegisteredTool) int {
|
||||
if registry == nil || manager == nil || len(discoveredTools) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
for _, tool := range discoveredTools {
|
||||
registry.Register(NewMCPTool(manager, tool))
|
||||
}
|
||||
return len(discoveredTools)
|
||||
}
|
||||
Reference in New Issue
Block a user