From 77d26e5ce32adebf31eb773d1dcb9ebabf57ae65 Mon Sep 17 00:00:00 2001 From: yuchou87 Date: Mon, 16 Feb 2026 19:33:31 +0800 Subject: [PATCH] fix(mcp): return aggregated error when all servers fail to connect - Add errors.Join to return aggregated error when all enabled MCP servers fail - Track enabled server count separately from total configured servers - Return error only when all servers fail, not for partial failures - Improve logging with accurate server counts (enabled vs connected) - Maintains fault tolerance: partial failures don't stop initialization --- pkg/mcp/manager.go | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/pkg/mcp/manager.go b/pkg/mcp/manager.go index d6ca28f76..be941ec23 100644 --- a/pkg/mcp/manager.go +++ b/pkg/mcp/manager.go @@ -3,6 +3,7 @@ package mcp import ( "bufio" "context" + "errors" "fmt" "net/http" "os" @@ -24,12 +25,12 @@ type headerTransport struct { func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { // Clone the request to avoid modifying the original req = req.Clone(req.Context()) - + // Add custom headers for key, value := range t.headers { req.Header.Set(key, value) } - + // Use the base transport base := t.base if base == nil { @@ -129,6 +130,7 @@ func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error var wg sync.WaitGroup errs := make(chan error, len(cfg.Tools.MCP.Servers)) + enabledCount := 0 for name, serverCfg := range cfg.Tools.MCP.Servers { if !serverCfg.Enabled { @@ -139,6 +141,7 @@ func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error continue } + enabledCount++ wg.Add(1) go func(name string, serverCfg config.MCPServerConfig) { defer wg.Done() @@ -163,20 +166,32 @@ func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error allErrors = append(allErrors, err) } + connectedCount := len(m.GetServers()) + + // If all enabled servers failed to connect, return aggregated error + if enabledCount > 0 && connectedCount == 0 { + logger.ErrorCF("mcp", "All MCP servers failed to connect", + map[string]interface{}{ + "failed": len(allErrors), + "total": enabledCount, + }) + return errors.Join(allErrors...) + } + if len(allErrors) > 0 { logger.WarnCF("mcp", "Some MCP servers failed to connect", map[string]interface{}{ - "failed": len(allErrors), - "total": len(cfg.Tools.MCP.Servers), + "failed": len(allErrors), + "connected": connectedCount, + "total": enabledCount, }) - // Don't fail completely if some servers fail to connect + // Don't fail completely if some servers successfully connected } - connectedCount := len(m.GetServers()) logger.InfoCF("mcp", "MCP server initialization complete", map[string]interface{}{ "connected": connectedCount, - "total": len(cfg.Tools.MCP.Servers), + "total": enabledCount, }) return nil @@ -223,11 +238,11 @@ func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCP "server": name, "url": cfg.URL, }) - + sseTransport := &mcp.StreamableClientTransport{ Endpoint: cfg.URL, } - + // Add custom headers if provided if len(cfg.Headers) > 0 { // Create a custom HTTP client with header-injecting transport @@ -243,7 +258,7 @@ func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCP "header_count": len(cfg.Headers), }) } - + transport = sseTransport case "stdio": if cfg.Command == "" { @@ -259,7 +274,7 @@ func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCP // Set environment variables env := cmd.Environ() - + // Load environment variables from file if specified if cfg.EnvFile != "" { envVars, err := loadEnvFile(cfg.EnvFile) @@ -276,14 +291,14 @@ func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCP "var_count": len(envVars), }) } - + // Environment variables from config override those from file if len(cfg.Env) > 0 { for k, v := range cfg.Env { env = append(env, fmt.Sprintf("%s=%s", k, v)) } } - + // Set environment if we added any variables if len(env) > len(cmd.Environ()) { cmd.Env = env