fix(mcp): retry tool calls on lost HTTP sessions and fix client lifecycle

This commit is contained in:
afjcjsbx
2026-04-24 20:20:57 +02:00
parent 8d51d306b3
commit 8f8af0874d
2 changed files with 377 additions and 36 deletions
+162 -36
View File
@@ -117,10 +117,12 @@ func loadEnvFile(path string) (map[string]string, error) {
// ServerConnection represents a connection to an MCP server
type ServerConnection struct {
Name string
Client *mcp.Client
Session *mcp.ClientSession
Tools []*mcp.Tool
Name string
// Config is the server configuration stored alongside the connection;
// reconnectServer passes it back to connectServerFunc to dial a replacement.
Config config.MCPServerConfig
Client *mcp.Client
Session *mcp.ClientSession
Tools []*mcp.Tool
// reconnectMu is held by reconnectServer for the whole duration of a
// reconnect attempt that starts from this connection.
reconnectMu sync.Mutex
}
// Manager manages multiple MCP server connections
@@ -131,6 +133,8 @@ type Manager struct {
wg sync.WaitGroup // tracks in-flight CallTool calls
}
var connectServerFunc = connectServer
// NewManager creates a new MCP manager
func NewManager() *Manager {
return &Manager{
@@ -260,6 +264,28 @@ func (m *Manager) ConnectServer(
name string,
cfg config.MCPServerConfig,
) error {
conn, err := connectServerFunc(ctx, name, cfg)
if err != nil {
return err
}
m.mu.Lock()
defer m.mu.Unlock()
if m.closed.Load() {
_ = conn.Session.Close()
return fmt.Errorf("manager is closed")
}
m.servers[name] = conn
return nil
}
func connectServer(
ctx context.Context,
name string,
cfg config.MCPServerConfig,
) (*ServerConnection, error) {
logger.InfoCF("mcp", "Connecting to MCP server",
map[string]any{
"server": name,
@@ -285,14 +311,14 @@ func (m *Manager) ConnectServer(
} else if cfg.Command != "" {
transportType = "stdio"
} else {
return fmt.Errorf("either URL or command must be provided")
return nil, fmt.Errorf("either URL or command must be provided")
}
}
switch transportType {
case "sse", "http":
if cfg.URL == "" {
return fmt.Errorf("URL is required for SSE/HTTP transport")
return nil, fmt.Errorf("URL is required for SSE/HTTP transport")
}
// Configure DisableStandaloneSSE based on transport type.
@@ -334,7 +360,7 @@ func (m *Manager) ConnectServer(
transport = sseTransport
case "stdio":
if cfg.Command == "" {
return fmt.Errorf("command is required for stdio transport")
return nil, fmt.Errorf("command is required for stdio transport")
}
logger.DebugCF("mcp", "Using stdio transport",
map[string]any{
@@ -359,7 +385,7 @@ func (m *Manager) ConnectServer(
if cfg.EnvFile != "" {
envVars, err := loadEnvFile(cfg.EnvFile)
if err != nil {
return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
return nil, fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
}
for k, v := range envVars {
envMap[k] = v
@@ -385,7 +411,7 @@ func (m *Manager) ConnectServer(
cmd.Env = env
transport = &isolatedCommandTransport{Command: cmd}
default:
return fmt.Errorf(
return nil, fmt.Errorf(
"unsupported transport type: %s (supported: stdio, sse, http)",
transportType,
)
@@ -394,7 +420,7 @@ func (m *Manager) ConnectServer(
// Connect to server
session, err := client.Connect(ctx, transport, nil)
if err != nil {
return fmt.Errorf("failed to connect: %w", err)
return nil, fmt.Errorf("failed to connect: %w", err)
}
// Get server info
@@ -408,38 +434,19 @@ func (m *Manager) ConnectServer(
})
// List available tools if supported
var tools []*mcp.Tool
if initResult.Capabilities.Tools != nil {
for tool, err := range session.Tools(ctx, nil) {
if err != nil {
logger.WarnCF("mcp", "Error listing tool",
map[string]any{
"server": name,
"error": err.Error(),
})
continue
}
tools = append(tools, tool)
}
logger.InfoCF("mcp", "Listed tools from MCP server",
map[string]any{
"server": name,
"toolCount": len(tools),
})
tools, err := listServerTools(ctx, name, session, initResult)
if err != nil {
_ = session.Close()
return nil, err
}
// Store connection
m.mu.Lock()
m.servers[name] = &ServerConnection{
return &ServerConnection{
Name: name,
Config: cfg,
Client: client,
Session: session,
Tools: tools,
}
m.mu.Unlock()
return nil
}, nil
}
// GetServers returns all connected servers
@@ -498,12 +505,131 @@ func (m *Manager) CallTool(
result, err := conn.Session.CallTool(ctx, params)
if err != nil {
if shouldReconnectCallError(err) {
logger.WarnCF("mcp", "MCP server session was lost during tool call, reconnecting",
map[string]any{
"server": serverName,
"tool": toolName,
"error": err.Error(),
})
reconnectedConn, reconnectErr := m.reconnectServer(ctx, serverName, conn)
if reconnectErr != nil {
return nil, fmt.Errorf("failed to recover lost MCP session: %w", reconnectErr)
}
result, err = reconnectedConn.Session.CallTool(ctx, params)
if err == nil {
return result, nil
}
}
return nil, fmt.Errorf("failed to call tool: %w", err)
}
return result, nil
}
// listServerTools returns the tools advertised by the MCP server behind session.
//
// If the initialize result is missing, carries no capabilities, or does not
// advertise the tools capability, no listing request is made and a nil slice
// is returned. Errors yielded while iterating the tool list are logged as
// warnings and skipped, so the returned error is currently always nil.
func listServerTools(
	ctx context.Context,
	name string,
	session *mcp.ClientSession,
	initResult *mcp.InitializeResult,
) ([]*mcp.Tool, error) {
	var tools []*mcp.Tool

	// Capabilities is a pointer; guard both levels so a server that omits
	// capabilities from its initialize response cannot cause a nil dereference.
	if initResult == nil || initResult.Capabilities == nil || initResult.Capabilities.Tools == nil {
		return tools, nil
	}

	for tool, err := range session.Tools(ctx, nil) {
		if err != nil {
			logger.WarnCF("mcp", "Error listing tool",
				map[string]any{
					"server": name,
					"error":  err.Error(),
				})
			continue
		}
		tools = append(tools, tool)
	}

	logger.InfoCF("mcp", "Listed tools from MCP server",
		map[string]any{
			"server":    name,
			"toolCount": len(tools),
		})

	return tools, nil
}
// shouldReconnectCallError reports whether err indicates that the MCP session
// is no longer known to the server, in which case the caller should reconnect.
//
// It first checks the error chain for mcp.ErrSessionMissing. As a fallback it
// looks for the sentinel's message inside err's text, case-insensitively, to
// cover errors whose chain does not expose the sentinel to errors.Is.
func shouldReconnectCallError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, mcp.ErrSessionMissing) {
		return true
	}

	// Lower-case both sides: lowering only the haystack would never match if
	// the sentinel's message contained an upper-case character.
	needle := strings.ToLower(mcp.ErrSessionMissing.Error())
	return strings.Contains(strings.ToLower(err.Error()), needle)
}
// reconnectServer replaces the registered connection for serverName after
// staleConn's session was lost, and returns the connection the caller should
// retry on.
//
// Attempts that start from the same staleConn are serialized on its
// reconnectMu. If the registry already holds a different connection for
// serverName (before or after dialing), that connection is returned and any
// freshly dialed one is closed. An error is returned when staleConn is nil,
// the manager is closed, the server is no longer registered, or dialing fails.
func (m *Manager) reconnectServer(
	ctx context.Context,
	serverName string,
	staleConn *ServerConnection,
) (*ServerConnection, error) {
	if staleConn == nil {
		return nil, fmt.Errorf("server %s not found", serverName)
	}

	staleConn.reconnectMu.Lock()
	defer staleConn.reconnectMu.Unlock()

	if m.closed.Load() {
		return nil, fmt.Errorf("manager is closed")
	}

	// Cheap pre-check under the read lock: bail out without dialing when the
	// entry is gone or has already been swapped for another connection.
	m.mu.RLock()
	registered, found := m.servers[serverName]
	m.mu.RUnlock()
	switch {
	case !found:
		return nil, fmt.Errorf("server %s not found", serverName)
	case registered != staleConn:
		return registered, nil
	}

	replacement, err := connectServerFunc(ctx, serverName, staleConn.Config)
	if err != nil {
		return nil, err
	}

	// Decide the outcome under the write lock; the losing connection's session
	// is closed only after the lock has been released.
	var (
		winner  *ServerConnection
		discard *ServerConnection
		failure error
	)
	m.mu.Lock()
	registered, found = m.servers[serverName]
	switch {
	case m.closed.Load():
		discard, failure = replacement, fmt.Errorf("manager is closed")
	case !found:
		discard, failure = replacement, fmt.Errorf("server %s not found", serverName)
	case registered == staleConn:
		m.servers[serverName] = replacement
		winner, discard = replacement, staleConn
	default:
		winner, discard = registered, replacement
	}
	m.mu.Unlock()

	_ = discard.Session.Close() // best-effort cleanup of the unused session
	if failure != nil {
		return nil, failure
	}
	return winner, nil
}
// Close closes all server connections
func (m *Manager) Close() error {
// Use Swap to atomically set closed=true and get the previous value
+215
View File
@@ -2,11 +2,16 @@ package mcp
import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/config"
@@ -312,6 +317,81 @@ func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) {
})
}
// TestCallTool_ReconnectsWhenHTTPServerLosesSession verifies that CallTool,
// on receiving an error wrapping ErrSessionMissing from the registered
// connection, dials exactly one replacement through connectServerFunc, retries
// the call on it, and leaves the replacement registered under the same name.
func TestCallTool_ReconnectsWhenHTTPServerLosesSession(t *testing.T) {
	// Swap the dial seam for this test only and restore it afterwards.
	savedDial := connectServerFunc
	t.Cleanup(func() { connectServerFunc = savedDial })

	lostSessionErr := fmt.Errorf(
		`sending "tools/call": failed to connect (session ID: session-1): %w`,
		sdkmcp.ErrSessionMissing,
	)
	staleConn, staleTransport, err := newScriptedServerConnection("session-1", nil, lostSessionErr)
	if err != nil {
		t.Fatalf("newScriptedServerConnection(stale) error = %v", err)
	}

	scriptedReply := &sdkmcp.CallToolResult{
		Content: []sdkmcp.Content{
			&sdkmcp.TextContent{Text: "reconnected"},
		},
	}
	freshConn, freshTransport, err := newScriptedServerConnection("session-2", scriptedReply, nil)
	if err != nil {
		t.Fatalf("newScriptedServerConnection(fresh) error = %v", err)
	}

	// The first dial hands out the fresh connection; any further dial is an error.
	var connectCalls int
	connectServerFunc = func(context.Context, string, config.MCPServerConfig) (*ServerConnection, error) {
		connectCalls++
		if connectCalls != 1 {
			return nil, fmt.Errorf("unexpected reconnect attempt %d", connectCalls)
		}
		return freshConn, nil
	}

	mgr := NewManager()
	mgr.servers["flaky"] = staleConn

	args := map[string]any{"query": "hello"}
	result, err := mgr.CallTool(context.Background(), "flaky", "echo", args)
	if err != nil {
		t.Fatalf("CallTool() error = %v", err)
	}
	if result == nil || len(result.Content) != 1 {
		t.Fatalf("CallTool() returned unexpected content: %#v", result)
	}

	first := result.Content[0]
	text, isText := first.(*sdkmcp.TextContent)
	if !isText {
		t.Fatalf("CallTool() content type = %T, want *sdkmcp.TextContent", first)
	}
	if got := text.Text; got != "reconnected" {
		t.Fatalf("CallTool() text = %q, want %q", got, "reconnected")
	}

	conn, registered := mgr.GetServer("flaky")
	if !registered {
		t.Fatal("expected flaky server to remain connected after reconnect")
	}
	if gotID := conn.Session.ID(); gotID != "session-2" {
		t.Fatalf("Session.ID() = %q, want %q", gotID, "session-2")
	}

	if connectCalls != 1 {
		t.Fatalf("connectCalls = %d, want 1", connectCalls)
	}
	if n := staleTransport.toolCallCalls; n != 1 {
		t.Fatalf("stale toolCallCalls = %d, want 1", n)
	}
	if n := freshTransport.toolCallCalls; n != 1 {
		t.Fatalf("fresh toolCallCalls = %d, want 1", n)
	}
}
func TestClose_IdempotentOnEmptyManager(t *testing.T) {
mgr := NewManager()
@@ -322,3 +402,138 @@ func TestClose_IdempotentOnEmptyManager(t *testing.T) {
t.Fatalf("second close should be idempotent, got: %v", err)
}
}
// newScriptedServerConnection builds a ServerConnection named "flaky" whose
// session runs over a scriptedTransport. The transport reports sessionID and
// answers "tools/call" with toolCallResult, or fails the write with
// toolCallErr when that is non-nil. The transport is returned as well so
// tests can inspect its counters. A non-nil error means the client could not
// connect over the scripted transport.
func newScriptedServerConnection(
	sessionID string,
	toolCallResult *sdkmcp.CallToolResult,
	toolCallErr error,
) (*ServerConnection, *scriptedTransport, error) {
	script := &scriptedTransport{
		sessionID:      sessionID,
		toolCallResult: toolCallResult,
		toolCallErr:    toolCallErr,
	}

	impl := &sdkmcp.Implementation{
		Name:    "picoclaw-test",
		Version: "1.0.0",
	}
	client := sdkmcp.NewClient(impl, nil)

	session, err := client.Connect(context.Background(), script, nil)
	if err != nil {
		return nil, nil, err
	}

	echoTool := &sdkmcp.Tool{
		Name:        "echo",
		Description: "Echo test tool",
		InputSchema: map[string]any{"type": "object"},
	}
	conn := &ServerConnection{
		Name:    "flaky",
		Config:  config.MCPServerConfig{Enabled: true, Type: "http", URL: "https://example.invalid/mcp"},
		Client:  client,
		Session: session,
		Tools:   []*sdkmcp.Tool{echoTool},
	}
	return conn, script, nil
}
// scriptedTransport is a test double that serves as both the transport and the
// connection (Connect returns the receiver). Write fabricates responses for the
// requests it recognizes and queues them on incoming, where Read picks them up.
type scriptedTransport struct {
// sessionID is the value returned by SessionID.
sessionID string
// toolCallResult is marshaled as the response to "tools/call" when
// toolCallErr is nil.
toolCallResult *sdkmcp.CallToolResult
// toolCallErr, when non-nil, is returned from Write for "tools/call"
// instead of queuing a response.
toolCallErr error
// mu is held while Write increments toolCallCalls and while Close updates closed.
mu sync.Mutex
// toolCallCalls counts the "tools/call" requests seen by Write.
toolCallCalls int
// closed records that Close has already run, making Close idempotent.
closed bool
// incoming carries responses from Write to Read; Connect creates it lazily
// with a buffer of 4.
incoming chan jsonrpc.Message
}
// Connect returns the transport itself as the connection, creating the
// buffered response queue on first use.
func (t *scriptedTransport) Connect(context.Context) (sdkmcp.Connection, error) {
	if t.incoming != nil {
		return t, nil
	}
	t.incoming = make(chan jsonrpc.Message, 4)
	return t, nil
}
// Read blocks until a queued response is available, the queue is closed
// (reported as io.EOF), or ctx is done (reported as ctx.Err()).
func (t *scriptedTransport) Read(ctx context.Context) (jsonrpc.Message, error) {
	select {
	case msg, open := <-t.incoming:
		if open {
			return msg, nil
		}
		return nil, io.EOF
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// Write plays the server side of the conversation. Messages that are not
// requests are accepted and ignored. For requests:
//   - "initialize" queues an InitializeResult advertising the tools capability;
//   - "notifications/initialized" is acknowledged without a response;
//   - "tools/call" bumps the call counter, then either returns toolCallErr or
//     queues toolCallResult;
//   - any other method is rejected with an error.
//
// Queuing a response gives up with ctx.Err() if ctx is done first.
func (t *scriptedTransport) Write(ctx context.Context, msg jsonrpc.Message) error {
	req, isRequest := msg.(*jsonrpc.Request)
	if !isRequest {
		return nil
	}

	// respond marshals result and queues it as the response to req.
	respond := func(result any) error {
		payload, err := json.Marshal(result)
		if err != nil {
			return err
		}
		select {
		case t.incoming <- &jsonrpc.Response{ID: req.ID, Result: payload}:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	switch req.Method {
	case "notifications/initialized":
		return nil
	case "initialize":
		return respond(&sdkmcp.InitializeResult{
			ProtocolVersion: "2025-11-25",
			ServerInfo: &sdkmcp.Implementation{
				Name:    "scripted-test-server",
				Version: "1.0.0",
			},
			Capabilities: &sdkmcp.ServerCapabilities{
				Tools: &sdkmcp.ToolCapabilities{},
			},
		})
	case "tools/call":
		t.mu.Lock()
		t.toolCallCalls++
		t.mu.Unlock()

		if t.toolCallErr != nil {
			return t.toolCallErr
		}
		return respond(t.toolCallResult)
	default:
		return fmt.Errorf("unexpected method %q", req.Method)
	}
}
// Close marks the transport closed and closes the response queue so that a
// pending or later Read observes io.EOF. It is idempotent and always returns
// nil.
func (t *scriptedTransport) Close() error {
	t.mu.Lock()
	defer t.mu.Unlock()

	if t.closed {
		return nil
	}
	t.closed = true

	// The queue is created lazily by Connect; closing a nil channel panics,
	// so skip it when the transport was never connected.
	if t.incoming != nil {
		close(t.incoming)
	}
	return nil
}
// SessionID returns the session identifier this transport was configured with.
func (t *scriptedTransport) SessionID() string {
return t.sessionID
}