Files
picoclaw/pkg/mcp/manager_test.go
T

540 lines
13 KiB
Go

package mcp
import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestLoadEnvFile(t *testing.T) {
tests := []struct {
name string
content string
expected map[string]string
expectErr bool
}{
{
name: "basic env file",
content: `API_KEY=secret123
DATABASE_URL=postgres://localhost/db
PORT=8080`,
expected: map[string]string{
"API_KEY": "secret123",
"DATABASE_URL": "postgres://localhost/db",
"PORT": "8080",
},
expectErr: false,
},
{
name: "with comments and empty lines",
content: `# This is a comment
API_KEY=secret123
# Another comment
DATABASE_URL=postgres://localhost/db
PORT=8080`,
expected: map[string]string{
"API_KEY": "secret123",
"DATABASE_URL": "postgres://localhost/db",
"PORT": "8080",
},
expectErr: false,
},
{
name: "with quoted values",
content: `API_KEY="secret with spaces"
NAME='single quoted'
PLAIN=no-quotes`,
expected: map[string]string{
"API_KEY": "secret with spaces",
"NAME": "single quoted",
"PLAIN": "no-quotes",
},
expectErr: false,
},
{
name: "with spaces around equals",
content: `API_KEY = secret123
DATABASE_URL= postgres://localhost/db
PORT =8080`,
expected: map[string]string{
"API_KEY": "secret123",
"DATABASE_URL": "postgres://localhost/db",
"PORT": "8080",
},
expectErr: false,
},
{
name: "invalid format - no equals",
content: `INVALID_LINE`,
expectErr: true,
},
{
name: "empty file",
content: ``,
expected: map[string]string{},
expectErr: false,
},
{
name: "only comments",
content: `# Comment 1
# Comment 2`,
expected: map[string]string{},
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
envFile := filepath.Join(tmpDir, ".env")
if err := os.WriteFile(envFile, []byte(tt.content), 0o644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
result, err := loadEnvFile(envFile)
if tt.expectErr {
if err == nil {
t.Errorf("Expected error but got none")
}
return
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
if len(result) != len(tt.expected) {
t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result))
}
for key, expectedValue := range tt.expected {
if actualValue, ok := result[key]; !ok {
t.Errorf("Expected key %s not found", key)
} else if actualValue != expectedValue {
t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue)
}
}
})
}
}
func TestLoadEnvFileNotFound(t *testing.T) {
_, err := loadEnvFile("/nonexistent/file.env")
if err == nil {
t.Error("Expected error for nonexistent file")
}
}
func TestExpandHomeCommandPath(t *testing.T) {
homeDir := t.TempDir()
t.Setenv("HOME", homeDir)
t.Setenv("USERPROFILE", homeDir)
want := filepath.Join(homeDir, "bin", "my-mcp")
got := expandHomeCommandPath("~" + string(os.PathSeparator) + filepath.Join("bin", "my-mcp"))
if got != want {
t.Fatalf("expandHomeCommandPath() = %q, want %q", got, want)
}
if got := expandHomeCommandPath("npx"); got != "npx" {
t.Fatalf("expandHomeCommandPath() should leave bare commands unchanged, got %q", got)
}
}
func TestEnvFilePriority(t *testing.T) {
// Create a temporary .env file
tmpDir := t.TempDir()
envFile := filepath.Join(tmpDir, ".env")
envContent := `API_KEY=from_file
DATABASE_URL=from_file
SHARED_VAR=from_file`
if err := os.WriteFile(envFile, []byte(envContent), 0o644); err != nil {
t.Fatalf("Failed to create .env file: %v", err)
}
// Load envFile
envVars, err := loadEnvFile(envFile)
if err != nil {
t.Fatalf("Failed to load env file: %v", err)
}
// Verify envFile variables
if envVars["API_KEY"] != "from_file" {
t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"])
}
// Simulate config.Env overriding envFile
configEnv := map[string]string{
"SHARED_VAR": "from_config",
"NEW_VAR": "from_config",
}
// Merge: envFile first, then config overrides
merged := make(map[string]string)
for k, v := range envVars {
merged[k] = v
}
for k, v := range configEnv {
merged[k] = v
}
// Verify priority: config.Env should override envFile
if merged["SHARED_VAR"] != "from_config" {
t.Errorf(
"Expected SHARED_VAR=from_config (config should override file), got %s",
merged["SHARED_VAR"],
)
}
if merged["API_KEY"] != "from_file" {
t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"])
}
if merged["NEW_VAR"] != "from_config" {
t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"])
}
}
func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) {
mgr := NewManager()
mcpCfg := config.MCPConfig{
ToolConfig: config.ToolConfig{
Enabled: true,
},
Servers: map[string]config.MCPServerConfig{
"test-server": {
Enabled: true,
Command: "echo",
Args: []string{"ok"},
EnvFile: ".env",
},
},
}
err := mgr.LoadFromMCPConfig(context.Background(), mcpCfg, "")
if err == nil {
t.Fatal("expected error for relative env_file with empty workspace path, got nil")
}
if !strings.Contains(err.Error(), "workspace path is empty") {
t.Fatalf("expected workspace path validation error, got: %v", err)
}
}
func TestNewManager_InitialState(t *testing.T) {
mgr := NewManager()
if mgr == nil {
t.Fatal("expected manager instance, got nil")
}
if len(mgr.GetServers()) != 0 {
t.Fatalf("expected no servers on new manager, got %d", len(mgr.GetServers()))
}
}
func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
mgr := NewManager()
err := mgr.LoadFromMCPConfig(
context.Background(),
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: false}},
"/tmp",
)
if err != nil {
t.Fatalf("expected nil error when MCP disabled, got: %v", err)
}
err = mgr.LoadFromMCPConfig(
context.Background(),
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: true}},
"/tmp",
)
if err != nil {
t.Fatalf("expected nil error when no servers configured, got: %v", err)
}
}
func TestGetServers_ReturnsCopy(t *testing.T) {
mgr := NewManager()
mgr.servers["s1"] = &ServerConnection{Name: "s1"}
servers := mgr.GetServers()
delete(servers, "s1")
if _, ok := mgr.GetServer("s1"); !ok {
t.Fatal("expected internal manager state to remain unchanged")
}
}
func TestGetAllTools_FiltersEmptyTools(t *testing.T) {
mgr := NewManager()
mgr.servers["empty"] = &ServerConnection{Name: "empty", Tools: nil}
mgr.servers["with-tools"] = &ServerConnection{Name: "with-tools", Tools: []*sdkmcp.Tool{{}}}
all := mgr.GetAllTools()
if _, ok := all["empty"]; ok {
t.Fatal("expected server without tools to be excluded")
}
if _, ok := all["with-tools"]; !ok {
t.Fatal("expected server with tools to be included")
}
}
func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) {
t.Run("manager closed", func(t *testing.T) {
mgr := NewManager()
mgr.closed.Store(true)
_, err := mgr.CallTool(context.Background(), "s1", "tool", nil)
if err == nil || !strings.Contains(err.Error(), "manager is closed") {
t.Fatalf("expected manager closed error, got: %v", err)
}
})
t.Run("server missing", func(t *testing.T) {
mgr := NewManager()
_, err := mgr.CallTool(context.Background(), "missing", "tool", nil)
if err == nil || !strings.Contains(err.Error(), "not found") {
t.Fatalf("expected server not found error, got: %v", err)
}
})
}
func TestCallTool_ReconnectsWhenHTTPServerLosesSession(t *testing.T) {
originalConnectServerFunc := connectServerFunc
t.Cleanup(func() {
connectServerFunc = originalConnectServerFunc
})
staleConn, staleTransport, err := newScriptedServerConnection(
"session-1",
nil,
fmt.Errorf(`sending "tools/call": failed to connect (session ID: session-1): %w`, sdkmcp.ErrSessionMissing),
)
if err != nil {
t.Fatalf("newScriptedServerConnection(stale) error = %v", err)
}
freshConn, freshTransport, err := newScriptedServerConnection(
"session-2",
&sdkmcp.CallToolResult{
Content: []sdkmcp.Content{
&sdkmcp.TextContent{Text: "reconnected"},
},
},
nil,
)
if err != nil {
t.Fatalf("newScriptedServerConnection(fresh) error = %v", err)
}
connectCalls := 0
connectServerFunc = func(ctx context.Context, name string, cfg config.MCPServerConfig) (*ServerConnection, error) {
connectCalls++
if connectCalls == 1 {
return freshConn, nil
}
return nil, fmt.Errorf("unexpected reconnect attempt %d", connectCalls)
}
mgr := NewManager()
mgr.servers["flaky"] = staleConn
result, err := mgr.CallTool(context.Background(), "flaky", "echo", map[string]any{
"query": "hello",
})
if err != nil {
t.Fatalf("CallTool() error = %v", err)
}
if result == nil || len(result.Content) != 1 {
t.Fatalf("CallTool() returned unexpected content: %#v", result)
}
text, ok := result.Content[0].(*sdkmcp.TextContent)
if !ok {
t.Fatalf("CallTool() content type = %T, want *sdkmcp.TextContent", result.Content[0])
}
if text.Text != "reconnected" {
t.Fatalf("CallTool() text = %q, want %q", text.Text, "reconnected")
}
conn, ok := mgr.GetServer("flaky")
if !ok {
t.Fatal("expected flaky server to remain connected after reconnect")
}
if conn.Session.ID() != "session-2" {
t.Fatalf("Session.ID() = %q, want %q", conn.Session.ID(), "session-2")
}
if connectCalls != 1 {
t.Fatalf("connectCalls = %d, want 1", connectCalls)
}
if staleTransport.toolCallCalls != 1 {
t.Fatalf("stale toolCallCalls = %d, want 1", staleTransport.toolCallCalls)
}
if freshTransport.toolCallCalls != 1 {
t.Fatalf("fresh toolCallCalls = %d, want 1", freshTransport.toolCallCalls)
}
}
func TestClose_IdempotentOnEmptyManager(t *testing.T) {
mgr := NewManager()
if err := mgr.Close(); err != nil {
t.Fatalf("first close should succeed, got: %v", err)
}
if err := mgr.Close(); err != nil {
t.Fatalf("second close should be idempotent, got: %v", err)
}
}
func newScriptedServerConnection(
sessionID string,
toolCallResult *sdkmcp.CallToolResult,
toolCallErr error,
) (*ServerConnection, *scriptedTransport, error) {
transport := &scriptedTransport{
sessionID: sessionID,
toolCallResult: toolCallResult,
toolCallErr: toolCallErr,
}
client := sdkmcp.NewClient(&sdkmcp.Implementation{
Name: "picoclaw-test",
Version: "1.0.0",
}, nil)
session, err := client.Connect(context.Background(), transport, nil)
if err != nil {
return nil, nil, err
}
return &ServerConnection{
Name: "flaky",
Config: config.MCPServerConfig{Enabled: true, Type: "http", URL: "https://example.invalid/mcp"},
Client: client,
Session: session,
Tools: []*sdkmcp.Tool{
{
Name: "echo",
Description: "Echo test tool",
InputSchema: map[string]any{"type": "object"},
},
},
}, transport, nil
}
type scriptedTransport struct {
sessionID string
toolCallResult *sdkmcp.CallToolResult
toolCallErr error
mu sync.Mutex
toolCallCalls int
closed bool
incoming chan jsonrpc.Message
}
func (t *scriptedTransport) Connect(context.Context) (sdkmcp.Connection, error) {
if t.incoming == nil {
t.incoming = make(chan jsonrpc.Message, 4)
}
return t, nil
}
func (t *scriptedTransport) Read(ctx context.Context) (jsonrpc.Message, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case msg, ok := <-t.incoming:
if !ok {
return nil, io.EOF
}
return msg, nil
}
}
func (t *scriptedTransport) Write(ctx context.Context, msg jsonrpc.Message) error {
req, ok := msg.(*jsonrpc.Request)
if !ok {
return nil
}
switch req.Method {
case "initialize":
payload, err := json.Marshal(&sdkmcp.InitializeResult{
ProtocolVersion: "2025-11-25",
ServerInfo: &sdkmcp.Implementation{
Name: "scripted-test-server",
Version: "1.0.0",
},
Capabilities: &sdkmcp.ServerCapabilities{
Tools: &sdkmcp.ToolCapabilities{},
},
})
if err != nil {
return err
}
select {
case <-ctx.Done():
return ctx.Err()
case t.incoming <- &jsonrpc.Response{ID: req.ID, Result: payload}:
return nil
}
case "notifications/initialized":
return nil
case "tools/call":
t.mu.Lock()
t.toolCallCalls++
t.mu.Unlock()
if t.toolCallErr != nil {
return t.toolCallErr
}
payload, err := json.Marshal(t.toolCallResult)
if err != nil {
return err
}
select {
case <-ctx.Done():
return ctx.Err()
case t.incoming <- &jsonrpc.Response{ID: req.ID, Result: payload}:
return nil
}
}
return fmt.Errorf("unexpected method %q", req.Method)
}
func (t *scriptedTransport) Close() error {
t.mu.Lock()
defer t.mu.Unlock()
if t.closed {
return nil
}
t.closed = true
close(t.incoming)
return nil
}
func (t *scriptedTransport) SessionID() string {
return t.sessionID
}