mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
758 lines
19 KiB
Go
758 lines
19 KiB
Go
package mcp
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
|
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
|
)
|
|
|
|
func TestLoadEnvFile(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
content string
|
|
expected map[string]string
|
|
expectErr bool
|
|
}{
|
|
{
|
|
name: "basic env file",
|
|
content: `API_KEY=secret123
|
|
DATABASE_URL=postgres://localhost/db
|
|
PORT=8080`,
|
|
expected: map[string]string{
|
|
"API_KEY": "secret123",
|
|
"DATABASE_URL": "postgres://localhost/db",
|
|
"PORT": "8080",
|
|
},
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "with comments and empty lines",
|
|
content: `# This is a comment
|
|
API_KEY=secret123
|
|
|
|
# Another comment
|
|
DATABASE_URL=postgres://localhost/db
|
|
|
|
PORT=8080`,
|
|
expected: map[string]string{
|
|
"API_KEY": "secret123",
|
|
"DATABASE_URL": "postgres://localhost/db",
|
|
"PORT": "8080",
|
|
},
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "with quoted values",
|
|
content: `API_KEY="secret with spaces"
|
|
NAME='single quoted'
|
|
PLAIN=no-quotes`,
|
|
expected: map[string]string{
|
|
"API_KEY": "secret with spaces",
|
|
"NAME": "single quoted",
|
|
"PLAIN": "no-quotes",
|
|
},
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "with spaces around equals",
|
|
content: `API_KEY = secret123
|
|
DATABASE_URL= postgres://localhost/db
|
|
PORT =8080`,
|
|
expected: map[string]string{
|
|
"API_KEY": "secret123",
|
|
"DATABASE_URL": "postgres://localhost/db",
|
|
"PORT": "8080",
|
|
},
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "invalid format - no equals",
|
|
content: `INVALID_LINE`,
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "empty file",
|
|
content: ``,
|
|
expected: map[string]string{},
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "only comments",
|
|
content: `# Comment 1
|
|
# Comment 2`,
|
|
expected: map[string]string{},
|
|
expectErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
envFile := filepath.Join(tmpDir, ".env")
|
|
|
|
if err := os.WriteFile(envFile, []byte(tt.content), 0o644); err != nil {
|
|
t.Fatalf("Failed to create test file: %v", err)
|
|
}
|
|
|
|
result, err := loadEnvFile(envFile)
|
|
|
|
if tt.expectErr {
|
|
if err == nil {
|
|
t.Errorf("Expected error but got none")
|
|
}
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
return
|
|
}
|
|
|
|
if len(result) != len(tt.expected) {
|
|
t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result))
|
|
}
|
|
|
|
for key, expectedValue := range tt.expected {
|
|
if actualValue, ok := result[key]; !ok {
|
|
t.Errorf("Expected key %s not found", key)
|
|
} else if actualValue != expectedValue {
|
|
t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestLoadEnvFileNotFound(t *testing.T) {
|
|
_, err := loadEnvFile("/nonexistent/file.env")
|
|
if err == nil {
|
|
t.Error("Expected error for nonexistent file")
|
|
}
|
|
}
|
|
|
|
func TestExpandHomeCommandPath(t *testing.T) {
|
|
homeDir := t.TempDir()
|
|
t.Setenv("HOME", homeDir)
|
|
t.Setenv("USERPROFILE", homeDir)
|
|
|
|
want := filepath.Join(homeDir, "bin", "my-mcp")
|
|
got := expandHomeCommandPath("~" + string(os.PathSeparator) + filepath.Join("bin", "my-mcp"))
|
|
if got != want {
|
|
t.Fatalf("expandHomeCommandPath() = %q, want %q", got, want)
|
|
}
|
|
|
|
if got := expandHomeCommandPath("npx"); got != "npx" {
|
|
t.Fatalf("expandHomeCommandPath() should leave bare commands unchanged, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestEnvFilePriority(t *testing.T) {
|
|
// Create a temporary .env file
|
|
tmpDir := t.TempDir()
|
|
envFile := filepath.Join(tmpDir, ".env")
|
|
|
|
envContent := `API_KEY=from_file
|
|
DATABASE_URL=from_file
|
|
SHARED_VAR=from_file`
|
|
|
|
if err := os.WriteFile(envFile, []byte(envContent), 0o644); err != nil {
|
|
t.Fatalf("Failed to create .env file: %v", err)
|
|
}
|
|
|
|
// Load envFile
|
|
envVars, err := loadEnvFile(envFile)
|
|
if err != nil {
|
|
t.Fatalf("Failed to load env file: %v", err)
|
|
}
|
|
|
|
// Verify envFile variables
|
|
if envVars["API_KEY"] != "from_file" {
|
|
t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"])
|
|
}
|
|
|
|
// Simulate config.Env overriding envFile
|
|
configEnv := map[string]string{
|
|
"SHARED_VAR": "from_config",
|
|
"NEW_VAR": "from_config",
|
|
}
|
|
|
|
// Merge: envFile first, then config overrides
|
|
merged := make(map[string]string)
|
|
for k, v := range envVars {
|
|
merged[k] = v
|
|
}
|
|
for k, v := range configEnv {
|
|
merged[k] = v
|
|
}
|
|
|
|
// Verify priority: config.Env should override envFile
|
|
if merged["SHARED_VAR"] != "from_config" {
|
|
t.Errorf(
|
|
"Expected SHARED_VAR=from_config (config should override file), got %s",
|
|
merged["SHARED_VAR"],
|
|
)
|
|
}
|
|
if merged["API_KEY"] != "from_file" {
|
|
t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"])
|
|
}
|
|
if merged["NEW_VAR"] != "from_config" {
|
|
t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"])
|
|
}
|
|
}
|
|
|
|
func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) {
|
|
mgr := NewManager()
|
|
|
|
mcpCfg := config.MCPConfig{
|
|
ToolConfig: config.ToolConfig{
|
|
Enabled: true,
|
|
},
|
|
Servers: map[string]config.MCPServerConfig{
|
|
"test-server": {
|
|
Enabled: true,
|
|
Command: "echo",
|
|
Args: []string{"ok"},
|
|
EnvFile: ".env",
|
|
},
|
|
},
|
|
}
|
|
|
|
err := mgr.LoadFromMCPConfig(context.Background(), mcpCfg, "")
|
|
if err == nil {
|
|
t.Fatal("expected error for relative env_file with empty workspace path, got nil")
|
|
}
|
|
|
|
if !strings.Contains(err.Error(), "workspace path is empty") {
|
|
t.Fatalf("expected workspace path validation error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestNewManager_InitialState(t *testing.T) {
|
|
mgr := NewManager()
|
|
if mgr == nil {
|
|
t.Fatal("expected manager instance, got nil")
|
|
}
|
|
if len(mgr.GetServers()) != 0 {
|
|
t.Fatalf("expected no servers on new manager, got %d", len(mgr.GetServers()))
|
|
}
|
|
}
|
|
|
|
func TestConnectServerPublishesRuntimeEvents(t *testing.T) {
|
|
originalConnectServerFunc := connectServerFunc
|
|
t.Cleanup(func() {
|
|
connectServerFunc = originalConnectServerFunc
|
|
})
|
|
|
|
eventBus := runtimeevents.NewBus()
|
|
defer func() {
|
|
if err := eventBus.Close(); err != nil {
|
|
t.Errorf("event bus close failed: %v", err)
|
|
}
|
|
}()
|
|
|
|
_, eventsCh, err := eventBus.Channel().OfKind(
|
|
runtimeevents.KindMCPServerConnected,
|
|
runtimeevents.KindMCPServerFailed,
|
|
).SubscribeChan(t.Context(), runtimeevents.SubscribeOptions{Name: "mcp-events", Buffer: 2})
|
|
if err != nil {
|
|
t.Fatalf("SubscribeChan failed: %v", err)
|
|
}
|
|
|
|
connectServerFunc = func(
|
|
_ context.Context,
|
|
name string,
|
|
cfg config.MCPServerConfig,
|
|
) (*ServerConnection, error) {
|
|
if name == "bad" {
|
|
return nil, fmt.Errorf("connect failed")
|
|
}
|
|
return &ServerConnection{
|
|
Name: name,
|
|
Config: cfg,
|
|
Tools: []*sdkmcp.Tool{{Name: "echo"}},
|
|
}, nil
|
|
}
|
|
|
|
mgr := NewManager(WithRuntimeEvents(eventBus))
|
|
err = mgr.ConnectServer(context.Background(), "good", config.MCPServerConfig{
|
|
Type: "stdio",
|
|
Command: "echo",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("ConnectServer(good) error = %v", err)
|
|
}
|
|
connected := receiveMCPRuntimeEvent(t, eventsCh)
|
|
if connected.Kind != runtimeevents.KindMCPServerConnected ||
|
|
connected.Source.Name != "good" ||
|
|
connected.Severity != runtimeevents.SeverityInfo {
|
|
t.Fatalf("connected event = %+v", connected)
|
|
}
|
|
if connected.Attrs["server"] != "good" ||
|
|
connected.Attrs["type"] != "stdio" ||
|
|
connected.Attrs["tool_count"] != 1 {
|
|
t.Fatalf("connected attrs = %#v", connected.Attrs)
|
|
}
|
|
|
|
err = mgr.ConnectServer(context.Background(), "bad", config.MCPServerConfig{
|
|
Type: "stdio",
|
|
Command: "echo",
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected ConnectServer(bad) to fail")
|
|
}
|
|
failed := receiveMCPRuntimeEvent(t, eventsCh)
|
|
if failed.Kind != runtimeevents.KindMCPServerFailed ||
|
|
failed.Source.Name != "bad" ||
|
|
failed.Severity != runtimeevents.SeverityError {
|
|
t.Fatalf("failed event = %+v", failed)
|
|
}
|
|
if failed.Attrs["server"] != "bad" || failed.Attrs["error"] != "connect failed" {
|
|
t.Fatalf("failed attrs = %#v", failed.Attrs)
|
|
}
|
|
}
|
|
|
|
func receiveMCPRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
|
|
t.Helper()
|
|
|
|
select {
|
|
case evt, ok := <-ch:
|
|
if !ok {
|
|
t.Fatal("runtime event channel closed before expected event")
|
|
}
|
|
return evt
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for runtime event")
|
|
return runtimeevents.Event{}
|
|
}
|
|
}
|
|
|
|
func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
|
|
mgr := NewManager()
|
|
|
|
err := mgr.LoadFromMCPConfig(
|
|
context.Background(),
|
|
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: false}},
|
|
"/tmp",
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("expected nil error when MCP disabled, got: %v", err)
|
|
}
|
|
|
|
err = mgr.LoadFromMCPConfig(
|
|
context.Background(),
|
|
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: true}},
|
|
"/tmp",
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("expected nil error when no servers configured, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestGetServers_ReturnsCopy(t *testing.T) {
|
|
mgr := NewManager()
|
|
mgr.servers["s1"] = &ServerConnection{Name: "s1"}
|
|
|
|
servers := mgr.GetServers()
|
|
delete(servers, "s1")
|
|
|
|
if _, ok := mgr.GetServer("s1"); !ok {
|
|
t.Fatal("expected internal manager state to remain unchanged")
|
|
}
|
|
}
|
|
|
|
func TestGetAllTools_FiltersEmptyTools(t *testing.T) {
|
|
mgr := NewManager()
|
|
mgr.servers["empty"] = &ServerConnection{Name: "empty", Tools: nil}
|
|
mgr.servers["with-tools"] = &ServerConnection{Name: "with-tools", Tools: []*sdkmcp.Tool{{}}}
|
|
|
|
all := mgr.GetAllTools()
|
|
if _, ok := all["empty"]; ok {
|
|
t.Fatal("expected server without tools to be excluded")
|
|
}
|
|
if _, ok := all["with-tools"]; !ok {
|
|
t.Fatal("expected server with tools to be included")
|
|
}
|
|
}
|
|
|
|
func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) {
|
|
t.Run("manager closed", func(t *testing.T) {
|
|
mgr := NewManager()
|
|
mgr.closed.Store(true)
|
|
|
|
_, err := mgr.CallTool(context.Background(), "s1", "tool", nil)
|
|
if err == nil || !strings.Contains(err.Error(), "manager is closed") {
|
|
t.Fatalf("expected manager closed error, got: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("server missing", func(t *testing.T) {
|
|
mgr := NewManager()
|
|
|
|
_, err := mgr.CallTool(context.Background(), "missing", "tool", nil)
|
|
if err == nil || !strings.Contains(err.Error(), "not found") {
|
|
t.Fatalf("expected server not found error, got: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestConnectServer_StreamableHTTPRequestResponseMode(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
for _, transportType := range []string{"http", "streamable-http"} {
|
|
t.Run(transportType, func(t *testing.T) {
|
|
server := sdkmcp.NewServer(&sdkmcp.Implementation{
|
|
Name: "streamable-test-server",
|
|
Version: "1.0.0",
|
|
}, nil)
|
|
sdkmcp.AddTool(server, &sdkmcp.Tool{
|
|
Name: "echo",
|
|
Description: "Echo test tool",
|
|
}, func(ctx context.Context, req *sdkmcp.CallToolRequest, args map[string]any) (*sdkmcp.CallToolResult, any, error) {
|
|
return &sdkmcp.CallToolResult{
|
|
Content: []sdkmcp.Content{
|
|
&sdkmcp.TextContent{Text: "ok"},
|
|
},
|
|
}, nil, nil
|
|
})
|
|
|
|
type observedRequest struct {
|
|
Method string
|
|
SessionID string
|
|
Authorization string
|
|
}
|
|
|
|
var (
|
|
mu sync.Mutex
|
|
observed []observedRequest
|
|
)
|
|
|
|
handler := sdkmcp.NewStreamableHTTPHandler(func(*http.Request) *sdkmcp.Server {
|
|
return server
|
|
}, nil)
|
|
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
mu.Lock()
|
|
observed = append(observed, observedRequest{
|
|
Method: r.Method,
|
|
SessionID: r.Header.Get("Mcp-Session-Id"),
|
|
Authorization: r.Header.Get("Authorization"),
|
|
})
|
|
mu.Unlock()
|
|
handler.ServeHTTP(w, r)
|
|
}))
|
|
defer httpServer.Close()
|
|
|
|
conn, err := connectServer(context.Background(), "streamable", config.MCPServerConfig{
|
|
Enabled: true,
|
|
Type: transportType,
|
|
URL: httpServer.URL,
|
|
Headers: map[string]string{
|
|
"Authorization": "Bearer test-token",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("connectServer(%q) error = %v", transportType, err)
|
|
}
|
|
if got := len(conn.Tools); got != 1 {
|
|
t.Fatalf("len(conn.Tools) = %d, want 1", got)
|
|
}
|
|
if got := conn.Session.ID(); got == "" {
|
|
t.Fatal("expected non-empty streamable session ID")
|
|
}
|
|
if err := conn.Session.Close(); err != nil {
|
|
t.Fatalf("Session.Close() error = %v", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
var (
|
|
getCount int
|
|
postCount int
|
|
deleteCount int
|
|
postWithSession bool
|
|
deleteWithSession bool
|
|
requestsWithAuth int
|
|
requestsWithoutAuth []string
|
|
)
|
|
|
|
for _, req := range observed {
|
|
switch req.Method {
|
|
case http.MethodGet:
|
|
getCount++
|
|
case http.MethodPost:
|
|
postCount++
|
|
if req.SessionID != "" {
|
|
postWithSession = true
|
|
}
|
|
case http.MethodDelete:
|
|
deleteCount++
|
|
if req.SessionID != "" {
|
|
deleteWithSession = true
|
|
}
|
|
}
|
|
|
|
if req.Authorization == "Bearer test-token" {
|
|
requestsWithAuth++
|
|
} else {
|
|
requestsWithoutAuth = append(requestsWithoutAuth, req.Method)
|
|
}
|
|
}
|
|
|
|
if getCount != 0 {
|
|
t.Fatalf("expected no standalone GET requests for %q transport, saw %d", transportType, getCount)
|
|
}
|
|
if postCount == 0 {
|
|
t.Fatal("expected POST requests during streamable HTTP handshake")
|
|
}
|
|
if deleteCount != 1 {
|
|
t.Fatalf("DELETE count = %d, want 1", deleteCount)
|
|
}
|
|
if !postWithSession {
|
|
t.Fatal("expected at least one POST request with Mcp-Session-Id")
|
|
}
|
|
if !deleteWithSession {
|
|
t.Fatal("expected DELETE request with Mcp-Session-Id")
|
|
}
|
|
if requestsWithAuth != len(observed) {
|
|
t.Fatalf("Authorization header missing on requests: %v", requestsWithoutAuth)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCallTool_ReconnectsWhenHTTPServerLosesSession(t *testing.T) {
|
|
originalConnectServerFunc := connectServerFunc
|
|
t.Cleanup(func() {
|
|
connectServerFunc = originalConnectServerFunc
|
|
})
|
|
|
|
staleConn, staleTransport, err := newScriptedServerConnection(
|
|
"session-1",
|
|
nil,
|
|
fmt.Errorf(`sending "tools/call": failed to connect (session ID: session-1): %w`, sdkmcp.ErrSessionMissing),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("newScriptedServerConnection(stale) error = %v", err)
|
|
}
|
|
freshConn, freshTransport, err := newScriptedServerConnection(
|
|
"session-2",
|
|
&sdkmcp.CallToolResult{
|
|
Content: []sdkmcp.Content{
|
|
&sdkmcp.TextContent{Text: "reconnected"},
|
|
},
|
|
},
|
|
nil,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("newScriptedServerConnection(fresh) error = %v", err)
|
|
}
|
|
|
|
connectCalls := 0
|
|
connectServerFunc = func(ctx context.Context, name string, cfg config.MCPServerConfig) (*ServerConnection, error) {
|
|
connectCalls++
|
|
if connectCalls == 1 {
|
|
return freshConn, nil
|
|
}
|
|
return nil, fmt.Errorf("unexpected reconnect attempt %d", connectCalls)
|
|
}
|
|
|
|
mgr := NewManager()
|
|
mgr.servers["flaky"] = staleConn
|
|
|
|
result, err := mgr.CallTool(context.Background(), "flaky", "echo", map[string]any{
|
|
"query": "hello",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("CallTool() error = %v", err)
|
|
}
|
|
if result == nil || len(result.Content) != 1 {
|
|
t.Fatalf("CallTool() returned unexpected content: %#v", result)
|
|
}
|
|
|
|
text, ok := result.Content[0].(*sdkmcp.TextContent)
|
|
if !ok {
|
|
t.Fatalf("CallTool() content type = %T, want *sdkmcp.TextContent", result.Content[0])
|
|
}
|
|
if text.Text != "reconnected" {
|
|
t.Fatalf("CallTool() text = %q, want %q", text.Text, "reconnected")
|
|
}
|
|
|
|
conn, ok := mgr.GetServer("flaky")
|
|
if !ok {
|
|
t.Fatal("expected flaky server to remain connected after reconnect")
|
|
}
|
|
if conn.Session.ID() != "session-2" {
|
|
t.Fatalf("Session.ID() = %q, want %q", conn.Session.ID(), "session-2")
|
|
}
|
|
if connectCalls != 1 {
|
|
t.Fatalf("connectCalls = %d, want 1", connectCalls)
|
|
}
|
|
if staleTransport.toolCallCalls != 1 {
|
|
t.Fatalf("stale toolCallCalls = %d, want 1", staleTransport.toolCallCalls)
|
|
}
|
|
if freshTransport.toolCallCalls != 1 {
|
|
t.Fatalf("fresh toolCallCalls = %d, want 1", freshTransport.toolCallCalls)
|
|
}
|
|
}
|
|
|
|
func TestClose_IdempotentOnEmptyManager(t *testing.T) {
|
|
mgr := NewManager()
|
|
|
|
if err := mgr.Close(); err != nil {
|
|
t.Fatalf("first close should succeed, got: %v", err)
|
|
}
|
|
if err := mgr.Close(); err != nil {
|
|
t.Fatalf("second close should be idempotent, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func newScriptedServerConnection(
|
|
sessionID string,
|
|
toolCallResult *sdkmcp.CallToolResult,
|
|
toolCallErr error,
|
|
) (*ServerConnection, *scriptedTransport, error) {
|
|
transport := &scriptedTransport{
|
|
sessionID: sessionID,
|
|
toolCallResult: toolCallResult,
|
|
toolCallErr: toolCallErr,
|
|
}
|
|
|
|
client := sdkmcp.NewClient(&sdkmcp.Implementation{
|
|
Name: "picoclaw-test",
|
|
Version: "1.0.0",
|
|
}, nil)
|
|
session, err := client.Connect(context.Background(), transport, nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return &ServerConnection{
|
|
Name: "flaky",
|
|
Config: config.MCPServerConfig{Enabled: true, Type: "http", URL: "https://example.invalid/mcp"},
|
|
Client: client,
|
|
Session: session,
|
|
Tools: []*sdkmcp.Tool{
|
|
{
|
|
Name: "echo",
|
|
Description: "Echo test tool",
|
|
InputSchema: map[string]any{"type": "object"},
|
|
},
|
|
},
|
|
}, transport, nil
|
|
}
|
|
|
|
type scriptedTransport struct {
|
|
sessionID string
|
|
toolCallResult *sdkmcp.CallToolResult
|
|
toolCallErr error
|
|
|
|
mu sync.Mutex
|
|
toolCallCalls int
|
|
closed bool
|
|
incoming chan jsonrpc.Message
|
|
}
|
|
|
|
func (t *scriptedTransport) Connect(context.Context) (sdkmcp.Connection, error) {
|
|
if t.incoming == nil {
|
|
t.incoming = make(chan jsonrpc.Message, 4)
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
func (t *scriptedTransport) Read(ctx context.Context) (jsonrpc.Message, error) {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case msg, ok := <-t.incoming:
|
|
if !ok {
|
|
return nil, io.EOF
|
|
}
|
|
return msg, nil
|
|
}
|
|
}
|
|
|
|
func (t *scriptedTransport) Write(ctx context.Context, msg jsonrpc.Message) error {
|
|
req, ok := msg.(*jsonrpc.Request)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
switch req.Method {
|
|
case "initialize":
|
|
payload, err := json.Marshal(&sdkmcp.InitializeResult{
|
|
ProtocolVersion: "2025-11-25",
|
|
ServerInfo: &sdkmcp.Implementation{
|
|
Name: "scripted-test-server",
|
|
Version: "1.0.0",
|
|
},
|
|
Capabilities: &sdkmcp.ServerCapabilities{
|
|
Tools: &sdkmcp.ToolCapabilities{},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case t.incoming <- &jsonrpc.Response{ID: req.ID, Result: payload}:
|
|
return nil
|
|
}
|
|
|
|
case "notifications/initialized":
|
|
return nil
|
|
|
|
case "tools/call":
|
|
t.mu.Lock()
|
|
t.toolCallCalls++
|
|
t.mu.Unlock()
|
|
|
|
if t.toolCallErr != nil {
|
|
return t.toolCallErr
|
|
}
|
|
|
|
payload, err := json.Marshal(t.toolCallResult)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case t.incoming <- &jsonrpc.Response{ID: req.ID, Result: payload}:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("unexpected method %q", req.Method)
|
|
}
|
|
|
|
func (t *scriptedTransport) Close() error {
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
if t.closed {
|
|
return nil
|
|
}
|
|
t.closed = true
|
|
close(t.incoming)
|
|
return nil
|
|
}
|
|
|
|
func (t *scriptedTransport) SessionID() string {
|
|
return t.sessionID
|
|
}
|