Merge pull request #2811 from afjcjsbx/fix/mcp-streamable-http-support

fix(mcp): support streamable HTTP alias, request-response mode and integration tests
This commit is contained in:
Mauro
2026-05-15 12:55:58 +02:00
committed by GitHub
23 changed files with 1417 additions and 42 deletions
+5 -1
View File
@@ -1126,7 +1126,11 @@ type MCPServerConfig struct {
Env map[string]string `json:"env,omitempty"`
// EnvFile is the path to a file containing environment variables (stdio only)
EnvFile string `json:"env_file,omitempty"`
// Type is "stdio", "sse", or "http" (default: stdio if command is set, sse if url is set)
// Type is "stdio", "sse", "http", or "streamable-http".
// "http" and "streamable-http" both select streamable HTTP request-response
// mode, while "sse" keeps the standalone SSE listener enabled for
// server-initiated notifications. Defaults: stdio if command is set, sse if
// url is set.
Type string `json:"type,omitempty"`
// URL is used for SSE/HTTP transport
URL string `json:"url,omitempty"`
+32
View File
@@ -0,0 +1,32 @@
package config
import "strings"
// NormalizeMCPTransportType canonicalizes MCP transport names used in config.
// "http" is PicoClaw's streamable HTTP request-response mode, and
// "streamable-http" is accepted as an explicit alias for the same transport.
func NormalizeMCPTransportType(transport string) string {
normalized := strings.ToLower(strings.TrimSpace(transport))
switch normalized {
case "streamable-http", "streamable_http", "streamablehttp":
return "http"
default:
return normalized
}
}
// EffectiveMCPTransportType returns the normalized configured transport, or the
// inferred default when the config leaves Type empty.
func EffectiveMCPTransportType(server MCPServerConfig) string {
if transport := NormalizeMCPTransportType(server.Type); transport != "" {
return transport
}
if server.URL != "" {
return "sse"
}
if server.Command != "" {
return "stdio"
}
return ""
}
+1 -10
View File
@@ -79,14 +79,5 @@ func setMCPAttrString(attrs map[string]any, key, value string) {
}
func mcpTransportType(cfg config.MCPServerConfig) string {
if cfg.Type != "" {
return cfg.Type
}
if cfg.URL != "" {
return "sse"
}
if cfg.Command != "" {
return "stdio"
}
return ""
return config.EffectiveMCPTransportType(cfg)
}
+7 -14
View File
@@ -342,17 +342,9 @@ func connectServer(
// Create transport based on configuration
// Auto-detect transport type if not explicitly specified
var transport mcp.Transport
transportType := cfg.Type
// Auto-detect: if URL is provided, use SSE; if command is provided, use stdio
transportType := config.EffectiveMCPTransportType(cfg)
if transportType == "" {
if cfg.URL != "" {
transportType = "sse"
} else if cfg.Command != "" {
transportType = "stdio"
} else {
return nil, fmt.Errorf("either URL or command must be provided")
}
return nil, fmt.Errorf("either URL or command must be provided")
}
switch transportType {
@@ -362,12 +354,13 @@ func connectServer(
}
// Configure DisableStandaloneSSE based on transport type.
// - "http": Request-response only mode. Disable the standalone SSE stream
// to avoid compatibility issues with servers that don't support GET /mcp.
// - "http": Streamable HTTP request-response mode. Disable the standalone
// SSE stream to avoid compatibility issues with servers that don't
// support the optional GET listener.
// - "sse": Bidirectional mode. Enable the standalone SSE stream to receive
// server-initiated notifications (e.g., ToolListChangedNotification).
// - Empty or auto-detected: Defaults to "sse" behavior (standalone SSE enabled).
disableStandaloneSSE := (cfg.Type == "http")
disableStandaloneSSE := transportType == "http"
logger.DebugCF("mcp", "Using SSE/HTTP transport",
map[string]any{
@@ -452,7 +445,7 @@ func connectServer(
transport = &isolatedCommandTransport{Command: cmd}
default:
return nil, fmt.Errorf(
"unsupported transport type: %s (supported: stdio, sse, http)",
"unsupported transport type: %s (supported: stdio, sse, http, streamable-http)",
transportType,
)
}
+326
View File
@@ -0,0 +1,326 @@
//go:build integration
package mcp
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/config"
)
// Run with: go test -tags=integration ./pkg/mcp
func TestIntegration_StreamableHTTPCompatibility(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
tests := []struct {
name string
transportType string
jsonResponse bool
rejectStandaloneGET bool
wantResponseContentType string
}{
{
name: "http/json-only-without-get-listener",
transportType: "http",
jsonResponse: true,
rejectStandaloneGET: true,
wantResponseContentType: "application/json",
},
{
name: "http/streaming-post-responses",
transportType: "http",
jsonResponse: false,
rejectStandaloneGET: false,
wantResponseContentType: "text/event-stream",
},
{
name: "streamable-http-alias/json-only-without-get-listener",
transportType: "streamable-http",
jsonResponse: true,
rejectStandaloneGET: true,
wantResponseContentType: "application/json",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server, recorder := newRecordedGoSDKStreamableServer(t, tt.jsonResponse, tt.rejectStandaloneGET)
defer server.Close()
mgr := NewManager()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err := mgr.ConnectServer(ctx, "compat", config.MCPServerConfig{
Enabled: true,
Type: tt.transportType,
URL: server.URL,
Headers: map[string]string{
"Authorization": "Bearer integration-token",
},
})
if err != nil {
t.Fatalf("ConnectServer() error = %v", err)
}
tools := mgr.GetAllTools()
if got := len(tools["compat"]); got != 1 {
t.Fatalf("len(GetAllTools()[\"compat\"]) = %d, want 1", got)
}
result, err := mgr.CallTool(ctx, "compat", "echo", map[string]any{
"message": "hello from integration",
})
if err != nil {
t.Fatalf("CallTool() error = %v", err)
}
if got, want := extractTextResult(t, result), "hello from integration"; got != want {
t.Fatalf("CallTool() text = %q, want %q", got, want)
}
if err := mgr.Close(); err != nil {
t.Fatalf("Manager.Close() error = %v", err)
}
assertRecordedCompatibility(t, recorder.snapshot(), tt.wantResponseContentType)
})
}
}
type recordedRequest struct {
Method string
Path string
JSONRPCMethod string
RequestSessionID string
Authorization string
ResponseStatusCode int
ResponseContentType string
}
type requestRecorder struct {
mu sync.Mutex
requests []recordedRequest
}
func (r *requestRecorder) add(req recordedRequest) {
r.mu.Lock()
defer r.mu.Unlock()
r.requests = append(r.requests, req)
}
func (r *requestRecorder) snapshot() []recordedRequest {
r.mu.Lock()
defer r.mu.Unlock()
out := make([]recordedRequest, len(r.requests))
copy(out, r.requests)
return out
}
func newRecordedGoSDKStreamableServer(
t *testing.T,
jsonResponse bool,
rejectStandaloneGET bool,
) (*httptest.Server, *requestRecorder) {
t.Helper()
server := sdkmcp.NewServer(&sdkmcp.Implementation{
Name: "streamable-integration-server",
Version: "1.0.0",
}, nil)
sdkmcp.AddTool(server, &sdkmcp.Tool{
Name: "echo",
Description: "Echo a message",
}, func(ctx context.Context, req *sdkmcp.CallToolRequest, args map[string]any) (*sdkmcp.CallToolResult, any, error) {
message, _ := args["message"].(string)
return &sdkmcp.CallToolResult{
Content: []sdkmcp.Content{
&sdkmcp.TextContent{Text: message},
},
}, nil, nil
})
recorder := &requestRecorder{}
handler := sdkmcp.NewStreamableHTTPHandler(func(*http.Request) *sdkmcp.Server {
return server
}, &sdkmcp.StreamableHTTPOptions{
JSONResponse: jsonResponse,
})
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if rejectStandaloneGET && r.Method == http.MethodGet {
w.WriteHeader(http.StatusMethodNotAllowed)
recorder.add(recordedRequest{
Method: r.Method,
Path: r.URL.Path,
RequestSessionID: r.Header.Get("Mcp-Session-Id"),
Authorization: r.Header.Get("Authorization"),
ResponseStatusCode: http.StatusMethodNotAllowed,
ResponseContentType: normalizeContentType(w.Header().Get("Content-Type")),
})
return
}
recorded := recordedRequest{
Method: r.Method,
Path: r.URL.Path,
RequestSessionID: r.Header.Get("Mcp-Session-Id"),
Authorization: r.Header.Get("Authorization"),
}
if r.Method == http.MethodPost {
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("reading request body: %v", err)
}
r.Body = io.NopCloser(bytes.NewReader(body))
var envelope struct {
Method string `json:"method"`
}
if err := json.Unmarshal(body, &envelope); err == nil {
recorded.JSONRPCMethod = envelope.Method
}
}
rw := &recordingResponseWriter{ResponseWriter: w}
handler.ServeHTTP(rw, r)
recorded.ResponseStatusCode = rw.statusCode()
recorded.ResponseContentType = normalizeContentType(rw.Header().Get("Content-Type"))
recorder.add(recorded)
}))
return httpServer, recorder
}
type recordingResponseWriter struct {
http.ResponseWriter
status int
}
func (w *recordingResponseWriter) WriteHeader(status int) {
if w.status == 0 {
w.status = status
}
w.ResponseWriter.WriteHeader(status)
}
func (w *recordingResponseWriter) Write(p []byte) (int, error) {
if w.status == 0 {
w.status = http.StatusOK
}
return w.ResponseWriter.Write(p)
}
func (w *recordingResponseWriter) Flush() {
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
func (w *recordingResponseWriter) statusCode() int {
if w.status != 0 {
return w.status
}
return http.StatusOK
}
func extractTextResult(t *testing.T, result *sdkmcp.CallToolResult) string {
t.Helper()
if result == nil || len(result.Content) != 1 {
t.Fatalf("unexpected CallToolResult: %#v", result)
}
text, ok := result.Content[0].(*sdkmcp.TextContent)
if !ok {
t.Fatalf("CallToolResult content type = %T, want *sdkmcp.TextContent", result.Content[0])
}
return text.Text
}
func assertRecordedCompatibility(
t *testing.T,
requests []recordedRequest,
wantResponseContentType string,
) {
t.Helper()
var (
getCount int
deleteCount int
postWithSession int
deleteWithSession int
requestsMissingAuth []string
observedContentTypesByMethod = map[string]string{}
)
for _, req := range requests {
switch req.Method {
case http.MethodGet:
getCount++
case http.MethodPost:
if req.RequestSessionID != "" {
postWithSession++
}
if req.JSONRPCMethod != "" && observedContentTypesByMethod[req.JSONRPCMethod] == "" {
observedContentTypesByMethod[req.JSONRPCMethod] = req.ResponseContentType
}
case http.MethodDelete:
deleteCount++
if req.RequestSessionID != "" {
deleteWithSession++
}
}
if req.Authorization != "Bearer integration-token" {
requestsMissingAuth = append(requestsMissingAuth, req.Method+" "+req.Path)
}
}
if getCount != 0 {
t.Fatalf("expected no standalone GET requests for streamable HTTP mode, saw %d", getCount)
}
if deleteCount != 1 {
t.Fatalf("DELETE count = %d, want 1", deleteCount)
}
if postWithSession == 0 {
t.Fatal("expected at least one POST request with Mcp-Session-Id")
}
if deleteWithSession != 1 {
t.Fatalf("expected exactly one DELETE with Mcp-Session-Id, got %d", deleteWithSession)
}
if len(requestsMissingAuth) > 0 {
t.Fatalf("Authorization header missing on requests: %v", requestsMissingAuth)
}
for _, method := range []string{"initialize", "tools/list", "tools/call"} {
if observedContentTypesByMethod[method] == "" {
t.Fatalf("did not observe POST response for JSON-RPC method %q", method)
}
if observedContentTypesByMethod[method] != wantResponseContentType {
t.Fatalf(
"response content-type for %s = %q, want %q",
method,
observedContentTypesByMethod[method],
wantResponseContentType,
)
}
}
}
func normalizeContentType(value string) string {
return strings.TrimSpace(strings.SplitN(value, ";", 2)[0])
}
@@ -0,0 +1,148 @@
//go:build integration
package mcp
import (
"context"
"encoding/json"
"os"
"strconv"
"strings"
"testing"
"time"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/config"
)
// TestIntegration_RealConfiguredServer is an opt-in smoke test for a real MCP
// server configured via environment variables.
//
// Run with:
//
// go test -tags=integration ./pkg/mcp -run TestIntegration_RealConfiguredServer -v
//
// Minimum configuration:
//
// PICOCLAW_MCP_REAL_SERVER_JSON='{"enabled":true,"type":"http","url":"http://127.0.0.1:8080/mcp"}'
//
// Optional tool invocation:
//
// PICOCLAW_MCP_REAL_TOOL_NAME=echo
// PICOCLAW_MCP_REAL_TOOL_ARGS_JSON='{"message":"hello"}'
// PICOCLAW_MCP_REAL_EXPECT_SUBSTRING=hello
//
// Stdio subprocess example:
//
// PICOCLAW_MCP_REAL_SERVER_JSON='{"enabled":true,"type":"stdio","command":"npx","args":["-y","@modelcontextprotocol/server-filesystem","."]}'
func TestIntegration_RealConfiguredServer(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
serverJSON := strings.TrimSpace(os.Getenv("PICOCLAW_MCP_REAL_SERVER_JSON"))
if serverJSON == "" {
t.Skip("skipping integration test (set PICOCLAW_MCP_REAL_SERVER_JSON to enable)")
}
serverCfg, err := loadRealServerConfig(serverJSON)
if err != nil {
t.Fatalf("loadRealServerConfig() error = %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
mgr := NewManager()
if err := mgr.ConnectServer(ctx, "real", serverCfg); err != nil {
t.Fatalf("ConnectServer() error = %v", err)
}
defer func() {
if err := mgr.Close(); err != nil {
t.Errorf("Manager.Close() error = %v", err)
}
}()
tools := mgr.GetAllTools()["real"]
if len(tools) == 0 {
t.Fatal("expected at least one discovered tool from real MCP server")
}
t.Logf("connected to real MCP server via %s with %d tool(s)", config.EffectiveMCPTransportType(serverCfg), len(tools))
for _, tool := range tools {
if tool != nil {
t.Logf("discovered tool: %s", tool.Name)
}
}
if expectedCountRaw := strings.TrimSpace(os.Getenv("PICOCLAW_MCP_REAL_EXPECT_TOOL_COUNT")); expectedCountRaw != "" {
expectedCount, err := strconv.Atoi(expectedCountRaw)
if err != nil {
t.Fatalf("invalid PICOCLAW_MCP_REAL_EXPECT_TOOL_COUNT %q: %v", expectedCountRaw, err)
}
if len(tools) != expectedCount {
t.Fatalf("tool count = %d, want %d", len(tools), expectedCount)
}
}
toolName := strings.TrimSpace(os.Getenv("PICOCLAW_MCP_REAL_TOOL_NAME"))
if toolName == "" {
return
}
toolArgs, err := loadRealToolArgs(os.Getenv("PICOCLAW_MCP_REAL_TOOL_ARGS_JSON"))
if err != nil {
t.Fatalf("loadRealToolArgs() error = %v", err)
}
result, err := mgr.CallTool(ctx, "real", toolName, toolArgs)
if err != nil {
t.Fatalf("CallTool(%q) error = %v", toolName, err)
}
textPayload := joinTextContents(result)
t.Logf("tool %q returned text payload: %q", toolName, textPayload)
if want := os.Getenv("PICOCLAW_MCP_REAL_EXPECT_SUBSTRING"); want != "" && !strings.Contains(textPayload, want) {
t.Fatalf("tool result %q does not contain expected substring %q", textPayload, want)
}
}
func loadRealServerConfig(raw string) (config.MCPServerConfig, error) {
var cfg config.MCPServerConfig
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
return config.MCPServerConfig{}, err
}
if !cfg.Enabled {
cfg.Enabled = true
}
return cfg, nil
}
func loadRealToolArgs(raw string) (map[string]any, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return map[string]any{}, nil
}
var args map[string]any
if err := json.Unmarshal([]byte(raw), &args); err != nil {
return nil, err
}
return args, nil
}
func joinTextContents(result *sdkmcp.CallToolResult) string {
if result == nil || len(result.Content) == 0 {
return ""
}
parts := make([]string, 0, len(result.Content))
for _, content := range result.Content {
if text, ok := content.(*sdkmcp.TextContent); ok && text != nil {
parts = append(parts, text.Text)
}
}
return strings.Join(parts, "\n")
}
+129
View File
@@ -5,6 +5,8 @@ import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
@@ -408,6 +410,133 @@ func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) {
})
}
func TestConnectServer_StreamableHTTPRequestResponseMode(t *testing.T) {
t.Parallel()
for _, transportType := range []string{"http", "streamable-http"} {
t.Run(transportType, func(t *testing.T) {
t.Parallel()
server := sdkmcp.NewServer(&sdkmcp.Implementation{
Name: "streamable-test-server",
Version: "1.0.0",
}, nil)
sdkmcp.AddTool(server, &sdkmcp.Tool{
Name: "echo",
Description: "Echo test tool",
}, func(ctx context.Context, req *sdkmcp.CallToolRequest, args map[string]any) (*sdkmcp.CallToolResult, any, error) {
return &sdkmcp.CallToolResult{
Content: []sdkmcp.Content{
&sdkmcp.TextContent{Text: "ok"},
},
}, nil, nil
})
type observedRequest struct {
Method string
SessionID string
Authorization string
}
var (
mu sync.Mutex
observed []observedRequest
)
handler := sdkmcp.NewStreamableHTTPHandler(func(*http.Request) *sdkmcp.Server {
return server
}, nil)
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
observed = append(observed, observedRequest{
Method: r.Method,
SessionID: r.Header.Get("Mcp-Session-Id"),
Authorization: r.Header.Get("Authorization"),
})
mu.Unlock()
handler.ServeHTTP(w, r)
}))
defer httpServer.Close()
conn, err := connectServer(context.Background(), "streamable", config.MCPServerConfig{
Enabled: true,
Type: transportType,
URL: httpServer.URL,
Headers: map[string]string{
"Authorization": "Bearer test-token",
},
})
if err != nil {
t.Fatalf("connectServer(%q) error = %v", transportType, err)
}
if got := len(conn.Tools); got != 1 {
t.Fatalf("len(conn.Tools) = %d, want 1", got)
}
if got := conn.Session.ID(); got == "" {
t.Fatal("expected non-empty streamable session ID")
}
if err := conn.Session.Close(); err != nil {
t.Fatalf("Session.Close() error = %v", err)
}
mu.Lock()
defer mu.Unlock()
var (
getCount int
postCount int
deleteCount int
postWithSession bool
deleteWithSession bool
requestsWithAuth int
requestsWithoutAuth []string
)
for _, req := range observed {
switch req.Method {
case http.MethodGet:
getCount++
case http.MethodPost:
postCount++
if req.SessionID != "" {
postWithSession = true
}
case http.MethodDelete:
deleteCount++
if req.SessionID != "" {
deleteWithSession = true
}
}
if req.Authorization == "Bearer test-token" {
requestsWithAuth++
} else {
requestsWithoutAuth = append(requestsWithoutAuth, req.Method)
}
}
if getCount != 0 {
t.Fatalf("expected no standalone GET requests for %q transport, saw %d", transportType, getCount)
}
if postCount == 0 {
t.Fatal("expected POST requests during streamable HTTP handshake")
}
if deleteCount != 1 {
t.Fatalf("DELETE count = %d, want 1", deleteCount)
}
if !postWithSession {
t.Fatal("expected at least one POST request with Mcp-Session-Id")
}
if !deleteWithSession {
t.Fatal("expected DELETE request with Mcp-Session-Id")
}
if requestsWithAuth != len(observed) {
t.Fatalf("Authorization header missing on requests: %v", requestsWithoutAuth)
}
})
}
}
func TestCallTool_ReconnectsWhenHTTPServerLosesSession(t *testing.T) {
originalConnectServerFunc := connectServerFunc
t.Cleanup(func() {