mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
604 lines
12 KiB
Go
604 lines
12 KiB
Go
package mcp
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"os/exec"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/logger"
|
|
)
|
|
|
|
// Client is the transport-agnostic MCP client contract.
|
|
type Client interface {
|
|
Start(ctx context.Context) error
|
|
ListTools(ctx context.Context) ([]RemoteTool, error)
|
|
CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error)
|
|
Close() error
|
|
}
|
|
|
|
// StdioClient speaks MCP over stdio (JSON-RPC framed with Content-Length headers).
|
|
type StdioClient struct {
|
|
config ServerConfig
|
|
mode string
|
|
|
|
mu sync.Mutex
|
|
writeMu sync.Mutex
|
|
|
|
started bool
|
|
closed bool
|
|
|
|
cmd *exec.Cmd
|
|
stdin io.WriteCloser
|
|
stdout io.ReadCloser
|
|
stderr io.ReadCloser
|
|
waitCh chan struct{}
|
|
pending map[string]chan rpcResponse
|
|
|
|
nextID uint64
|
|
}
|
|
|
|
// rpcRequest is an outgoing JSON-RPC 2.0 message. With an empty ID it is
// encoded without an "id" member, which makes it a notification.
type rpcRequest struct {
	JSONRPC string `json:"jsonrpc"`          // always "2.0"
	ID      string `json:"id,omitempty"`     // empty for notifications
	Method  string `json:"method"`           // e.g. "tools/list"
	Params  any    `json:"params,omitempty"` // omitted only when nil
}
|
|
|
|
type rpcResponseEnvelope struct {
|
|
JSONRPC string `json:"jsonrpc"`
|
|
ID json.RawMessage `json:"id,omitempty"`
|
|
Result json.RawMessage `json:"result,omitempty"`
|
|
Error *rpcError `json:"error,omitempty"`
|
|
Method string `json:"method,omitempty"`
|
|
}
|
|
|
|
// rpcError mirrors the JSON-RPC 2.0 "error" object returned by the server.
type rpcError struct {
	Code    int    `json:"code"`    // numeric JSON-RPC error code
	Message string `json:"message"` // human-readable description
}
|
|
|
|
type rpcResponse struct {
|
|
result json.RawMessage
|
|
rpcErr *rpcError
|
|
err error
|
|
}
|
|
|
|
// initializeParams is the payload of the MCP "initialize" request sent by
// Start: the protocol revision the client speaks, the capabilities it
// advertises, and identifying client information.
type initializeParams struct {
	ProtocolVersion string         `json:"protocolVersion"`
	Capabilities    map[string]any `json:"capabilities"`
	ClientInfo      map[string]any `json:"clientInfo"` // `any` for consistency with the rest of the file
}
|
|
|
|
func NewStdioClient(config ServerConfig) *StdioClient {
|
|
return &StdioClient{
|
|
config: config,
|
|
mode: normalizeProtocol(config.Protocol),
|
|
}
|
|
}
|
|
|
|
// Start launches the configured server process and performs the MCP
// initialize handshake over its stdio.
//
// The first call spawns the process with the configured command, args,
// working directory and environment. It then wires up the stdin/stdout/stderr
// pipes and starts the reader, waiter and stderr-drain goroutines. Next it
// sends "initialize", bounded by InitTimeout unless ctx already has a
// deadline, and finally sends "notifications/initialized". If either message
// fails, the client is closed and the error is returned.
//
// Later calls return nil while the client is live, and an error once it has
// been closed (including after a failed handshake), so callers never proceed
// against a dead server.
//
// NOTE(review): `started` is set before the handshake completes, so a Start
// that races with an in-progress handshake returns nil immediately and its
// caller may send requests before "initialized" has gone out. Fixing this
// needs an extra "init done" signal on the struct.
func (c *StdioClient) Start(ctx context.Context) error {
	c.mu.Lock()
	if c.started {
		closed := c.closed
		c.mu.Unlock()
		if closed {
			return fmt.Errorf("mcp server %q is closed", c.config.Name)
		}
		return nil
	}
	if strings.TrimSpace(c.config.Command) == "" {
		c.mu.Unlock()
		return fmt.Errorf("mcp server %q command is empty", c.config.Name)
	}

	cmd := exec.Command(c.config.Command, c.config.Args...)
	if c.config.WorkingDir != "" {
		cmd.Dir = c.config.WorkingDir
	}
	cmd.Env = buildProcessEnv(c.config.Env)

	stdin, err := cmd.StdinPipe()
	if err != nil {
		c.mu.Unlock()
		return fmt.Errorf("create stdin pipe: %w", err)
	}
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		c.mu.Unlock()
		return fmt.Errorf("create stdout pipe: %w", err)
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		c.mu.Unlock()
		return fmt.Errorf("create stderr pipe: %w", err)
	}
	if err := cmd.Start(); err != nil {
		c.mu.Unlock()
		return fmt.Errorf("start process: %w", err)
	}

	c.started = true
	c.closed = false
	c.cmd = cmd
	c.stdin = stdin
	c.stdout = stdout
	c.stderr = stderr
	c.waitCh = make(chan struct{})
	c.pending = make(map[string]chan rpcResponse)
	c.mu.Unlock()

	// Background workers: route replies, reap the process, drain stderr.
	go c.readLoop()
	go c.waitLoop()
	go c.drainStderr()

	initCtx, cancel := withTimeoutIfMissing(ctx, c.config.InitTimeout())
	defer cancel()

	_, err = c.request(initCtx, "initialize", initializeParams{
		ProtocolVersion: "2024-11-05",
		Capabilities: map[string]any{
			"tools": map[string]any{},
		},
		ClientInfo: map[string]any{
			"name":    "picoclaw",
			"version": "dev",
		},
	})
	if err != nil {
		_ = c.Close() // best effort: the handshake error is what matters
		return fmt.Errorf("initialize failed: %w", err)
	}

	if err := c.notify("notifications/initialized", map[string]any{}); err != nil {
		_ = c.Close()
		return fmt.Errorf("initialized notification failed: %w", err)
	}

	return nil
}
|
|
|
|
// ListTools returns every tool the server advertises. It follows tools/list
// pagination cursors for at most maxToolListPages pages and reports an error
// if the server still has more after that. Each page request is bounded by
// CallTimeout unless ctx already carries a deadline. The client is started
// lazily.
func (c *StdioClient) ListTools(ctx context.Context) ([]RemoteTool, error) {
	if err := c.Start(ctx); err != nil {
		return nil, err
	}

	type toolEntry struct {
		Name        string         `json:"name"`
		Description string         `json:"description,omitempty"`
		InputSchema map[string]any `json:"inputSchema"`
	}
	type listPage struct {
		Tools      []toolEntry `json:"tools"`
		NextCursor string      `json:"nextCursor,omitempty"`
	}

	collected := make([]RemoteTool, 0, 8)
	var cursor string

	for fetched := 0; fetched < maxToolListPages; fetched++ {
		params := map[string]any{}
		if cursor != "" {
			params["cursor"] = cursor
		}

		pageCtx, cancelPage := withTimeoutIfMissing(ctx, c.config.CallTimeout())
		raw, err := c.request(pageCtx, "tools/list", params)
		cancelPage()
		if err != nil {
			return nil, err
		}

		var page listPage
		if err := json.Unmarshal(raw, &page); err != nil {
			return nil, fmt.Errorf("decode tools/list response: %w", err)
		}

		for _, entry := range page.Tools {
			collected = append(collected, RemoteTool{
				Name:        entry.Name,
				Description: entry.Description,
				InputSchema: entry.InputSchema,
			})
		}

		cursor = page.NextCursor
		if cursor == "" {
			return collected, nil
		}
	}

	return nil, fmt.Errorf("tools/list exceeded %d pages", maxToolListPages)
}
|
|
|
|
// CallTool invokes toolName via "tools/call" and formats the server's payload
// with the configured response limit. The client is started lazily. The call
// is bounded by CallTimeout unless ctx already has a deadline.
//
// A nil arguments map is sent as an empty object.
func (c *StdioClient) CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error) {
	if err := c.Start(ctx); err != nil {
		return CallResult{}, err
	}

	// A nil map would marshal as JSON null. MCP defines "arguments" as an
	// optional object, so servers that validate their input reject null.
	// Reassigning the parameter does not touch the caller's map.
	if arguments == nil {
		arguments = map[string]any{}
	}

	callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout())
	defer cancel()

	raw, err := c.request(callCtx, "tools/call", map[string]any{
		"name":      toolName,
		"arguments": arguments,
	})
	if err != nil {
		return CallResult{}, err
	}

	return formatCallPayload(raw, c.config.ResponseLimit())
}
|
|
|
|
// Close terminates the server process and fails every in-flight request.
//
// It does nothing if the client was never started or is already closed.
// Otherwise it marks the client closed and fails pending requests with
// "mcp client closed". It then closes the process's stdin, kills the process,
// and waits up to two seconds for waitLoop to observe the exit. It always
// returns nil.
func (c *StdioClient) Close() error {
	c.mu.Lock()
	if c.closed || !c.started {
		c.mu.Unlock()
		return nil
	}
	c.closed = true
	proc, in, exited := c.cmd, c.stdin, c.waitCh
	c.mu.Unlock()

	c.failPending(errors.New("mcp client closed"))

	if in != nil {
		_ = in.Close() // best effort; the process is killed right after
	}
	if proc != nil && proc.Process != nil {
		_ = proc.Process.Kill()
	}

	if exited == nil {
		return nil
	}
	grace := time.NewTimer(2 * time.Second)
	defer grace.Stop()
	select {
	case <-exited:
	case <-grace.C:
	}
	return nil
}
|
|
|
|
func (c *StdioClient) request(ctx context.Context, method string, params any) (json.RawMessage, error) {
|
|
id := strconv.FormatUint(atomic.AddUint64(&c.nextID, 1), 10)
|
|
responseCh := make(chan rpcResponse, 1)
|
|
|
|
c.mu.Lock()
|
|
if c.closed {
|
|
c.mu.Unlock()
|
|
return nil, fmt.Errorf("mcp server %q is closed", c.config.Name)
|
|
}
|
|
c.pending[id] = responseCh
|
|
c.mu.Unlock()
|
|
|
|
req := rpcRequest{
|
|
JSONRPC: "2.0",
|
|
ID: id,
|
|
Method: method,
|
|
Params: params,
|
|
}
|
|
if err := c.writeMessage(req); err != nil {
|
|
c.removePending(id)
|
|
return nil, err
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
c.removePending(id)
|
|
return nil, ctx.Err()
|
|
case response := <-responseCh:
|
|
if response.err != nil {
|
|
return nil, response.err
|
|
}
|
|
if response.rpcErr != nil {
|
|
return nil, fmt.Errorf("mcp error %d: %s", response.rpcErr.Code, response.rpcErr.Message)
|
|
}
|
|
return response.result, nil
|
|
}
|
|
}
|
|
|
|
func (c *StdioClient) notify(method string, params any) error {
|
|
req := rpcRequest{
|
|
JSONRPC: "2.0",
|
|
Method: method,
|
|
Params: params,
|
|
}
|
|
return c.writeMessage(req)
|
|
}
|
|
|
|
// writeMessage marshals payload and writes it to the server's stdin using the
// configured framing:
//   - JSON lines: the body followed by '\n';
//   - otherwise: a "Content-Length: N\r\n\r\n" header followed by the body.
//
// Writes are serialized with writeMu so concurrent messages never interleave.
// It fails if the client is closed or has no stdin yet.
func (c *StdioClient) writeMessage(payload any) error {
	c.mu.Lock()
	stdin := c.stdin
	writable := !c.closed && stdin != nil
	c.mu.Unlock()
	if !writable {
		return fmt.Errorf("mcp server %q is not writable", c.config.Name)
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("marshal json-rpc payload: %w", err)
	}

	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	switch c.mode {
	case ProtocolJSONLines:
		if _, err := stdin.Write(append(body, '\n')); err != nil {
			return fmt.Errorf("write jsonl body: %w", err)
		}
	default:
		header := "Content-Length: " + strconv.Itoa(len(body)) + "\r\n\r\n"
		if _, err := io.WriteString(stdin, header); err != nil {
			return fmt.Errorf("write frame header: %w", err)
		}
		if _, err := stdin.Write(body); err != nil {
			return fmt.Errorf("write frame body: %w", err)
		}
	}
	return nil
}
|
|
|
|
func (c *StdioClient) readLoop() {
|
|
if c.mode == ProtocolJSONLines {
|
|
c.readJSONLLoop()
|
|
return
|
|
}
|
|
|
|
c.readMCPFrameLoop()
|
|
}
|
|
|
|
// readMCPFrameLoop reads Content-Length framed messages from stdout and
// dispatches each decoded envelope.
//
// Any framing or read error, including EOF when the process exits, fails all
// pending requests with that error and ends the loop. A frame whose body is
// not valid JSON is skipped and logged: without the log, the request it
// answered would only surface as a timeout.
func (c *StdioClient) readMCPFrameLoop() {
	reader := bufio.NewReader(c.stdout)

	for {
		payload, err := readFramePayload(reader)
		if err != nil {
			c.failPending(err)
			return
		}

		var envelope rpcResponseEnvelope
		if err := json.Unmarshal(payload, &envelope); err != nil {
			logger.WarnCF("mcp", "Ignoring malformed MCP frame",
				map[string]any{
					"server": c.config.Name,
					"error":  err.Error(),
				})
			continue
		}
		c.dispatchResponse(envelope)
	}
}
|
|
|
|
// readJSONLLoop reads newline-delimited JSON messages from stdout and
// dispatches each decoded envelope. Lines may be up to maxFrameBytes long.
//
// Blank lines are ignored. Lines that are not valid JSON are skipped and
// logged at debug level, so a garbled reply can be diagnosed instead of
// showing up only as a timeout. When the stream ends, all pending requests
// are failed: with the scanner's error (e.g. a line exceeding maxFrameBytes)
// or with io.EOF on a clean end of stream.
func (c *StdioClient) readJSONLLoop() {
	scanner := bufio.NewScanner(c.stdout)
	scanner.Buffer(make([]byte, 0, defaultScannerBufferBytes), maxFrameBytes)

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		var envelope rpcResponseEnvelope
		if err := json.Unmarshal([]byte(line), &envelope); err != nil {
			logger.DebugCF("mcp", "Ignoring non-JSON-RPC line from MCP server",
				map[string]any{
					"server": c.config.Name,
					"error":  err.Error(),
				})
			continue
		}
		c.dispatchResponse(envelope)
	}

	err := scanner.Err()
	if err == nil {
		err = io.EOF // Scanner reports a clean end of input as a nil error
	}
	c.failPending(err)
}
|
|
|
|
func (c *StdioClient) dispatchResponse(envelope rpcResponseEnvelope) {
|
|
if len(envelope.ID) == 0 {
|
|
return
|
|
}
|
|
|
|
id, ok := parseRPCID(envelope.ID)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
c.mu.Lock()
|
|
responseCh := c.pending[id]
|
|
if responseCh != nil {
|
|
delete(c.pending, id)
|
|
}
|
|
c.mu.Unlock()
|
|
|
|
if responseCh == nil {
|
|
return
|
|
}
|
|
|
|
response := rpcResponse{
|
|
result: envelope.Result,
|
|
rpcErr: envelope.Error,
|
|
}
|
|
select {
|
|
case responseCh <- response:
|
|
default:
|
|
}
|
|
}
|
|
|
|
func (c *StdioClient) waitLoop() {
|
|
c.mu.Lock()
|
|
cmd := c.cmd
|
|
waitCh := c.waitCh
|
|
serverName := c.config.Name
|
|
c.mu.Unlock()
|
|
|
|
if cmd == nil {
|
|
if waitCh != nil {
|
|
close(waitCh)
|
|
}
|
|
return
|
|
}
|
|
|
|
err := cmd.Wait()
|
|
if waitCh != nil {
|
|
close(waitCh)
|
|
}
|
|
if err != nil {
|
|
logger.WarnCF("mcp", "MCP process exited with error",
|
|
map[string]any{
|
|
"server": serverName,
|
|
"error": err.Error(),
|
|
})
|
|
}
|
|
}
|
|
|
|
// drainStderr forwards each non-blank stderr line from the server to the
// debug log until the stream ends.
//
// The stream must be consumed to the end. If nobody reads stderr, the OS pipe
// buffer fills up and the server blocks on its next stderr write. So when the
// scanner gives up (e.g. bufio.ErrTooLong on a line longer than its buffer),
// the remainder is discarded rather than abandoned.
func (c *StdioClient) drainStderr() {
	c.mu.Lock()
	stderr := c.stderr
	serverName := c.config.Name
	c.mu.Unlock()

	if stderr == nil {
		return
	}

	scanner := bufio.NewScanner(stderr)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}
		logger.DebugCF("mcp", "MCP server stderr",
			map[string]any{
				"server": serverName,
				"line":   line,
			})
	}

	// A closed pipe (Cmd.Wait closes it after the process exits) is a normal
	// end of stream. Any other error means there is still data to consume.
	if err := scanner.Err(); err != nil && !errors.Is(err, os.ErrClosed) {
		logger.DebugCF("mcp", "MCP server stderr unreadable, discarding remainder",
			map[string]any{
				"server": serverName,
				"error":  err.Error(),
			})
		_, _ = io.Copy(io.Discard, stderr) // keep the pipe drained; errors only end the drain
	}
}
|
|
|
|
func (c *StdioClient) failPending(err error) {
|
|
c.mu.Lock()
|
|
pending := c.pending
|
|
c.pending = make(map[string]chan rpcResponse)
|
|
c.mu.Unlock()
|
|
|
|
if len(pending) == 0 {
|
|
return
|
|
}
|
|
|
|
for _, ch := range pending {
|
|
select {
|
|
case ch <- rpcResponse{err: err}:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *StdioClient) removePending(id string) {
|
|
c.mu.Lock()
|
|
delete(c.pending, id)
|
|
c.mu.Unlock()
|
|
}
|
|
|
|
// readFramePayload reads one Content-Length framed message from reader and
// returns its body.
//
// Header lines run until the first empty line. Header names are matched
// case-insensitively, and every header other than Content-Length is ignored
// (if it repeats, the last value wins). Blank lines before the first header
// are skipped, so a stray newline a server writes after a body does not
// desynchronize the stream.
//
// Errors: read errors from reader (including io.EOF and io.ErrUnexpectedEOF)
// are returned unwrapped. A missing, malformed, negative, zero, or larger
// than maxFrameBytes Content-Length is reported as a descriptive error.
func readFramePayload(reader *bufio.Reader) ([]byte, error) {
	contentLength := -1
	sawHeader := false

	for {
		line, err := reader.ReadString('\n')
		if err != nil {
			return nil, err
		}
		trimmed := strings.TrimRight(line, "\r\n")
		if trimmed == "" {
			if !sawHeader {
				continue // separator noise between frames
			}
			break // end of the header section
		}
		sawHeader = true

		name, value, ok := strings.Cut(trimmed, ":")
		if !ok || !strings.EqualFold(strings.TrimSpace(name), "content-length") {
			continue
		}
		value = strings.TrimSpace(value)
		length, err := strconv.Atoi(value)
		if err != nil {
			return nil, fmt.Errorf("invalid content-length %q: %w", value, err)
		}
		if length < 0 {
			return nil, fmt.Errorf("invalid content-length %q", value)
		}
		contentLength = length
	}

	switch {
	case contentLength < 0:
		return nil, errors.New("missing content-length")
	case contentLength == 0:
		return nil, errors.New("empty frame (content-length 0)")
	case contentLength > maxFrameBytes:
		return nil, fmt.Errorf("frame too large (%d bytes)", contentLength)
	}

	payload := make([]byte, contentLength)
	if _, err := io.ReadFull(reader, payload); err != nil {
		return nil, err
	}
	return payload, nil
}
|
|
|
|
// parseRPCID converts a raw JSON-RPC id into the string key used by the
// pending-request map.
//
// String ids are returned verbatim. Numeric ids are decoded as json.Number, so
// integers keep full precision (a float64 round-trip corrupts ids above 2^53)
// and integral float spellings such as 5.0 still map to "5". Non-integral
// numbers keep their literal text rather than being truncated into a
// different id. null, objects, arrays, booleans and empty input report ok=false.
func parseRPCID(raw json.RawMessage) (string, bool) {
	trimmed := strings.TrimSpace(string(raw))
	if trimmed == "" || trimmed == "null" {
		return "", false
	}

	var stringID string
	if err := json.Unmarshal(raw, &stringID); err == nil {
		return stringID, true
	}

	var numberID json.Number
	if err := json.Unmarshal(raw, &numberID); err != nil {
		return "", false
	}
	if n, err := numberID.Int64(); err == nil {
		return strconv.FormatInt(n, 10), true
	}
	if f, err := numberID.Float64(); err == nil {
		if n := int64(f); float64(n) == f {
			return strconv.FormatInt(n, 10), true
		}
	}
	return numberID.String(), true
}
|
|
|
|
// withTimeoutIfMissing derives a cancellable context from parent. If parent
// has no deadline and timeout is positive, the derived context also expires
// after timeout. An existing deadline always takes precedence. A zero or
// negative timeout means "no extra deadline"; without this guard it would
// yield an already-expired context and fail every call. The returned cancel
// function must always be called.
func withTimeoutIfMissing(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	if _, hasDeadline := parent.Deadline(); hasDeadline || timeout <= 0 {
		return context.WithCancel(parent)
	}
	return context.WithTimeout(parent, timeout)
}
|
|
|
|
// buildProcessEnv returns the current process environment followed by the
// custom variables as KEY=VALUE entries, in sorted key order for
// deterministic output. The custom entries come last, so os/exec's
// last-value-wins handling lets them override inherited values.
func buildProcessEnv(custom map[string]string) []string {
	inherited := os.Environ()
	if len(custom) == 0 {
		return inherited
	}

	names := make([]string, 0, len(custom))
	for name := range custom {
		names = append(names, name)
	}
	sort.Strings(names)

	merged := make([]string, len(inherited), len(inherited)+len(names))
	copy(merged, inherited)
	for _, name := range names {
		merged = append(merged, name+"="+custom[name])
	}
	return merged
}
|
|
|
|
// normalizeProtocol maps a user-supplied protocol name (trimmed and
// lower-cased) to a supported wire protocol. Only ProtocolJSONLines is
// recognized explicitly. Everything else, including the empty string, falls
// back to ProtocolMCPFrames.
func normalizeProtocol(protocol string) string {
	if strings.ToLower(strings.TrimSpace(protocol)) == ProtocolJSONLines {
		return ProtocolJSONLines
	}
	return ProtocolMCPFrames
}
|