mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(mcp): return aggregated error when all servers fail to connect
- Add errors.Join to return aggregated error when all enabled MCP servers fail - Track enabled server count separately from total configured servers - Return error only when all servers fail, not for partial failures - Improve logging with accurate server counts (enabled vs connected) - Maintains fault tolerance: partial failures don't stop initialization
This commit is contained in:
+28
-13
@@ -3,6 +3,7 @@ package mcp
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -24,12 +25,12 @@ type headerTransport struct {
|
||||
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Clone the request to avoid modifying the original
|
||||
req = req.Clone(req.Context())
|
||||
|
||||
|
||||
// Add custom headers
|
||||
for key, value := range t.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
|
||||
// Use the base transport
|
||||
base := t.base
|
||||
if base == nil {
|
||||
@@ -129,6 +130,7 @@ func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, len(cfg.Tools.MCP.Servers))
|
||||
enabledCount := 0
|
||||
|
||||
for name, serverCfg := range cfg.Tools.MCP.Servers {
|
||||
if !serverCfg.Enabled {
|
||||
@@ -139,6 +141,7 @@ func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error
|
||||
continue
|
||||
}
|
||||
|
||||
enabledCount++
|
||||
wg.Add(1)
|
||||
go func(name string, serverCfg config.MCPServerConfig) {
|
||||
defer wg.Done()
|
||||
@@ -163,20 +166,32 @@ func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error
|
||||
allErrors = append(allErrors, err)
|
||||
}
|
||||
|
||||
connectedCount := len(m.GetServers())
|
||||
|
||||
// If all enabled servers failed to connect, return aggregated error
|
||||
if enabledCount > 0 && connectedCount == 0 {
|
||||
logger.ErrorCF("mcp", "All MCP servers failed to connect",
|
||||
map[string]interface{}{
|
||||
"failed": len(allErrors),
|
||||
"total": enabledCount,
|
||||
})
|
||||
return errors.Join(allErrors...)
|
||||
}
|
||||
|
||||
if len(allErrors) > 0 {
|
||||
logger.WarnCF("mcp", "Some MCP servers failed to connect",
|
||||
map[string]interface{}{
|
||||
"failed": len(allErrors),
|
||||
"total": len(cfg.Tools.MCP.Servers),
|
||||
"failed": len(allErrors),
|
||||
"connected": connectedCount,
|
||||
"total": enabledCount,
|
||||
})
|
||||
// Don't fail completely if some servers fail to connect
|
||||
// Don't fail completely if some servers successfully connected
|
||||
}
|
||||
|
||||
connectedCount := len(m.GetServers())
|
||||
logger.InfoCF("mcp", "MCP server initialization complete",
|
||||
map[string]interface{}{
|
||||
"connected": connectedCount,
|
||||
"total": len(cfg.Tools.MCP.Servers),
|
||||
"total": enabledCount,
|
||||
})
|
||||
|
||||
return nil
|
||||
@@ -223,11 +238,11 @@ func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCP
|
||||
"server": name,
|
||||
"url": cfg.URL,
|
||||
})
|
||||
|
||||
|
||||
sseTransport := &mcp.StreamableClientTransport{
|
||||
Endpoint: cfg.URL,
|
||||
}
|
||||
|
||||
|
||||
// Add custom headers if provided
|
||||
if len(cfg.Headers) > 0 {
|
||||
// Create a custom HTTP client with header-injecting transport
|
||||
@@ -243,7 +258,7 @@ func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCP
|
||||
"header_count": len(cfg.Headers),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
transport = sseTransport
|
||||
case "stdio":
|
||||
if cfg.Command == "" {
|
||||
@@ -259,7 +274,7 @@ func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCP
|
||||
|
||||
// Set environment variables
|
||||
env := cmd.Environ()
|
||||
|
||||
|
||||
// Load environment variables from file if specified
|
||||
if cfg.EnvFile != "" {
|
||||
envVars, err := loadEnvFile(cfg.EnvFile)
|
||||
@@ -276,14 +291,14 @@ func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCP
|
||||
"var_count": len(envVars),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// Environment variables from config override those from file
|
||||
if len(cfg.Env) > 0 {
|
||||
for k, v := range cfg.Env {
|
||||
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Set environment if we added any variables
|
||||
if len(env) > len(cmd.Environ()) {
|
||||
cmd.Env = env
|
||||
|
||||
Reference in New Issue
Block a user