Files
picoclaw/pkg/mcp/manager.go
T

720 lines
18 KiB
Go

package mcp
import (
"bufio"
"context"
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/logger"
)
// headerTransport is an http.RoundTripper that adds custom headers to requests
type headerTransport struct {
base http.RoundTripper
headers map[string]string
}
func expandHomeCommandPath(command string) string {
if command == "" || command[0] != '~' {
return command
}
home, err := os.UserHomeDir()
if err != nil {
return command
}
if command == "~" {
return home
}
if strings.HasPrefix(command, "~/") || strings.HasPrefix(command, "~\\") {
return filepath.Join(home, command[2:])
}
return command
}
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Clone the request to avoid modifying the original
req = req.Clone(req.Context())
// Add custom headers
for key, value := range t.headers {
req.Header.Set(key, value)
}
// Use the base transport
base := t.base
if base == nil {
base = http.DefaultTransport
}
return base.RoundTrip(req)
}
// loadEnvFile loads environment variables from a file in .env format
// Each line should be in the format: KEY=value
// Lines starting with # are comments
// Empty lines are ignored
func loadEnvFile(path string) (map[string]string, error) {
file, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("failed to open env file: %w", err)
}
defer file.Close()
envVars := make(map[string]string)
scanner := bufio.NewScanner(file)
lineNum := 0
for scanner.Scan() {
lineNum++
line := strings.TrimSpace(scanner.Text())
// Skip empty lines and comments
if line == "" || strings.HasPrefix(line, "#") {
continue
}
// Parse KEY=value
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line)
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
if key == "" {
return nil, fmt.Errorf("invalid format at line %d: empty key", lineNum)
}
// Remove surrounding quotes if present
if len(value) >= 2 {
if (value[0] == '"' && value[len(value)-1] == '"') ||
(value[0] == '\'' && value[len(value)-1] == '\'') {
value = value[1 : len(value)-1]
}
}
envVars[key] = value
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading env file: %w", err)
}
return envVars, nil
}
// ServerConnection represents a connection to an MCP server
type ServerConnection struct {
Name string
Config config.MCPServerConfig
Client *mcp.Client
Session *mcp.ClientSession
Tools []*mcp.Tool
reconnectMu sync.Mutex
}
// Manager manages multiple MCP server connections
type Manager struct {
servers map[string]*ServerConnection
runtimeEvents runtimeevents.Bus
mu sync.RWMutex
closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race
wg sync.WaitGroup // tracks in-flight CallTool calls
}
var connectServerFunc = connectServer
// ManagerOption configures an MCP manager.
type ManagerOption func(*Manager)
// WithRuntimeEvents injects the runtime event bus used for MCP observations.
func WithRuntimeEvents(eventBus runtimeevents.Bus) ManagerOption {
return func(m *Manager) {
m.runtimeEvents = eventBus
}
}
// ServerEventPayload describes MCP server connection events.
type ServerEventPayload struct {
Server string `json:"server"`
Type string `json:"type,omitempty"`
URL string `json:"url,omitempty"`
Command string `json:"command,omitempty"`
Tool string `json:"tool,omitempty"`
ToolCount int `json:"tool_count,omitempty"`
Error string `json:"error,omitempty"`
}
// NewManager creates a new MCP manager
func NewManager(opts ...ManagerOption) *Manager {
m := &Manager{
servers: make(map[string]*ServerConnection),
}
for _, opt := range opts {
if opt != nil {
opt(m)
}
}
return m
}
// LoadFromConfig loads MCP servers from configuration
func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error {
return m.LoadFromMCPConfig(ctx, cfg.Tools.MCP, cfg.WorkspacePath())
}
// LoadFromMCPConfig loads MCP servers from MCP configuration and workspace path.
// This is the minimal dependency version that doesn't require the full Config object.
func (m *Manager) LoadFromMCPConfig(
ctx context.Context,
mcpCfg config.MCPConfig,
workspacePath string,
) error {
if !mcpCfg.Enabled {
logger.InfoCF("mcp", "MCP integration is disabled", nil)
return nil
}
if len(mcpCfg.Servers) == 0 {
logger.InfoCF("mcp", "No MCP servers configured", nil)
return nil
}
logger.InfoCF("mcp", "Initializing MCP servers",
map[string]any{
"count": len(mcpCfg.Servers),
})
var wg sync.WaitGroup
errs := make(chan error, len(mcpCfg.Servers))
enabledCount := 0
for name, serverCfg := range mcpCfg.Servers {
if !serverCfg.Enabled {
logger.DebugCF("mcp", "Skipping disabled server",
map[string]any{
"server": name,
})
continue
}
enabledCount++
wg.Add(1)
go func(name string, serverCfg config.MCPServerConfig, workspace string) {
defer wg.Done()
// Resolve relative envFile paths relative to workspace
if serverCfg.EnvFile != "" && !filepath.IsAbs(serverCfg.EnvFile) {
if workspace == "" {
err := fmt.Errorf(
"workspace path is empty while resolving relative envFile %q for server %s",
serverCfg.EnvFile,
name,
)
logger.ErrorCF("mcp", "Invalid MCP server configuration",
map[string]any{
"server": name,
"env_file": serverCfg.EnvFile,
"error": err.Error(),
})
errs <- err
return
}
serverCfg.EnvFile = filepath.Join(workspace, serverCfg.EnvFile)
}
if err := m.ConnectServer(ctx, name, serverCfg); err != nil {
logger.ErrorCF("mcp", "Failed to connect to MCP server",
map[string]any{
"server": name,
"error": err.Error(),
})
errs <- fmt.Errorf("failed to connect to server %s: %w", name, err)
}
}(name, serverCfg, workspacePath)
}
wg.Wait()
close(errs)
// Collect errors
var allErrors []error
for err := range errs {
allErrors = append(allErrors, err)
}
connectedCount := len(m.GetServers())
// If all enabled servers failed to connect, return aggregated error
if enabledCount > 0 && connectedCount == 0 {
logger.ErrorCF("mcp", "All MCP servers failed to connect",
map[string]any{
"failed": len(allErrors),
"total": enabledCount,
})
return errors.Join(allErrors...)
}
if len(allErrors) > 0 {
logger.WarnCF("mcp", "Some MCP servers failed to connect",
map[string]any{
"failed": len(allErrors),
"connected": connectedCount,
"total": enabledCount,
})
// Don't fail completely if some servers successfully connected
}
logger.InfoCF("mcp", "MCP server initialization complete",
map[string]any{
"connected": connectedCount,
"total": enabledCount,
})
return nil
}
// ConnectServer connects to a single MCP server
func (m *Manager) ConnectServer(
ctx context.Context,
name string,
cfg config.MCPServerConfig,
) error {
m.publishServerEvent(runtimeevents.KindMCPServerConnecting, name, cfg, 0, nil)
conn, err := connectServerFunc(ctx, name, cfg)
if err != nil {
m.publishServerEvent(runtimeevents.KindMCPServerFailed, name, cfg, 0, err)
return err
}
m.mu.Lock()
defer m.mu.Unlock()
if m.closed.Load() {
_ = conn.Session.Close()
m.publishServerEvent(runtimeevents.KindMCPServerFailed, name, cfg, 0, fmt.Errorf("manager is closed"))
return fmt.Errorf("manager is closed")
}
m.servers[name] = conn
for _, tool := range conn.Tools {
toolName := ""
if tool != nil {
toolName = tool.Name
}
m.publishToolDiscovered(name, cfg, toolName)
}
m.publishServerEvent(runtimeevents.KindMCPServerConnected, name, cfg, len(conn.Tools), nil)
return nil
}
func connectServer(
ctx context.Context,
name string,
cfg config.MCPServerConfig,
) (*ServerConnection, error) {
logger.InfoCF("mcp", "Connecting to MCP server",
map[string]any{
"server": name,
"command": cfg.Command,
"args_count": len(cfg.Args),
})
// Create client
client := mcp.NewClient(&mcp.Implementation{
Name: "picoclaw",
Version: "1.0.0",
}, nil)
// Create transport based on configuration
// Auto-detect transport type if not explicitly specified
var transport mcp.Transport
transportType := config.EffectiveMCPTransportType(cfg)
if transportType == "" {
return nil, fmt.Errorf("either URL or command must be provided")
}
switch transportType {
case "sse", "http":
if cfg.URL == "" {
return nil, fmt.Errorf("URL is required for SSE/HTTP transport")
}
// Configure DisableStandaloneSSE based on transport type.
// - "http": Streamable HTTP request-response mode. Disable the standalone
// SSE stream to avoid compatibility issues with servers that don't
// support the optional GET listener.
// - "sse": Bidirectional mode. Enable the standalone SSE stream to receive
// server-initiated notifications (e.g., ToolListChangedNotification).
// - Empty or auto-detected: Defaults to "sse" behavior (standalone SSE enabled).
disableStandaloneSSE := transportType == "http"
logger.DebugCF("mcp", "Using SSE/HTTP transport",
map[string]any{
"server": name,
"url": cfg.URL,
"disableStandaloneSSE": disableStandaloneSSE,
})
sseTransport := &mcp.StreamableClientTransport{
Endpoint: cfg.URL,
DisableStandaloneSSE: disableStandaloneSSE,
}
// Add custom headers if provided
if len(cfg.Headers) > 0 {
// Create a custom HTTP client with header-injecting transport
sseTransport.HTTPClient = &http.Client{
Transport: &headerTransport{
base: http.DefaultTransport,
headers: cfg.Headers,
},
}
logger.DebugCF("mcp", "Added custom HTTP headers",
map[string]any{
"server": name,
"header_count": len(cfg.Headers),
})
}
transport = sseTransport
case "stdio":
if cfg.Command == "" {
return nil, fmt.Errorf("command is required for stdio transport")
}
logger.DebugCF("mcp", "Using stdio transport",
map[string]any{
"server": name,
"command": cfg.Command,
})
// Create command with context
cmd := exec.CommandContext(ctx, expandHomeCommandPath(cfg.Command), cfg.Args...)
// Build environment variables with proper override semantics
// Use a map to ensure config variables override file variables
envMap := make(map[string]string)
// Start with parent process environment
for _, e := range cmd.Environ() {
if idx := strings.Index(e, "="); idx > 0 {
envMap[e[:idx]] = e[idx+1:]
}
}
// Load environment variables from file if specified
if cfg.EnvFile != "" {
envVars, err := loadEnvFile(cfg.EnvFile)
if err != nil {
return nil, fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
}
for k, v := range envVars {
envMap[k] = v
}
logger.DebugCF("mcp", "Loaded environment variables from file",
map[string]any{
"server": name,
"envFile": cfg.EnvFile,
"var_count": len(envVars),
})
}
// Environment variables from config override those from file
for k, v := range cfg.Env {
envMap[k] = v
}
// Convert map to slice
env := make([]string, 0, len(envMap))
for k, v := range envMap {
env = append(env, fmt.Sprintf("%s=%s", k, v))
}
cmd.Env = env
transport = &isolatedCommandTransport{Command: cmd}
default:
return nil, fmt.Errorf(
"unsupported transport type: %s (supported: stdio, sse, http, streamable-http)",
transportType,
)
}
// Connect to server
session, err := client.Connect(ctx, transport, nil)
if err != nil {
return nil, fmt.Errorf("failed to connect: %w", err)
}
// Get server info
initResult := session.InitializeResult()
logger.InfoCF("mcp", "Connected to MCP server",
map[string]any{
"server": name,
"serverName": initResult.ServerInfo.Name,
"serverVersion": initResult.ServerInfo.Version,
"protocol": initResult.ProtocolVersion,
})
// List available tools if supported
tools, err := listServerTools(ctx, name, session, initResult)
if err != nil {
_ = session.Close()
return nil, err
}
return &ServerConnection{
Name: name,
Config: cfg,
Client: client,
Session: session,
Tools: tools,
}, nil
}
// GetServers returns all connected servers
func (m *Manager) GetServers() map[string]*ServerConnection {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string]*ServerConnection, len(m.servers))
for k, v := range m.servers {
result[k] = v
}
return result
}
// GetServer returns a specific server connection
func (m *Manager) GetServer(name string) (*ServerConnection, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
conn, ok := m.servers[name]
return conn, ok
}
// CallTool calls a tool on a specific server
func (m *Manager) CallTool(
ctx context.Context,
serverName, toolName string,
arguments map[string]any,
) (*mcp.CallToolResult, error) {
// Check if closed before acquiring lock (fast path)
if m.closed.Load() {
return nil, fmt.Errorf("manager is closed")
}
m.mu.RLock()
// Double-check after acquiring lock to prevent TOCTOU race
if m.closed.Load() {
m.mu.RUnlock()
return nil, fmt.Errorf("manager is closed")
}
conn, ok := m.servers[serverName]
if ok {
m.wg.Add(1) // Add to WaitGroup while holding the lock
}
m.mu.RUnlock()
if !ok {
return nil, fmt.Errorf("server %s not found", serverName)
}
defer m.wg.Done()
params := &mcp.CallToolParams{
Name: toolName,
Arguments: arguments,
}
result, err := conn.Session.CallTool(ctx, params)
if err != nil {
if shouldReconnectCallError(err) {
logger.WarnCF("mcp", "MCP server session was lost during tool call, reconnecting",
map[string]any{
"server": serverName,
"tool": toolName,
"error": err.Error(),
})
reconnectedConn, reconnectErr := m.reconnectServer(ctx, serverName, conn)
if reconnectErr != nil {
return nil, fmt.Errorf("failed to recover lost MCP session: %w", reconnectErr)
}
result, err = reconnectedConn.Session.CallTool(ctx, params)
if err == nil {
return result, nil
}
}
return nil, fmt.Errorf("failed to call tool: %w", err)
}
return result, nil
}
func listServerTools(
ctx context.Context,
name string,
session *mcp.ClientSession,
initResult *mcp.InitializeResult,
) ([]*mcp.Tool, error) {
var tools []*mcp.Tool
if initResult.Capabilities.Tools == nil {
return tools, nil
}
for tool, err := range session.Tools(ctx, nil) {
if err != nil {
logger.WarnCF("mcp", "Error listing tool",
map[string]any{
"server": name,
"error": err.Error(),
})
continue
}
tools = append(tools, tool)
}
logger.InfoCF("mcp", "Listed tools from MCP server",
map[string]any{
"server": name,
"toolCount": len(tools),
})
return tools, nil
}
func shouldReconnectCallError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, mcp.ErrSessionMissing) {
return true
}
return strings.Contains(strings.ToLower(err.Error()), mcp.ErrSessionMissing.Error())
}
func (m *Manager) reconnectServer(
ctx context.Context,
serverName string,
staleConn *ServerConnection,
) (*ServerConnection, error) {
if staleConn == nil {
return nil, fmt.Errorf("server %s not found", serverName)
}
staleConn.reconnectMu.Lock()
defer staleConn.reconnectMu.Unlock()
if m.closed.Load() {
return nil, fmt.Errorf("manager is closed")
}
m.mu.RLock()
currentConn, ok := m.servers[serverName]
m.mu.RUnlock()
if !ok {
return nil, fmt.Errorf("server %s not found", serverName)
}
if currentConn != staleConn {
return currentConn, nil
}
freshConn, err := connectServerFunc(ctx, serverName, staleConn.Config)
if err != nil {
return nil, err
}
m.mu.Lock()
if m.closed.Load() {
m.mu.Unlock()
_ = freshConn.Session.Close()
return nil, fmt.Errorf("manager is closed")
}
currentConn, ok = m.servers[serverName]
if !ok {
m.mu.Unlock()
_ = freshConn.Session.Close()
return nil, fmt.Errorf("server %s not found", serverName)
}
if currentConn == staleConn {
m.servers[serverName] = freshConn
staleToClose := staleConn
m.mu.Unlock()
_ = staleToClose.Session.Close()
return freshConn, nil
}
m.mu.Unlock()
_ = freshConn.Session.Close()
return currentConn, nil
}
// Close closes all server connections
func (m *Manager) Close() error {
// Use Swap to atomically set closed=true and get the previous value
// This prevents TOCTOU race with CallTool's closed check
if m.closed.Swap(true) {
return nil // already closed
}
// Wait for all in-flight CallTool calls to finish before closing sessions
// After closed=true is set, no new CallTool can start (they check closed first)
m.wg.Wait()
m.mu.Lock()
defer m.mu.Unlock()
logger.InfoCF("mcp", "Closing all MCP server connections",
map[string]any{
"count": len(m.servers),
})
var errs []error
for name, conn := range m.servers {
if err := conn.Session.Close(); err != nil {
logger.ErrorCF("mcp", "Failed to close server connection",
map[string]any{
"server": name,
"error": err.Error(),
})
errs = append(errs, fmt.Errorf("server %s: %w", name, err))
}
}
m.servers = make(map[string]*ServerConnection)
if len(errs) > 0 {
return fmt.Errorf("failed to close %d server(s): %w", len(errs), errors.Join(errs...))
}
return nil
}
// GetAllTools returns all tools from all connected servers
func (m *Manager) GetAllTools() map[string][]*mcp.Tool {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string][]*mcp.Tool)
for name, conn := range m.servers {
if len(conn.Tools) > 0 {
result[name] = conn.Tools
}
}
return result
}