feat(mcp): add Model Context Protocol integration

Implement comprehensive MCP support with stdio/HTTP/SSE transports, environment variable configuration (env and envFile), custom headers, tool registration, and automatic resource cleanup. Includes full test coverage and VSCode-compatible configuration.

- Added pkg/mcp/manager.go for server lifecycle management
- Added pkg/tools/mcp_tool.go for tool wrapping
- Integrated into agent loop with cleanup
- Support for envFile loading (.env format)
- Headers injection for HTTP/SSE authentication
- Example configs for filesystem, github, brave-search, postgres
This commit is contained in:
yuchou87
2026-02-15 17:26:36 +08:00
parent 9a3f3611c3
commit 91c168db20
9 changed files with 1366 additions and 8 deletions
+60 -2
View File
@@ -14,7 +14,9 @@
"enabled": false,
"token": "YOUR_TELEGRAM_BOT_TOKEN",
"proxy": "",
"allow_from": ["YOUR_USER_ID"]
"allow_from": [
"YOUR_USER_ID"
]
},
"discord": {
"enabled": false,
@@ -115,6 +117,62 @@
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
}
},
"mcp": {
"enabled": false,
"servers": {
"filesystem": {
"enabled": false,
"command": "npx",
"args": [
"-y",
"@modelcontextprotocol/server-filesystem",
"/tmp"
],
"env": {}
},
"github": {
"enabled": false,
"command": "npx",
"args": [
"-y",
"@modelcontextprotocol/server-github"
],
"env": {
"GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN"
},
"envFile": ".env"
},
"brave-search": {
"enabled": false,
"command": "npx",
"args": [
"-y",
"@modelcontextprotocol/server-brave-search"
],
"env": {
"BRAVE_API_KEY": "YOUR_BRAVE_API_KEY"
}
},
"postgres": {
"enabled": false,
"command": "npx",
"args": [
"-y",
"@modelcontextprotocol/server-postgres",
"postgresql://user:password@localhost/dbname"
]
},
"remote-http-example": {
"enabled": false,
"url": "https://mcp-server.example.com/stream",
"type": "sse",
"headers": {
"Authorization": "Bearer YOUR_TOKEN",
"X-Custom-Header": "custom-value"
}
}
}
}
},
"heartbeat": {
@@ -129,4 +187,4 @@
"host": "0.0.0.0",
"port": 18790
}
}
}
+3 -3
View File
@@ -11,6 +11,7 @@ require (
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/modelcontextprotocol/go-sdk v1.3.0
github.com/mymmrac/telego v1.6.0
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/openai/openai-go/v3 v3.22.0
@@ -19,7 +20,7 @@ require (
golang.org/x/oauth2 v0.35.0
)
require github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
require (
github.com/andybalholm/brotli v1.2.0 // indirect
@@ -28,9 +29,9 @@ require (
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/github/copilot-sdk/go v0.1.23
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/go-resty/resty/v2 v2.17.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/grbit/go-json v0.11.0 // indirect
github.com/klauspost/compress v1.18.4 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
@@ -47,5 +48,4 @@ require (
golang.org/x/net v0.50.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
)
+10
View File
@@ -43,6 +43,8 @@ github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
@@ -58,6 +60,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -84,6 +88,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
github.com/modelcontextprotocol/go-sdk v1.3.0 h1:gMfZkv3DzQF5q/DcQePo5rahEY+sguyPfXDfNBcT0Zs=
github.com/modelcontextprotocol/go-sdk v1.3.0/go.mod h1:AnQ//Qc6+4nIyyrB4cxBU7UW9VibK4iOZBeyP/rF1IE=
github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0=
github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
@@ -141,6 +147,8 @@ github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpB
github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
@@ -228,6 +236,8 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+43 -3
View File
@@ -22,6 +22,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/mcp"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/state"
@@ -40,6 +41,7 @@ type AgentLoop struct {
state *state.Manager
contextBuilder *ContextBuilder
tools *tools.ToolRegistry
mcpManager *mcp.Manager // MCP server manager for resource cleanup
running atomic.Bool
summarizing sync.Map // Tracks which sessions are currently being summarized
}
@@ -58,7 +60,7 @@ type processOptions struct {
// createToolRegistry creates a tool registry with common tools.
// This is shared between main agent and subagents.
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus, mcpManager *mcp.Manager) *tools.ToolRegistry {
registry := tools.NewToolRegistry()
// File system tools
@@ -99,6 +101,23 @@ func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msg
})
registry.Register(messageTool)
// Register MCP tools from all connected servers
if mcpManager != nil {
servers := mcpManager.GetServers()
for serverName, conn := range servers {
for _, tool := range conn.Tools {
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
registry.Register(mcpTool)
logger.DebugCF("agent", "Registered MCP tool",
map[string]interface{}{
"server": serverName,
"tool": tool.Name,
"name": mcpTool.Name(),
})
}
}
}
return registry
}
@@ -108,12 +127,22 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
restrict := cfg.Agents.Defaults.RestrictToWorkspace
// Initialize MCP Manager and load servers
mcpManager := mcp.NewManager()
ctx := context.Background()
if err := mcpManager.LoadFromConfig(ctx, cfg); err != nil {
logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available",
map[string]interface{}{
"error": err.Error(),
})
}
// Create tool registry for main agent
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager)
// Create subagent manager with its own tool registry
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager)
// Subagent doesn't need spawn/subagent tools to avoid recursion
subagentManager.SetTools(subagentTools)
@@ -145,6 +174,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
state: stateManager,
contextBuilder: contextBuilder,
tools: toolsRegistry,
mcpManager: mcpManager,
summarizing: sync.Map{},
}
}
@@ -193,6 +223,16 @@ func (al *AgentLoop) Run(ctx context.Context) error {
func (al *AgentLoop) Stop() {
al.running.Store(false)
// Clean up MCP connections
if al.mcpManager != nil {
if err := al.mcpManager.Close(); err != nil {
logger.ErrorCF("agent", "Failed to close MCP manager",
map[string]interface{}{
"error": err.Error(),
})
}
}
}
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
+61
View File
@@ -212,6 +212,35 @@ type WebToolsConfig struct {
type ToolsConfig struct {
Web WebToolsConfig `json:"web"`
MCP MCPConfig `json:"mcp"`
}
// MCPServerConfig defines configuration for a single MCP server
type MCPServerConfig struct {
// Enabled indicates whether this MCP server is active
Enabled bool `json:"enabled"`
// Command is the executable to run (e.g., "npx", "python", "/path/to/server")
Command string `json:"command"`
// Args are the arguments to pass to the command
Args []string `json:"args,omitempty"`
// Env are environment variables to set for the server process (stdio only)
Env map[string]string `json:"env,omitempty"`
// EnvFile is the path to a file containing environment variables (stdio only)
EnvFile string `json:"envFile,omitempty"`
// Type is "stdio", "sse", or "http" (default: stdio if command is set, sse if url is set)
Type string `json:"type,omitempty"`
// URL is used for SSE/HTTP transport
URL string `json:"url,omitempty"`
// Headers are HTTP headers to send with requests (sse/http only)
Headers map[string]string `json:"headers,omitempty"`
}
// MCPConfig defines configuration for all MCP servers
type MCPConfig struct {
// Enabled globally enables/disables MCP integration
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
// Servers is a map of server name to server configuration
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
}
func DefaultConfig() *Config {
@@ -321,6 +350,38 @@ func DefaultConfig() *Config {
MaxResults: 5,
},
},
MCP: MCPConfig{
Enabled: false,
Servers: map[string]MCPServerConfig{
"filesystem": {
Enabled: false,
Command: "npx",
Args: []string{"-y", "@modelcontextprotocol/server-filesystem", "/tmp"},
Env: map[string]string{},
},
"github": {
Enabled: false,
Command: "npx",
Args: []string{"-y", "@modelcontextprotocol/server-github"},
Env: map[string]string{
"GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN",
},
},
"brave-search": {
Enabled: false,
Command: "npx",
Args: []string{"-y", "@modelcontextprotocol/server-brave-search"},
Env: map[string]string{
"BRAVE_API_KEY": "YOUR_BRAVE_API_KEY",
},
},
"postgres": {
Enabled: false,
Command: "npx",
Args: []string{"-y", "@modelcontextprotocol/server-postgres", "postgresql://user:password@localhost/dbname"},
},
},
},
},
Heartbeat: HeartbeatConfig{
Enabled: true,
+432
View File
@@ -0,0 +1,432 @@
package mcp
import (
"bufio"
"context"
"fmt"
"net/http"
"os"
"os/exec"
"strings"
"sync"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
// headerTransport is an http.RoundTripper that adds custom headers to requests
type headerTransport struct {
base http.RoundTripper
headers map[string]string
}
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Clone the request to avoid modifying the original
req = req.Clone(req.Context())
// Add custom headers
for key, value := range t.headers {
req.Header.Set(key, value)
}
// Use the base transport
base := t.base
if base == nil {
base = http.DefaultTransport
}
return base.RoundTrip(req)
}
// loadEnvFile loads environment variables from a file in .env format
// Each line should be in the format: KEY=value
// Lines starting with # are comments
// Empty lines are ignored
func loadEnvFile(path string) (map[string]string, error) {
file, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("failed to open env file: %w", err)
}
defer file.Close()
envVars := make(map[string]string)
scanner := bufio.NewScanner(file)
lineNum := 0
for scanner.Scan() {
lineNum++
line := strings.TrimSpace(scanner.Text())
// Skip empty lines and comments
if line == "" || strings.HasPrefix(line, "#") {
continue
}
// Parse KEY=value
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line)
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
// Remove surrounding quotes if present
if len(value) >= 2 {
if (value[0] == '"' && value[len(value)-1] == '"') ||
(value[0] == '\'' && value[len(value)-1] == '\'') {
value = value[1 : len(value)-1]
}
}
envVars[key] = value
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading env file: %w", err)
}
return envVars, nil
}
// ServerConnection represents a connection to an MCP server
type ServerConnection struct {
Name string
Client *mcp.Client
Session *mcp.ClientSession
Tools []*mcp.Tool
}
// Manager manages multiple MCP server connections
type Manager struct {
servers map[string]*ServerConnection
mu sync.RWMutex
}
// NewManager creates a new MCP manager
func NewManager() *Manager {
return &Manager{
servers: make(map[string]*ServerConnection),
}
}
// LoadFromConfig loads MCP servers from configuration
func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error {
if !cfg.Tools.MCP.Enabled {
logger.InfoCF("mcp", "MCP integration is disabled", nil)
return nil
}
if len(cfg.Tools.MCP.Servers) == 0 {
logger.InfoCF("mcp", "No MCP servers configured", nil)
return nil
}
logger.InfoCF("mcp", "Initializing MCP servers",
map[string]interface{}{
"count": len(cfg.Tools.MCP.Servers),
})
var wg sync.WaitGroup
errs := make(chan error, len(cfg.Tools.MCP.Servers))
for name, serverCfg := range cfg.Tools.MCP.Servers {
if !serverCfg.Enabled {
logger.DebugCF("mcp", "Skipping disabled server",
map[string]interface{}{
"server": name,
})
continue
}
wg.Add(1)
go func(name string, serverCfg config.MCPServerConfig) {
defer wg.Done()
if err := m.ConnectServer(ctx, name, serverCfg); err != nil {
logger.ErrorCF("mcp", "Failed to connect to MCP server",
map[string]interface{}{
"server": name,
"error": err.Error(),
})
errs <- fmt.Errorf("failed to connect to server %s: %w", name, err)
}
}(name, serverCfg)
}
wg.Wait()
close(errs)
// Collect errors
var allErrors []error
for err := range errs {
allErrors = append(allErrors, err)
}
if len(allErrors) > 0 {
logger.WarnCF("mcp", "Some MCP servers failed to connect",
map[string]interface{}{
"failed": len(allErrors),
"total": len(cfg.Tools.MCP.Servers),
})
// Don't fail completely if some servers fail to connect
}
connectedCount := len(m.GetServers())
logger.InfoCF("mcp", "MCP server initialization complete",
map[string]interface{}{
"connected": connectedCount,
"total": len(cfg.Tools.MCP.Servers),
})
return nil
}
// ConnectServer connects to a single MCP server
func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCPServerConfig) error {
logger.InfoCF("mcp", "Connecting to MCP server",
map[string]interface{}{
"server": name,
"command": cfg.Command,
"args": cfg.Args,
})
// Create client
client := mcp.NewClient(&mcp.Implementation{
Name: "picoclaw",
Version: "1.0.0",
}, nil)
// Create transport based on configuration
// Auto-detect transport type if not explicitly specified
var transport mcp.Transport
transportType := cfg.Type
// Auto-detect: if URL is provided, use SSE; if command is provided, use stdio
if transportType == "" {
if cfg.URL != "" {
transportType = "sse"
} else if cfg.Command != "" {
transportType = "stdio"
} else {
return fmt.Errorf("either URL or command must be provided")
}
}
switch transportType {
case "sse", "http":
if cfg.URL == "" {
return fmt.Errorf("URL is required for SSE/HTTP transport")
}
logger.DebugCF("mcp", "Using SSE/HTTP transport",
map[string]interface{}{
"server": name,
"url": cfg.URL,
})
sseTransport := &mcp.StreamableClientTransport{
Endpoint: cfg.URL,
}
// Add custom headers if provided
if len(cfg.Headers) > 0 {
// Create a custom HTTP client with header-injecting transport
sseTransport.HTTPClient = &http.Client{
Transport: &headerTransport{
base: http.DefaultTransport,
headers: cfg.Headers,
},
}
logger.DebugCF("mcp", "Added custom HTTP headers",
map[string]interface{}{
"server": name,
"header_count": len(cfg.Headers),
})
}
transport = sseTransport
case "stdio":
if cfg.Command == "" {
return fmt.Errorf("command is required for stdio transport")
}
logger.DebugCF("mcp", "Using stdio transport",
map[string]interface{}{
"server": name,
"command": cfg.Command,
})
// Create command with context
cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...)
// Set environment variables
env := cmd.Environ()
// Load environment variables from file if specified
if cfg.EnvFile != "" {
envVars, err := loadEnvFile(cfg.EnvFile)
if err != nil {
return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
}
for k, v := range envVars {
env = append(env, fmt.Sprintf("%s=%s", k, v))
}
logger.DebugCF("mcp", "Loaded environment variables from file",
map[string]interface{}{
"server": name,
"envFile": cfg.EnvFile,
"var_count": len(envVars),
})
}
// Environment variables from config override those from file
if len(cfg.Env) > 0 {
for k, v := range cfg.Env {
env = append(env, fmt.Sprintf("%s=%s", k, v))
}
}
// Set environment if we added any variables
if len(env) > len(cmd.Environ()) {
cmd.Env = env
}
transport = &mcp.CommandTransport{Command: cmd}
default:
return fmt.Errorf("unsupported transport type: %s (supported: stdio, sse, http)", transportType)
}
// Connect to server
session, err := client.Connect(ctx, transport, nil)
if err != nil {
return fmt.Errorf("failed to connect: %w", err)
}
// Get server info
initResult := session.InitializeResult()
logger.InfoCF("mcp", "Connected to MCP server",
map[string]interface{}{
"server": name,
"serverName": initResult.ServerInfo.Name,
"serverVersion": initResult.ServerInfo.Version,
"protocol": initResult.ProtocolVersion,
})
// List available tools if supported
var tools []*mcp.Tool
if initResult.Capabilities.Tools != nil {
for tool, err := range session.Tools(ctx, nil) {
if err != nil {
logger.WarnCF("mcp", "Error listing tool",
map[string]interface{}{
"server": name,
"error": err.Error(),
})
continue
}
tools = append(tools, tool)
}
logger.InfoCF("mcp", "Listed tools from MCP server",
map[string]interface{}{
"server": name,
"toolCount": len(tools),
})
}
// Store connection
m.mu.Lock()
m.servers[name] = &ServerConnection{
Name: name,
Client: client,
Session: session,
Tools: tools,
}
m.mu.Unlock()
return nil
}
// GetServers returns all connected servers
func (m *Manager) GetServers() map[string]*ServerConnection {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string]*ServerConnection, len(m.servers))
for k, v := range m.servers {
result[k] = v
}
return result
}
// GetServer returns a specific server connection
func (m *Manager) GetServer(name string) (*ServerConnection, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
conn, ok := m.servers[name]
return conn, ok
}
// CallTool calls a tool on a specific server
func (m *Manager) CallTool(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
conn, ok := m.GetServer(serverName)
if !ok {
return nil, fmt.Errorf("server %s not found", serverName)
}
params := &mcp.CallToolParams{
Name: toolName,
Arguments: arguments,
}
result, err := conn.Session.CallTool(ctx, params)
if err != nil {
return nil, fmt.Errorf("failed to call tool: %w", err)
}
return result, nil
}
// Close closes all server connections
func (m *Manager) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
logger.InfoCF("mcp", "Closing all MCP server connections",
map[string]interface{}{
"count": len(m.servers),
})
var errs []error
for name, conn := range m.servers {
if err := conn.Session.Close(); err != nil {
logger.ErrorCF("mcp", "Failed to close server connection",
map[string]interface{}{
"server": name,
"error": err.Error(),
})
errs = append(errs, err)
}
}
m.servers = make(map[string]*ServerConnection)
if len(errs) > 0 {
return fmt.Errorf("failed to close %d server(s)", len(errs))
}
return nil
}
// GetAllTools returns all tools from all connected servers
func (m *Manager) GetAllTools() map[string][]*mcp.Tool {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string][]*mcp.Tool)
for name, conn := range m.servers {
if len(conn.Tools) > 0 {
result[name] = conn.Tools
}
}
return result
}
+182
View File
@@ -0,0 +1,182 @@
package mcp
import (
"os"
"path/filepath"
"testing"
)
func TestLoadEnvFile(t *testing.T) {
tests := []struct {
name string
content string
expected map[string]string
expectErr bool
}{
{
name: "basic env file",
content: `API_KEY=secret123
DATABASE_URL=postgres://localhost/db
PORT=8080`,
expected: map[string]string{
"API_KEY": "secret123",
"DATABASE_URL": "postgres://localhost/db",
"PORT": "8080",
},
expectErr: false,
},
{
name: "with comments and empty lines",
content: `# This is a comment
API_KEY=secret123
# Another comment
DATABASE_URL=postgres://localhost/db
PORT=8080`,
expected: map[string]string{
"API_KEY": "secret123",
"DATABASE_URL": "postgres://localhost/db",
"PORT": "8080",
},
expectErr: false,
},
{
name: "with quoted values",
content: `API_KEY="secret with spaces"
NAME='single quoted'
PLAIN=no-quotes`,
expected: map[string]string{
"API_KEY": "secret with spaces",
"NAME": "single quoted",
"PLAIN": "no-quotes",
},
expectErr: false,
},
{
name: "with spaces around equals",
content: `API_KEY = secret123
DATABASE_URL= postgres://localhost/db
PORT =8080`,
expected: map[string]string{
"API_KEY": "secret123",
"DATABASE_URL": "postgres://localhost/db",
"PORT": "8080",
},
expectErr: false,
},
{
name: "invalid format - no equals",
content: `INVALID_LINE`,
expectErr: true,
},
{
name: "empty file",
content: ``,
expected: map[string]string{},
expectErr: false,
},
{
name: "only comments",
content: `# Comment 1
# Comment 2`,
expected: map[string]string{},
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
envFile := filepath.Join(tmpDir, ".env")
if err := os.WriteFile(envFile, []byte(tt.content), 0644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
result, err := loadEnvFile(envFile)
if tt.expectErr {
if err == nil {
t.Errorf("Expected error but got none")
}
return
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
if len(result) != len(tt.expected) {
t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result))
}
for key, expectedValue := range tt.expected {
if actualValue, ok := result[key]; !ok {
t.Errorf("Expected key %s not found", key)
} else if actualValue != expectedValue {
t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue)
}
}
})
}
}
func TestLoadEnvFileNotFound(t *testing.T) {
_, err := loadEnvFile("/nonexistent/file.env")
if err == nil {
t.Error("Expected error for nonexistent file")
}
}
func TestEnvFilePriority(t *testing.T) {
// Create a temporary .env file
tmpDir := t.TempDir()
envFile := filepath.Join(tmpDir, ".env")
envContent := `API_KEY=from_file
DATABASE_URL=from_file
SHARED_VAR=from_file`
if err := os.WriteFile(envFile, []byte(envContent), 0644); err != nil {
t.Fatalf("Failed to create .env file: %v", err)
}
// Load envFile
envVars, err := loadEnvFile(envFile)
if err != nil {
t.Fatalf("Failed to load env file: %v", err)
}
// Verify envFile variables
if envVars["API_KEY"] != "from_file" {
t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"])
}
// Simulate config.Env overriding envFile
configEnv := map[string]string{
"SHARED_VAR": "from_config",
"NEW_VAR": "from_config",
}
// Merge: envFile first, then config overrides
merged := make(map[string]string)
for k, v := range envVars {
merged[k] = v
}
for k, v := range configEnv {
merged[k] = v
}
// Verify priority: config.Env should override envFile
if merged["SHARED_VAR"] != "from_config" {
t.Errorf("Expected SHARED_VAR=from_config (config should override file), got %s", merged["SHARED_VAR"])
}
if merged["API_KEY"] != "from_file" {
t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"])
}
if merged["NEW_VAR"] != "from_config" {
t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"])
}
}
+119
View File
@@ -0,0 +1,119 @@
package tools
import (
"context"
"fmt"
"strings"
"github.com/modelcontextprotocol/go-sdk/mcp"
mcpPkg "github.com/sipeed/picoclaw/pkg/mcp"
)
// MCPManager defines the interface for MCP manager operations
// This allows for easier testing with mock implementations
type MCPManager interface {
CallTool(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error)
}
// MCPTool wraps an MCP tool to implement the Tool interface
type MCPTool struct {
manager MCPManager
serverName string
tool *mcp.Tool
}
// NewMCPTool creates a new MCP tool wrapper
func NewMCPTool(manager *mcpPkg.Manager, serverName string, tool *mcp.Tool) *MCPTool {
return &MCPTool{
manager: manager,
serverName: serverName,
tool: tool,
}
}
// Name returns the tool name, prefixed with the server name
func (t *MCPTool) Name() string {
// Prefix with server name to avoid conflicts
return fmt.Sprintf("mcp_%s_%s", t.serverName, t.tool.Name)
}
// Description returns the tool description
func (t *MCPTool) Description() string {
desc := t.tool.Description
if desc == "" {
desc = fmt.Sprintf("MCP tool from %s server", t.serverName)
}
// Add server info to description
return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc)
}
// Parameters returns the tool parameters schema
func (t *MCPTool) Parameters() map[string]interface{} {
// The InputSchema is already a JSON Schema object
schema := t.tool.InputSchema
// Convert to map[string]interface{} for compatibility
result := make(map[string]interface{})
// Use reflection to convert the schema
// The schema should already be in the correct format
if schema != nil {
// Attempt to convert directly
if schemaMap, ok := schema.(map[string]interface{}); ok {
return schemaMap
}
// Otherwise, build it manually
result["type"] = "object"
result["properties"] = map[string]interface{}{}
result["required"] = []string{}
} else {
// Default schema when nil
result["type"] = "object"
result["properties"] = map[string]interface{}{}
result["required"] = []string{}
}
return result
}
// Execute executes the MCP tool
func (t *MCPTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args)
if err != nil {
return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err)
}
// Handle error result from server
if result.IsError {
errMsg := extractContentText(result.Content)
return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)).
WithError(fmt.Errorf("MCP tool error: %s", errMsg))
}
// Extract text content from result
output := extractContentText(result.Content)
return &ToolResult{
ForLLM: output,
IsError: false,
}
}
// extractContentText extracts text from MCP content array
func extractContentText(content []mcp.Content) string {
var parts []string
for _, c := range content {
switch v := c.(type) {
case *mcp.TextContent:
parts = append(parts, v.Text)
case *mcp.ImageContent:
// For images, just indicate that an image was returned
parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType))
default:
// For other content types, use string representation
parts = append(parts, fmt.Sprintf("[Content: %T]", v))
}
}
return strings.Join(parts, "\n")
}
+456
View File
@@ -0,0 +1,456 @@
package tools
import (
"context"
"fmt"
"strings"
"testing"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// MockMCPManager is a mock implementation of MCPManager interface for testing
type MockMCPManager struct {
callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error)
}
func (m *MockMCPManager) CallTool(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
if m.callToolFunc != nil {
return m.callToolFunc(ctx, serverName, toolName, arguments)
}
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "mock result"},
},
IsError: false,
}, nil
}
// newMCPToolForTest creates an MCP tool for testing with mock manager
func newMCPToolForTest(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
return &MCPTool{
manager: manager,
serverName: serverName,
tool: tool,
}
}
// TestNewMCPTool verifies MCP tool creation
func TestNewMCPTool(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{
Name: "test_tool",
Description: "A test tool",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"input": map[string]interface{}{
"type": "string",
"description": "Test input",
},
},
},
}
mcpTool := newMCPToolForTest(manager, "test_server", tool)
if mcpTool == nil {
t.Fatal("NewMCPTool should not return nil")
}
// Verify tool properties we can access
if mcpTool.Name() != "mcp_test_server_test_tool" {
t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name())
}
}
// TestMCPTool_Name verifies tool name with server prefix
func TestMCPTool_Name(t *testing.T) {
tests := []struct {
name string
serverName string
toolName string
expected string
}{
{
name: "simple name",
serverName: "github",
toolName: "create_issue",
expected: "mcp_github_create_issue",
},
{
name: "filesystem server",
serverName: "filesystem",
toolName: "read_file",
expected: "mcp_filesystem_read_file",
},
{
name: "remote server",
serverName: "remote-api",
toolName: "fetch_data",
expected: "mcp_remote-api_fetch_data",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{Name: tt.toolName}
mcpTool := newMCPToolForTest(manager, tt.serverName, tool)
result := mcpTool.Name()
if result != tt.expected {
t.Errorf("Expected name '%s', got '%s'", tt.expected, result)
}
})
}
}
// TestMCPTool_Description verifies tool description generation
func TestMCPTool_Description(t *testing.T) {
tests := []struct {
name string
serverName string
toolDescription string
expectContains []string
}{
{
name: "with description",
serverName: "github",
toolDescription: "Create a GitHub issue",
expectContains: []string{"[MCP:github]", "Create a GitHub issue"},
},
{
name: "empty description",
serverName: "filesystem",
toolDescription: "",
expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{
Name: "test_tool",
Description: tt.toolDescription,
}
mcpTool := newMCPToolForTest(manager, tt.serverName, tool)
result := mcpTool.Description()
for _, expected := range tt.expectContains {
if !strings.Contains(result, expected) {
t.Errorf("Description should contain '%s', got: %s", expected, result)
}
}
})
}
}
// TestMCPTool_Parameters verifies parameter schema conversion
func TestMCPTool_Parameters(t *testing.T) {
tests := []struct {
name string
inputSchema interface{}
expectType string
}{
{
name: "map schema",
inputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"query": map[string]interface{}{
"type": "string",
"description": "Search query",
},
},
"required": []string{"query"},
},
expectType: "object",
},
{
name: "nil schema",
inputSchema: nil,
expectType: "object",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{
Name: "test_tool",
InputSchema: tt.inputSchema,
}
mcpTool := newMCPToolForTest(manager, "test_server", tool)
params := mcpTool.Parameters()
if params == nil {
t.Fatal("Parameters should not be nil")
}
if params["type"] != tt.expectType {
t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"])
}
})
}
}
// TestMCPTool_Execute_Success tests successful tool execution
func TestMCPTool_Execute_Success(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
// Verify correct parameters passed
if serverName != "github" {
t.Errorf("Expected serverName 'github', got '%s'", serverName)
}
if toolName != "search_repos" {
t.Errorf("Expected toolName 'search_repos', got '%s'", toolName)
}
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "Found 3 repositories"},
},
IsError: false,
}, nil
},
}
tool := &mcp.Tool{
Name: "search_repos",
Description: "Search GitHub repositories",
}
mcpTool := newMCPToolForTest(manager, "github", tool)
ctx := context.Background()
args := map[string]interface{}{
"query": "golang mcp",
}
result := mcpTool.Execute(ctx, args)
if result == nil {
t.Fatal("Result should not be nil")
}
if result.IsError {
t.Errorf("Expected no error, got error: %s", result.ForLLM)
}
if result.ForLLM != "Found 3 repositories" {
t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM)
}
}
// TestMCPTool_Execute_ManagerError tests execution when manager returns error
func TestMCPTool_Execute_ManagerError(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
return nil, fmt.Errorf("connection failed")
},
}
tool := &mcp.Tool{Name: "test_tool"}
mcpTool := newMCPToolForTest(manager, "test_server", tool)
ctx := context.Background()
result := mcpTool.Execute(ctx, map[string]interface{}{})
if result == nil {
t.Fatal("Result should not be nil")
}
if !result.IsError {
t.Error("Expected IsError to be true")
}
if !strings.Contains(result.ForLLM, "MCP tool execution failed") {
t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "connection failed") {
t.Errorf("Error message should include original error, got: %s", result.ForLLM)
}
}
// TestMCPTool_Execute_ServerError tests execution when server returns error
func TestMCPTool_Execute_ServerError(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "Invalid API key"},
},
IsError: true,
}, nil
},
}
tool := &mcp.Tool{Name: "test_tool"}
mcpTool := newMCPToolForTest(manager, "test_server", tool)
ctx := context.Background()
result := mcpTool.Execute(ctx, map[string]interface{}{})
if result == nil {
t.Fatal("Result should not be nil")
}
if !result.IsError {
t.Error("Expected IsError to be true")
}
if !strings.Contains(result.ForLLM, "MCP tool returned error") {
t.Errorf("Error message should mention server error, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Invalid API key") {
t.Errorf("Error message should include server message, got: %s", result.ForLLM)
}
}
// TestMCPTool_Execute_MultipleContent tests execution with multiple content items
func TestMCPTool_Execute_MultipleContent(t *testing.T) {
manager := &MockMCPManager{
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "First line"},
&mcp.TextContent{Text: "Second line"},
&mcp.TextContent{Text: "Third line"},
},
IsError: false,
}, nil
},
}
tool := &mcp.Tool{Name: "multi_output"}
mcpTool := newMCPToolForTest(manager, "test_server", tool)
ctx := context.Background()
result := mcpTool.Execute(ctx, map[string]interface{}{})
if result.IsError {
t.Errorf("Expected no error, got: %s", result.ForLLM)
}
expected := "First line\nSecond line\nThird line"
if result.ForLLM != expected {
t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM)
}
}
// TestExtractContentText_TextContent tests text content extraction
func TestExtractContentText_TextContent(t *testing.T) {
content := []mcp.Content{
&mcp.TextContent{Text: "Hello World"},
&mcp.TextContent{Text: "Second message"},
}
result := extractContentText(content)
expected := "Hello World\nSecond message"
if result != expected {
t.Errorf("Expected '%s', got '%s'", expected, result)
}
}
// TestExtractContentText_ImageContent tests image content extraction
func TestExtractContentText_ImageContent(t *testing.T) {
content := []mcp.Content{
&mcp.ImageContent{
Data: []byte("base64data"),
MIMEType: "image/png",
},
}
result := extractContentText(content)
if !strings.Contains(result, "[Image:") {
t.Errorf("Expected image indicator, got: %s", result)
}
if !strings.Contains(result, "image/png") {
t.Errorf("Expected MIME type in output, got: %s", result)
}
}
// TestExtractContentText_MixedContent tests mixed content types
func TestExtractContentText_MixedContent(t *testing.T) {
content := []mcp.Content{
&mcp.TextContent{Text: "Description"},
&mcp.ImageContent{
Data: []byte("data"),
MIMEType: "image/jpeg",
},
&mcp.TextContent{Text: "More text"},
}
result := extractContentText(content)
if !strings.Contains(result, "Description") {
t.Errorf("Should contain text content, got: %s", result)
}
if !strings.Contains(result, "[Image:") {
t.Errorf("Should contain image indicator, got: %s", result)
}
if !strings.Contains(result, "More text") {
t.Errorf("Should contain second text, got: %s", result)
}
}
// TestExtractContentText_EmptyContent tests empty content array
func TestExtractContentText_EmptyContent(t *testing.T) {
content := []mcp.Content{}
result := extractContentText(content)
if result != "" {
t.Errorf("Expected empty string for empty content, got: %s", result)
}
}
// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface
func TestMCPTool_InterfaceCompliance(t *testing.T) {
manager := &MockMCPManager{}
tool := &mcp.Tool{Name: "test"}
mcpTool := newMCPToolForTest(manager, "test_server", tool)
// Verify it implements Tool interface
var _ Tool = mcpTool
}
// TestMCPTool_Parameters_MapSchema tests schema that's already a map
func TestMCPTool_Parameters_MapSchema(t *testing.T) {
manager := &MockMCPManager{}
schema := map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"name": map[string]interface{}{
"type": "string",
"description": "The name parameter",
},
},
"required": []string{"name"},
}
tool := &mcp.Tool{
Name: "test_tool",
InputSchema: schema,
}
mcpTool := newMCPToolForTest(manager, "test_server", tool)
params := mcpTool.Parameters()
// Should return the schema as-is when it's already a map
if params["type"] != "object" {
t.Errorf("Expected type 'object', got '%v'", params["type"])
}
props, ok := params["properties"].(map[string]interface{})
if !ok {
t.Error("Properties should be a map")
}
nameParam, ok := props["name"].(map[string]interface{})
if !ok {
t.Error("Name parameter should exist")
}
if nameParam["type"] != "string" {
t.Errorf("Name type should be 'string', got '%v'", nameParam["type"])
}
}