mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(mcp): add Model Context Protocol integration
Implement comprehensive MCP support with stdio/HTTP/SSE transports, environment variable configuration (env and envFile), custom headers, tool registration, and automatic resource cleanup. Includes full test coverage and VSCode-compatible configuration. - Added pkg/mcp/manager.go for server lifecycle management - Added pkg/tools/mcp_tool.go for tool wrapping - Integrated into agent loop with cleanup - Support for envFile loading (.env format) - Headers injection for HTTP/SSE authentication - Example configs for filesystem, github, brave-search, postgres
This commit is contained in:
@@ -0,0 +1,432 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// headerTransport is an http.RoundTripper that adds custom headers to requests
|
||||
type headerTransport struct {
|
||||
base http.RoundTripper
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Clone the request to avoid modifying the original
|
||||
req = req.Clone(req.Context())
|
||||
|
||||
// Add custom headers
|
||||
for key, value := range t.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Use the base transport
|
||||
base := t.base
|
||||
if base == nil {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
return base.RoundTrip(req)
|
||||
}
|
||||
|
||||
// loadEnvFile loads environment variables from a file in .env format
|
||||
// Each line should be in the format: KEY=value
|
||||
// Lines starting with # are comments
|
||||
// Empty lines are ignored
|
||||
func loadEnvFile(path string) (map[string]string, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open env file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
envVars := make(map[string]string)
|
||||
scanner := bufio.NewScanner(file)
|
||||
lineNum := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
// Skip empty lines and comments
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse KEY=value
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line)
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
// Remove surrounding quotes if present
|
||||
if len(value) >= 2 {
|
||||
if (value[0] == '"' && value[len(value)-1] == '"') ||
|
||||
(value[0] == '\'' && value[len(value)-1] == '\'') {
|
||||
value = value[1 : len(value)-1]
|
||||
}
|
||||
}
|
||||
|
||||
envVars[key] = value
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error reading env file: %w", err)
|
||||
}
|
||||
|
||||
return envVars, nil
|
||||
}
|
||||
|
||||
// ServerConnection represents a connection to an MCP server
|
||||
type ServerConnection struct {
|
||||
Name string
|
||||
Client *mcp.Client
|
||||
Session *mcp.ClientSession
|
||||
Tools []*mcp.Tool
|
||||
}
|
||||
|
||||
// Manager manages multiple MCP server connections
|
||||
type Manager struct {
|
||||
servers map[string]*ServerConnection
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewManager creates a new MCP manager
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
servers: make(map[string]*ServerConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromConfig loads MCP servers from configuration
|
||||
func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error {
|
||||
if !cfg.Tools.MCP.Enabled {
|
||||
logger.InfoCF("mcp", "MCP integration is disabled", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(cfg.Tools.MCP.Servers) == 0 {
|
||||
logger.InfoCF("mcp", "No MCP servers configured", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.InfoCF("mcp", "Initializing MCP servers",
|
||||
map[string]interface{}{
|
||||
"count": len(cfg.Tools.MCP.Servers),
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, len(cfg.Tools.MCP.Servers))
|
||||
|
||||
for name, serverCfg := range cfg.Tools.MCP.Servers {
|
||||
if !serverCfg.Enabled {
|
||||
logger.DebugCF("mcp", "Skipping disabled server",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(name string, serverCfg config.MCPServerConfig) {
|
||||
defer wg.Done()
|
||||
|
||||
if err := m.ConnectServer(ctx, name, serverCfg); err != nil {
|
||||
logger.ErrorCF("mcp", "Failed to connect to MCP server",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
errs <- fmt.Errorf("failed to connect to server %s: %w", name, err)
|
||||
}
|
||||
}(name, serverCfg)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
// Collect errors
|
||||
var allErrors []error
|
||||
for err := range errs {
|
||||
allErrors = append(allErrors, err)
|
||||
}
|
||||
|
||||
if len(allErrors) > 0 {
|
||||
logger.WarnCF("mcp", "Some MCP servers failed to connect",
|
||||
map[string]interface{}{
|
||||
"failed": len(allErrors),
|
||||
"total": len(cfg.Tools.MCP.Servers),
|
||||
})
|
||||
// Don't fail completely if some servers fail to connect
|
||||
}
|
||||
|
||||
connectedCount := len(m.GetServers())
|
||||
logger.InfoCF("mcp", "MCP server initialization complete",
|
||||
map[string]interface{}{
|
||||
"connected": connectedCount,
|
||||
"total": len(cfg.Tools.MCP.Servers),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectServer connects to a single MCP server
|
||||
func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCPServerConfig) error {
|
||||
logger.InfoCF("mcp", "Connecting to MCP server",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"command": cfg.Command,
|
||||
"args": cfg.Args,
|
||||
})
|
||||
|
||||
// Create client
|
||||
client := mcp.NewClient(&mcp.Implementation{
|
||||
Name: "picoclaw",
|
||||
Version: "1.0.0",
|
||||
}, nil)
|
||||
|
||||
// Create transport based on configuration
|
||||
// Auto-detect transport type if not explicitly specified
|
||||
var transport mcp.Transport
|
||||
transportType := cfg.Type
|
||||
|
||||
// Auto-detect: if URL is provided, use SSE; if command is provided, use stdio
|
||||
if transportType == "" {
|
||||
if cfg.URL != "" {
|
||||
transportType = "sse"
|
||||
} else if cfg.Command != "" {
|
||||
transportType = "stdio"
|
||||
} else {
|
||||
return fmt.Errorf("either URL or command must be provided")
|
||||
}
|
||||
}
|
||||
|
||||
switch transportType {
|
||||
case "sse", "http":
|
||||
if cfg.URL == "" {
|
||||
return fmt.Errorf("URL is required for SSE/HTTP transport")
|
||||
}
|
||||
logger.DebugCF("mcp", "Using SSE/HTTP transport",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"url": cfg.URL,
|
||||
})
|
||||
|
||||
sseTransport := &mcp.StreamableClientTransport{
|
||||
Endpoint: cfg.URL,
|
||||
}
|
||||
|
||||
// Add custom headers if provided
|
||||
if len(cfg.Headers) > 0 {
|
||||
// Create a custom HTTP client with header-injecting transport
|
||||
sseTransport.HTTPClient = &http.Client{
|
||||
Transport: &headerTransport{
|
||||
base: http.DefaultTransport,
|
||||
headers: cfg.Headers,
|
||||
},
|
||||
}
|
||||
logger.DebugCF("mcp", "Added custom HTTP headers",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"header_count": len(cfg.Headers),
|
||||
})
|
||||
}
|
||||
|
||||
transport = sseTransport
|
||||
case "stdio":
|
||||
if cfg.Command == "" {
|
||||
return fmt.Errorf("command is required for stdio transport")
|
||||
}
|
||||
logger.DebugCF("mcp", "Using stdio transport",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"command": cfg.Command,
|
||||
})
|
||||
// Create command with context
|
||||
cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...)
|
||||
|
||||
// Set environment variables
|
||||
env := cmd.Environ()
|
||||
|
||||
// Load environment variables from file if specified
|
||||
if cfg.EnvFile != "" {
|
||||
envVars, err := loadEnvFile(cfg.EnvFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
|
||||
}
|
||||
for k, v := range envVars {
|
||||
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
logger.DebugCF("mcp", "Loaded environment variables from file",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"envFile": cfg.EnvFile,
|
||||
"var_count": len(envVars),
|
||||
})
|
||||
}
|
||||
|
||||
// Environment variables from config override those from file
|
||||
if len(cfg.Env) > 0 {
|
||||
for k, v := range cfg.Env {
|
||||
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
}
|
||||
|
||||
// Set environment if we added any variables
|
||||
if len(env) > len(cmd.Environ()) {
|
||||
cmd.Env = env
|
||||
}
|
||||
|
||||
transport = &mcp.CommandTransport{Command: cmd}
|
||||
default:
|
||||
return fmt.Errorf("unsupported transport type: %s (supported: stdio, sse, http)", transportType)
|
||||
}
|
||||
|
||||
// Connect to server
|
||||
session, err := client.Connect(ctx, transport, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
// Get server info
|
||||
initResult := session.InitializeResult()
|
||||
logger.InfoCF("mcp", "Connected to MCP server",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"serverName": initResult.ServerInfo.Name,
|
||||
"serverVersion": initResult.ServerInfo.Version,
|
||||
"protocol": initResult.ProtocolVersion,
|
||||
})
|
||||
|
||||
// List available tools if supported
|
||||
var tools []*mcp.Tool
|
||||
if initResult.Capabilities.Tools != nil {
|
||||
for tool, err := range session.Tools(ctx, nil) {
|
||||
if err != nil {
|
||||
logger.WarnCF("mcp", "Error listing tool",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
logger.InfoCF("mcp", "Listed tools from MCP server",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"toolCount": len(tools),
|
||||
})
|
||||
}
|
||||
|
||||
// Store connection
|
||||
m.mu.Lock()
|
||||
m.servers[name] = &ServerConnection{
|
||||
Name: name,
|
||||
Client: client,
|
||||
Session: session,
|
||||
Tools: tools,
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServers returns all connected servers
|
||||
func (m *Manager) GetServers() map[string]*ServerConnection {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string]*ServerConnection, len(m.servers))
|
||||
for k, v := range m.servers {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetServer returns a specific server connection
|
||||
func (m *Manager) GetServer(name string) (*ServerConnection, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
conn, ok := m.servers[name]
|
||||
return conn, ok
|
||||
}
|
||||
|
||||
// CallTool calls a tool on a specific server
|
||||
func (m *Manager) CallTool(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
|
||||
conn, ok := m.GetServer(serverName)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server %s not found", serverName)
|
||||
}
|
||||
|
||||
params := &mcp.CallToolParams{
|
||||
Name: toolName,
|
||||
Arguments: arguments,
|
||||
}
|
||||
|
||||
result, err := conn.Session.CallTool(ctx, params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call tool: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Close closes all server connections
|
||||
func (m *Manager) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
logger.InfoCF("mcp", "Closing all MCP server connections",
|
||||
map[string]interface{}{
|
||||
"count": len(m.servers),
|
||||
})
|
||||
|
||||
var errs []error
|
||||
for name, conn := range m.servers {
|
||||
if err := conn.Session.Close(); err != nil {
|
||||
logger.ErrorCF("mcp", "Failed to close server connection",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.servers = make(map[string]*ServerConnection)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to close %d server(s)", len(errs))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllTools returns all tools from all connected servers
|
||||
func (m *Manager) GetAllTools() map[string][]*mcp.Tool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string][]*mcp.Tool)
|
||||
for name, conn := range m.servers {
|
||||
if len(conn.Tools) > 0 {
|
||||
result[name] = conn.Tools
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadEnvFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expected map[string]string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic env file",
|
||||
content: `API_KEY=secret123
|
||||
DATABASE_URL=postgres://localhost/db
|
||||
PORT=8080`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret123",
|
||||
"DATABASE_URL": "postgres://localhost/db",
|
||||
"PORT": "8080",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "with comments and empty lines",
|
||||
content: `# This is a comment
|
||||
API_KEY=secret123
|
||||
|
||||
# Another comment
|
||||
DATABASE_URL=postgres://localhost/db
|
||||
|
||||
PORT=8080`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret123",
|
||||
"DATABASE_URL": "postgres://localhost/db",
|
||||
"PORT": "8080",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "with quoted values",
|
||||
content: `API_KEY="secret with spaces"
|
||||
NAME='single quoted'
|
||||
PLAIN=no-quotes`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret with spaces",
|
||||
"NAME": "single quoted",
|
||||
"PLAIN": "no-quotes",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "with spaces around equals",
|
||||
content: `API_KEY = secret123
|
||||
DATABASE_URL= postgres://localhost/db
|
||||
PORT =8080`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret123",
|
||||
"DATABASE_URL": "postgres://localhost/db",
|
||||
"PORT": "8080",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid format - no equals",
|
||||
content: `INVALID_LINE`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty file",
|
||||
content: ``,
|
||||
expected: map[string]string{},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "only comments",
|
||||
content: `# Comment 1
|
||||
# Comment 2`,
|
||||
expected: map[string]string{},
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
envFile := filepath.Join(tmpDir, ".env")
|
||||
|
||||
if err := os.WriteFile(envFile, []byte(tt.content), 0644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
result, err := loadEnvFile(envFile)
|
||||
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result))
|
||||
}
|
||||
|
||||
for key, expectedValue := range tt.expected {
|
||||
if actualValue, ok := result[key]; !ok {
|
||||
t.Errorf("Expected key %s not found", key)
|
||||
} else if actualValue != expectedValue {
|
||||
t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadEnvFileNotFound(t *testing.T) {
|
||||
_, err := loadEnvFile("/nonexistent/file.env")
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvFilePriority(t *testing.T) {
|
||||
// Create a temporary .env file
|
||||
tmpDir := t.TempDir()
|
||||
envFile := filepath.Join(tmpDir, ".env")
|
||||
|
||||
envContent := `API_KEY=from_file
|
||||
DATABASE_URL=from_file
|
||||
SHARED_VAR=from_file`
|
||||
|
||||
if err := os.WriteFile(envFile, []byte(envContent), 0644); err != nil {
|
||||
t.Fatalf("Failed to create .env file: %v", err)
|
||||
}
|
||||
|
||||
// Load envFile
|
||||
envVars, err := loadEnvFile(envFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load env file: %v", err)
|
||||
}
|
||||
|
||||
// Verify envFile variables
|
||||
if envVars["API_KEY"] != "from_file" {
|
||||
t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"])
|
||||
}
|
||||
|
||||
// Simulate config.Env overriding envFile
|
||||
configEnv := map[string]string{
|
||||
"SHARED_VAR": "from_config",
|
||||
"NEW_VAR": "from_config",
|
||||
}
|
||||
|
||||
// Merge: envFile first, then config overrides
|
||||
merged := make(map[string]string)
|
||||
for k, v := range envVars {
|
||||
merged[k] = v
|
||||
}
|
||||
for k, v := range configEnv {
|
||||
merged[k] = v
|
||||
}
|
||||
|
||||
// Verify priority: config.Env should override envFile
|
||||
if merged["SHARED_VAR"] != "from_config" {
|
||||
t.Errorf("Expected SHARED_VAR=from_config (config should override file), got %s", merged["SHARED_VAR"])
|
||||
}
|
||||
if merged["API_KEY"] != "from_file" {
|
||||
t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"])
|
||||
}
|
||||
if merged["NEW_VAR"] != "from_config" {
|
||||
t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user