package mcp import ( "bufio" "context" "errors" "fmt" "net/http" "os" "os/exec" "path/filepath" "strings" "sync" "sync/atomic" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" ) // headerTransport is an http.RoundTripper that adds custom headers to requests type headerTransport struct { base http.RoundTripper headers map[string]string } func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { // Clone the request to avoid modifying the original req = req.Clone(req.Context()) // Add custom headers for key, value := range t.headers { req.Header.Set(key, value) } // Use the base transport base := t.base if base == nil { base = http.DefaultTransport } return base.RoundTrip(req) } // loadEnvFile loads environment variables from a file in .env format // Each line should be in the format: KEY=value // Lines starting with # are comments // Empty lines are ignored func loadEnvFile(path string) (map[string]string, error) { file, err := os.Open(path) if err != nil { return nil, fmt.Errorf("failed to open env file: %w", err) } defer file.Close() envVars := make(map[string]string) scanner := bufio.NewScanner(file) lineNum := 0 for scanner.Scan() { lineNum++ line := strings.TrimSpace(scanner.Text()) // Skip empty lines and comments if line == "" || strings.HasPrefix(line, "#") { continue } // Parse KEY=value parts := strings.SplitN(line, "=", 2) if len(parts) != 2 { return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line) } key := strings.TrimSpace(parts[0]) value := strings.TrimSpace(parts[1]) if key == "" { return nil, fmt.Errorf("invalid format at line %d: empty key", lineNum) } // Remove surrounding quotes if present if len(value) >= 2 { if (value[0] == '"' && value[len(value)-1] == '"') || (value[0] == '\'' && value[len(value)-1] == '\'') { value = value[1 : len(value)-1] } } envVars[key] = value } if err := scanner.Err(); err != nil { return nil, fmt.Errorf("error reading env file: %w", err) } return envVars, nil } // ServerConnection represents a connection to an MCP server type ServerConnection struct { Name string Client *mcp.Client Session *mcp.ClientSession Tools []*mcp.Tool } // Manager manages multiple MCP server connections type Manager struct { servers map[string]*ServerConnection mu sync.RWMutex closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race wg sync.WaitGroup // tracks in-flight CallTool calls } // NewManager creates a new MCP manager func NewManager() *Manager { return &Manager{ servers: make(map[string]*ServerConnection), } } // LoadFromConfig loads MCP servers from configuration func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error { return m.LoadFromMCPConfig(ctx, cfg.Tools.MCP, cfg.WorkspacePath()) } // LoadFromMCPConfig loads MCP servers from MCP configuration and workspace path. // This is the minimal dependency version that doesn't require the full Config object. func (m *Manager) LoadFromMCPConfig( ctx context.Context, mcpCfg config.MCPConfig, workspacePath string, ) error { if !mcpCfg.Enabled { logger.InfoCF("mcp", "MCP integration is disabled", nil) return nil } if len(mcpCfg.Servers) == 0 { logger.InfoCF("mcp", "No MCP servers configured", nil) return nil } logger.InfoCF("mcp", "Initializing MCP servers", map[string]any{ "count": len(mcpCfg.Servers), }) var wg sync.WaitGroup errs := make(chan error, len(mcpCfg.Servers)) enabledCount := 0 for name, serverCfg := range mcpCfg.Servers { if !serverCfg.Enabled { logger.DebugCF("mcp", "Skipping disabled server", map[string]any{ "server": name, }) continue } enabledCount++ wg.Add(1) go func(name string, serverCfg config.MCPServerConfig, workspace string) { defer wg.Done() // Resolve relative envFile paths relative to workspace if serverCfg.EnvFile != "" && !filepath.IsAbs(serverCfg.EnvFile) { if workspace == "" { err := fmt.Errorf( "workspace path is empty while resolving relative envFile %q for server %s", serverCfg.EnvFile, name, ) logger.ErrorCF("mcp", "Invalid MCP server configuration", map[string]any{ "server": name, "env_file": serverCfg.EnvFile, "error": err.Error(), }) errs <- err return } serverCfg.EnvFile = filepath.Join(workspace, serverCfg.EnvFile) } if err := m.ConnectServer(ctx, name, serverCfg); err != nil { logger.ErrorCF("mcp", "Failed to connect to MCP server", map[string]any{ "server": name, "error": err.Error(), }) errs <- fmt.Errorf("failed to connect to server %s: %w", name, err) } }(name, serverCfg, workspacePath) } wg.Wait() close(errs) // Collect errors var allErrors []error for err := range errs { allErrors = append(allErrors, err) } connectedCount := len(m.GetServers()) // If all enabled servers failed to connect, return aggregated error if enabledCount > 0 && connectedCount == 0 { logger.ErrorCF("mcp", "All MCP servers failed to connect", map[string]any{ "failed": len(allErrors), "total": enabledCount, }) return errors.Join(allErrors...) } if len(allErrors) > 0 { logger.WarnCF("mcp", "Some MCP servers failed to connect", map[string]any{ "failed": len(allErrors), "connected": connectedCount, "total": enabledCount, }) // Don't fail completely if some servers successfully connected } logger.InfoCF("mcp", "MCP server initialization complete", map[string]any{ "connected": connectedCount, "total": enabledCount, }) return nil } // ConnectServer connects to a single MCP server func (m *Manager) ConnectServer( ctx context.Context, name string, cfg config.MCPServerConfig, ) error { logger.InfoCF("mcp", "Connecting to MCP server", map[string]any{ "server": name, "command": cfg.Command, "args_count": len(cfg.Args), }) // Create client client := mcp.NewClient(&mcp.Implementation{ Name: "picoclaw", Version: "1.0.0", }, nil) // Create transport based on configuration // Auto-detect transport type if not explicitly specified var transport mcp.Transport transportType := cfg.Type // Auto-detect: if URL is provided, use SSE; if command is provided, use stdio if transportType == "" { if cfg.URL != "" { transportType = "sse" } else if cfg.Command != "" { transportType = "stdio" } else { return fmt.Errorf("either URL or command must be provided") } } switch transportType { case "sse", "http": if cfg.URL == "" { return fmt.Errorf("URL is required for SSE/HTTP transport") } // Configure DisableStandaloneSSE based on transport type. // - "http": Request-response only mode. Disable the standalone SSE stream // to avoid compatibility issues with servers that don't support GET /mcp. // - "sse": Bidirectional mode. Enable the standalone SSE stream to receive // server-initiated notifications (e.g., ToolListChangedNotification). // - Empty or auto-detected: Defaults to "sse" behavior (standalone SSE enabled). disableStandaloneSSE := (cfg.Type == "http") logger.DebugCF("mcp", "Using SSE/HTTP transport", map[string]any{ "server": name, "url": cfg.URL, "disableStandaloneSSE": disableStandaloneSSE, }) sseTransport := &mcp.StreamableClientTransport{ Endpoint: cfg.URL, DisableStandaloneSSE: disableStandaloneSSE, } // Add custom headers if provided if len(cfg.Headers) > 0 { // Create a custom HTTP client with header-injecting transport sseTransport.HTTPClient = &http.Client{ Transport: &headerTransport{ base: http.DefaultTransport, headers: cfg.Headers, }, } logger.DebugCF("mcp", "Added custom HTTP headers", map[string]any{ "server": name, "header_count": len(cfg.Headers), }) } transport = sseTransport case "stdio": if cfg.Command == "" { return fmt.Errorf("command is required for stdio transport") } logger.DebugCF("mcp", "Using stdio transport", map[string]any{ "server": name, "command": cfg.Command, }) // Create command with context cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...) // Build environment variables with proper override semantics // Use a map to ensure config variables override file variables envMap := make(map[string]string) // Start with parent process environment for _, e := range cmd.Environ() { if idx := strings.Index(e, "="); idx > 0 { envMap[e[:idx]] = e[idx+1:] } } // Load environment variables from file if specified if cfg.EnvFile != "" { envVars, err := loadEnvFile(cfg.EnvFile) if err != nil { return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err) } for k, v := range envVars { envMap[k] = v } logger.DebugCF("mcp", "Loaded environment variables from file", map[string]any{ "server": name, "envFile": cfg.EnvFile, "var_count": len(envVars), }) } // Environment variables from config override those from file for k, v := range cfg.Env { envMap[k] = v } // Convert map to slice env := make([]string, 0, len(envMap)) for k, v := range envMap { env = append(env, fmt.Sprintf("%s=%s", k, v)) } cmd.Env = env transport = &mcp.CommandTransport{Command: cmd} default: return fmt.Errorf( "unsupported transport type: %s (supported: stdio, sse, http)", transportType, ) } // Connect to server session, err := client.Connect(ctx, transport, nil) if err != nil { return fmt.Errorf("failed to connect: %w", err) } // Get server info initResult := session.InitializeResult() logger.InfoCF("mcp", "Connected to MCP server", map[string]any{ "server": name, "serverName": initResult.ServerInfo.Name, "serverVersion": initResult.ServerInfo.Version, "protocol": initResult.ProtocolVersion, }) // List available tools if supported var tools []*mcp.Tool if initResult.Capabilities.Tools != nil { for tool, err := range session.Tools(ctx, nil) { if err != nil { logger.WarnCF("mcp", "Error listing tool", map[string]any{ "server": name, "error": err.Error(), }) continue } tools = append(tools, tool) } logger.InfoCF("mcp", "Listed tools from MCP server", map[string]any{ "server": name, "toolCount": len(tools), }) } // Store connection m.mu.Lock() m.servers[name] = &ServerConnection{ Name: name, Client: client, Session: session, Tools: tools, } m.mu.Unlock() return nil } // GetServers returns all connected servers func (m *Manager) GetServers() map[string]*ServerConnection { m.mu.RLock() defer m.mu.RUnlock() result := make(map[string]*ServerConnection, len(m.servers)) for k, v := range m.servers { result[k] = v } return result } // GetServer returns a specific server connection func (m *Manager) GetServer(name string) (*ServerConnection, bool) { m.mu.RLock() defer m.mu.RUnlock() conn, ok := m.servers[name] return conn, ok } // CallTool calls a tool on a specific server func (m *Manager) CallTool( ctx context.Context, serverName, toolName string, arguments map[string]any, ) (*mcp.CallToolResult, error) { // Check if closed before acquiring lock (fast path) if m.closed.Load() { return nil, fmt.Errorf("manager is closed") } m.mu.RLock() // Double-check after acquiring lock to prevent TOCTOU race if m.closed.Load() { m.mu.RUnlock() return nil, fmt.Errorf("manager is closed") } conn, ok := m.servers[serverName] if ok { m.wg.Add(1) // Add to WaitGroup while holding the lock } m.mu.RUnlock() if !ok { return nil, fmt.Errorf("server %s not found", serverName) } defer m.wg.Done() params := &mcp.CallToolParams{ Name: toolName, Arguments: arguments, } result, err := conn.Session.CallTool(ctx, params) if err != nil { return nil, fmt.Errorf("failed to call tool: %w", err) } return result, nil } // Close closes all server connections func (m *Manager) Close() error { // Use Swap to atomically set closed=true and get the previous value // This prevents TOCTOU race with CallTool's closed check if m.closed.Swap(true) { return nil // already closed } // Wait for all in-flight CallTool calls to finish before closing sessions // After closed=true is set, no new CallTool can start (they check closed first) m.wg.Wait() m.mu.Lock() defer m.mu.Unlock() logger.InfoCF("mcp", "Closing all MCP server connections", map[string]any{ "count": len(m.servers), }) var errs []error for name, conn := range m.servers { if err := conn.Session.Close(); err != nil { logger.ErrorCF("mcp", "Failed to close server connection", map[string]any{ "server": name, "error": err.Error(), }) errs = append(errs, fmt.Errorf("server %s: %w", name, err)) } } m.servers = make(map[string]*ServerConnection) if len(errs) > 0 { return fmt.Errorf("failed to close %d server(s): %w", len(errs), errors.Join(errs...)) } return nil } // GetAllTools returns all tools from all connected servers func (m *Manager) GetAllTools() map[string][]*mcp.Tool { m.mu.RLock() defer m.mu.RUnlock() result := make(map[string][]*mcp.Tool) for name, conn := range m.servers { if len(conn.Tools) > 0 { result[name] = conn.Tools } } return result }