mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
6892d006d6
Replace full *config.Config reference with config.MCPConfig value type in AgentLoop to allow garbage collection of unused configuration data.

Changes:
- AgentLoop now stores only MCPConfig and workspacePath (minimal deps)
- Add mcp.Manager.LoadFromMCPConfig() for minimal dependency version
- Keep LoadFromConfig() for backward compatibility
- Full Config object can be GC'd after NewAgentLoop() returns

This optimization reduces memory usage by not holding references to unused channel, provider, gateway, and device configurations.
460 lines
11 KiB
Go
460 lines
11 KiB
Go
package mcp
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
"github.com/sipeed/picoclaw/pkg/logger"
|
|
)
|
|
|
|
// headerTransport is an http.RoundTripper that sets a fixed collection of
// custom headers on every outgoing request before handing it to a base
// transport.
type headerTransport struct {
	// base performs the actual round trip; http.DefaultTransport is used when
	// it is nil.
	base http.RoundTripper
	// headers maps header names to the values set on each request.
	headers map[string]string
}

// RoundTrip implements http.RoundTripper. It clones req, applies the
// configured headers to the clone (replacing any existing values for the same
// names) and forwards the clone to the base transport. The caller's request
// is never modified.
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// A RoundTripper must not mutate the request it is given, so work on a copy.
	req = req.Clone(req.Context())

	// Clone keeps a nil Header map nil, and writing to a nil map panics. This
	// only happens when RoundTrip is invoked directly with a hand-built request.
	if req.Header == nil {
		req.Header = make(http.Header, len(t.headers))
	}

	for key, value := range t.headers {
		// For client requests net/http sends req.Host and ignores a "Host"
		// entry in the header map, so mirror the value into req.Host as well.
		if http.CanonicalHeaderKey(key) == "Host" {
			req.Host = value
		}
		req.Header.Set(key, value)
	}

	base := t.base
	if base == nil {
		base = http.DefaultTransport
	}
	return base.RoundTrip(req)
}
|
|
|
|
// loadEnvFile reads environment variables from a .env-style file at path.
//
// Each meaningful line has the form KEY=value. Blank lines and lines whose
// first non-space character is '#' are skipped. Whitespace around the key and
// the value is trimmed, and one pair of matching surrounding single or double
// quotes is removed from the value. Only the first '=' separates key from
// value, so values may themselves contain '='. When a key occurs more than
// once, the last occurrence wins.
//
// It returns an error if the file cannot be opened or read, if a line has no
// '=', or if a line has an empty key.
func loadEnvFile(path string) (map[string]string, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to open env file: %w", err)
	}
	defer file.Close()

	envVars := make(map[string]string)
	scanner := bufio.NewScanner(file)

	for lineNum := 1; scanner.Scan(); lineNum++ {
		line := strings.TrimSpace(scanner.Text())

		// Skip empty lines and comments.
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}

		rawKey, rawValue, found := strings.Cut(line, "=")
		if !found {
			return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line)
		}

		key := strings.TrimSpace(rawKey)
		if key == "" {
			// An empty key would later yield an invalid "=value" environment
			// entry. The line is not echoed because it holds only the value,
			// which may be a secret.
			return nil, fmt.Errorf("invalid format at line %d: empty key", lineNum)
		}

		value := strings.TrimSpace(rawValue)

		// Remove one pair of matching surrounding quotes, if present.
		if n := len(value); n >= 2 {
			first, last := value[0], value[n-1]
			if first == last && (first == '"' || first == '\'') {
				value = value[1 : n-1]
			}
		}

		envVars[key] = value
	}

	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("error reading env file: %w", err)
	}

	return envVars, nil
}
|
|
|
|
// ServerConnection represents a connection to an MCP server
|
|
type ServerConnection struct {
|
|
Name string
|
|
Client *mcp.Client
|
|
Session *mcp.ClientSession
|
|
Tools []*mcp.Tool
|
|
}
|
|
|
|
// Manager manages multiple MCP server connections
|
|
type Manager struct {
|
|
servers map[string]*ServerConnection
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewManager creates a new MCP manager
|
|
func NewManager() *Manager {
|
|
return &Manager{
|
|
servers: make(map[string]*ServerConnection),
|
|
}
|
|
}
|
|
|
|
// LoadFromConfig loads MCP servers from configuration
|
|
func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error {
|
|
return m.LoadFromMCPConfig(ctx, cfg.Tools.MCP, cfg.WorkspacePath())
|
|
}
|
|
|
|
// LoadFromMCPConfig loads MCP servers from MCP configuration and workspace path.
|
|
// This is the minimal dependency version that doesn't require the full Config object.
|
|
func (m *Manager) LoadFromMCPConfig(ctx context.Context, mcpCfg config.MCPConfig, workspacePath string) error {
|
|
if !mcpCfg.Enabled {
|
|
logger.InfoCF("mcp", "MCP integration is disabled", nil)
|
|
return nil
|
|
}
|
|
|
|
if len(mcpCfg.Servers) == 0 {
|
|
logger.InfoCF("mcp", "No MCP servers configured", nil)
|
|
return nil
|
|
}
|
|
|
|
logger.InfoCF("mcp", "Initializing MCP servers",
|
|
map[string]interface{}{
|
|
"count": len(mcpCfg.Servers),
|
|
})
|
|
|
|
var wg sync.WaitGroup
|
|
errs := make(chan error, len(mcpCfg.Servers))
|
|
enabledCount := 0
|
|
|
|
for name, serverCfg := range mcpCfg.Servers {
|
|
if !serverCfg.Enabled {
|
|
logger.DebugCF("mcp", "Skipping disabled server",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
})
|
|
continue
|
|
}
|
|
|
|
enabledCount++
|
|
wg.Add(1)
|
|
go func(name string, serverCfg config.MCPServerConfig, workspace string) {
|
|
defer wg.Done()
|
|
|
|
// Resolve relative envFile paths relative to workspace
|
|
if serverCfg.EnvFile != "" && !filepath.IsAbs(serverCfg.EnvFile) {
|
|
serverCfg.EnvFile = filepath.Join(workspace, serverCfg.EnvFile)
|
|
}
|
|
|
|
if err := m.ConnectServer(ctx, name, serverCfg); err != nil {
|
|
logger.ErrorCF("mcp", "Failed to connect to MCP server",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"error": err.Error(),
|
|
})
|
|
errs <- fmt.Errorf("failed to connect to server %s: %w", name, err)
|
|
}
|
|
}(name, serverCfg, workspacePath)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errs)
|
|
|
|
// Collect errors
|
|
var allErrors []error
|
|
for err := range errs {
|
|
allErrors = append(allErrors, err)
|
|
}
|
|
|
|
connectedCount := len(m.GetServers())
|
|
|
|
// If all enabled servers failed to connect, return aggregated error
|
|
if enabledCount > 0 && connectedCount == 0 {
|
|
logger.ErrorCF("mcp", "All MCP servers failed to connect",
|
|
map[string]interface{}{
|
|
"failed": len(allErrors),
|
|
"total": enabledCount,
|
|
})
|
|
return errors.Join(allErrors...)
|
|
}
|
|
|
|
if len(allErrors) > 0 {
|
|
logger.WarnCF("mcp", "Some MCP servers failed to connect",
|
|
map[string]interface{}{
|
|
"failed": len(allErrors),
|
|
"connected": connectedCount,
|
|
"total": enabledCount,
|
|
})
|
|
// Don't fail completely if some servers successfully connected
|
|
}
|
|
|
|
logger.InfoCF("mcp", "MCP server initialization complete",
|
|
map[string]interface{}{
|
|
"connected": connectedCount,
|
|
"total": enabledCount,
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// ConnectServer connects to a single MCP server
|
|
func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCPServerConfig) error {
|
|
logger.InfoCF("mcp", "Connecting to MCP server",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"command": cfg.Command,
|
|
"args": cfg.Args,
|
|
})
|
|
|
|
// Create client
|
|
client := mcp.NewClient(&mcp.Implementation{
|
|
Name: "picoclaw",
|
|
Version: "1.0.0",
|
|
}, nil)
|
|
|
|
// Create transport based on configuration
|
|
// Auto-detect transport type if not explicitly specified
|
|
var transport mcp.Transport
|
|
transportType := cfg.Type
|
|
|
|
// Auto-detect: if URL is provided, use SSE; if command is provided, use stdio
|
|
if transportType == "" {
|
|
if cfg.URL != "" {
|
|
transportType = "sse"
|
|
} else if cfg.Command != "" {
|
|
transportType = "stdio"
|
|
} else {
|
|
return fmt.Errorf("either URL or command must be provided")
|
|
}
|
|
}
|
|
|
|
switch transportType {
|
|
case "sse", "http":
|
|
if cfg.URL == "" {
|
|
return fmt.Errorf("URL is required for SSE/HTTP transport")
|
|
}
|
|
logger.DebugCF("mcp", "Using SSE/HTTP transport",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"url": cfg.URL,
|
|
})
|
|
|
|
sseTransport := &mcp.StreamableClientTransport{
|
|
Endpoint: cfg.URL,
|
|
}
|
|
|
|
// Add custom headers if provided
|
|
if len(cfg.Headers) > 0 {
|
|
// Create a custom HTTP client with header-injecting transport
|
|
sseTransport.HTTPClient = &http.Client{
|
|
Transport: &headerTransport{
|
|
base: http.DefaultTransport,
|
|
headers: cfg.Headers,
|
|
},
|
|
}
|
|
logger.DebugCF("mcp", "Added custom HTTP headers",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"header_count": len(cfg.Headers),
|
|
})
|
|
}
|
|
|
|
transport = sseTransport
|
|
case "stdio":
|
|
if cfg.Command == "" {
|
|
return fmt.Errorf("command is required for stdio transport")
|
|
}
|
|
logger.DebugCF("mcp", "Using stdio transport",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"command": cfg.Command,
|
|
})
|
|
// Create command with context
|
|
cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...)
|
|
|
|
// Set environment variables
|
|
env := cmd.Environ()
|
|
|
|
// Load environment variables from file if specified
|
|
if cfg.EnvFile != "" {
|
|
envVars, err := loadEnvFile(cfg.EnvFile)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
|
|
}
|
|
for k, v := range envVars {
|
|
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
|
}
|
|
logger.DebugCF("mcp", "Loaded environment variables from file",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"envFile": cfg.EnvFile,
|
|
"var_count": len(envVars),
|
|
})
|
|
}
|
|
|
|
// Environment variables from config override those from file
|
|
if len(cfg.Env) > 0 {
|
|
for k, v := range cfg.Env {
|
|
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
|
}
|
|
}
|
|
|
|
// Set environment if we added any variables
|
|
if len(env) > len(cmd.Environ()) {
|
|
cmd.Env = env
|
|
}
|
|
|
|
transport = &mcp.CommandTransport{Command: cmd}
|
|
default:
|
|
return fmt.Errorf("unsupported transport type: %s (supported: stdio, sse, http)", transportType)
|
|
}
|
|
|
|
// Connect to server
|
|
session, err := client.Connect(ctx, transport, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect: %w", err)
|
|
}
|
|
|
|
// Get server info
|
|
initResult := session.InitializeResult()
|
|
logger.InfoCF("mcp", "Connected to MCP server",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"serverName": initResult.ServerInfo.Name,
|
|
"serverVersion": initResult.ServerInfo.Version,
|
|
"protocol": initResult.ProtocolVersion,
|
|
})
|
|
|
|
// List available tools if supported
|
|
var tools []*mcp.Tool
|
|
if initResult.Capabilities.Tools != nil {
|
|
for tool, err := range session.Tools(ctx, nil) {
|
|
if err != nil {
|
|
logger.WarnCF("mcp", "Error listing tool",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"error": err.Error(),
|
|
})
|
|
continue
|
|
}
|
|
tools = append(tools, tool)
|
|
}
|
|
|
|
logger.InfoCF("mcp", "Listed tools from MCP server",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"toolCount": len(tools),
|
|
})
|
|
}
|
|
|
|
// Store connection
|
|
m.mu.Lock()
|
|
m.servers[name] = &ServerConnection{
|
|
Name: name,
|
|
Client: client,
|
|
Session: session,
|
|
Tools: tools,
|
|
}
|
|
m.mu.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetServers returns all connected servers
|
|
func (m *Manager) GetServers() map[string]*ServerConnection {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
result := make(map[string]*ServerConnection, len(m.servers))
|
|
for k, v := range m.servers {
|
|
result[k] = v
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetServer returns a specific server connection
|
|
func (m *Manager) GetServer(name string) (*ServerConnection, bool) {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
conn, ok := m.servers[name]
|
|
return conn, ok
|
|
}
|
|
|
|
// CallTool calls a tool on a specific server
|
|
func (m *Manager) CallTool(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
|
|
conn, ok := m.GetServer(serverName)
|
|
if !ok {
|
|
return nil, fmt.Errorf("server %s not found", serverName)
|
|
}
|
|
|
|
params := &mcp.CallToolParams{
|
|
Name: toolName,
|
|
Arguments: arguments,
|
|
}
|
|
|
|
result, err := conn.Session.CallTool(ctx, params)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to call tool: %w", err)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// Close closes all server connections
|
|
func (m *Manager) Close() error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
logger.InfoCF("mcp", "Closing all MCP server connections",
|
|
map[string]interface{}{
|
|
"count": len(m.servers),
|
|
})
|
|
|
|
var errs []error
|
|
for name, conn := range m.servers {
|
|
if err := conn.Session.Close(); err != nil {
|
|
logger.ErrorCF("mcp", "Failed to close server connection",
|
|
map[string]interface{}{
|
|
"server": name,
|
|
"error": err.Error(),
|
|
})
|
|
errs = append(errs, err)
|
|
}
|
|
}
|
|
|
|
m.servers = make(map[string]*ServerConnection)
|
|
|
|
if len(errs) > 0 {
|
|
return fmt.Errorf("failed to close %d server(s)", len(errs))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetAllTools returns all tools from all connected servers
|
|
func (m *Manager) GetAllTools() map[string][]*mcp.Tool {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
result := make(map[string][]*mcp.Tool)
|
|
for name, conn := range m.servers {
|
|
if len(conn.Tools) > 0 {
|
|
result[name] = conn.Tools
|
|
}
|
|
}
|
|
return result
|
|
}
|