// Source: picoclaw/pkg/mcp/client.go (Go, 604 lines, 12 KiB)

package mcp
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
// Client is the transport-agnostic MCP client contract.
// StdioClient in this file is the stdio-based implementation.
type Client interface {
// Start prepares the client for use. StdioClient spawns the server process
// and performs the initialize handshake.
Start(ctx context.Context) error
// ListTools returns the tools advertised by the server.
ListTools(ctx context.Context) ([]RemoteTool, error)
// CallTool invokes the tool named toolName with the given arguments and
// returns its result.
CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error)
// Close shuts the client down. StdioClient fails in-flight requests and
// kills the server process.
Close() error
}
// StdioClient is a Client that talks to an MCP server child process over its
// standard streams. Messages are JSON-RPC, framed either with Content-Length
// headers or as one JSON document per line, depending on mode.
type StdioClient struct {
// config is the server configuration supplied to NewStdioClient.
config ServerConfig
// mode is the wire framing, as normalized by normalizeProtocol.
mode string
// mu protects the lifecycle flags, the process handles and the pending map.
// (The read loops read stdout without taking it.)
mu sync.Mutex
// writeMu serializes writes to stdin so concurrent messages do not interleave.
writeMu sync.Mutex
// started is set once the process has been spawned; Close does not clear it.
started bool
// closed is set by Close; request and writeMessage refuse work once it is set.
closed bool
// cmd is the spawned server process.
cmd *exec.Cmd
// stdin, stdout and stderr are the pipes attached to the server process.
stdin io.WriteCloser
stdout io.ReadCloser
stderr io.ReadCloser
// waitCh is closed by waitLoop once cmd.Wait has returned.
waitCh chan struct{}
// pending maps request IDs to the channel awaiting that request's response.
pending map[string]chan rpcResponse
// nextID is the most recently issued request ID; it is advanced atomically.
nextID uint64
}
// rpcRequest is the JSON-RPC message written to the server. It is used for
// both requests and notifications: notify leaves ID empty, so omitempty drops
// the id member from the encoded message.
type rpcRequest struct {
JSONRPC string `json:"jsonrpc"`
ID string `json:"id,omitempty"`
Method string `json:"method"`
Params any `json:"params,omitempty"`
}
// rpcResponseEnvelope is the shape every inbound message is decoded into.
// ID and Result stay raw so they can be interpreted later (parseRPCID accepts
// string and numeric ids). Method is populated when the inbound message
// carries a "method" member.
type rpcResponseEnvelope struct {
JSONRPC string `json:"jsonrpc"`
ID json.RawMessage `json:"id,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *rpcError `json:"error,omitempty"`
Method string `json:"method,omitempty"`
}
// rpcError is the "error" member of a JSON-RPC response: a numeric code and a
// human-readable message.
type rpcError struct {
Code int `json:"code"`
Message string `json:"message"`
}
// rpcResponse is what a waiting request receives on its pending channel.
type rpcResponse struct {
// result is the raw "result" member of the server's response.
result json.RawMessage
// rpcErr is the server-reported JSON-RPC error, if any.
rpcErr *rpcError
// err is a client-side failure delivered by failPending (read error, EOF,
// or client closed).
err error
}
// initializeParams is the params object sent with the "initialize" request
// during Start.
type initializeParams struct {
	// ProtocolVersion is the MCP protocol revision announced to the server.
	ProtocolVersion string `json:"protocolVersion"`
	// Capabilities lists the client capabilities announced to the server.
	Capabilities map[string]any `json:"capabilities"`
	// ClientInfo identifies this client (Start sends a name and version).
	ClientInfo map[string]any `json:"clientInfo"`
}
// NewStdioClient returns an unstarted client for the given server
// configuration. The wire framing is fixed here from config.Protocol; no
// process is spawned until Start is called.
func NewStdioClient(config ServerConfig) *StdioClient {
	client := new(StdioClient)
	client.config = config
	client.mode = normalizeProtocol(config.Protocol)
	return client
}
// Start spawns the configured server process, launches the background
// reader/waiter/stderr goroutines and performs the MCP handshake (an
// "initialize" request followed by a "notifications/initialized"
// notification).
//
// Calling Start on a client that was already started returns nil without doing
// anything. If the handshake fails the client is closed and the error is
// returned. The handshake is bounded by ctx, or by the configured init timeout
// when ctx carries no deadline.
//
// NOTE(review): started is set before the handshake completes, so a concurrent
// caller can observe started == true and return while initialization is still
// in flight.
func (c *StdioClient) Start(ctx context.Context) error {
	// The spawn phase mutates lifecycle fields and therefore runs under c.mu.
	// The handshake must run unlocked because request and writeMessage take
	// the same mutex.
	alreadyStarted, err := func() (bool, error) {
		c.mu.Lock()
		defer c.mu.Unlock()

		if c.started {
			return true, nil
		}
		if strings.TrimSpace(c.config.Command) == "" {
			return false, fmt.Errorf("mcp server %q command is empty", c.config.Name)
		}

		proc := exec.Command(c.config.Command, c.config.Args...)
		if c.config.WorkingDir != "" {
			proc.Dir = c.config.WorkingDir
		}
		proc.Env = buildProcessEnv(c.config.Env)

		in, err := proc.StdinPipe()
		if err != nil {
			return false, fmt.Errorf("create stdin pipe: %w", err)
		}
		out, err := proc.StdoutPipe()
		if err != nil {
			return false, fmt.Errorf("create stdout pipe: %w", err)
		}
		errOut, err := proc.StderrPipe()
		if err != nil {
			return false, fmt.Errorf("create stderr pipe: %w", err)
		}
		if err := proc.Start(); err != nil {
			return false, fmt.Errorf("start process: %w", err)
		}

		c.started, c.closed = true, false
		c.cmd = proc
		c.stdin, c.stdout, c.stderr = in, out, errOut
		c.waitCh = make(chan struct{})
		c.pending = make(map[string]chan rpcResponse)
		return false, nil
	}()
	if err != nil || alreadyStarted {
		return err
	}

	go c.readLoop()
	go c.waitLoop()
	go c.drainStderr()

	initCtx, cancel := withTimeoutIfMissing(ctx, c.config.InitTimeout())
	defer cancel()

	handshake := initializeParams{
		ProtocolVersion: "2024-11-05",
		Capabilities: map[string]any{
			"tools": map[string]any{},
		},
		ClientInfo: map[string]any{
			"name":    "picoclaw",
			"version": "dev",
		},
	}
	if _, err := c.request(initCtx, "initialize", handshake); err != nil {
		_ = c.Close()
		return fmt.Errorf("initialize failed: %w", err)
	}
	if err := c.notify("notifications/initialized", map[string]any{}); err != nil {
		_ = c.Close()
		return fmt.Errorf("initialized notification failed: %w", err)
	}
	return nil
}
// ListTools returns every tool the server advertises, following "tools/list"
// pagination via nextCursor. The client is started first if necessary.
//
// Each page request is bounded by ctx, or by the configured call timeout when
// ctx has no deadline. An error is returned if the server is still handing out
// cursors after maxToolListPages pages.
func (c *StdioClient) ListTools(ctx context.Context) ([]RemoteTool, error) {
	if err := c.Start(ctx); err != nil {
		return nil, err
	}

	// toolsPage mirrors one page of a tools/list result.
	type toolsPage struct {
		Tools []struct {
			Name        string         `json:"name"`
			Description string         `json:"description,omitempty"`
			InputSchema map[string]any `json:"inputSchema"`
		} `json:"tools"`
		NextCursor string `json:"nextCursor,omitempty"`
	}

	// fetchPage requests and decodes a single page. An empty cursor asks for
	// the first page.
	fetchPage := func(cursor string) (toolsPage, error) {
		params := map[string]any{}
		if cursor != "" {
			params["cursor"] = cursor
		}

		callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout())
		defer cancel()

		raw, err := c.request(callCtx, "tools/list", params)
		if err != nil {
			return toolsPage{}, err
		}
		var page toolsPage
		if err := json.Unmarshal(raw, &page); err != nil {
			return toolsPage{}, fmt.Errorf("decode tools/list response: %w", err)
		}
		return page, nil
	}

	collected := make([]RemoteTool, 0, 8)
	nextCursor := ""
	for fetched := 0; fetched < maxToolListPages; fetched++ {
		page, err := fetchPage(nextCursor)
		if err != nil {
			return nil, err
		}
		for i := range page.Tools {
			entry := page.Tools[i]
			collected = append(collected, RemoteTool{
				Name:        entry.Name,
				Description: entry.Description,
				InputSchema: entry.InputSchema,
			})
		}
		if page.NextCursor == "" {
			return collected, nil
		}
		nextCursor = page.NextCursor
	}
	return nil, fmt.Errorf("tools/list exceeded %d pages", maxToolListPages)
}
// CallTool sends a "tools/call" request for toolName with the given arguments
// and returns the response as formatted by formatCallPayload (using the
// configured response limit). The client is started first if necessary.
//
// The call is bounded by ctx, or by the configured call timeout when ctx has
// no deadline.
func (c *StdioClient) CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error) {
	var none CallResult
	if err := c.Start(ctx); err != nil {
		return none, err
	}

	callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout())
	defer cancel()

	params := map[string]any{
		"name":      toolName,
		"arguments": arguments,
	}
	payload, err := c.request(callCtx, "tools/call", params)
	if err != nil {
		return none, err
	}
	return formatCallPayload(payload, c.config.ResponseLimit())
}
// Close shuts the client down: it marks the client closed, fails every
// in-flight request with "mcp client closed", closes the server's stdin, kills
// the process and then waits up to two seconds for waitLoop to report the
// exit.
//
// Close is a no-op on a client that was never started or is already closed,
// and it always returns nil.
func (c *StdioClient) Close() error {
	c.mu.Lock()
	nothingToDo := !c.started || c.closed
	if !nothingToDo {
		c.closed = true
	}
	proc, in, exited := c.cmd, c.stdin, c.waitCh
	c.mu.Unlock()

	if nothingToDo {
		return nil
	}

	c.failPending(errors.New("mcp client closed"))

	// Teardown is best-effort: the process may already be gone.
	if in != nil {
		_ = in.Close()
	}
	if proc != nil && proc.Process != nil {
		_ = proc.Process.Kill()
	}
	if exited == nil {
		return nil
	}

	grace := time.NewTimer(2 * time.Second)
	defer grace.Stop()
	select {
	case <-exited:
	case <-grace.C:
	}
	return nil
}
// request sends a JSON-RPC request and blocks until its response arrives, the
// pending request is failed (read error or Close), or ctx is done.
//
// It returns the raw "result" member on success. A server-side JSON-RPC error
// is returned as "mcp error <code>: <message>". When ctx ends first, the
// pending entry is removed and ctx.Err() is returned.
func (c *StdioClient) request(ctx context.Context, method string, params any) (json.RawMessage, error) {
	seq := atomic.AddUint64(&c.nextID, 1)
	id := strconv.FormatUint(seq, 10)

	// Buffer of one so the read loop never blocks on a caller that has
	// already given up.
	replyCh := make(chan rpcResponse, 1)

	register := func() error {
		c.mu.Lock()
		defer c.mu.Unlock()
		if c.closed {
			return fmt.Errorf("mcp server %q is closed", c.config.Name)
		}
		c.pending[id] = replyCh
		return nil
	}
	if err := register(); err != nil {
		return nil, err
	}

	outbound := rpcRequest{JSONRPC: "2.0", ID: id, Method: method, Params: params}
	if err := c.writeMessage(outbound); err != nil {
		c.removePending(id)
		return nil, err
	}

	select {
	case reply := <-replyCh:
		switch {
		case reply.err != nil:
			return nil, reply.err
		case reply.rpcErr != nil:
			return nil, fmt.Errorf("mcp error %d: %s", reply.rpcErr.Code, reply.rpcErr.Message)
		}
		return reply.result, nil
	case <-ctx.Done():
		c.removePending(id)
		return nil, ctx.Err()
	}
}
// notify sends a JSON-RPC notification: a message with no id, for which no
// response is awaited. It returns any error from writing the message.
func (c *StdioClient) notify(method string, params any) error {
	return c.writeMessage(rpcRequest{
		JSONRPC: "2.0",
		Method:  method,
		Params:  params,
	})
}
// writeMessage JSON-encodes payload and writes it to the server's stdin using
// the client's framing: one newline-terminated line in JSON-lines mode,
// otherwise a "Content-Length" header followed by the body.
//
// Writes are serialized with writeMu so concurrent messages cannot interleave.
// It fails if the client is closed or has no stdin.
func (c *StdioClient) writeMessage(payload any) error {
	c.mu.Lock()
	out := c.stdin
	writable := !c.closed && out != nil
	c.mu.Unlock()
	if !writable {
		return fmt.Errorf("mcp server %q is not writable", c.config.Name)
	}

	encoded, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("marshal json-rpc payload: %w", err)
	}

	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	switch c.mode {
	case ProtocolJSONLines:
		if _, err := out.Write(append(encoded, '\n')); err != nil {
			return fmt.Errorf("write jsonl body: %w", err)
		}
	default:
		header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(encoded))
		if _, err := io.WriteString(out, header); err != nil {
			return fmt.Errorf("write frame header: %w", err)
		}
		if _, err := out.Write(encoded); err != nil {
			return fmt.Errorf("write frame body: %w", err)
		}
	}
	return nil
}
// readLoop runs the stdout reader that matches the client's framing mode. It
// returns when that reader returns.
func (c *StdioClient) readLoop() {
	switch c.mode {
	case ProtocolJSONLines:
		c.readJSONLLoop()
	default:
		c.readMCPFrameLoop()
	}
}
// readMCPFrameLoop reads Content-Length framed messages from the server's
// stdout and dispatches each one. When a frame cannot be read, every pending
// request is failed with that error and the loop ends. A frame whose body is
// not valid JSON is dropped and reading continues.
func (c *StdioClient) readMCPFrameLoop() {
	frames := bufio.NewReader(c.stdout)
	for {
		body, readErr := readFramePayload(frames)
		if readErr != nil {
			c.failPending(readErr)
			return
		}

		var msg rpcResponseEnvelope
		if json.Unmarshal(body, &msg) == nil {
			c.dispatchResponse(msg)
		}
	}
}
// readJSONLLoop reads newline-delimited JSON messages from the server's stdout
// and dispatches each one. Blank lines and lines that are not valid JSON are
// skipped. When the stream ends, every pending request is failed with the
// scanner's error, or with io.EOF if the stream ended cleanly.
func (c *StdioClient) readJSONLLoop() {
	lines := bufio.NewScanner(c.stdout)
	lines.Buffer(make([]byte, 0, defaultScannerBufferBytes), maxFrameBytes)

	for lines.Scan() {
		text := strings.TrimSpace(lines.Text())
		if text == "" {
			continue
		}
		var msg rpcResponseEnvelope
		if json.Unmarshal([]byte(text), &msg) == nil {
			c.dispatchResponse(msg)
		}
	}

	cause := lines.Err()
	if cause == nil {
		cause = io.EOF
	}
	c.failPending(cause)
}
// dispatchResponse routes an inbound message to the request waiting for it.
//
// Only genuine responses are routed. A message that carries a "method" member
// is a server-initiated request or notification. Its id comes from the
// server's own id space, so it must not be matched against the client's
// pending requests: it could collide with one and complete it with an empty
// result. Messages without an id, with an unparseable id, or whose id has no
// pending request are ignored.
func (c *StdioClient) dispatchResponse(envelope rpcResponseEnvelope) {
	if envelope.Method != "" || len(envelope.ID) == 0 {
		return
	}
	id, ok := parseRPCID(envelope.ID)
	if !ok {
		return
	}

	// Claim the pending entry under the lock so each request is answered at
	// most once.
	c.mu.Lock()
	responseCh, found := c.pending[id]
	if found {
		delete(c.pending, id)
	}
	c.mu.Unlock()
	if !found {
		return
	}

	response := rpcResponse{
		result: envelope.Result,
		rpcErr: envelope.Error,
	}
	// Non-blocking send: the read loop must never stall on a single waiter.
	select {
	case responseCh <- response:
	default:
	}
}
// waitLoop waits for the server process to exit and then closes waitCh so
// Close can observe the exit. A non-nil error from Wait is logged as a
// warning. If there is no process, waitCh is closed immediately.
func (c *StdioClient) waitLoop() {
	c.mu.Lock()
	proc, done, name := c.cmd, c.waitCh, c.config.Name
	c.mu.Unlock()

	signalExit := func() {
		if done != nil {
			close(done)
		}
	}

	if proc == nil {
		signalExit()
		return
	}

	waitErr := proc.Wait()
	signalExit()
	if waitErr == nil {
		return
	}
	logger.WarnCF("mcp", "MCP process exited with error", map[string]any{
		"server": name,
		"error":  waitErr.Error(),
	})
}
// drainStderr forwards the server's stderr to the debug log, one non-blank
// line per entry, until the stream ends.
//
// The scanner gives up with bufio.ErrTooLong on a line longer than its buffer.
// If reading simply stopped there, the stderr pipe would fill up and the
// server would block on its next write. In that case the rest of the stream is
// read and discarded.
func (c *StdioClient) drainStderr() {
	c.mu.Lock()
	stderr := c.stderr
	serverName := c.config.Name
	c.mu.Unlock()
	if stderr == nil {
		return
	}

	scanner := bufio.NewScanner(stderr)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}
		logger.DebugCF("mcp", "MCP server stderr",
			map[string]any{
				"server": serverName,
				"line":   line,
			})
	}

	if errors.Is(scanner.Err(), bufio.ErrTooLong) {
		// Best-effort: the only goal is to keep the pipe from filling, so
		// the copy result is irrelevant.
		_, _ = io.Copy(io.Discard, stderr)
	}
}
// failPending fails every in-flight request with err. The pending map is
// swapped for an empty one under the lock, and the waiters are notified after
// the lock is released.
func (c *StdioClient) failPending(err error) {
	c.mu.Lock()
	orphaned := c.pending
	c.pending = make(map[string]chan rpcResponse)
	c.mu.Unlock()

	failure := rpcResponse{err: err}
	for _, waiter := range orphaned {
		// Non-blocking: a waiter that already holds a response is skipped.
		select {
		case waiter <- failure:
		default:
		}
	}
}
// removePending forgets the pending request with the given id, so a response
// that arrives later is ignored.
func (c *StdioClient) removePending(id string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.pending, id)
}
// readFramePayload reads one Content-Length framed message from reader and
// returns its body.
//
// Header lines are read up to the first blank line. Only Content-Length is
// interpreted (case-insensitively); other headers and lines without a colon
// are ignored. If Content-Length appears more than once, the last value wins.
//
// It returns an error when reading fails, when the Content-Length value is not
// an integer, when no Content-Length header was present, when the declared
// length is zero or negative, or when it exceeds maxFrameBytes.
func readFramePayload(reader *bufio.Reader) ([]byte, error) {
	contentLength := 0
	sawLength := false
	for {
		line, err := reader.ReadString('\n')
		if err != nil {
			return nil, err
		}
		header := strings.TrimRight(line, "\r\n")
		if header == "" {
			break // blank line ends the header section
		}

		name, value, found := strings.Cut(header, ":")
		if !found || !strings.EqualFold(strings.TrimSpace(name), "content-length") {
			continue
		}
		value = strings.TrimSpace(value)
		length, err := strconv.Atoi(value)
		if err != nil {
			return nil, fmt.Errorf("invalid content-length %q: %w", value, err)
		}
		contentLength = length
		sawLength = true
	}

	// Separate "header absent" from "header present but unusable" so the
	// error points at the actual protocol violation.
	if !sawLength {
		return nil, fmt.Errorf("missing content-length")
	}
	if contentLength <= 0 {
		return nil, fmt.Errorf("invalid content-length %d", contentLength)
	}
	if contentLength > maxFrameBytes {
		return nil, fmt.Errorf("frame too large (%d bytes)", contentLength)
	}

	payload := make([]byte, contentLength)
	if _, err := io.ReadFull(reader, payload); err != nil {
		return nil, err
	}
	return payload, nil
}
// parseRPCID normalizes a raw JSON-RPC id to the string form used as the key
// of the pending-request map.
//
// String ids are returned unchanged. Numeric ids are rendered as a base-10
// integer without passing through float64, so values beyond 2^53 keep full
// precision. A number written with a fraction or exponent is accepted only if
// its value is integral (2.0 -> "2"). The second result is false for null,
// non-integral numbers, out-of-range numbers, and any other JSON type.
func parseRPCID(raw json.RawMessage) (string, bool) {
	// Decode through a pointer so JSON null (which leaves it nil) can be told
	// apart from an empty string id.
	var stringID *string
	if err := json.Unmarshal(raw, &stringID); err == nil {
		if stringID == nil {
			return "", false
		}
		return *stringID, true
	}

	var number json.Number
	if err := json.Unmarshal(raw, &number); err != nil {
		return "", false
	}
	if signed, err := number.Int64(); err == nil {
		return strconv.FormatInt(signed, 10), true
	}
	if unsigned, err := strconv.ParseUint(number.String(), 10, 64); err == nil {
		return strconv.FormatUint(unsigned, 10), true
	}

	// Remaining forms use a fraction or exponent. Accept them only when the
	// value is an integer that float64 represents exactly.
	value, err := number.Float64()
	if err != nil {
		return "", false
	}
	const maxExactInt = 1 << 53
	if value < -maxExactInt || value > maxExactInt || value != float64(int64(value)) {
		return "", false
	}
	return strconv.FormatInt(int64(value), 10), true
}
// withTimeoutIfMissing derives a cancellable context from parent. If parent
// already has a deadline it is kept as is; otherwise the derived context
// expires after timeout. The returned CancelFunc must always be called.
func withTimeoutIfMissing(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	_, bounded := parent.Deadline()
	if !bounded {
		return context.WithTimeout(parent, timeout)
	}
	return context.WithCancel(parent)
}
// buildProcessEnv returns the environment for the server process: the current
// process environment followed by the custom entries as "KEY=value", appended
// in sorted key order so the result is deterministic. With no custom entries
// the current environment is returned unchanged.
func buildProcessEnv(custom map[string]string) []string {
	inherited := os.Environ()
	if len(custom) == 0 {
		return inherited
	}

	names := make([]string, 0, len(custom))
	for name := range custom {
		names = append(names, name)
	}
	sort.Strings(names)

	merged := make([]string, len(inherited), len(inherited)+len(names))
	copy(merged, inherited)
	for _, name := range names {
		merged = append(merged, name+"="+custom[name])
	}
	return merged
}
func normalizeProtocol(protocol string) string {
switch strings.ToLower(strings.TrimSpace(protocol)) {
case "", ProtocolMCPFrames:
return ProtocolMCPFrames
case ProtocolJSONLines:
return ProtocolJSONLines
default:
return ProtocolMCPFrames
}
}