mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(mcp): retry tool calls on lost HTTP sessions and fix client lifecycle
This commit is contained in:
+162
-36
@@ -117,10 +117,12 @@ func loadEnvFile(path string) (map[string]string, error) {
|
||||
|
||||
// ServerConnection represents a connection to an MCP server
|
||||
type ServerConnection struct {
|
||||
Name string
|
||||
Client *mcp.Client
|
||||
Session *mcp.ClientSession
|
||||
Tools []*mcp.Tool
|
||||
Name string
|
||||
Config config.MCPServerConfig
|
||||
Client *mcp.Client
|
||||
Session *mcp.ClientSession
|
||||
Tools []*mcp.Tool
|
||||
reconnectMu sync.Mutex
|
||||
}
|
||||
|
||||
// Manager manages multiple MCP server connections
|
||||
@@ -131,6 +133,8 @@ type Manager struct {
|
||||
wg sync.WaitGroup // tracks in-flight CallTool calls
|
||||
}
|
||||
|
||||
var connectServerFunc = connectServer
|
||||
|
||||
// NewManager creates a new MCP manager
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
@@ -260,6 +264,28 @@ func (m *Manager) ConnectServer(
|
||||
name string,
|
||||
cfg config.MCPServerConfig,
|
||||
) error {
|
||||
conn, err := connectServerFunc(ctx, name, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed.Load() {
|
||||
_ = conn.Session.Close()
|
||||
return fmt.Errorf("manager is closed")
|
||||
}
|
||||
|
||||
m.servers[name] = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
func connectServer(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
cfg config.MCPServerConfig,
|
||||
) (*ServerConnection, error) {
|
||||
logger.InfoCF("mcp", "Connecting to MCP server",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
@@ -285,14 +311,14 @@ func (m *Manager) ConnectServer(
|
||||
} else if cfg.Command != "" {
|
||||
transportType = "stdio"
|
||||
} else {
|
||||
return fmt.Errorf("either URL or command must be provided")
|
||||
return nil, fmt.Errorf("either URL or command must be provided")
|
||||
}
|
||||
}
|
||||
|
||||
switch transportType {
|
||||
case "sse", "http":
|
||||
if cfg.URL == "" {
|
||||
return fmt.Errorf("URL is required for SSE/HTTP transport")
|
||||
return nil, fmt.Errorf("URL is required for SSE/HTTP transport")
|
||||
}
|
||||
|
||||
// Configure DisableStandaloneSSE based on transport type.
|
||||
@@ -334,7 +360,7 @@ func (m *Manager) ConnectServer(
|
||||
transport = sseTransport
|
||||
case "stdio":
|
||||
if cfg.Command == "" {
|
||||
return fmt.Errorf("command is required for stdio transport")
|
||||
return nil, fmt.Errorf("command is required for stdio transport")
|
||||
}
|
||||
logger.DebugCF("mcp", "Using stdio transport",
|
||||
map[string]any{
|
||||
@@ -359,7 +385,7 @@ func (m *Manager) ConnectServer(
|
||||
if cfg.EnvFile != "" {
|
||||
envVars, err := loadEnvFile(cfg.EnvFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
|
||||
return nil, fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
|
||||
}
|
||||
for k, v := range envVars {
|
||||
envMap[k] = v
|
||||
@@ -385,7 +411,7 @@ func (m *Manager) ConnectServer(
|
||||
cmd.Env = env
|
||||
transport = &isolatedCommandTransport{Command: cmd}
|
||||
default:
|
||||
return fmt.Errorf(
|
||||
return nil, fmt.Errorf(
|
||||
"unsupported transport type: %s (supported: stdio, sse, http)",
|
||||
transportType,
|
||||
)
|
||||
@@ -394,7 +420,7 @@ func (m *Manager) ConnectServer(
|
||||
// Connect to server
|
||||
session, err := client.Connect(ctx, transport, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
// Get server info
|
||||
@@ -408,38 +434,19 @@ func (m *Manager) ConnectServer(
|
||||
})
|
||||
|
||||
// List available tools if supported
|
||||
var tools []*mcp.Tool
|
||||
if initResult.Capabilities.Tools != nil {
|
||||
for tool, err := range session.Tools(ctx, nil) {
|
||||
if err != nil {
|
||||
logger.WarnCF("mcp", "Error listing tool",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
logger.InfoCF("mcp", "Listed tools from MCP server",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"toolCount": len(tools),
|
||||
})
|
||||
tools, err := listServerTools(ctx, name, session, initResult)
|
||||
if err != nil {
|
||||
_ = session.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store connection
|
||||
m.mu.Lock()
|
||||
m.servers[name] = &ServerConnection{
|
||||
return &ServerConnection{
|
||||
Name: name,
|
||||
Config: cfg,
|
||||
Client: client,
|
||||
Session: session,
|
||||
Tools: tools,
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetServers returns all connected servers
|
||||
@@ -498,12 +505,131 @@ func (m *Manager) CallTool(
|
||||
|
||||
result, err := conn.Session.CallTool(ctx, params)
|
||||
if err != nil {
|
||||
if shouldReconnectCallError(err) {
|
||||
logger.WarnCF("mcp", "MCP server session was lost during tool call, reconnecting",
|
||||
map[string]any{
|
||||
"server": serverName,
|
||||
"tool": toolName,
|
||||
"error": err.Error(),
|
||||
})
|
||||
|
||||
reconnectedConn, reconnectErr := m.reconnectServer(ctx, serverName, conn)
|
||||
if reconnectErr != nil {
|
||||
return nil, fmt.Errorf("failed to recover lost MCP session: %w", reconnectErr)
|
||||
}
|
||||
|
||||
result, err = reconnectedConn.Session.CallTool(ctx, params)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to call tool: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func listServerTools(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
session *mcp.ClientSession,
|
||||
initResult *mcp.InitializeResult,
|
||||
) ([]*mcp.Tool, error) {
|
||||
var tools []*mcp.Tool
|
||||
if initResult.Capabilities.Tools == nil {
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
for tool, err := range session.Tools(ctx, nil) {
|
||||
if err != nil {
|
||||
logger.WarnCF("mcp", "Error listing tool",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
logger.InfoCF("mcp", "Listed tools from MCP server",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"toolCount": len(tools),
|
||||
})
|
||||
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
func shouldReconnectCallError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, mcp.ErrSessionMissing) {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(strings.ToLower(err.Error()), mcp.ErrSessionMissing.Error())
|
||||
}
|
||||
|
||||
func (m *Manager) reconnectServer(
|
||||
ctx context.Context,
|
||||
serverName string,
|
||||
staleConn *ServerConnection,
|
||||
) (*ServerConnection, error) {
|
||||
if staleConn == nil {
|
||||
return nil, fmt.Errorf("server %s not found", serverName)
|
||||
}
|
||||
|
||||
staleConn.reconnectMu.Lock()
|
||||
defer staleConn.reconnectMu.Unlock()
|
||||
|
||||
if m.closed.Load() {
|
||||
return nil, fmt.Errorf("manager is closed")
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
currentConn, ok := m.servers[serverName]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server %s not found", serverName)
|
||||
}
|
||||
if currentConn != staleConn {
|
||||
return currentConn, nil
|
||||
}
|
||||
|
||||
freshConn, err := connectServerFunc(ctx, serverName, staleConn.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
if m.closed.Load() {
|
||||
m.mu.Unlock()
|
||||
_ = freshConn.Session.Close()
|
||||
return nil, fmt.Errorf("manager is closed")
|
||||
}
|
||||
|
||||
currentConn, ok = m.servers[serverName]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
_ = freshConn.Session.Close()
|
||||
return nil, fmt.Errorf("server %s not found", serverName)
|
||||
}
|
||||
|
||||
if currentConn == staleConn {
|
||||
m.servers[serverName] = freshConn
|
||||
staleToClose := staleConn
|
||||
m.mu.Unlock()
|
||||
_ = staleToClose.Session.Close()
|
||||
return freshConn, nil
|
||||
}
|
||||
|
||||
m.mu.Unlock()
|
||||
_ = freshConn.Session.Close()
|
||||
return currentConn, nil
|
||||
}
|
||||
|
||||
// Close closes all server connections
|
||||
func (m *Manager) Close() error {
|
||||
// Use Swap to atomically set closed=true and get the previous value
|
||||
|
||||
@@ -2,11 +2,16 @@ package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -312,6 +317,81 @@ func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCallTool_ReconnectsWhenHTTPServerLosesSession(t *testing.T) {
|
||||
originalConnectServerFunc := connectServerFunc
|
||||
t.Cleanup(func() {
|
||||
connectServerFunc = originalConnectServerFunc
|
||||
})
|
||||
|
||||
staleConn, staleTransport, err := newScriptedServerConnection(
|
||||
"session-1",
|
||||
nil,
|
||||
fmt.Errorf(`sending "tools/call": failed to connect (session ID: session-1): %w`, sdkmcp.ErrSessionMissing),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("newScriptedServerConnection(stale) error = %v", err)
|
||||
}
|
||||
freshConn, freshTransport, err := newScriptedServerConnection(
|
||||
"session-2",
|
||||
&sdkmcp.CallToolResult{
|
||||
Content: []sdkmcp.Content{
|
||||
&sdkmcp.TextContent{Text: "reconnected"},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("newScriptedServerConnection(fresh) error = %v", err)
|
||||
}
|
||||
|
||||
connectCalls := 0
|
||||
connectServerFunc = func(ctx context.Context, name string, cfg config.MCPServerConfig) (*ServerConnection, error) {
|
||||
connectCalls++
|
||||
if connectCalls == 1 {
|
||||
return freshConn, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected reconnect attempt %d", connectCalls)
|
||||
}
|
||||
|
||||
mgr := NewManager()
|
||||
mgr.servers["flaky"] = staleConn
|
||||
|
||||
result, err := mgr.CallTool(context.Background(), "flaky", "echo", map[string]any{
|
||||
"query": "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CallTool() error = %v", err)
|
||||
}
|
||||
if result == nil || len(result.Content) != 1 {
|
||||
t.Fatalf("CallTool() returned unexpected content: %#v", result)
|
||||
}
|
||||
|
||||
text, ok := result.Content[0].(*sdkmcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatalf("CallTool() content type = %T, want *sdkmcp.TextContent", result.Content[0])
|
||||
}
|
||||
if text.Text != "reconnected" {
|
||||
t.Fatalf("CallTool() text = %q, want %q", text.Text, "reconnected")
|
||||
}
|
||||
|
||||
conn, ok := mgr.GetServer("flaky")
|
||||
if !ok {
|
||||
t.Fatal("expected flaky server to remain connected after reconnect")
|
||||
}
|
||||
if conn.Session.ID() != "session-2" {
|
||||
t.Fatalf("Session.ID() = %q, want %q", conn.Session.ID(), "session-2")
|
||||
}
|
||||
if connectCalls != 1 {
|
||||
t.Fatalf("connectCalls = %d, want 1", connectCalls)
|
||||
}
|
||||
if staleTransport.toolCallCalls != 1 {
|
||||
t.Fatalf("stale toolCallCalls = %d, want 1", staleTransport.toolCallCalls)
|
||||
}
|
||||
if freshTransport.toolCallCalls != 1 {
|
||||
t.Fatalf("fresh toolCallCalls = %d, want 1", freshTransport.toolCallCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClose_IdempotentOnEmptyManager(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
@@ -322,3 +402,138 @@ func TestClose_IdempotentOnEmptyManager(t *testing.T) {
|
||||
t.Fatalf("second close should be idempotent, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newScriptedServerConnection(
|
||||
sessionID string,
|
||||
toolCallResult *sdkmcp.CallToolResult,
|
||||
toolCallErr error,
|
||||
) (*ServerConnection, *scriptedTransport, error) {
|
||||
transport := &scriptedTransport{
|
||||
sessionID: sessionID,
|
||||
toolCallResult: toolCallResult,
|
||||
toolCallErr: toolCallErr,
|
||||
}
|
||||
|
||||
client := sdkmcp.NewClient(&sdkmcp.Implementation{
|
||||
Name: "picoclaw-test",
|
||||
Version: "1.0.0",
|
||||
}, nil)
|
||||
session, err := client.Connect(context.Background(), transport, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return &ServerConnection{
|
||||
Name: "flaky",
|
||||
Config: config.MCPServerConfig{Enabled: true, Type: "http", URL: "https://example.invalid/mcp"},
|
||||
Client: client,
|
||||
Session: session,
|
||||
Tools: []*sdkmcp.Tool{
|
||||
{
|
||||
Name: "echo",
|
||||
Description: "Echo test tool",
|
||||
InputSchema: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
}, transport, nil
|
||||
}
|
||||
|
||||
type scriptedTransport struct {
|
||||
sessionID string
|
||||
toolCallResult *sdkmcp.CallToolResult
|
||||
toolCallErr error
|
||||
|
||||
mu sync.Mutex
|
||||
toolCallCalls int
|
||||
closed bool
|
||||
incoming chan jsonrpc.Message
|
||||
}
|
||||
|
||||
func (t *scriptedTransport) Connect(context.Context) (sdkmcp.Connection, error) {
|
||||
if t.incoming == nil {
|
||||
t.incoming = make(chan jsonrpc.Message, 4)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *scriptedTransport) Read(ctx context.Context) (jsonrpc.Message, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case msg, ok := <-t.incoming:
|
||||
if !ok {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *scriptedTransport) Write(ctx context.Context, msg jsonrpc.Message) error {
|
||||
req, ok := msg.(*jsonrpc.Request)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch req.Method {
|
||||
case "initialize":
|
||||
payload, err := json.Marshal(&sdkmcp.InitializeResult{
|
||||
ProtocolVersion: "2025-11-25",
|
||||
ServerInfo: &sdkmcp.Implementation{
|
||||
Name: "scripted-test-server",
|
||||
Version: "1.0.0",
|
||||
},
|
||||
Capabilities: &sdkmcp.ServerCapabilities{
|
||||
Tools: &sdkmcp.ToolCapabilities{},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case t.incoming <- &jsonrpc.Response{ID: req.ID, Result: payload}:
|
||||
return nil
|
||||
}
|
||||
|
||||
case "notifications/initialized":
|
||||
return nil
|
||||
|
||||
case "tools/call":
|
||||
t.mu.Lock()
|
||||
t.toolCallCalls++
|
||||
t.mu.Unlock()
|
||||
|
||||
if t.toolCallErr != nil {
|
||||
return t.toolCallErr
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(t.toolCallResult)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case t.incoming <- &jsonrpc.Response{ID: req.ID, Result: payload}:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("unexpected method %q", req.Method)
|
||||
}
|
||||
|
||||
func (t *scriptedTransport) Close() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.closed {
|
||||
return nil
|
||||
}
|
||||
t.closed = true
|
||||
close(t.incoming)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *scriptedTransport) SessionID() string {
|
||||
return t.sessionID
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user