fix(mcp): return aggregated error when all servers fail to connect

- Use errors.Join to return an aggregated error when all enabled MCP servers fail to connect
- Track enabled server count separately from total configured servers
- Return error only when all servers fail, not for partial failures
- Improve logging with accurate server counts (enabled vs connected)
- Maintain fault tolerance: partial failures don't stop initialization
This commit is contained in:
yuchou87
2026-02-16 19:33:31 +08:00
parent 24610693e4
commit 77d26e5ce3
+28 -13
View File
@@ -3,6 +3,7 @@ package mcp
import (
"bufio"
"context"
"errors"
"fmt"
"net/http"
"os"
@@ -24,12 +25,12 @@ type headerTransport struct {
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Clone the request to avoid modifying the original
req = req.Clone(req.Context())
// Add custom headers
for key, value := range t.headers {
req.Header.Set(key, value)
}
// Use the base transport
base := t.base
if base == nil {
@@ -129,6 +130,7 @@ func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error
var wg sync.WaitGroup
errs := make(chan error, len(cfg.Tools.MCP.Servers))
enabledCount := 0
for name, serverCfg := range cfg.Tools.MCP.Servers {
if !serverCfg.Enabled {
@@ -139,6 +141,7 @@ func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error
continue
}
enabledCount++
wg.Add(1)
go func(name string, serverCfg config.MCPServerConfig) {
defer wg.Done()
@@ -163,20 +166,32 @@ func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error
allErrors = append(allErrors, err)
}
connectedCount := len(m.GetServers())
// If all enabled servers failed to connect, return aggregated error
if enabledCount > 0 && connectedCount == 0 {
logger.ErrorCF("mcp", "All MCP servers failed to connect",
map[string]interface{}{
"failed": len(allErrors),
"total": enabledCount,
})
return errors.Join(allErrors...)
}
if len(allErrors) > 0 {
logger.WarnCF("mcp", "Some MCP servers failed to connect",
map[string]interface{}{
"failed": len(allErrors),
"total": len(cfg.Tools.MCP.Servers),
"failed": len(allErrors),
"connected": connectedCount,
"total": enabledCount,
})
// Don't fail completely if some servers fail to connect
// Don't fail completely if some servers successfully connected
}
connectedCount := len(m.GetServers())
logger.InfoCF("mcp", "MCP server initialization complete",
map[string]interface{}{
"connected": connectedCount,
"total": len(cfg.Tools.MCP.Servers),
"total": enabledCount,
})
return nil
@@ -223,11 +238,11 @@ func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCP
"server": name,
"url": cfg.URL,
})
sseTransport := &mcp.StreamableClientTransport{
Endpoint: cfg.URL,
}
// Add custom headers if provided
if len(cfg.Headers) > 0 {
// Create a custom HTTP client with header-injecting transport
@@ -243,7 +258,7 @@ func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCP
"header_count": len(cfg.Headers),
})
}
transport = sseTransport
case "stdio":
if cfg.Command == "" {
@@ -259,7 +274,7 @@ func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCP
// Set environment variables
env := cmd.Environ()
// Load environment variables from file if specified
if cfg.EnvFile != "" {
envVars, err := loadEnvFile(cfg.EnvFile)
@@ -276,14 +291,14 @@ func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCP
"var_count": len(envVars),
})
}
// Environment variables from config override those from file
if len(cfg.Env) > 0 {
for k, v := range cfg.Env {
env = append(env, fmt.Sprintf("%s=%s", k, v))
}
}
// Set environment if we added any variables
if len(env) > len(cmd.Environ()) {
cmd.Env = env