diff --git a/pkg/mcp/manager.go b/pkg/mcp/manager.go index e28388827..92ea426a6 100644 --- a/pkg/mcp/manager.go +++ b/pkg/mcp/manager.go @@ -117,10 +117,12 @@ func loadEnvFile(path string) (map[string]string, error) { // ServerConnection represents a connection to an MCP server type ServerConnection struct { - Name string - Client *mcp.Client - Session *mcp.ClientSession - Tools []*mcp.Tool + Name string + Config config.MCPServerConfig + Client *mcp.Client + Session *mcp.ClientSession + Tools []*mcp.Tool + reconnectMu sync.Mutex } // Manager manages multiple MCP server connections @@ -131,6 +133,8 @@ type Manager struct { wg sync.WaitGroup // tracks in-flight CallTool calls } +var connectServerFunc = connectServer + // NewManager creates a new MCP manager func NewManager() *Manager { return &Manager{ @@ -260,6 +264,28 @@ func (m *Manager) ConnectServer( name string, cfg config.MCPServerConfig, ) error { + conn, err := connectServerFunc(ctx, name, cfg) + if err != nil { + return err + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed.Load() { + _ = conn.Session.Close() + return fmt.Errorf("manager is closed") + } + + m.servers[name] = conn + return nil +} + +func connectServer( + ctx context.Context, + name string, + cfg config.MCPServerConfig, +) (*ServerConnection, error) { logger.InfoCF("mcp", "Connecting to MCP server", map[string]any{ "server": name, @@ -285,14 +311,14 @@ func (m *Manager) ConnectServer( } else if cfg.Command != "" { transportType = "stdio" } else { - return fmt.Errorf("either URL or command must be provided") + return nil, fmt.Errorf("either URL or command must be provided") } } switch transportType { case "sse", "http": if cfg.URL == "" { - return fmt.Errorf("URL is required for SSE/HTTP transport") + return nil, fmt.Errorf("URL is required for SSE/HTTP transport") } // Configure DisableStandaloneSSE based on transport type. @@ -334,7 +360,7 @@ func (m *Manager) ConnectServer( transport = sseTransport case "stdio": if cfg.Command == "" { - return fmt.Errorf("command is required for stdio transport") + return nil, fmt.Errorf("command is required for stdio transport") } logger.DebugCF("mcp", "Using stdio transport", map[string]any{ @@ -359,7 +385,7 @@ func (m *Manager) ConnectServer( if cfg.EnvFile != "" { envVars, err := loadEnvFile(cfg.EnvFile) if err != nil { - return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err) + return nil, fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err) } for k, v := range envVars { envMap[k] = v @@ -385,7 +411,7 @@ func (m *Manager) ConnectServer( cmd.Env = env transport = &isolatedCommandTransport{Command: cmd} default: - return fmt.Errorf( + return nil, fmt.Errorf( "unsupported transport type: %s (supported: stdio, sse, http)", transportType, ) @@ -394,7 +420,7 @@ func (m *Manager) ConnectServer( // Connect to server session, err := client.Connect(ctx, transport, nil) if err != nil { - return fmt.Errorf("failed to connect: %w", err) + return nil, fmt.Errorf("failed to connect: %w", err) } // Get server info @@ -408,38 +434,19 @@ func (m *Manager) ConnectServer( }) // List available tools if supported - var tools []*mcp.Tool - if initResult.Capabilities.Tools != nil { - for tool, err := range session.Tools(ctx, nil) { - if err != nil { - logger.WarnCF("mcp", "Error listing tool", - map[string]any{ - "server": name, - "error": err.Error(), - }) - continue - } - tools = append(tools, tool) - } - - logger.InfoCF("mcp", "Listed tools from MCP server", - map[string]any{ - "server": name, - "toolCount": len(tools), - }) + tools, err := listServerTools(ctx, name, session, initResult) + if err != nil { + _ = session.Close() + return nil, err } - // Store connection - m.mu.Lock() - m.servers[name] = &ServerConnection{ + return &ServerConnection{ Name: name, + Config: cfg, Client: client, Session: session, Tools: tools, - } - m.mu.Unlock() - - return nil + }, nil } // GetServers returns all connected servers @@ -498,12 +505,131 @@ func (m *Manager) CallTool( result, err := conn.Session.CallTool(ctx, params) if err != nil { + if shouldReconnectCallError(err) { + logger.WarnCF("mcp", "MCP server session was lost during tool call, reconnecting", + map[string]any{ + "server": serverName, + "tool": toolName, + "error": err.Error(), + }) + + reconnectedConn, reconnectErr := m.reconnectServer(ctx, serverName, conn) + if reconnectErr != nil { + return nil, fmt.Errorf("failed to recover lost MCP session: %w", reconnectErr) + } + + result, err = reconnectedConn.Session.CallTool(ctx, params) + if err == nil { + return result, nil + } + } + return nil, fmt.Errorf("failed to call tool: %w", err) } return result, nil } +func listServerTools( + ctx context.Context, + name string, + session *mcp.ClientSession, + initResult *mcp.InitializeResult, +) ([]*mcp.Tool, error) { + var tools []*mcp.Tool + if initResult.Capabilities.Tools == nil { + return tools, nil + } + + for tool, err := range session.Tools(ctx, nil) { + if err != nil { + logger.WarnCF("mcp", "Error listing tool", + map[string]any{ + "server": name, + "error": err.Error(), + }) + continue + } + tools = append(tools, tool) + } + + logger.InfoCF("mcp", "Listed tools from MCP server", + map[string]any{ + "server": name, + "toolCount": len(tools), + }) + + return tools, nil +} + +func shouldReconnectCallError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, mcp.ErrSessionMissing) { + return true + } + return strings.Contains(strings.ToLower(err.Error()), mcp.ErrSessionMissing.Error()) +} + +func (m *Manager) reconnectServer( + ctx context.Context, + serverName string, + staleConn *ServerConnection, +) (*ServerConnection, error) { + if staleConn == nil { + return nil, fmt.Errorf("server %s not found", serverName) + } + + staleConn.reconnectMu.Lock() + defer staleConn.reconnectMu.Unlock() + + if m.closed.Load() { + return nil, fmt.Errorf("manager is closed") + } + + m.mu.RLock() + currentConn, ok := m.servers[serverName] + m.mu.RUnlock() + if !ok { + return nil, fmt.Errorf("server %s not found", serverName) + } + if currentConn != staleConn { + return currentConn, nil + } + + freshConn, err := connectServerFunc(ctx, serverName, staleConn.Config) + if err != nil { + return nil, err + } + + m.mu.Lock() + if m.closed.Load() { + m.mu.Unlock() + _ = freshConn.Session.Close() + return nil, fmt.Errorf("manager is closed") + } + + currentConn, ok = m.servers[serverName] + if !ok { + m.mu.Unlock() + _ = freshConn.Session.Close() + return nil, fmt.Errorf("server %s not found", serverName) + } + + if currentConn == staleConn { + m.servers[serverName] = freshConn + staleToClose := staleConn + m.mu.Unlock() + _ = staleToClose.Session.Close() + return freshConn, nil + } + + m.mu.Unlock() + _ = freshConn.Session.Close() + return currentConn, nil +} + // Close closes all server connections func (m *Manager) Close() error { // Use Swap to atomically set closed=true and get the previous value diff --git a/pkg/mcp/manager_test.go b/pkg/mcp/manager_test.go index fff315655..682d4c346 100644 --- a/pkg/mcp/manager_test.go +++ b/pkg/mcp/manager_test.go @@ -2,11 +2,16 @@ package mcp import ( "context" + "encoding/json" + "fmt" + "io" "os" "path/filepath" "strings" + "sync" "testing" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/sipeed/picoclaw/pkg/config" @@ -312,6 +317,81 @@ func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) { }) } +func TestCallTool_ReconnectsWhenHTTPServerLosesSession(t *testing.T) { + originalConnectServerFunc := connectServerFunc + t.Cleanup(func() { + connectServerFunc = originalConnectServerFunc + }) + + staleConn, staleTransport, err := newScriptedServerConnection( + "session-1", + nil, + fmt.Errorf(`sending "tools/call": failed to connect (session ID: session-1): %w`, sdkmcp.ErrSessionMissing), + ) + if err != nil { + t.Fatalf("newScriptedServerConnection(stale) error = %v", err) + } + freshConn, freshTransport, err := newScriptedServerConnection( + "session-2", + &sdkmcp.CallToolResult{ + Content: []sdkmcp.Content{ + &sdkmcp.TextContent{Text: "reconnected"}, + }, + }, + nil, + ) + if err != nil { + t.Fatalf("newScriptedServerConnection(fresh) error = %v", err) + } + + connectCalls := 0 + connectServerFunc = func(ctx context.Context, name string, cfg config.MCPServerConfig) (*ServerConnection, error) { + connectCalls++ + if connectCalls == 1 { + return freshConn, nil + } + return nil, fmt.Errorf("unexpected reconnect attempt %d", connectCalls) + } + + mgr := NewManager() + mgr.servers["flaky"] = staleConn + + result, err := mgr.CallTool(context.Background(), "flaky", "echo", map[string]any{ + "query": "hello", + }) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if result == nil || len(result.Content) != 1 { + t.Fatalf("CallTool() returned unexpected content: %#v", result) + } + + text, ok := result.Content[0].(*sdkmcp.TextContent) + if !ok { + t.Fatalf("CallTool() content type = %T, want *sdkmcp.TextContent", result.Content[0]) + } + if text.Text != "reconnected" { + t.Fatalf("CallTool() text = %q, want %q", text.Text, "reconnected") + } + + conn, ok := mgr.GetServer("flaky") + if !ok { + t.Fatal("expected flaky server to remain connected after reconnect") + } + if conn.Session.ID() != "session-2" { + t.Fatalf("Session.ID() = %q, want %q", conn.Session.ID(), "session-2") + } + if connectCalls != 1 { + t.Fatalf("connectCalls = %d, want 1", connectCalls) + } + if staleTransport.toolCallCalls != 1 { + t.Fatalf("stale toolCallCalls = %d, want 1", staleTransport.toolCallCalls) + } + if freshTransport.toolCallCalls != 1 { + t.Fatalf("fresh toolCallCalls = %d, want 1", freshTransport.toolCallCalls) + } +} + func TestClose_IdempotentOnEmptyManager(t *testing.T) { mgr := NewManager() @@ -322,3 +402,138 @@ func TestClose_IdempotentOnEmptyManager(t *testing.T) { t.Fatalf("second close should be idempotent, got: %v", err) } } + +func newScriptedServerConnection( + sessionID string, + toolCallResult *sdkmcp.CallToolResult, + toolCallErr error, +) (*ServerConnection, *scriptedTransport, error) { + transport := &scriptedTransport{ + sessionID: sessionID, + toolCallResult: toolCallResult, + toolCallErr: toolCallErr, + } + + client := sdkmcp.NewClient(&sdkmcp.Implementation{ + Name: "picoclaw-test", + Version: "1.0.0", + }, nil) + session, err := client.Connect(context.Background(), transport, nil) + if err != nil { + return nil, nil, err + } + + return &ServerConnection{ + Name: "flaky", + Config: config.MCPServerConfig{Enabled: true, Type: "http", URL: "https://example.invalid/mcp"}, + Client: client, + Session: session, + Tools: []*sdkmcp.Tool{ + { + Name: "echo", + Description: "Echo test tool", + InputSchema: map[string]any{"type": "object"}, + }, + }, + }, transport, nil +} + +type scriptedTransport struct { + sessionID string + toolCallResult *sdkmcp.CallToolResult + toolCallErr error + + mu sync.Mutex + toolCallCalls int + closed bool + incoming chan jsonrpc.Message +} + +func (t *scriptedTransport) Connect(context.Context) (sdkmcp.Connection, error) { + if t.incoming == nil { + t.incoming = make(chan jsonrpc.Message, 4) + } + return t, nil +} + +func (t *scriptedTransport) Read(ctx context.Context) (jsonrpc.Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case msg, ok := <-t.incoming: + if !ok { + return nil, io.EOF + } + return msg, nil + } +} + +func (t *scriptedTransport) Write(ctx context.Context, msg jsonrpc.Message) error { + req, ok := msg.(*jsonrpc.Request) + if !ok { + return nil + } + + switch req.Method { + case "initialize": + payload, err := json.Marshal(&sdkmcp.InitializeResult{ + ProtocolVersion: "2025-11-25", + ServerInfo: &sdkmcp.Implementation{ + Name: "scripted-test-server", + Version: "1.0.0", + }, + Capabilities: &sdkmcp.ServerCapabilities{ + Tools: &sdkmcp.ToolCapabilities{}, + }, + }) + if err != nil { + return err + } + select { + case <-ctx.Done(): + return ctx.Err() + case t.incoming <- &jsonrpc.Response{ID: req.ID, Result: payload}: + return nil + } + + case "notifications/initialized": + return nil + + case "tools/call": + t.mu.Lock() + t.toolCallCalls++ + t.mu.Unlock() + + if t.toolCallErr != nil { + return t.toolCallErr + } + + payload, err := json.Marshal(t.toolCallResult) + if err != nil { + return err + } + select { + case <-ctx.Done(): + return ctx.Err() + case t.incoming <- &jsonrpc.Response{ID: req.ID, Result: payload}: + return nil + } + } + + return fmt.Errorf("unexpected method %q", req.Method) +} + +func (t *scriptedTransport) Close() error { + t.mu.Lock() + defer t.mu.Unlock() + if t.closed { + return nil + } + t.closed = true + close(t.incoming) + return nil +} + +func (t *scriptedTransport) SessionID() string { + return t.sessionID +}