Files
picoclaw/pkg/mcp/manager_integration_test.go
T

327 lines
8.4 KiB
Go

//go:build integration
package mcp
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/config"
)
// TestIntegration_StreamableHTTPCompatibility connects a Manager to an
// in-process go-sdk streamable HTTP server for each transport configuration,
// calls the server's "echo" tool, closes the manager, and then checks the HTTP
// traffic the server recorded via assertRecordedCompatibility.
//
// Run with: go test -tags=integration ./pkg/mcp
func TestIntegration_StreamableHTTPCompatibility(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test in short mode")
	}

	tests := []struct {
		name                    string
		transportType           string
		jsonResponse            bool
		rejectStandaloneGET     bool
		wantResponseContentType string
	}{
		{
			name:                    "http/json-only-without-get-listener",
			transportType:           "http",
			jsonResponse:            true,
			rejectStandaloneGET:     true,
			wantResponseContentType: "application/json",
		},
		{
			name:                    "http/streaming-post-responses",
			transportType:           "http",
			jsonResponse:            false,
			rejectStandaloneGET:     false,
			wantResponseContentType: "text/event-stream",
		},
		{
			name:                    "streamable-http-alias/json-only-without-get-listener",
			transportType:           "streamable-http",
			jsonResponse:            true,
			rejectStandaloneGET:     true,
			wantResponseContentType: "application/json",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			server, recorder := newRecordedGoSDKStreamableServer(t, tt.jsonResponse, tt.rejectStandaloneGET)
			defer server.Close()

			mgr := NewManager()
			// The success path closes the manager explicitly below, because
			// its error is checked and the DELETE it sends is asserted on.
			// This deferred best-effort close only covers early t.Fatalf
			// exits. Deferred calls run LIFO, so it runs before server.Close
			// and the session never outlives the test server.
			managerClosed := false
			defer func() {
				if !managerClosed {
					_ = mgr.Close()
				}
			}()

			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			err := mgr.ConnectServer(ctx, "compat", config.MCPServerConfig{
				Enabled: true,
				Type:    tt.transportType,
				URL:     server.URL,
				Headers: map[string]string{
					"Authorization": "Bearer integration-token",
				},
			})
			if err != nil {
				t.Fatalf("ConnectServer() error = %v", err)
			}

			tools := mgr.GetAllTools()
			if got := len(tools["compat"]); got != 1 {
				t.Fatalf("len(GetAllTools()[\"compat\"]) = %d, want 1", got)
			}

			result, err := mgr.CallTool(ctx, "compat", "echo", map[string]any{
				"message": "hello from integration",
			})
			if err != nil {
				t.Fatalf("CallTool() error = %v", err)
			}
			if got, want := extractTextResult(t, result), "hello from integration"; got != want {
				t.Fatalf("CallTool() text = %q, want %q", got, want)
			}

			// Mark as closed before calling Close so the deferred cleanup
			// never issues a second Close, even if this one fails.
			managerClosed = true
			if err := mgr.Close(); err != nil {
				t.Fatalf("Manager.Close() error = %v", err)
			}

			assertRecordedCompatibility(t, recorder.snapshot(), tt.wantResponseContentType)
		})
	}
}
// recordedRequest is one HTTP exchange observed by the test server: the
// request method and path, the JSON-RPC "method" peeked from a POST body
// (empty otherwise), the Mcp-Session-Id and Authorization request headers,
// and the response status code and normalized Content-Type.
type recordedRequest struct {
	Method              string
	Path                string
	JSONRPCMethod       string
	RequestSessionID    string
	Authorization       string
	ResponseStatusCode  int
	ResponseContentType string
}

// requestRecorder accumulates recordedRequest values in arrival order. Its
// methods serialize access through mu, so handlers running on different
// server goroutines may call add concurrently.
type requestRecorder struct {
	mu       sync.Mutex
	requests []recordedRequest
}

// add appends entry to the recorded exchanges.
func (r *requestRecorder) add(entry recordedRequest) {
	r.mu.Lock()
	r.requests = append(r.requests, entry)
	r.mu.Unlock()
}

// snapshot returns an independent copy of everything recorded so far. The
// result is never nil, and mutating it does not affect the recorder.
func (r *requestRecorder) snapshot() []recordedRequest {
	r.mu.Lock()
	defer r.mu.Unlock()
	return append(make([]recordedRequest, 0, len(r.requests)), r.requests...)
}
// newRecordedGoSDKStreamableServer starts an httptest.Server that serves a
// go-sdk MCP server through sdkmcp.NewStreamableHTTPHandler. The MCP server
// exposes a single "echo" tool that returns its "message" argument as text.
// It also returns a recorder that captures every HTTP exchange the server
// handled.
//
// jsonResponse is passed through as StreamableHTTPOptions.JSONResponse. When
// rejectStandaloneGET is true, every GET request is answered with
// 405 Method Not Allowed and never reaches the SDK handler.
//
// The caller owns the returned server and must Close it.
func newRecordedGoSDKStreamableServer(
	t *testing.T,
	jsonResponse bool,
	rejectStandaloneGET bool,
) (*httptest.Server, *requestRecorder) {
	t.Helper()

	server := sdkmcp.NewServer(&sdkmcp.Implementation{
		Name:    "streamable-integration-server",
		Version: "1.0.0",
	}, nil)
	sdkmcp.AddTool(server, &sdkmcp.Tool{
		Name:        "echo",
		Description: "Echo a message",
	}, func(ctx context.Context, req *sdkmcp.CallToolRequest, args map[string]any) (*sdkmcp.CallToolResult, any, error) {
		// A missing or non-string "message" echoes back an empty string.
		message, _ := args["message"].(string)
		return &sdkmcp.CallToolResult{
			Content: []sdkmcp.Content{
				&sdkmcp.TextContent{Text: message},
			},
		}, nil, nil
	})

	recorder := &requestRecorder{}
	handler := sdkmcp.NewStreamableHTTPHandler(func(*http.Request) *sdkmcp.Server {
		return server
	}, &sdkmcp.StreamableHTTPOptions{
		JSONResponse: jsonResponse,
	})

	httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if rejectStandaloneGET && r.Method == http.MethodGet {
			w.WriteHeader(http.StatusMethodNotAllowed)
			recorder.add(recordedRequest{
				Method:              r.Method,
				Path:                r.URL.Path,
				RequestSessionID:    r.Header.Get("Mcp-Session-Id"),
				Authorization:       r.Header.Get("Authorization"),
				ResponseStatusCode:  http.StatusMethodNotAllowed,
				ResponseContentType: normalizeContentType(w.Header().Get("Content-Type")),
			})
			return
		}

		recorded := recordedRequest{
			Method:           r.Method,
			Path:             r.URL.Path,
			RequestSessionID: r.Header.Get("Mcp-Session-Id"),
			Authorization:    r.Header.Get("Authorization"),
		}
		if r.Method == http.MethodPost {
			body, err := io.ReadAll(r.Body)
			if err != nil {
				// This handler runs on an HTTP server goroutine, not the test
				// goroutine, so t.Fatalf (FailNow) must not be used here.
				// Report the failure and reject the request instead.
				t.Errorf("reading request body: %v", err)
				http.Error(w, "reading request body", http.StatusBadRequest)
				return
			}
			// Restore the consumed body so the SDK handler can decode it.
			r.Body = io.NopCloser(bytes.NewReader(body))

			// Peek at the JSON-RPC method name. Bodies that do not decode into
			// this object shape leave JSONRPCMethod empty.
			var envelope struct {
				Method string `json:"method"`
			}
			if err := json.Unmarshal(body, &envelope); err == nil {
				recorded.JSONRPCMethod = envelope.Method
			}
		}

		rw := &recordingResponseWriter{ResponseWriter: w}
		handler.ServeHTTP(rw, r)
		// NOTE(review): the exchange is recorded only after the SDK handler
		// returns. If the handler flushed a streamed response, the client may
		// already have acted on it by then. A snapshot taken right after a
		// client call returns could therefore miss this entry. Confirm this
		// window cannot be hit in practice.
		recorded.ResponseStatusCode = rw.statusCode()
		recorded.ResponseContentType = normalizeContentType(rw.Header().Get("Content-Type"))
		recorder.add(recorded)
	}))
	return httpServer, recorder
}
// recordingResponseWriter wraps an http.ResponseWriter and remembers the
// first status code written through it. This lets the test server record
// what the SDK handler answered.
type recordingResponseWriter struct {
	http.ResponseWriter
	// status is the first code passed to WriteHeader, or 200 once Write is
	// called before any WriteHeader. 0 means nothing has been written yet.
	status int
}

// WriteHeader records the first status code it sees and forwards every call
// to the wrapped writer.
func (w *recordingResponseWriter) WriteHeader(status int) {
	if w.status == 0 {
		w.status = status
	}
	w.ResponseWriter.WriteHeader(status)
}

// Write records an implicit 200 OK if no status was set yet, then forwards p
// to the wrapped writer.
func (w *recordingResponseWriter) Write(p []byte) (int, error) {
	if w.status == 0 {
		w.status = http.StatusOK
	}
	return w.ResponseWriter.Write(p)
}

// Flush forwards to the wrapped writer when it implements http.Flusher. The
// embedded http.ResponseWriter interface does not promote Flush, so without
// this method a handler's http.Flusher type assertion on the wrapper would fail.
func (w *recordingResponseWriter) Flush() {
	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

// Unwrap returns the wrapped writer. http.ResponseController uses it to reach
// capabilities (write deadlines, hijacking, ...) that this wrapper does not
// expose directly.
func (w *recordingResponseWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}

// statusCode reports the recorded status. It defaults to 200 OK when the
// handler never wrote a status or a body.
func (w *recordingResponseWriter) statusCode() int {
	if w.status != 0 {
		return w.status
	}
	return http.StatusOK
}
// extractTextResult returns the text of a tool result that must consist of
// exactly one *sdkmcp.TextContent item. It fails the test when result is nil,
// has a different number of content items, or holds another content type.
func extractTextResult(t *testing.T, result *sdkmcp.CallToolResult) string {
	t.Helper()
	if result == nil || len(result.Content) != 1 {
		t.Fatalf("unexpected CallToolResult: %#v", result)
	}

	switch content := result.Content[0].(type) {
	case *sdkmcp.TextContent:
		return content.Text
	default:
		t.Fatalf("CallToolResult content type = %T, want *sdkmcp.TextContent", content)
		return "" // unreachable: Fatalf stops the test goroutine
	}
}
// assertRecordedCompatibility checks the HTTP exchanges captured by the test
// server against the expectations shared by every transport variant:
//   - no GET requests were made;
//   - exactly one DELETE was sent, and it carried Mcp-Session-Id;
//   - at least one POST carried Mcp-Session-Id;
//   - every request carried "Authorization: Bearer integration-token";
//   - for each of initialize, tools/list and tools/call, the first non-empty
//     POST response Content-Type equals wantResponseContentType.
//
// It stops the test at the first failed expectation, in the order listed.
func assertRecordedCompatibility(
	t *testing.T,
	requests []recordedRequest,
	wantResponseContentType string,
) {
	t.Helper()

	const wantAuthorization = "Bearer integration-token"

	var (
		gets, deletes                int
		sessionPOSTs, sessionDELETEs int
		unauthenticated              []string
	)
	// First non-empty response Content-Type seen per JSON-RPC method (POST only).
	contentTypeByRPC := make(map[string]string)

	for _, rec := range requests {
		if rec.Authorization != wantAuthorization {
			unauthenticated = append(unauthenticated, rec.Method+" "+rec.Path)
		}

		withSession := rec.RequestSessionID != ""
		switch rec.Method {
		case http.MethodGet:
			gets++
		case http.MethodDelete:
			deletes++
			if withSession {
				sessionDELETEs++
			}
		case http.MethodPost:
			if withSession {
				sessionPOSTs++
			}
			rpc := rec.JSONRPCMethod
			if rpc != "" && contentTypeByRPC[rpc] == "" {
				contentTypeByRPC[rpc] = rec.ResponseContentType
			}
		}
	}

	switch {
	case gets != 0:
		t.Fatalf("expected no standalone GET requests for streamable HTTP mode, saw %d", gets)
	case deletes != 1:
		t.Fatalf("DELETE count = %d, want 1", deletes)
	case sessionPOSTs == 0:
		t.Fatal("expected at least one POST request with Mcp-Session-Id")
	case sessionDELETEs != 1:
		t.Fatalf("expected exactly one DELETE with Mcp-Session-Id, got %d", sessionDELETEs)
	case len(unauthenticated) > 0:
		t.Fatalf("Authorization header missing on requests: %v", unauthenticated)
	}

	for _, rpc := range []string{"initialize", "tools/list", "tools/call"} {
		switch got := contentTypeByRPC[rpc]; {
		case got == "":
			t.Fatalf("did not observe POST response for JSON-RPC method %q", rpc)
		case got != wantResponseContentType:
			t.Fatalf(
				"response content-type for %s = %q, want %q",
				rpc,
				got,
				wantResponseContentType,
			)
		}
	}
}
// normalizeContentType reduces a Content-Type header value to its media type:
// everything before the first ';' (dropping parameters such as charset), with
// surrounding whitespace trimmed. An empty header yields "".
func normalizeContentType(value string) string {
	mediaType, _, _ := strings.Cut(value, ";")
	return strings.TrimSpace(mediaType)
}