mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
327 lines
8.4 KiB
Go
327 lines
8.4 KiB
Go
//go:build integration
|
|
|
|
package mcp
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
)
|
|
|
|
// Run with: go test -tags=integration ./pkg/mcp
|
|
func TestIntegration_StreamableHTTPCompatibility(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping integration test in short mode")
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
transportType string
|
|
jsonResponse bool
|
|
rejectStandaloneGET bool
|
|
wantResponseContentType string
|
|
}{
|
|
{
|
|
name: "http/json-only-without-get-listener",
|
|
transportType: "http",
|
|
jsonResponse: true,
|
|
rejectStandaloneGET: true,
|
|
wantResponseContentType: "application/json",
|
|
},
|
|
{
|
|
name: "http/streaming-post-responses",
|
|
transportType: "http",
|
|
jsonResponse: false,
|
|
rejectStandaloneGET: false,
|
|
wantResponseContentType: "text/event-stream",
|
|
},
|
|
{
|
|
name: "streamable-http-alias/json-only-without-get-listener",
|
|
transportType: "streamable-http",
|
|
jsonResponse: true,
|
|
rejectStandaloneGET: true,
|
|
wantResponseContentType: "application/json",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server, recorder := newRecordedGoSDKStreamableServer(t, tt.jsonResponse, tt.rejectStandaloneGET)
|
|
defer server.Close()
|
|
|
|
mgr := NewManager()
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
err := mgr.ConnectServer(ctx, "compat", config.MCPServerConfig{
|
|
Enabled: true,
|
|
Type: tt.transportType,
|
|
URL: server.URL,
|
|
Headers: map[string]string{
|
|
"Authorization": "Bearer integration-token",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("ConnectServer() error = %v", err)
|
|
}
|
|
|
|
tools := mgr.GetAllTools()
|
|
if got := len(tools["compat"]); got != 1 {
|
|
t.Fatalf("len(GetAllTools()[\"compat\"]) = %d, want 1", got)
|
|
}
|
|
|
|
result, err := mgr.CallTool(ctx, "compat", "echo", map[string]any{
|
|
"message": "hello from integration",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("CallTool() error = %v", err)
|
|
}
|
|
if got, want := extractTextResult(t, result), "hello from integration"; got != want {
|
|
t.Fatalf("CallTool() text = %q, want %q", got, want)
|
|
}
|
|
|
|
if err := mgr.Close(); err != nil {
|
|
t.Fatalf("Manager.Close() error = %v", err)
|
|
}
|
|
|
|
assertRecordedCompatibility(t, recorder.snapshot(), tt.wantResponseContentType)
|
|
})
|
|
}
|
|
}
|
|
|
|
// recordedRequest is one entry of the test server's traffic log: selected
// attributes of an incoming HTTP request plus the status code and media type
// of the response that was produced for it.
type recordedRequest struct {
	// Method is the HTTP method of the request.
	Method string

	// Path is the URL path of the request.
	Path string

	// JSONRPCMethod is the "method" field decoded from a POST body; it stays
	// empty for other HTTP methods or when the body does not decode.
	JSONRPCMethod string

	// RequestSessionID is the value of the Mcp-Session-Id request header.
	RequestSessionID string

	// Authorization is the value of the Authorization request header.
	Authorization string

	// ResponseStatusCode is the status code recorded for the response.
	ResponseStatusCode int

	// ResponseContentType is the response Content-Type after normalization.
	ResponseContentType string
}
|
|
|
|
type requestRecorder struct {
|
|
mu sync.Mutex
|
|
requests []recordedRequest
|
|
}
|
|
|
|
func (r *requestRecorder) add(req recordedRequest) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.requests = append(r.requests, req)
|
|
}
|
|
|
|
func (r *requestRecorder) snapshot() []recordedRequest {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
out := make([]recordedRequest, len(r.requests))
|
|
copy(out, r.requests)
|
|
return out
|
|
}
|
|
|
|
func newRecordedGoSDKStreamableServer(
|
|
t *testing.T,
|
|
jsonResponse bool,
|
|
rejectStandaloneGET bool,
|
|
) (*httptest.Server, *requestRecorder) {
|
|
t.Helper()
|
|
|
|
server := sdkmcp.NewServer(&sdkmcp.Implementation{
|
|
Name: "streamable-integration-server",
|
|
Version: "1.0.0",
|
|
}, nil)
|
|
sdkmcp.AddTool(server, &sdkmcp.Tool{
|
|
Name: "echo",
|
|
Description: "Echo a message",
|
|
}, func(ctx context.Context, req *sdkmcp.CallToolRequest, args map[string]any) (*sdkmcp.CallToolResult, any, error) {
|
|
message, _ := args["message"].(string)
|
|
return &sdkmcp.CallToolResult{
|
|
Content: []sdkmcp.Content{
|
|
&sdkmcp.TextContent{Text: message},
|
|
},
|
|
}, nil, nil
|
|
})
|
|
|
|
recorder := &requestRecorder{}
|
|
handler := sdkmcp.NewStreamableHTTPHandler(func(*http.Request) *sdkmcp.Server {
|
|
return server
|
|
}, &sdkmcp.StreamableHTTPOptions{
|
|
JSONResponse: jsonResponse,
|
|
})
|
|
|
|
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if rejectStandaloneGET && r.Method == http.MethodGet {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
recorder.add(recordedRequest{
|
|
Method: r.Method,
|
|
Path: r.URL.Path,
|
|
RequestSessionID: r.Header.Get("Mcp-Session-Id"),
|
|
Authorization: r.Header.Get("Authorization"),
|
|
ResponseStatusCode: http.StatusMethodNotAllowed,
|
|
ResponseContentType: normalizeContentType(w.Header().Get("Content-Type")),
|
|
})
|
|
return
|
|
}
|
|
|
|
recorded := recordedRequest{
|
|
Method: r.Method,
|
|
Path: r.URL.Path,
|
|
RequestSessionID: r.Header.Get("Mcp-Session-Id"),
|
|
Authorization: r.Header.Get("Authorization"),
|
|
}
|
|
if r.Method == http.MethodPost {
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
t.Fatalf("reading request body: %v", err)
|
|
}
|
|
r.Body = io.NopCloser(bytes.NewReader(body))
|
|
|
|
var envelope struct {
|
|
Method string `json:"method"`
|
|
}
|
|
if err := json.Unmarshal(body, &envelope); err == nil {
|
|
recorded.JSONRPCMethod = envelope.Method
|
|
}
|
|
}
|
|
|
|
rw := &recordingResponseWriter{ResponseWriter: w}
|
|
handler.ServeHTTP(rw, r)
|
|
|
|
recorded.ResponseStatusCode = rw.statusCode()
|
|
recorded.ResponseContentType = normalizeContentType(rw.Header().Get("Content-Type"))
|
|
recorder.add(recorded)
|
|
}))
|
|
|
|
return httpServer, recorder
|
|
}
|
|
|
|
// recordingResponseWriter wraps an http.ResponseWriter and remembers the final
// status code of the response written through it. All calls are forwarded to
// the embedded writer unchanged.
type recordingResponseWriter struct {
	http.ResponseWriter
	status int // first final status seen; 0 until a header is committed
}

// WriteHeader forwards status to the wrapped writer and latches it as the
// response status if none has been latched yet.
//
// Informational 1xx codes other than 101 are forwarded but not latched:
// net/http allows any number of them before the single final header, so
// latching one would misreport the response status.
func (w *recordingResponseWriter) WriteHeader(status int) {
	final := status >= http.StatusOK || status == http.StatusSwitchingProtocols
	if w.status == 0 && final {
		w.status = status
	}
	w.ResponseWriter.WriteHeader(status)
}

// Write forwards p to the wrapped writer. A write without a prior WriteHeader
// implies a 200 response, which is latched here.
func (w *recordingResponseWriter) Write(p []byte) (int, error) {
	if w.status == 0 {
		w.status = http.StatusOK
	}
	return w.ResponseWriter.Write(p)
}

// Flush forwards to the wrapped writer when it implements http.Flusher and is
// a no-op otherwise.
//
// Flushing before any header was written makes net/http's writers commit an
// implicit 200, so that status is latched here too; a later WriteHeader can no
// longer change what was sent.
func (w *recordingResponseWriter) Flush() {
	flusher, ok := w.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	if w.status == 0 {
		w.status = http.StatusOK
	}
	flusher.Flush()
}

// statusCode returns the latched status, or 200 when the handler never wrote
// a header or body.
func (w *recordingResponseWriter) statusCode() int {
	if w.status == 0 {
		return http.StatusOK
	}
	return w.status
}
|
|
|
|
func extractTextResult(t *testing.T, result *sdkmcp.CallToolResult) string {
|
|
t.Helper()
|
|
if result == nil || len(result.Content) != 1 {
|
|
t.Fatalf("unexpected CallToolResult: %#v", result)
|
|
}
|
|
text, ok := result.Content[0].(*sdkmcp.TextContent)
|
|
if !ok {
|
|
t.Fatalf("CallToolResult content type = %T, want *sdkmcp.TextContent", result.Content[0])
|
|
}
|
|
return text.Text
|
|
}
|
|
|
|
func assertRecordedCompatibility(
|
|
t *testing.T,
|
|
requests []recordedRequest,
|
|
wantResponseContentType string,
|
|
) {
|
|
t.Helper()
|
|
|
|
var (
|
|
getCount int
|
|
deleteCount int
|
|
postWithSession int
|
|
deleteWithSession int
|
|
requestsMissingAuth []string
|
|
observedContentTypesByMethod = map[string]string{}
|
|
)
|
|
|
|
for _, req := range requests {
|
|
switch req.Method {
|
|
case http.MethodGet:
|
|
getCount++
|
|
case http.MethodPost:
|
|
if req.RequestSessionID != "" {
|
|
postWithSession++
|
|
}
|
|
if req.JSONRPCMethod != "" && observedContentTypesByMethod[req.JSONRPCMethod] == "" {
|
|
observedContentTypesByMethod[req.JSONRPCMethod] = req.ResponseContentType
|
|
}
|
|
case http.MethodDelete:
|
|
deleteCount++
|
|
if req.RequestSessionID != "" {
|
|
deleteWithSession++
|
|
}
|
|
}
|
|
|
|
if req.Authorization != "Bearer integration-token" {
|
|
requestsMissingAuth = append(requestsMissingAuth, req.Method+" "+req.Path)
|
|
}
|
|
}
|
|
|
|
if getCount != 0 {
|
|
t.Fatalf("expected no standalone GET requests for streamable HTTP mode, saw %d", getCount)
|
|
}
|
|
if deleteCount != 1 {
|
|
t.Fatalf("DELETE count = %d, want 1", deleteCount)
|
|
}
|
|
if postWithSession == 0 {
|
|
t.Fatal("expected at least one POST request with Mcp-Session-Id")
|
|
}
|
|
if deleteWithSession != 1 {
|
|
t.Fatalf("expected exactly one DELETE with Mcp-Session-Id, got %d", deleteWithSession)
|
|
}
|
|
if len(requestsMissingAuth) > 0 {
|
|
t.Fatalf("Authorization header missing on requests: %v", requestsMissingAuth)
|
|
}
|
|
|
|
for _, method := range []string{"initialize", "tools/list", "tools/call"} {
|
|
if observedContentTypesByMethod[method] == "" {
|
|
t.Fatalf("did not observe POST response for JSON-RPC method %q", method)
|
|
}
|
|
if observedContentTypesByMethod[method] != wantResponseContentType {
|
|
t.Fatalf(
|
|
"response content-type for %s = %q, want %q",
|
|
method,
|
|
observedContentTypesByMethod[method],
|
|
wantResponseContentType,
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
// normalizeContentType reduces a Content-Type header value to its bare media
// type: everything from the first ';' on (the parameters) is dropped,
// surrounding whitespace is trimmed, and the result is lower-cased because
// media types compare case-insensitively. An empty value yields "".
func normalizeContentType(value string) string {
	mediaType, _, _ := strings.Cut(value, ";")
	return strings.ToLower(strings.TrimSpace(mediaType))
}
|