mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
91c168db20
Implement comprehensive MCP support with stdio/HTTP/SSE transports, environment variable configuration (env and envFile), custom headers, tool registration, and automatic resource cleanup. Includes full test coverage and VSCode-compatible configuration. - Added pkg/mcp/manager.go for server lifecycle management - Added pkg/tools/mcp_tool.go for tool wrapping - Integrated into agent loop with cleanup - Support for envFile loading (.env format) - Headers injection for HTTP/SSE authentication - Example configs for filesystem, github, brave-search, postgres
433 lines
10 KiB
Go
433 lines
10 KiB
Go
package mcp
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
"github.com/sipeed/picoclaw/pkg/logger"
|
|
)
|
|
|
|
// headerTransport is an http.RoundTripper that injects a fixed set of custom
// headers into every outgoing request before delegating to a base transport.
type headerTransport struct {
	// base performs the actual round trip; http.DefaultTransport is used when nil.
	base http.RoundTripper
	// headers maps header names to the values set on each request.
	headers map[string]string
}

// RoundTrip implements http.RoundTripper. It sends a clone of req that carries
// the configured headers, so the caller's request is never modified, and
// returns whatever the base transport returns.
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// Work on a clone to avoid modifying the original request.
	out := req.Clone(req.Context())

	// A request built by hand may carry a nil header map; Set would panic on it.
	if out.Header == nil {
		out.Header = make(http.Header)
	}

	for key, value := range t.headers {
		// For client requests net/http takes the Host from Request.Host and
		// ignores a "Host" entry in Header, so mirror it there as well.
		if strings.EqualFold(key, "Host") {
			out.Host = value
		}
		out.Header.Set(key, value)
	}

	// Fall back to the default transport when no base was configured.
	base := t.base
	if base == nil {
		base = http.DefaultTransport
	}
	return base.RoundTrip(out)
}
|
|
|
|
// loadEnvFile reads environment variables from a file in .env format and
// returns them as a map.
//
// Each meaningful line has the form KEY=value, optionally preceded by the
// shell keyword "export ". Blank lines and lines starting with # are skipped.
// Whitespace around the key and the value is trimmed, and one pair of matching
// single or double quotes around the value is removed. Only the first "="
// separates key from value, so values may themselves contain "=".
//
// An error is returned when the file cannot be opened or read, when a line has
// no "=", or when the variable name is empty.
func loadEnvFile(path string) (map[string]string, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to open env file: %w", err)
	}
	defer file.Close()

	envVars := make(map[string]string)
	scanner := bufio.NewScanner(file)

	for lineNum := 1; scanner.Scan(); lineNum++ {
		line := strings.TrimSpace(scanner.Text())

		// Skip empty lines and comments.
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}

		// Accept shell-style "export KEY=value" so files shared with shell
		// scripts do not yield a key named "export KEY".
		entry := strings.TrimSpace(strings.TrimPrefix(line, "export "))

		key, value, found := strings.Cut(entry, "=")
		if !found {
			return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line)
		}

		key = strings.TrimSpace(key)
		if key == "" {
			// An empty name cannot be a valid environment variable.
			return nil, fmt.Errorf("empty variable name at line %d", lineNum)
		}
		value = strings.TrimSpace(value)

		// Remove one pair of surrounding quotes if present.
		if n := len(value); n >= 2 {
			first, last := value[0], value[n-1]
			if first == last && (first == '"' || first == '\'') {
				value = value[1 : n-1]
			}
		}

		envVars[key] = value
	}

	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("error reading env file: %w", err)
	}

	return envVars, nil
}
|
|
|
|
// ServerConnection represents a connection to an MCP server
|
|
type ServerConnection struct {
|
|
Name string
|
|
Client *mcp.Client
|
|
Session *mcp.ClientSession
|
|
Tools []*mcp.Tool
|
|
}
|
|
|
|
// Manager manages multiple MCP server connections
|
|
type Manager struct {
|
|
servers map[string]*ServerConnection
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewManager creates a new MCP manager
|
|
func NewManager() *Manager {
|
|
return &Manager{
|
|
servers: make(map[string]*ServerConnection),
|
|
}
|
|
}
|
|
|
|
// LoadFromConfig loads MCP servers from configuration
|
|
func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error {
|
|
if !cfg.Tools.MCP.Enabled {
|
|
logger.InfoCF("mcp", "MCP integration is disabled", nil)
|
|
return nil
|
|
}
|
|
|
|
if len(cfg.Tools.MCP.Servers) == 0 {
|
|
logger.InfoCF("mcp", "No MCP servers configured", nil)
|
|
return nil
|
|
}
|
|
|
|
logger.InfoCF("mcp", "Initializing MCP servers",
|
|
map[string]interface{}{
|
|
"count": len(cfg.Tools.MCP.Servers),
|
|
})
|
|
|
|
var wg sync.WaitGroup
|
|
errs := make(chan error, len(cfg.Tools.MCP.Servers))
|
|
|
|
for name, serverCfg := range cfg.Tools.MCP.Servers {
|
|
if !serverCfg.Enabled {
|
|
logger.DebugCF("mcp", "Skipping disabled server",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
})
|
|
continue
|
|
}
|
|
|
|
wg.Add(1)
|
|
go func(name string, serverCfg config.MCPServerConfig) {
|
|
defer wg.Done()
|
|
|
|
if err := m.ConnectServer(ctx, name, serverCfg); err != nil {
|
|
logger.ErrorCF("mcp", "Failed to connect to MCP server",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"error": err.Error(),
|
|
})
|
|
errs <- fmt.Errorf("failed to connect to server %s: %w", name, err)
|
|
}
|
|
}(name, serverCfg)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errs)
|
|
|
|
// Collect errors
|
|
var allErrors []error
|
|
for err := range errs {
|
|
allErrors = append(allErrors, err)
|
|
}
|
|
|
|
if len(allErrors) > 0 {
|
|
logger.WarnCF("mcp", "Some MCP servers failed to connect",
|
|
map[string]interface{}{
|
|
"failed": len(allErrors),
|
|
"total": len(cfg.Tools.MCP.Servers),
|
|
})
|
|
// Don't fail completely if some servers fail to connect
|
|
}
|
|
|
|
connectedCount := len(m.GetServers())
|
|
logger.InfoCF("mcp", "MCP server initialization complete",
|
|
map[string]interface{}{
|
|
"connected": connectedCount,
|
|
"total": len(cfg.Tools.MCP.Servers),
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// ConnectServer connects to a single MCP server
|
|
func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCPServerConfig) error {
|
|
logger.InfoCF("mcp", "Connecting to MCP server",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"command": cfg.Command,
|
|
"args": cfg.Args,
|
|
})
|
|
|
|
// Create client
|
|
client := mcp.NewClient(&mcp.Implementation{
|
|
Name: "picoclaw",
|
|
Version: "1.0.0",
|
|
}, nil)
|
|
|
|
// Create transport based on configuration
|
|
// Auto-detect transport type if not explicitly specified
|
|
var transport mcp.Transport
|
|
transportType := cfg.Type
|
|
|
|
// Auto-detect: if URL is provided, use SSE; if command is provided, use stdio
|
|
if transportType == "" {
|
|
if cfg.URL != "" {
|
|
transportType = "sse"
|
|
} else if cfg.Command != "" {
|
|
transportType = "stdio"
|
|
} else {
|
|
return fmt.Errorf("either URL or command must be provided")
|
|
}
|
|
}
|
|
|
|
switch transportType {
|
|
case "sse", "http":
|
|
if cfg.URL == "" {
|
|
return fmt.Errorf("URL is required for SSE/HTTP transport")
|
|
}
|
|
logger.DebugCF("mcp", "Using SSE/HTTP transport",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"url": cfg.URL,
|
|
})
|
|
|
|
sseTransport := &mcp.StreamableClientTransport{
|
|
Endpoint: cfg.URL,
|
|
}
|
|
|
|
// Add custom headers if provided
|
|
if len(cfg.Headers) > 0 {
|
|
// Create a custom HTTP client with header-injecting transport
|
|
sseTransport.HTTPClient = &http.Client{
|
|
Transport: &headerTransport{
|
|
base: http.DefaultTransport,
|
|
headers: cfg.Headers,
|
|
},
|
|
}
|
|
logger.DebugCF("mcp", "Added custom HTTP headers",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"header_count": len(cfg.Headers),
|
|
})
|
|
}
|
|
|
|
transport = sseTransport
|
|
case "stdio":
|
|
if cfg.Command == "" {
|
|
return fmt.Errorf("command is required for stdio transport")
|
|
}
|
|
logger.DebugCF("mcp", "Using stdio transport",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"command": cfg.Command,
|
|
})
|
|
// Create command with context
|
|
cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...)
|
|
|
|
// Set environment variables
|
|
env := cmd.Environ()
|
|
|
|
// Load environment variables from file if specified
|
|
if cfg.EnvFile != "" {
|
|
envVars, err := loadEnvFile(cfg.EnvFile)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
|
|
}
|
|
for k, v := range envVars {
|
|
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
|
}
|
|
logger.DebugCF("mcp", "Loaded environment variables from file",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"envFile": cfg.EnvFile,
|
|
"var_count": len(envVars),
|
|
})
|
|
}
|
|
|
|
// Environment variables from config override those from file
|
|
if len(cfg.Env) > 0 {
|
|
for k, v := range cfg.Env {
|
|
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
|
}
|
|
}
|
|
|
|
// Set environment if we added any variables
|
|
if len(env) > len(cmd.Environ()) {
|
|
cmd.Env = env
|
|
}
|
|
|
|
transport = &mcp.CommandTransport{Command: cmd}
|
|
default:
|
|
return fmt.Errorf("unsupported transport type: %s (supported: stdio, sse, http)", transportType)
|
|
}
|
|
|
|
// Connect to server
|
|
session, err := client.Connect(ctx, transport, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect: %w", err)
|
|
}
|
|
|
|
// Get server info
|
|
initResult := session.InitializeResult()
|
|
logger.InfoCF("mcp", "Connected to MCP server",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"serverName": initResult.ServerInfo.Name,
|
|
"serverVersion": initResult.ServerInfo.Version,
|
|
"protocol": initResult.ProtocolVersion,
|
|
})
|
|
|
|
// List available tools if supported
|
|
var tools []*mcp.Tool
|
|
if initResult.Capabilities.Tools != nil {
|
|
for tool, err := range session.Tools(ctx, nil) {
|
|
if err != nil {
|
|
logger.WarnCF("mcp", "Error listing tool",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"error": err.Error(),
|
|
})
|
|
continue
|
|
}
|
|
tools = append(tools, tool)
|
|
}
|
|
|
|
logger.InfoCF("mcp", "Listed tools from MCP server",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"toolCount": len(tools),
|
|
})
|
|
}
|
|
|
|
// Store connection
|
|
m.mu.Lock()
|
|
m.servers[name] = &ServerConnection{
|
|
Name: name,
|
|
Client: client,
|
|
Session: session,
|
|
Tools: tools,
|
|
}
|
|
m.mu.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetServers returns all connected servers
|
|
func (m *Manager) GetServers() map[string]*ServerConnection {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
result := make(map[string]*ServerConnection, len(m.servers))
|
|
for k, v := range m.servers {
|
|
result[k] = v
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetServer returns a specific server connection
|
|
func (m *Manager) GetServer(name string) (*ServerConnection, bool) {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
conn, ok := m.servers[name]
|
|
return conn, ok
|
|
}
|
|
|
|
// CallTool calls a tool on a specific server
|
|
func (m *Manager) CallTool(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
|
|
conn, ok := m.GetServer(serverName)
|
|
if !ok {
|
|
return nil, fmt.Errorf("server %s not found", serverName)
|
|
}
|
|
|
|
params := &mcp.CallToolParams{
|
|
Name: toolName,
|
|
Arguments: arguments,
|
|
}
|
|
|
|
result, err := conn.Session.CallTool(ctx, params)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to call tool: %w", err)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// Close closes all server connections
|
|
func (m *Manager) Close() error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
logger.InfoCF("mcp", "Closing all MCP server connections",
|
|
map[string]interface{}{
|
|
"count": len(m.servers),
|
|
})
|
|
|
|
var errs []error
|
|
for name, conn := range m.servers {
|
|
if err := conn.Session.Close(); err != nil {
|
|
logger.ErrorCF("mcp", "Failed to close server connection",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"error": err.Error(),
|
|
})
|
|
errs = append(errs, err)
|
|
}
|
|
}
|
|
|
|
m.servers = make(map[string]*ServerConnection)
|
|
|
|
if len(errs) > 0 {
|
|
return fmt.Errorf("failed to close %d server(s)", len(errs))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetAllTools returns all tools from all connected servers
|
|
func (m *Manager) GetAllTools() map[string][]*mcp.Tool {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
result := make(map[string][]*mcp.Tool)
|
|
for name, conn := range m.servers {
|
|
if len(conn.Tools) > 0 {
|
|
result[name] = conn.Tools
|
|
}
|
|
}
|
|
return result
|
|
}
|