Files
picoclaw/pkg/mcp/manager.go
T
yuchou87 f1b798434d fix(mcp): prevent race condition between CallTool and Close
Add a closed flag to the Manager struct to prevent CallTool from
accessing server connections after Close has been called. The flag
is checked within the RLock in CallTool to ensure thread-safety.

Previously, CallTool could obtain a server reference using RLock,
then that reference could be closed by Close() running concurrently,
leading to use-after-close errors. Now:

1. CallTool checks the closed flag before accessing servers
2. Close sets the closed flag before closing connections
3. CallTool directly accesses m.servers within the lock instead
   of using GetServer() to avoid releasing the lock prematurely

This ensures CallTool will not use a server connection that is
being closed or has been closed.

Addresses Copilot code review feedback.
2026-02-19 19:47:05 +08:00

481 lines
12 KiB
Go

package mcp
import (
"bufio"
"context"
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
// headerTransport is an http.RoundTripper that injects a fixed set of custom
// headers into every outgoing request before delegating to a base transport.
type headerTransport struct {
	// base performs the actual round trip; http.DefaultTransport is used when nil.
	base http.RoundTripper
	// headers maps header names to the values set on each request.
	headers map[string]string
}

// RoundTrip implements http.RoundTripper. It applies the configured headers to
// a clone of req, leaving the caller's request untouched, and forwards the
// clone to the base transport, returning that transport's response and error.
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// A RoundTripper must not modify the caller's request, so work on a clone.
	req = req.Clone(req.Context())

	for key, value := range t.headers {
		req.Header.Set(key, value)
		// net/http ignores a "Host" entry in Header on client requests; the
		// Host header that goes on the wire is taken from Request.Host.
		if strings.EqualFold(key, "Host") {
			req.Host = value
		}
	}

	base := t.base
	if base == nil {
		base = http.DefaultTransport
	}
	return base.RoundTrip(req)
}
// loadEnvFile loads environment variables from a file in .env format.
//
// Each non-blank line must have the form KEY=value. Lines whose first
// non-space character is '#' are comments and are skipped, as are empty lines.
// Whitespace around the key and the value is trimmed, and one pair of matching
// surrounding quotes (single or double) is removed from the value. When a key
// appears more than once, the last occurrence wins.
//
// It returns an error if the file cannot be opened or read, if a line has no
// '=' separator, or if a line has an empty key.
func loadEnvFile(path string) (map[string]string, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to open env file: %w", err)
	}
	defer file.Close()

	envVars := make(map[string]string)
	scanner := bufio.NewScanner(file)
	lineNum := 0
	for scanner.Scan() {
		lineNum++
		line := strings.TrimSpace(scanner.Text())

		// Skip empty lines and comments.
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}

		// Split on the first '=' only, so values may themselves contain '='.
		rawKey, rawValue, found := strings.Cut(line, "=")
		if !found {
			return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line)
		}

		key := strings.TrimSpace(rawKey)
		if key == "" {
			// An empty name would later become the malformed entry "=value".
			return nil, fmt.Errorf("invalid format at line %d: empty key", lineNum)
		}

		value := strings.TrimSpace(rawValue)
		// Remove one pair of matching surrounding quotes if present.
		if len(value) >= 2 {
			first, last := value[0], value[len(value)-1]
			if first == last && (first == '"' || first == '\'') {
				value = value[1 : len(value)-1]
			}
		}

		envVars[key] = value
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("error reading env file: %w", err)
	}
	return envVars, nil
}
// ServerConnection represents a connection to an MCP server.
// ConnectServer builds one for each server it connects and stores it in the Manager.
type ServerConnection struct {
// Name is the key the connection is registered under in the Manager.
Name string
// Client is the MCP client that was used to establish Session.
Client *mcp.Client
// Session is the live client session; tool calls and Close go through it.
Session *mcp.ClientSession
// Tools holds the tools listed from the server at connect time. It stays nil
// when the server did not advertise the tools capability.
Tools []*mcp.Tool
}
// Manager manages multiple MCP server connections.
// Create one with NewManager so that the servers map is initialized.
type Manager struct {
// servers maps a server name to its connection.
servers map[string]*ServerConnection
// mu guards servers and closed.
mu sync.RWMutex
// closed is set by Close; CallTool refuses to run once it is true.
closed bool
}
// NewManager returns a Manager with an empty, initialized server registry.
func NewManager() *Manager {
	manager := new(Manager)
	manager.servers = map[string]*ServerConnection{}
	return manager
}
// LoadFromConfig loads MCP servers from the application configuration by
// handing its MCP section and workspace path to LoadFromMCPConfig.
func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error {
	mcpSection := cfg.Tools.MCP
	workspace := cfg.WorkspacePath()
	return m.LoadFromMCPConfig(ctx, mcpSection, workspace)
}
// LoadFromMCPConfig loads MCP servers from an MCP configuration and a workspace
// path, without requiring the full Config object.
//
// Every enabled server is connected in its own goroutine and the call waits
// for all of them to finish. Relative EnvFile paths are resolved against
// workspacePath. The joined connection errors are returned only when at least
// one server was enabled and the manager holds no connected server afterwards;
// partial failures are logged and nil is returned.
func (m *Manager) LoadFromMCPConfig(ctx context.Context, mcpCfg config.MCPConfig, workspacePath string) error {
	switch {
	case !mcpCfg.Enabled:
		logger.InfoCF("mcp", "MCP integration is disabled", nil)
		return nil
	case len(mcpCfg.Servers) == 0:
		logger.InfoCF("mcp", "No MCP servers configured", nil)
		return nil
	}

	configured := len(mcpCfg.Servers)
	logger.InfoCF("mcp", "Initializing MCP servers", map[string]any{"count": configured})

	// Buffered for the worst case so a failing worker never blocks on send.
	failures := make(chan error, configured)
	var pending sync.WaitGroup
	started := 0

	// connect runs in its own goroutine; sc is a private copy, so rewriting
	// EnvFile does not touch the caller's configuration.
	connect := func(serverName string, sc config.MCPServerConfig) {
		defer pending.Done()

		if envFile := sc.EnvFile; envFile != "" && !filepath.IsAbs(envFile) {
			sc.EnvFile = filepath.Join(workspacePath, envFile)
		}

		err := m.ConnectServer(ctx, serverName, sc)
		if err == nil {
			return
		}
		logger.ErrorCF("mcp", "Failed to connect to MCP server", map[string]any{
			"server": serverName,
			"error":  err.Error(),
		})
		failures <- fmt.Errorf("failed to connect to server %s: %w", serverName, err)
	}

	for serverName, sc := range mcpCfg.Servers {
		if sc.Enabled {
			started++
			pending.Add(1)
			go connect(serverName, sc)
			continue
		}
		logger.DebugCF("mcp", "Skipping disabled server", map[string]any{"server": serverName})
	}

	pending.Wait()
	close(failures)

	var failed []error
	for failure := range failures {
		failed = append(failed, failure)
	}

	connected := len(m.GetServers())

	// Nothing usable came up: report every failure to the caller.
	if started > 0 && connected == 0 {
		logger.ErrorCF("mcp", "All MCP servers failed to connect", map[string]any{
			"failed": len(failed),
			"total":  started,
		})
		return errors.Join(failed...)
	}

	// Partial failure is only a warning as long as some server is connected.
	if len(failed) > 0 {
		logger.WarnCF("mcp", "Some MCP servers failed to connect", map[string]any{
			"failed":    len(failed),
			"connected": connected,
			"total":     started,
		})
	}

	logger.InfoCF("mcp", "MCP server initialization complete", map[string]any{
		"connected": connected,
		"total":     started,
	})
	return nil
}
// ConnectServer connects to a single MCP server and registers the resulting
// connection under name.
//
// The transport is chosen from cfg.Type ("stdio", "sse" or "http"). When
// cfg.Type is empty it is inferred: a non-empty cfg.URL selects "sse",
// otherwise a non-empty cfg.Command selects "stdio". For stdio servers the
// child environment is built from the parent environment, then cfg.EnvFile,
// then cfg.Env, with later sources overriding earlier ones.
//
// It returns an error if the configuration is incomplete or names an
// unsupported transport, if the env file cannot be loaded, if the connection
// fails, or if the manager has already been closed. In the closed case the
// freshly opened session is closed again instead of being leaked. If a
// connection is already registered under name it is replaced and the old
// session is closed.
//
// NOTE(review): the stdio command is created with exec.CommandContext, so
// cancelling ctx also interrupts the child process. Callers presumably need
// to pass a ctx that lives as long as the connection; confirm at call sites.
func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCPServerConfig) error {
	logger.InfoCF("mcp", "Connecting to MCP server",
		map[string]interface{}{
			"server":  name,
			"command": cfg.Command,
			"args":    cfg.Args,
		})

	// Create client
	client := mcp.NewClient(&mcp.Implementation{
		Name:    "picoclaw",
		Version: "1.0.0",
	}, nil)

	// Create transport based on configuration.
	var transport mcp.Transport
	transportType := cfg.Type

	// Auto-detect: if URL is provided, use SSE; if command is provided, use stdio
	if transportType == "" {
		switch {
		case cfg.URL != "":
			transportType = "sse"
		case cfg.Command != "":
			transportType = "stdio"
		default:
			return fmt.Errorf("either URL or command must be provided")
		}
	}

	switch transportType {
	case "sse", "http":
		if cfg.URL == "" {
			return fmt.Errorf("URL is required for SSE/HTTP transport")
		}
		logger.DebugCF("mcp", "Using SSE/HTTP transport",
			map[string]interface{}{
				"server": name,
				"url":    cfg.URL,
			})

		sseTransport := &mcp.StreamableClientTransport{
			Endpoint: cfg.URL,
		}

		// Add custom headers if provided
		if len(cfg.Headers) > 0 {
			// Create a custom HTTP client with header-injecting transport
			sseTransport.HTTPClient = &http.Client{
				Transport: &headerTransport{
					base:    http.DefaultTransport,
					headers: cfg.Headers,
				},
			}
			logger.DebugCF("mcp", "Added custom HTTP headers",
				map[string]interface{}{
					"server":       name,
					"header_count": len(cfg.Headers),
				})
		}
		transport = sseTransport

	case "stdio":
		if cfg.Command == "" {
			return fmt.Errorf("command is required for stdio transport")
		}
		logger.DebugCF("mcp", "Using stdio transport",
			map[string]interface{}{
				"server":  name,
				"command": cfg.Command,
			})

		// Create command with context
		cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...)

		// Build the environment in a map so that later sources override
		// earlier ones instead of producing duplicate KEY= entries.
		envMap := make(map[string]string)

		// Start with parent process environment. Entries without a name
		// before '=' are skipped.
		for _, e := range cmd.Environ() {
			if idx := strings.Index(e, "="); idx > 0 {
				envMap[e[:idx]] = e[idx+1:]
			}
		}

		// Load environment variables from file if specified
		if cfg.EnvFile != "" {
			envVars, err := loadEnvFile(cfg.EnvFile)
			if err != nil {
				return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
			}
			for k, v := range envVars {
				envMap[k] = v
			}
			logger.DebugCF("mcp", "Loaded environment variables from file",
				map[string]interface{}{
					"server":    name,
					"envFile":   cfg.EnvFile,
					"var_count": len(envVars),
				})
		}

		// Environment variables from config override those from file
		for k, v := range cfg.Env {
			envMap[k] = v
		}

		// Convert map to slice
		env := make([]string, 0, len(envMap))
		for k, v := range envMap {
			env = append(env, fmt.Sprintf("%s=%s", k, v))
		}
		cmd.Env = env
		transport = &mcp.CommandTransport{Command: cmd}

	default:
		return fmt.Errorf("unsupported transport type: %s (supported: stdio, sse, http)", transportType)
	}

	// Connect to server
	session, err := client.Connect(ctx, transport, nil)
	if err != nil {
		return fmt.Errorf("failed to connect: %w", err)
	}

	// Get server info
	initResult := session.InitializeResult()
	logger.InfoCF("mcp", "Connected to MCP server",
		map[string]interface{}{
			"server":        name,
			"serverName":    initResult.ServerInfo.Name,
			"serverVersion": initResult.ServerInfo.Version,
			"protocol":      initResult.ProtocolVersion,
		})

	// List available tools if supported. A listing error is logged and does
	// not fail the connection.
	var tools []*mcp.Tool
	if initResult.Capabilities.Tools != nil {
		for tool, err := range session.Tools(ctx, nil) {
			if err != nil {
				logger.WarnCF("mcp", "Error listing tool",
					map[string]interface{}{
						"server": name,
						"error":  err.Error(),
					})
				continue
			}
			tools = append(tools, tool)
		}
		logger.InfoCF("mcp", "Listed tools from MCP server",
			map[string]interface{}{
				"server":    name,
				"toolCount": len(tools),
			})
	}

	conn := &ServerConnection{
		Name:    name,
		Client:  client,
		Session: session,
		Tools:   tools,
	}

	// Store connection. Close may have run while we were connecting; a
	// session stored after that point would never be closed, so refuse it.
	m.mu.Lock()
	if m.closed {
		m.mu.Unlock()
		if closeErr := session.Close(); closeErr != nil {
			logger.WarnCF("mcp", "Failed to close connection opened after manager shutdown",
				map[string]interface{}{
					"server": name,
					"error":  closeErr.Error(),
				})
		}
		return fmt.Errorf("manager is closed")
	}
	previous := m.servers[name]
	m.servers[name] = conn
	m.mu.Unlock()

	// A connection replaced under the same name is no longer reachable through
	// the manager, so close it here rather than leaking it.
	if previous != nil && previous.Session != nil {
		if closeErr := previous.Session.Close(); closeErr != nil {
			logger.WarnCF("mcp", "Failed to close replaced server connection",
				map[string]interface{}{
					"server": name,
					"error":  closeErr.Error(),
				})
		}
	}
	return nil
}
// GetServers returns a snapshot of all connected servers keyed by name.
// The returned map is a fresh copy, so callers may modify it freely; the
// ServerConnection values themselves are shared with the manager.
func (m *Manager) GetServers() map[string]*ServerConnection {
	m.mu.RLock()
	snapshot := make(map[string]*ServerConnection, len(m.servers))
	for serverName, server := range m.servers {
		snapshot[serverName] = server
	}
	m.mu.RUnlock()
	return snapshot
}
// GetServer looks up the connection registered under name. The boolean result
// reports whether such a connection exists.
func (m *Manager) GetServer(name string) (*ServerConnection, bool) {
	m.mu.RLock()
	server, found := m.servers[name]
	m.mu.RUnlock()
	return server, found
}
// CallTool calls the tool toolName on the server registered as serverName,
// passing arguments as the tool's arguments.
//
// It returns an error if the manager has been closed, if no server is
// registered under serverName, or if the underlying session call fails.
//
// The read lock is held for the whole call, not just the lookup. Close takes
// the write lock before closing sessions, so it cannot close the session while
// a call is in flight: it waits for running calls to return first. Use ctx to
// bound how long a call, and therefore a concurrent Close, can take.
func (m *Manager) CallTool(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if m.closed {
		return nil, fmt.Errorf("manager is closed")
	}

	conn, ok := m.servers[serverName]
	if !ok {
		return nil, fmt.Errorf("server %s not found", serverName)
	}

	params := &mcp.CallToolParams{
		Name:      toolName,
		Arguments: arguments,
	}
	result, err := conn.Session.CallTool(ctx, params)
	if err != nil {
		return nil, fmt.Errorf("failed to call tool: %w", err)
	}
	return result, nil
}
// Close closes all server connections and marks the manager as closed.
//
// It is safe to call more than once; calls after the first return nil without
// doing anything. Every session is attempted even if an earlier one fails.
// When any session fails to close, the returned error reports how many failed
// and wraps each underlying error, tagged with its server name, so callers can
// inspect them with errors.Is and errors.As.
func (m *Manager) Close() error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.closed {
		return nil
	}
	// Set the flag before closing anything so later CallTool calls are refused.
	m.closed = true

	logger.InfoCF("mcp", "Closing all MCP server connections",
		map[string]interface{}{
			"count": len(m.servers),
		})

	var errs []error
	for name, conn := range m.servers {
		if err := conn.Session.Close(); err != nil {
			logger.ErrorCF("mcp", "Failed to close server connection",
				map[string]interface{}{
					"server": name,
					"error":  err.Error(),
				})
			errs = append(errs, fmt.Errorf("server %s: %w", name, err))
		}
	}

	// Drop every connection, including those whose close failed.
	m.servers = make(map[string]*ServerConnection)

	if len(errs) > 0 {
		return fmt.Errorf("failed to close %d server(s): %w", len(errs), errors.Join(errs...))
	}
	return nil
}
// GetAllTools returns the tools of every connected server, keyed by server
// name. Servers that have no tools are left out of the result.
func (m *Manager) GetAllTools() map[string][]*mcp.Tool {
	toolsByServer := map[string][]*mcp.Tool{}

	m.mu.RLock()
	defer m.mu.RUnlock()

	for serverName, server := range m.servers {
		if len(server.Tools) == 0 {
			continue
		}
		toolsByServer[serverName] = server.Tools
	}
	return toolsByServer
}