package mcp import ( "context" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "os" "path/filepath" "strings" "sync" "testing" "time" "github.com/modelcontextprotocol/go-sdk/jsonrpc" sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/sipeed/picoclaw/pkg/config" runtimeevents "github.com/sipeed/picoclaw/pkg/events" ) func TestLoadEnvFile(t *testing.T) { tests := []struct { name string content string expected map[string]string expectErr bool }{ { name: "basic env file", content: `API_KEY=secret123 DATABASE_URL=postgres://localhost/db PORT=8080`, expected: map[string]string{ "API_KEY": "secret123", "DATABASE_URL": "postgres://localhost/db", "PORT": "8080", }, expectErr: false, }, { name: "with comments and empty lines", content: `# This is a comment API_KEY=secret123 # Another comment DATABASE_URL=postgres://localhost/db PORT=8080`, expected: map[string]string{ "API_KEY": "secret123", "DATABASE_URL": "postgres://localhost/db", "PORT": "8080", }, expectErr: false, }, { name: "with quoted values", content: `API_KEY="secret with spaces" NAME='single quoted' PLAIN=no-quotes`, expected: map[string]string{ "API_KEY": "secret with spaces", "NAME": "single quoted", "PLAIN": "no-quotes", }, expectErr: false, }, { name: "with spaces around equals", content: `API_KEY = secret123 DATABASE_URL= postgres://localhost/db PORT =8080`, expected: map[string]string{ "API_KEY": "secret123", "DATABASE_URL": "postgres://localhost/db", "PORT": "8080", }, expectErr: false, }, { name: "invalid format - no equals", content: `INVALID_LINE`, expectErr: true, }, { name: "empty file", content: ``, expected: map[string]string{}, expectErr: false, }, { name: "only comments", content: `# Comment 1 # Comment 2`, expected: map[string]string{}, expectErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tmpDir := t.TempDir() envFile := filepath.Join(tmpDir, ".env") if err := os.WriteFile(envFile, []byte(tt.content), 0o644); err != nil { t.Fatalf("Failed to create test file: %v", err) } result, err := loadEnvFile(envFile) if tt.expectErr { if err == nil { t.Errorf("Expected error but got none") } return } if err != nil { t.Errorf("Unexpected error: %v", err) return } if len(result) != len(tt.expected) { t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result)) } for key, expectedValue := range tt.expected { if actualValue, ok := result[key]; !ok { t.Errorf("Expected key %s not found", key) } else if actualValue != expectedValue { t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue) } } }) } } func TestLoadEnvFileNotFound(t *testing.T) { _, err := loadEnvFile("/nonexistent/file.env") if err == nil { t.Error("Expected error for nonexistent file") } } func TestExpandHomeCommandPath(t *testing.T) { homeDir := t.TempDir() t.Setenv("HOME", homeDir) t.Setenv("USERPROFILE", homeDir) want := filepath.Join(homeDir, "bin", "my-mcp") got := expandHomeCommandPath("~" + string(os.PathSeparator) + filepath.Join("bin", "my-mcp")) if got != want { t.Fatalf("expandHomeCommandPath() = %q, want %q", got, want) } if got := expandHomeCommandPath("npx"); got != "npx" { t.Fatalf("expandHomeCommandPath() should leave bare commands unchanged, got %q", got) } } func TestEnvFilePriority(t *testing.T) { // Create a temporary .env file tmpDir := t.TempDir() envFile := filepath.Join(tmpDir, ".env") envContent := `API_KEY=from_file DATABASE_URL=from_file SHARED_VAR=from_file` if err := os.WriteFile(envFile, []byte(envContent), 0o644); err != nil { t.Fatalf("Failed to create .env file: %v", err) } // Load envFile envVars, err := loadEnvFile(envFile) if err != nil { t.Fatalf("Failed to load env file: %v", err) } // Verify envFile variables if envVars["API_KEY"] != "from_file" { t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"]) } // Simulate config.Env overriding envFile configEnv := map[string]string{ "SHARED_VAR": "from_config", "NEW_VAR": "from_config", } // Merge: envFile first, then config overrides merged := make(map[string]string) for k, v := range envVars { merged[k] = v } for k, v := range configEnv { merged[k] = v } // Verify priority: config.Env should override envFile if merged["SHARED_VAR"] != "from_config" { t.Errorf( "Expected SHARED_VAR=from_config (config should override file), got %s", merged["SHARED_VAR"], ) } if merged["API_KEY"] != "from_file" { t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"]) } if merged["NEW_VAR"] != "from_config" { t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"]) } } func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) { mgr := NewManager() mcpCfg := config.MCPConfig{ ToolConfig: config.ToolConfig{ Enabled: true, }, Servers: map[string]config.MCPServerConfig{ "test-server": { Enabled: true, Command: "echo", Args: []string{"ok"}, EnvFile: ".env", }, }, } err := mgr.LoadFromMCPConfig(context.Background(), mcpCfg, "") if err == nil { t.Fatal("expected error for relative env_file with empty workspace path, got nil") } if !strings.Contains(err.Error(), "workspace path is empty") { t.Fatalf("expected workspace path validation error, got: %v", err) } } func TestNewManager_InitialState(t *testing.T) { mgr := NewManager() if mgr == nil { t.Fatal("expected manager instance, got nil") } if len(mgr.GetServers()) != 0 { t.Fatalf("expected no servers on new manager, got %d", len(mgr.GetServers())) } } func TestConnectServerPublishesRuntimeEvents(t *testing.T) { originalConnectServerFunc := connectServerFunc t.Cleanup(func() { connectServerFunc = originalConnectServerFunc }) eventBus := runtimeevents.NewBus() defer func() { if err := eventBus.Close(); err != nil { t.Errorf("event bus close failed: %v", err) } }() _, eventsCh, err := eventBus.Channel().OfKind( runtimeevents.KindMCPServerConnected, runtimeevents.KindMCPServerFailed, ).SubscribeChan(t.Context(), runtimeevents.SubscribeOptions{Name: "mcp-events", Buffer: 2}) if err != nil { t.Fatalf("SubscribeChan failed: %v", err) } connectServerFunc = func( _ context.Context, name string, cfg config.MCPServerConfig, ) (*ServerConnection, error) { if name == "bad" { return nil, fmt.Errorf("connect failed") } return &ServerConnection{ Name: name, Config: cfg, Tools: []*sdkmcp.Tool{{Name: "echo"}}, }, nil } mgr := NewManager(WithRuntimeEvents(eventBus)) err = mgr.ConnectServer(context.Background(), "good", config.MCPServerConfig{ Type: "stdio", Command: "echo", }) if err != nil { t.Fatalf("ConnectServer(good) error = %v", err) } connected := receiveMCPRuntimeEvent(t, eventsCh) if connected.Kind != runtimeevents.KindMCPServerConnected || connected.Source.Name != "good" || connected.Severity != runtimeevents.SeverityInfo { t.Fatalf("connected event = %+v", connected) } if connected.Attrs["server"] != "good" || connected.Attrs["type"] != "stdio" || connected.Attrs["tool_count"] != 1 { t.Fatalf("connected attrs = %#v", connected.Attrs) } err = mgr.ConnectServer(context.Background(), "bad", config.MCPServerConfig{ Type: "stdio", Command: "echo", }) if err == nil { t.Fatal("expected ConnectServer(bad) to fail") } failed := receiveMCPRuntimeEvent(t, eventsCh) if failed.Kind != runtimeevents.KindMCPServerFailed || failed.Source.Name != "bad" || failed.Severity != runtimeevents.SeverityError { t.Fatalf("failed event = %+v", failed) } if failed.Attrs["server"] != "bad" || failed.Attrs["error"] != "connect failed" { t.Fatalf("failed attrs = %#v", failed.Attrs) } } func receiveMCPRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event { t.Helper() select { case evt, ok := <-ch: if !ok { t.Fatal("runtime event channel closed before expected event") } return evt case <-time.After(time.Second): t.Fatal("timed out waiting for runtime event") return runtimeevents.Event{} } } func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) { mgr := NewManager() err := mgr.LoadFromMCPConfig( context.Background(), config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: false}}, "/tmp", ) if err != nil { t.Fatalf("expected nil error when MCP disabled, got: %v", err) } err = mgr.LoadFromMCPConfig( context.Background(), config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: true}}, "/tmp", ) if err != nil { t.Fatalf("expected nil error when no servers configured, got: %v", err) } } func TestGetServers_ReturnsCopy(t *testing.T) { mgr := NewManager() mgr.servers["s1"] = &ServerConnection{Name: "s1"} servers := mgr.GetServers() delete(servers, "s1") if _, ok := mgr.GetServer("s1"); !ok { t.Fatal("expected internal manager state to remain unchanged") } } func TestGetAllTools_FiltersEmptyTools(t *testing.T) { mgr := NewManager() mgr.servers["empty"] = &ServerConnection{Name: "empty", Tools: nil} mgr.servers["with-tools"] = &ServerConnection{Name: "with-tools", Tools: []*sdkmcp.Tool{{}}} all := mgr.GetAllTools() if _, ok := all["empty"]; ok { t.Fatal("expected server without tools to be excluded") } if _, ok := all["with-tools"]; !ok { t.Fatal("expected server with tools to be included") } } func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) { t.Run("manager closed", func(t *testing.T) { mgr := NewManager() mgr.closed.Store(true) _, err := mgr.CallTool(context.Background(), "s1", "tool", nil) if err == nil || !strings.Contains(err.Error(), "manager is closed") { t.Fatalf("expected manager closed error, got: %v", err) } }) t.Run("server missing", func(t *testing.T) { mgr := NewManager() _, err := mgr.CallTool(context.Background(), "missing", "tool", nil) if err == nil || !strings.Contains(err.Error(), "not found") { t.Fatalf("expected server not found error, got: %v", err) } }) } func TestConnectServer_StreamableHTTPRequestResponseMode(t *testing.T) { t.Parallel() for _, transportType := range []string{"http", "streamable-http"} { t.Run(transportType, func(t *testing.T) { t.Parallel() server := sdkmcp.NewServer(&sdkmcp.Implementation{ Name: "streamable-test-server", Version: "1.0.0", }, nil) sdkmcp.AddTool(server, &sdkmcp.Tool{ Name: "echo", Description: "Echo test tool", }, func(ctx context.Context, req *sdkmcp.CallToolRequest, args map[string]any) (*sdkmcp.CallToolResult, any, error) { return &sdkmcp.CallToolResult{ Content: []sdkmcp.Content{ &sdkmcp.TextContent{Text: "ok"}, }, }, nil, nil }) type observedRequest struct { Method string SessionID string Authorization string } var ( mu sync.Mutex observed []observedRequest ) handler := sdkmcp.NewStreamableHTTPHandler(func(*http.Request) *sdkmcp.Server { return server }, nil) httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.Lock() observed = append(observed, observedRequest{ Method: r.Method, SessionID: r.Header.Get("Mcp-Session-Id"), Authorization: r.Header.Get("Authorization"), }) mu.Unlock() handler.ServeHTTP(w, r) })) defer httpServer.Close() conn, err := connectServer(context.Background(), "streamable", config.MCPServerConfig{ Enabled: true, Type: transportType, URL: httpServer.URL, Headers: map[string]string{ "Authorization": "Bearer test-token", }, }) if err != nil { t.Fatalf("connectServer(%q) error = %v", transportType, err) } if got := len(conn.Tools); got != 1 { t.Fatalf("len(conn.Tools) = %d, want 1", got) } if got := conn.Session.ID(); got == "" { t.Fatal("expected non-empty streamable session ID") } if err := conn.Session.Close(); err != nil { t.Fatalf("Session.Close() error = %v", err) } mu.Lock() defer mu.Unlock() var ( getCount int postCount int deleteCount int postWithSession bool deleteWithSession bool requestsWithAuth int requestsWithoutAuth []string ) for _, req := range observed { switch req.Method { case http.MethodGet: getCount++ case http.MethodPost: postCount++ if req.SessionID != "" { postWithSession = true } case http.MethodDelete: deleteCount++ if req.SessionID != "" { deleteWithSession = true } } if req.Authorization == "Bearer test-token" { requestsWithAuth++ } else { requestsWithoutAuth = append(requestsWithoutAuth, req.Method) } } if getCount != 0 { t.Fatalf("expected no standalone GET requests for %q transport, saw %d", transportType, getCount) } if postCount == 0 { t.Fatal("expected POST requests during streamable HTTP handshake") } if deleteCount != 1 { t.Fatalf("DELETE count = %d, want 1", deleteCount) } if !postWithSession { t.Fatal("expected at least one POST request with Mcp-Session-Id") } if !deleteWithSession { t.Fatal("expected DELETE request with Mcp-Session-Id") } if requestsWithAuth != len(observed) { t.Fatalf("Authorization header missing on requests: %v", requestsWithoutAuth) } }) } } func TestCallTool_ReconnectsWhenHTTPServerLosesSession(t *testing.T) { originalConnectServerFunc := connectServerFunc t.Cleanup(func() { connectServerFunc = originalConnectServerFunc }) staleConn, staleTransport, err := newScriptedServerConnection( "session-1", nil, fmt.Errorf(`sending "tools/call": failed to connect (session ID: session-1): %w`, sdkmcp.ErrSessionMissing), ) if err != nil { t.Fatalf("newScriptedServerConnection(stale) error = %v", err) } freshConn, freshTransport, err := newScriptedServerConnection( "session-2", &sdkmcp.CallToolResult{ Content: []sdkmcp.Content{ &sdkmcp.TextContent{Text: "reconnected"}, }, }, nil, ) if err != nil { t.Fatalf("newScriptedServerConnection(fresh) error = %v", err) } connectCalls := 0 connectServerFunc = func(ctx context.Context, name string, cfg config.MCPServerConfig) (*ServerConnection, error) { connectCalls++ if connectCalls == 1 { return freshConn, nil } return nil, fmt.Errorf("unexpected reconnect attempt %d", connectCalls) } mgr := NewManager() mgr.servers["flaky"] = staleConn result, err := mgr.CallTool(context.Background(), "flaky", "echo", map[string]any{ "query": "hello", }) if err != nil { t.Fatalf("CallTool() error = %v", err) } if result == nil || len(result.Content) != 1 { t.Fatalf("CallTool() returned unexpected content: %#v", result) } text, ok := result.Content[0].(*sdkmcp.TextContent) if !ok { t.Fatalf("CallTool() content type = %T, want *sdkmcp.TextContent", result.Content[0]) } if text.Text != "reconnected" { t.Fatalf("CallTool() text = %q, want %q", text.Text, "reconnected") } conn, ok := mgr.GetServer("flaky") if !ok { t.Fatal("expected flaky server to remain connected after reconnect") } if conn.Session.ID() != "session-2" { t.Fatalf("Session.ID() = %q, want %q", conn.Session.ID(), "session-2") } if connectCalls != 1 { t.Fatalf("connectCalls = %d, want 1", connectCalls) } if staleTransport.toolCallCalls != 1 { t.Fatalf("stale toolCallCalls = %d, want 1", staleTransport.toolCallCalls) } if freshTransport.toolCallCalls != 1 { t.Fatalf("fresh toolCallCalls = %d, want 1", freshTransport.toolCallCalls) } } func TestClose_IdempotentOnEmptyManager(t *testing.T) { mgr := NewManager() if err := mgr.Close(); err != nil { t.Fatalf("first close should succeed, got: %v", err) } if err := mgr.Close(); err != nil { t.Fatalf("second close should be idempotent, got: %v", err) } } func newScriptedServerConnection( sessionID string, toolCallResult *sdkmcp.CallToolResult, toolCallErr error, ) (*ServerConnection, *scriptedTransport, error) { transport := &scriptedTransport{ sessionID: sessionID, toolCallResult: toolCallResult, toolCallErr: toolCallErr, } client := sdkmcp.NewClient(&sdkmcp.Implementation{ Name: "picoclaw-test", Version: "1.0.0", }, nil) session, err := client.Connect(context.Background(), transport, nil) if err != nil { return nil, nil, err } return &ServerConnection{ Name: "flaky", Config: config.MCPServerConfig{Enabled: true, Type: "http", URL: "https://example.invalid/mcp"}, Client: client, Session: session, Tools: []*sdkmcp.Tool{ { Name: "echo", Description: "Echo test tool", InputSchema: map[string]any{"type": "object"}, }, }, }, transport, nil } type scriptedTransport struct { sessionID string toolCallResult *sdkmcp.CallToolResult toolCallErr error mu sync.Mutex toolCallCalls int closed bool incoming chan jsonrpc.Message } func (t *scriptedTransport) Connect(context.Context) (sdkmcp.Connection, error) { if t.incoming == nil { t.incoming = make(chan jsonrpc.Message, 4) } return t, nil } func (t *scriptedTransport) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() case msg, ok := <-t.incoming: if !ok { return nil, io.EOF } return msg, nil } } func (t *scriptedTransport) Write(ctx context.Context, msg jsonrpc.Message) error { req, ok := msg.(*jsonrpc.Request) if !ok { return nil } switch req.Method { case "initialize": payload, err := json.Marshal(&sdkmcp.InitializeResult{ ProtocolVersion: "2025-11-25", ServerInfo: &sdkmcp.Implementation{ Name: "scripted-test-server", Version: "1.0.0", }, Capabilities: &sdkmcp.ServerCapabilities{ Tools: &sdkmcp.ToolCapabilities{}, }, }) if err != nil { return err } select { case <-ctx.Done(): return ctx.Err() case t.incoming <- &jsonrpc.Response{ID: req.ID, Result: payload}: return nil } case "notifications/initialized": return nil case "tools/call": t.mu.Lock() t.toolCallCalls++ t.mu.Unlock() if t.toolCallErr != nil { return t.toolCallErr } payload, err := json.Marshal(t.toolCallResult) if err != nil { return err } select { case <-ctx.Done(): return ctx.Err() case t.incoming <- &jsonrpc.Response{ID: req.ID, Result: payload}: return nil } } return fmt.Errorf("unexpected method %q", req.Method) } func (t *scriptedTransport) Close() error { t.mu.Lock() defer t.mu.Unlock() if t.closed { return nil } t.closed = true close(t.incoming) return nil } func (t *scriptedTransport) SessionID() string { return t.sessionID }