mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
Merge pull request #2811 from afjcjsbx/fix/mcp-streamable-http-support
fix(mcp): support streamable HTTP alias, request-response mode and integration tests
This commit is contained in:
@@ -1126,7 +1126,11 @@ type MCPServerConfig struct {
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
// EnvFile is the path to a file containing environment variables (stdio only)
|
||||
EnvFile string `json:"env_file,omitempty"`
|
||||
// Type is "stdio", "sse", or "http" (default: stdio if command is set, sse if url is set)
|
||||
// Type is "stdio", "sse", "http", or "streamable-http".
|
||||
// "http" and "streamable-http" both select streamable HTTP request-response
|
||||
// mode, while "sse" keeps the standalone SSE listener enabled for
|
||||
// server-initiated notifications. Defaults: stdio if command is set, sse if
|
||||
// url is set.
|
||||
Type string `json:"type,omitempty"`
|
||||
// URL is used for SSE/HTTP transport
|
||||
URL string `json:"url,omitempty"`
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
package config
|
||||
|
||||
import "strings"
|
||||
|
||||
// NormalizeMCPTransportType canonicalizes MCP transport names used in config.
|
||||
// "http" is PicoClaw's streamable HTTP request-response mode, and
|
||||
// "streamable-http" is accepted as an explicit alias for the same transport.
|
||||
func NormalizeMCPTransportType(transport string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(transport))
|
||||
|
||||
switch normalized {
|
||||
case "streamable-http", "streamable_http", "streamablehttp":
|
||||
return "http"
|
||||
default:
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
|
||||
// EffectiveMCPTransportType returns the normalized configured transport, or the
|
||||
// inferred default when the config leaves Type empty.
|
||||
func EffectiveMCPTransportType(server MCPServerConfig) string {
|
||||
if transport := NormalizeMCPTransportType(server.Type); transport != "" {
|
||||
return transport
|
||||
}
|
||||
if server.URL != "" {
|
||||
return "sse"
|
||||
}
|
||||
if server.Command != "" {
|
||||
return "stdio"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
+1
-10
@@ -79,14 +79,5 @@ func setMCPAttrString(attrs map[string]any, key, value string) {
|
||||
}
|
||||
|
||||
func mcpTransportType(cfg config.MCPServerConfig) string {
|
||||
if cfg.Type != "" {
|
||||
return cfg.Type
|
||||
}
|
||||
if cfg.URL != "" {
|
||||
return "sse"
|
||||
}
|
||||
if cfg.Command != "" {
|
||||
return "stdio"
|
||||
}
|
||||
return ""
|
||||
return config.EffectiveMCPTransportType(cfg)
|
||||
}
|
||||
|
||||
+7
-14
@@ -342,17 +342,9 @@ func connectServer(
|
||||
// Create transport based on configuration
|
||||
// Auto-detect transport type if not explicitly specified
|
||||
var transport mcp.Transport
|
||||
transportType := cfg.Type
|
||||
|
||||
// Auto-detect: if URL is provided, use SSE; if command is provided, use stdio
|
||||
transportType := config.EffectiveMCPTransportType(cfg)
|
||||
if transportType == "" {
|
||||
if cfg.URL != "" {
|
||||
transportType = "sse"
|
||||
} else if cfg.Command != "" {
|
||||
transportType = "stdio"
|
||||
} else {
|
||||
return nil, fmt.Errorf("either URL or command must be provided")
|
||||
}
|
||||
return nil, fmt.Errorf("either URL or command must be provided")
|
||||
}
|
||||
|
||||
switch transportType {
|
||||
@@ -362,12 +354,13 @@ func connectServer(
|
||||
}
|
||||
|
||||
// Configure DisableStandaloneSSE based on transport type.
|
||||
// - "http": Request-response only mode. Disable the standalone SSE stream
|
||||
// to avoid compatibility issues with servers that don't support GET /mcp.
|
||||
// - "http": Streamable HTTP request-response mode. Disable the standalone
|
||||
// SSE stream to avoid compatibility issues with servers that don't
|
||||
// support the optional GET listener.
|
||||
// - "sse": Bidirectional mode. Enable the standalone SSE stream to receive
|
||||
// server-initiated notifications (e.g., ToolListChangedNotification).
|
||||
// - Empty or auto-detected: Defaults to "sse" behavior (standalone SSE enabled).
|
||||
disableStandaloneSSE := (cfg.Type == "http")
|
||||
disableStandaloneSSE := transportType == "http"
|
||||
|
||||
logger.DebugCF("mcp", "Using SSE/HTTP transport",
|
||||
map[string]any{
|
||||
@@ -452,7 +445,7 @@ func connectServer(
|
||||
transport = &isolatedCommandTransport{Command: cmd}
|
||||
default:
|
||||
return nil, fmt.Errorf(
|
||||
"unsupported transport type: %s (supported: stdio, sse, http)",
|
||||
"unsupported transport type: %s (supported: stdio, sse, http, streamable-http)",
|
||||
transportType,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,326 @@
|
||||
//go:build integration
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// Run with: go test -tags=integration ./pkg/mcp
|
||||
func TestIntegration_StreamableHTTPCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
transportType string
|
||||
jsonResponse bool
|
||||
rejectStandaloneGET bool
|
||||
wantResponseContentType string
|
||||
}{
|
||||
{
|
||||
name: "http/json-only-without-get-listener",
|
||||
transportType: "http",
|
||||
jsonResponse: true,
|
||||
rejectStandaloneGET: true,
|
||||
wantResponseContentType: "application/json",
|
||||
},
|
||||
{
|
||||
name: "http/streaming-post-responses",
|
||||
transportType: "http",
|
||||
jsonResponse: false,
|
||||
rejectStandaloneGET: false,
|
||||
wantResponseContentType: "text/event-stream",
|
||||
},
|
||||
{
|
||||
name: "streamable-http-alias/json-only-without-get-listener",
|
||||
transportType: "streamable-http",
|
||||
jsonResponse: true,
|
||||
rejectStandaloneGET: true,
|
||||
wantResponseContentType: "application/json",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, recorder := newRecordedGoSDKStreamableServer(t, tt.jsonResponse, tt.rejectStandaloneGET)
|
||||
defer server.Close()
|
||||
|
||||
mgr := NewManager()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := mgr.ConnectServer(ctx, "compat", config.MCPServerConfig{
|
||||
Enabled: true,
|
||||
Type: tt.transportType,
|
||||
URL: server.URL,
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer integration-token",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ConnectServer() error = %v", err)
|
||||
}
|
||||
|
||||
tools := mgr.GetAllTools()
|
||||
if got := len(tools["compat"]); got != 1 {
|
||||
t.Fatalf("len(GetAllTools()[\"compat\"]) = %d, want 1", got)
|
||||
}
|
||||
|
||||
result, err := mgr.CallTool(ctx, "compat", "echo", map[string]any{
|
||||
"message": "hello from integration",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CallTool() error = %v", err)
|
||||
}
|
||||
if got, want := extractTextResult(t, result), "hello from integration"; got != want {
|
||||
t.Fatalf("CallTool() text = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
if err := mgr.Close(); err != nil {
|
||||
t.Fatalf("Manager.Close() error = %v", err)
|
||||
}
|
||||
|
||||
assertRecordedCompatibility(t, recorder.snapshot(), tt.wantResponseContentType)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type recordedRequest struct {
|
||||
Method string
|
||||
Path string
|
||||
JSONRPCMethod string
|
||||
RequestSessionID string
|
||||
Authorization string
|
||||
ResponseStatusCode int
|
||||
ResponseContentType string
|
||||
}
|
||||
|
||||
type requestRecorder struct {
|
||||
mu sync.Mutex
|
||||
requests []recordedRequest
|
||||
}
|
||||
|
||||
func (r *requestRecorder) add(req recordedRequest) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.requests = append(r.requests, req)
|
||||
}
|
||||
|
||||
func (r *requestRecorder) snapshot() []recordedRequest {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
out := make([]recordedRequest, len(r.requests))
|
||||
copy(out, r.requests)
|
||||
return out
|
||||
}
|
||||
|
||||
func newRecordedGoSDKStreamableServer(
|
||||
t *testing.T,
|
||||
jsonResponse bool,
|
||||
rejectStandaloneGET bool,
|
||||
) (*httptest.Server, *requestRecorder) {
|
||||
t.Helper()
|
||||
|
||||
server := sdkmcp.NewServer(&sdkmcp.Implementation{
|
||||
Name: "streamable-integration-server",
|
||||
Version: "1.0.0",
|
||||
}, nil)
|
||||
sdkmcp.AddTool(server, &sdkmcp.Tool{
|
||||
Name: "echo",
|
||||
Description: "Echo a message",
|
||||
}, func(ctx context.Context, req *sdkmcp.CallToolRequest, args map[string]any) (*sdkmcp.CallToolResult, any, error) {
|
||||
message, _ := args["message"].(string)
|
||||
return &sdkmcp.CallToolResult{
|
||||
Content: []sdkmcp.Content{
|
||||
&sdkmcp.TextContent{Text: message},
|
||||
},
|
||||
}, nil, nil
|
||||
})
|
||||
|
||||
recorder := &requestRecorder{}
|
||||
handler := sdkmcp.NewStreamableHTTPHandler(func(*http.Request) *sdkmcp.Server {
|
||||
return server
|
||||
}, &sdkmcp.StreamableHTTPOptions{
|
||||
JSONResponse: jsonResponse,
|
||||
})
|
||||
|
||||
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if rejectStandaloneGET && r.Method == http.MethodGet {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
recorder.add(recordedRequest{
|
||||
Method: r.Method,
|
||||
Path: r.URL.Path,
|
||||
RequestSessionID: r.Header.Get("Mcp-Session-Id"),
|
||||
Authorization: r.Header.Get("Authorization"),
|
||||
ResponseStatusCode: http.StatusMethodNotAllowed,
|
||||
ResponseContentType: normalizeContentType(w.Header().Get("Content-Type")),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
recorded := recordedRequest{
|
||||
Method: r.Method,
|
||||
Path: r.URL.Path,
|
||||
RequestSessionID: r.Header.Get("Mcp-Session-Id"),
|
||||
Authorization: r.Header.Get("Authorization"),
|
||||
}
|
||||
if r.Method == http.MethodPost {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading request body: %v", err)
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
|
||||
var envelope struct {
|
||||
Method string `json:"method"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &envelope); err == nil {
|
||||
recorded.JSONRPCMethod = envelope.Method
|
||||
}
|
||||
}
|
||||
|
||||
rw := &recordingResponseWriter{ResponseWriter: w}
|
||||
handler.ServeHTTP(rw, r)
|
||||
|
||||
recorded.ResponseStatusCode = rw.statusCode()
|
||||
recorded.ResponseContentType = normalizeContentType(rw.Header().Get("Content-Type"))
|
||||
recorder.add(recorded)
|
||||
}))
|
||||
|
||||
return httpServer, recorder
|
||||
}
|
||||
|
||||
type recordingResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *recordingResponseWriter) WriteHeader(status int) {
|
||||
if w.status == 0 {
|
||||
w.status = status
|
||||
}
|
||||
w.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
func (w *recordingResponseWriter) Write(p []byte) (int, error) {
|
||||
if w.status == 0 {
|
||||
w.status = http.StatusOK
|
||||
}
|
||||
return w.ResponseWriter.Write(p)
|
||||
}
|
||||
|
||||
func (w *recordingResponseWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *recordingResponseWriter) statusCode() int {
|
||||
if w.status != 0 {
|
||||
return w.status
|
||||
}
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func extractTextResult(t *testing.T, result *sdkmcp.CallToolResult) string {
|
||||
t.Helper()
|
||||
if result == nil || len(result.Content) != 1 {
|
||||
t.Fatalf("unexpected CallToolResult: %#v", result)
|
||||
}
|
||||
text, ok := result.Content[0].(*sdkmcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatalf("CallToolResult content type = %T, want *sdkmcp.TextContent", result.Content[0])
|
||||
}
|
||||
return text.Text
|
||||
}
|
||||
|
||||
func assertRecordedCompatibility(
|
||||
t *testing.T,
|
||||
requests []recordedRequest,
|
||||
wantResponseContentType string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
var (
|
||||
getCount int
|
||||
deleteCount int
|
||||
postWithSession int
|
||||
deleteWithSession int
|
||||
requestsMissingAuth []string
|
||||
observedContentTypesByMethod = map[string]string{}
|
||||
)
|
||||
|
||||
for _, req := range requests {
|
||||
switch req.Method {
|
||||
case http.MethodGet:
|
||||
getCount++
|
||||
case http.MethodPost:
|
||||
if req.RequestSessionID != "" {
|
||||
postWithSession++
|
||||
}
|
||||
if req.JSONRPCMethod != "" && observedContentTypesByMethod[req.JSONRPCMethod] == "" {
|
||||
observedContentTypesByMethod[req.JSONRPCMethod] = req.ResponseContentType
|
||||
}
|
||||
case http.MethodDelete:
|
||||
deleteCount++
|
||||
if req.RequestSessionID != "" {
|
||||
deleteWithSession++
|
||||
}
|
||||
}
|
||||
|
||||
if req.Authorization != "Bearer integration-token" {
|
||||
requestsMissingAuth = append(requestsMissingAuth, req.Method+" "+req.Path)
|
||||
}
|
||||
}
|
||||
|
||||
if getCount != 0 {
|
||||
t.Fatalf("expected no standalone GET requests for streamable HTTP mode, saw %d", getCount)
|
||||
}
|
||||
if deleteCount != 1 {
|
||||
t.Fatalf("DELETE count = %d, want 1", deleteCount)
|
||||
}
|
||||
if postWithSession == 0 {
|
||||
t.Fatal("expected at least one POST request with Mcp-Session-Id")
|
||||
}
|
||||
if deleteWithSession != 1 {
|
||||
t.Fatalf("expected exactly one DELETE with Mcp-Session-Id, got %d", deleteWithSession)
|
||||
}
|
||||
if len(requestsMissingAuth) > 0 {
|
||||
t.Fatalf("Authorization header missing on requests: %v", requestsMissingAuth)
|
||||
}
|
||||
|
||||
for _, method := range []string{"initialize", "tools/list", "tools/call"} {
|
||||
if observedContentTypesByMethod[method] == "" {
|
||||
t.Fatalf("did not observe POST response for JSON-RPC method %q", method)
|
||||
}
|
||||
if observedContentTypesByMethod[method] != wantResponseContentType {
|
||||
t.Fatalf(
|
||||
"response content-type for %s = %q, want %q",
|
||||
method,
|
||||
observedContentTypesByMethod[method],
|
||||
wantResponseContentType,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeContentType(value string) string {
|
||||
return strings.TrimSpace(strings.SplitN(value, ";", 2)[0])
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
//go:build integration
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// TestIntegration_RealConfiguredServer is an opt-in smoke test for a real MCP
|
||||
// server configured via environment variables.
|
||||
//
|
||||
// Run with:
|
||||
//
|
||||
// go test -tags=integration ./pkg/mcp -run TestIntegration_RealConfiguredServer -v
|
||||
//
|
||||
// Minimum configuration:
|
||||
//
|
||||
// PICOCLAW_MCP_REAL_SERVER_JSON='{"enabled":true,"type":"http","url":"http://127.0.0.1:8080/mcp"}'
|
||||
//
|
||||
// Optional tool invocation:
|
||||
//
|
||||
// PICOCLAW_MCP_REAL_TOOL_NAME=echo
|
||||
// PICOCLAW_MCP_REAL_TOOL_ARGS_JSON='{"message":"hello"}'
|
||||
// PICOCLAW_MCP_REAL_EXPECT_SUBSTRING=hello
|
||||
//
|
||||
// Stdio subprocess example:
|
||||
//
|
||||
// PICOCLAW_MCP_REAL_SERVER_JSON='{"enabled":true,"type":"stdio","command":"npx","args":["-y","@modelcontextprotocol/server-filesystem","."]}'
|
||||
func TestIntegration_RealConfiguredServer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
serverJSON := strings.TrimSpace(os.Getenv("PICOCLAW_MCP_REAL_SERVER_JSON"))
|
||||
if serverJSON == "" {
|
||||
t.Skip("skipping integration test (set PICOCLAW_MCP_REAL_SERVER_JSON to enable)")
|
||||
}
|
||||
|
||||
serverCfg, err := loadRealServerConfig(serverJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("loadRealServerConfig() error = %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mgr := NewManager()
|
||||
if err := mgr.ConnectServer(ctx, "real", serverCfg); err != nil {
|
||||
t.Fatalf("ConnectServer() error = %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := mgr.Close(); err != nil {
|
||||
t.Errorf("Manager.Close() error = %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
tools := mgr.GetAllTools()["real"]
|
||||
if len(tools) == 0 {
|
||||
t.Fatal("expected at least one discovered tool from real MCP server")
|
||||
}
|
||||
|
||||
t.Logf("connected to real MCP server via %s with %d tool(s)", config.EffectiveMCPTransportType(serverCfg), len(tools))
|
||||
for _, tool := range tools {
|
||||
if tool != nil {
|
||||
t.Logf("discovered tool: %s", tool.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if expectedCountRaw := strings.TrimSpace(os.Getenv("PICOCLAW_MCP_REAL_EXPECT_TOOL_COUNT")); expectedCountRaw != "" {
|
||||
expectedCount, err := strconv.Atoi(expectedCountRaw)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid PICOCLAW_MCP_REAL_EXPECT_TOOL_COUNT %q: %v", expectedCountRaw, err)
|
||||
}
|
||||
if len(tools) != expectedCount {
|
||||
t.Fatalf("tool count = %d, want %d", len(tools), expectedCount)
|
||||
}
|
||||
}
|
||||
|
||||
toolName := strings.TrimSpace(os.Getenv("PICOCLAW_MCP_REAL_TOOL_NAME"))
|
||||
if toolName == "" {
|
||||
return
|
||||
}
|
||||
|
||||
toolArgs, err := loadRealToolArgs(os.Getenv("PICOCLAW_MCP_REAL_TOOL_ARGS_JSON"))
|
||||
if err != nil {
|
||||
t.Fatalf("loadRealToolArgs() error = %v", err)
|
||||
}
|
||||
|
||||
result, err := mgr.CallTool(ctx, "real", toolName, toolArgs)
|
||||
if err != nil {
|
||||
t.Fatalf("CallTool(%q) error = %v", toolName, err)
|
||||
}
|
||||
|
||||
textPayload := joinTextContents(result)
|
||||
t.Logf("tool %q returned text payload: %q", toolName, textPayload)
|
||||
|
||||
if want := os.Getenv("PICOCLAW_MCP_REAL_EXPECT_SUBSTRING"); want != "" && !strings.Contains(textPayload, want) {
|
||||
t.Fatalf("tool result %q does not contain expected substring %q", textPayload, want)
|
||||
}
|
||||
}
|
||||
|
||||
func loadRealServerConfig(raw string) (config.MCPServerConfig, error) {
|
||||
var cfg config.MCPServerConfig
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
return config.MCPServerConfig{}, err
|
||||
}
|
||||
if !cfg.Enabled {
|
||||
cfg.Enabled = true
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func loadRealToolArgs(raw string) (map[string]any, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func joinTextContents(result *sdkmcp.CallToolResult) string {
|
||||
if result == nil || len(result.Content) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(result.Content))
|
||||
for _, content := range result.Content {
|
||||
if text, ok := content.(*sdkmcp.TextContent); ok && text != nil {
|
||||
parts = append(parts, text.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -408,6 +410,133 @@ func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectServer_StreamableHTTPRequestResponseMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, transportType := range []string{"http", "streamable-http"} {
|
||||
t.Run(transportType, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := sdkmcp.NewServer(&sdkmcp.Implementation{
|
||||
Name: "streamable-test-server",
|
||||
Version: "1.0.0",
|
||||
}, nil)
|
||||
sdkmcp.AddTool(server, &sdkmcp.Tool{
|
||||
Name: "echo",
|
||||
Description: "Echo test tool",
|
||||
}, func(ctx context.Context, req *sdkmcp.CallToolRequest, args map[string]any) (*sdkmcp.CallToolResult, any, error) {
|
||||
return &sdkmcp.CallToolResult{
|
||||
Content: []sdkmcp.Content{
|
||||
&sdkmcp.TextContent{Text: "ok"},
|
||||
},
|
||||
}, nil, nil
|
||||
})
|
||||
|
||||
type observedRequest struct {
|
||||
Method string
|
||||
SessionID string
|
||||
Authorization string
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
observed []observedRequest
|
||||
)
|
||||
|
||||
handler := sdkmcp.NewStreamableHTTPHandler(func(*http.Request) *sdkmcp.Server {
|
||||
return server
|
||||
}, nil)
|
||||
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
observed = append(observed, observedRequest{
|
||||
Method: r.Method,
|
||||
SessionID: r.Header.Get("Mcp-Session-Id"),
|
||||
Authorization: r.Header.Get("Authorization"),
|
||||
})
|
||||
mu.Unlock()
|
||||
handler.ServeHTTP(w, r)
|
||||
}))
|
||||
defer httpServer.Close()
|
||||
|
||||
conn, err := connectServer(context.Background(), "streamable", config.MCPServerConfig{
|
||||
Enabled: true,
|
||||
Type: transportType,
|
||||
URL: httpServer.URL,
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer test-token",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("connectServer(%q) error = %v", transportType, err)
|
||||
}
|
||||
if got := len(conn.Tools); got != 1 {
|
||||
t.Fatalf("len(conn.Tools) = %d, want 1", got)
|
||||
}
|
||||
if got := conn.Session.ID(); got == "" {
|
||||
t.Fatal("expected non-empty streamable session ID")
|
||||
}
|
||||
if err := conn.Session.Close(); err != nil {
|
||||
t.Fatalf("Session.Close() error = %v", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
var (
|
||||
getCount int
|
||||
postCount int
|
||||
deleteCount int
|
||||
postWithSession bool
|
||||
deleteWithSession bool
|
||||
requestsWithAuth int
|
||||
requestsWithoutAuth []string
|
||||
)
|
||||
|
||||
for _, req := range observed {
|
||||
switch req.Method {
|
||||
case http.MethodGet:
|
||||
getCount++
|
||||
case http.MethodPost:
|
||||
postCount++
|
||||
if req.SessionID != "" {
|
||||
postWithSession = true
|
||||
}
|
||||
case http.MethodDelete:
|
||||
deleteCount++
|
||||
if req.SessionID != "" {
|
||||
deleteWithSession = true
|
||||
}
|
||||
}
|
||||
|
||||
if req.Authorization == "Bearer test-token" {
|
||||
requestsWithAuth++
|
||||
} else {
|
||||
requestsWithoutAuth = append(requestsWithoutAuth, req.Method)
|
||||
}
|
||||
}
|
||||
|
||||
if getCount != 0 {
|
||||
t.Fatalf("expected no standalone GET requests for %q transport, saw %d", transportType, getCount)
|
||||
}
|
||||
if postCount == 0 {
|
||||
t.Fatal("expected POST requests during streamable HTTP handshake")
|
||||
}
|
||||
if deleteCount != 1 {
|
||||
t.Fatalf("DELETE count = %d, want 1", deleteCount)
|
||||
}
|
||||
if !postWithSession {
|
||||
t.Fatal("expected at least one POST request with Mcp-Session-Id")
|
||||
}
|
||||
if !deleteWithSession {
|
||||
t.Fatal("expected DELETE request with Mcp-Session-Id")
|
||||
}
|
||||
if requestsWithAuth != len(observed) {
|
||||
t.Fatalf("Authorization header missing on requests: %v", requestsWithoutAuth)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallTool_ReconnectsWhenHTTPServerLosesSession(t *testing.T) {
|
||||
originalConnectServerFunc := connectServerFunc
|
||||
t.Cleanup(func() {
|
||||
|
||||
Reference in New Issue
Block a user