mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(mcp): add Model Context Protocol integration
Implement comprehensive MCP support with stdio/HTTP/SSE transports, environment variable configuration (env and envFile), custom headers, tool registration, and automatic resource cleanup. Includes full test coverage and VSCode-compatible configuration. - Added pkg/mcp/manager.go for server lifecycle management - Added pkg/tools/mcp_tool.go for tool wrapping - Integrated into agent loop with cleanup - Support for envFile loading (.env format) - Headers injection for HTTP/SSE authentication - Example configs for filesystem, github, brave-search, postgres
This commit is contained in:
@@ -14,7 +14,9 @@
|
||||
"enabled": false,
|
||||
"token": "YOUR_TELEGRAM_BOT_TOKEN",
|
||||
"proxy": "",
|
||||
"allow_from": ["YOUR_USER_ID"]
|
||||
"allow_from": [
|
||||
"YOUR_USER_ID"
|
||||
]
|
||||
},
|
||||
"discord": {
|
||||
"enabled": false,
|
||||
@@ -115,6 +117,62 @@
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"mcp": {
|
||||
"enabled": false,
|
||||
"servers": {
|
||||
"filesystem": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-filesystem",
|
||||
"/tmp"
|
||||
],
|
||||
"env": {}
|
||||
},
|
||||
"github": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-github"
|
||||
],
|
||||
"env": {
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN"
|
||||
},
|
||||
"envFile": ".env"
|
||||
},
|
||||
"brave-search": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-brave-search"
|
||||
],
|
||||
"env": {
|
||||
"BRAVE_API_KEY": "YOUR_BRAVE_API_KEY"
|
||||
}
|
||||
},
|
||||
"postgres": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-postgres",
|
||||
"postgresql://user:password@localhost/dbname"
|
||||
]
|
||||
},
|
||||
"remote-http-example": {
|
||||
"enabled": false,
|
||||
"url": "https://mcp-server.example.com/stream",
|
||||
"type": "sse",
|
||||
"headers": {
|
||||
"Authorization": "Bearer YOUR_TOKEN",
|
||||
"X-Custom-Header": "custom-value"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
@@ -129,4 +187,4 @@
|
||||
"host": "0.0.0.0",
|
||||
"port": 18790
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.0
|
||||
github.com/mymmrac/telego v1.6.0
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||
github.com/openai/openai-go/v3 v3.22.0
|
||||
@@ -19,7 +20,7 @@ require (
|
||||
golang.org/x/oauth2 v0.35.0
|
||||
)
|
||||
|
||||
|
||||
require github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
@@ -28,9 +29,9 @@ require (
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/github/copilot-sdk/go v0.1.23
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/go-resty/resty/v2 v2.17.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/grbit/go-json v0.11.0 // indirect
|
||||
github.com/klauspost/compress v1.18.4 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
@@ -47,5 +48,4 @@ require (
|
||||
golang.org/x/net v0.50.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
|
||||
)
|
||||
|
||||
@@ -43,6 +43,8 @@ github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
|
||||
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
|
||||
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
|
||||
@@ -58,6 +60,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -84,6 +88,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.0 h1:gMfZkv3DzQF5q/DcQePo5rahEY+sguyPfXDfNBcT0Zs=
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.0/go.mod h1:AnQ//Qc6+4nIyyrB4cxBU7UW9VibK4iOZBeyP/rF1IE=
|
||||
github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0=
|
||||
github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
@@ -141,6 +147,8 @@ github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpB
|
||||
github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
@@ -228,6 +236,8 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
|
||||
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
+43
-3
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
@@ -40,6 +41,7 @@ type AgentLoop struct {
|
||||
state *state.Manager
|
||||
contextBuilder *ContextBuilder
|
||||
tools *tools.ToolRegistry
|
||||
mcpManager *mcp.Manager // MCP server manager for resource cleanup
|
||||
running atomic.Bool
|
||||
summarizing sync.Map // Tracks which sessions are currently being summarized
|
||||
}
|
||||
@@ -58,7 +60,7 @@ type processOptions struct {
|
||||
|
||||
// createToolRegistry creates a tool registry with common tools.
|
||||
// This is shared between main agent and subagents.
|
||||
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
|
||||
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus, mcpManager *mcp.Manager) *tools.ToolRegistry {
|
||||
registry := tools.NewToolRegistry()
|
||||
|
||||
// File system tools
|
||||
@@ -99,6 +101,23 @@ func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msg
|
||||
})
|
||||
registry.Register(messageTool)
|
||||
|
||||
// Register MCP tools from all connected servers
|
||||
if mcpManager != nil {
|
||||
servers := mcpManager.GetServers()
|
||||
for serverName, conn := range servers {
|
||||
for _, tool := range conn.Tools {
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
registry.Register(mcpTool)
|
||||
logger.DebugCF("agent", "Registered MCP tool",
|
||||
map[string]interface{}{
|
||||
"server": serverName,
|
||||
"tool": tool.Name,
|
||||
"name": mcpTool.Name(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return registry
|
||||
}
|
||||
|
||||
@@ -108,12 +127,22 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
|
||||
restrict := cfg.Agents.Defaults.RestrictToWorkspace
|
||||
|
||||
// Initialize MCP Manager and load servers
|
||||
mcpManager := mcp.NewManager()
|
||||
ctx := context.Background()
|
||||
if err := mcpManager.LoadFromConfig(ctx, cfg); err != nil {
|
||||
logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available",
|
||||
map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Create tool registry for main agent
|
||||
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
|
||||
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager)
|
||||
|
||||
// Create subagent manager with its own tool registry
|
||||
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
|
||||
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
|
||||
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager)
|
||||
// Subagent doesn't need spawn/subagent tools to avoid recursion
|
||||
subagentManager.SetTools(subagentTools)
|
||||
|
||||
@@ -145,6 +174,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
state: stateManager,
|
||||
contextBuilder: contextBuilder,
|
||||
tools: toolsRegistry,
|
||||
mcpManager: mcpManager,
|
||||
summarizing: sync.Map{},
|
||||
}
|
||||
}
|
||||
@@ -193,6 +223,16 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
|
||||
func (al *AgentLoop) Stop() {
|
||||
al.running.Store(false)
|
||||
|
||||
// Clean up MCP connections
|
||||
if al.mcpManager != nil {
|
||||
if err := al.mcpManager.Close(); err != nil {
|
||||
logger.ErrorCF("agent", "Failed to close MCP manager",
|
||||
map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||
|
||||
@@ -212,6 +212,35 @@ type WebToolsConfig struct {
|
||||
|
||||
type ToolsConfig struct {
|
||||
Web WebToolsConfig `json:"web"`
|
||||
MCP MCPConfig `json:"mcp"`
|
||||
}
|
||||
|
||||
// MCPServerConfig defines configuration for a single MCP server
|
||||
type MCPServerConfig struct {
|
||||
// Enabled indicates whether this MCP server is active
|
||||
Enabled bool `json:"enabled"`
|
||||
// Command is the executable to run (e.g., "npx", "python", "/path/to/server")
|
||||
Command string `json:"command"`
|
||||
// Args are the arguments to pass to the command
|
||||
Args []string `json:"args,omitempty"`
|
||||
// Env are environment variables to set for the server process (stdio only)
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
// EnvFile is the path to a file containing environment variables (stdio only)
|
||||
EnvFile string `json:"envFile,omitempty"`
|
||||
// Type is "stdio", "sse", or "http" (default: stdio if command is set, sse if url is set)
|
||||
Type string `json:"type,omitempty"`
|
||||
// URL is used for SSE/HTTP transport
|
||||
URL string `json:"url,omitempty"`
|
||||
// Headers are HTTP headers to send with requests (sse/http only)
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
}
|
||||
|
||||
// MCPConfig defines configuration for all MCP servers
|
||||
type MCPConfig struct {
|
||||
// Enabled globally enables/disables MCP integration
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
|
||||
// Servers is a map of server name to server configuration
|
||||
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
|
||||
}
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
@@ -321,6 +350,38 @@ func DefaultConfig() *Config {
|
||||
MaxResults: 5,
|
||||
},
|
||||
},
|
||||
MCP: MCPConfig{
|
||||
Enabled: false,
|
||||
Servers: map[string]MCPServerConfig{
|
||||
"filesystem": {
|
||||
Enabled: false,
|
||||
Command: "npx",
|
||||
Args: []string{"-y", "@modelcontextprotocol/server-filesystem", "/tmp"},
|
||||
Env: map[string]string{},
|
||||
},
|
||||
"github": {
|
||||
Enabled: false,
|
||||
Command: "npx",
|
||||
Args: []string{"-y", "@modelcontextprotocol/server-github"},
|
||||
Env: map[string]string{
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN",
|
||||
},
|
||||
},
|
||||
"brave-search": {
|
||||
Enabled: false,
|
||||
Command: "npx",
|
||||
Args: []string{"-y", "@modelcontextprotocol/server-brave-search"},
|
||||
Env: map[string]string{
|
||||
"BRAVE_API_KEY": "YOUR_BRAVE_API_KEY",
|
||||
},
|
||||
},
|
||||
"postgres": {
|
||||
Enabled: false,
|
||||
Command: "npx",
|
||||
Args: []string{"-y", "@modelcontextprotocol/server-postgres", "postgresql://user:password@localhost/dbname"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
Enabled: true,
|
||||
|
||||
@@ -0,0 +1,432 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// headerTransport is an http.RoundTripper that adds custom headers to requests
|
||||
type headerTransport struct {
|
||||
base http.RoundTripper
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Clone the request to avoid modifying the original
|
||||
req = req.Clone(req.Context())
|
||||
|
||||
// Add custom headers
|
||||
for key, value := range t.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Use the base transport
|
||||
base := t.base
|
||||
if base == nil {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
return base.RoundTrip(req)
|
||||
}
|
||||
|
||||
// loadEnvFile loads environment variables from a file in .env format
|
||||
// Each line should be in the format: KEY=value
|
||||
// Lines starting with # are comments
|
||||
// Empty lines are ignored
|
||||
func loadEnvFile(path string) (map[string]string, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open env file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
envVars := make(map[string]string)
|
||||
scanner := bufio.NewScanner(file)
|
||||
lineNum := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
// Skip empty lines and comments
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse KEY=value
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line)
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
// Remove surrounding quotes if present
|
||||
if len(value) >= 2 {
|
||||
if (value[0] == '"' && value[len(value)-1] == '"') ||
|
||||
(value[0] == '\'' && value[len(value)-1] == '\'') {
|
||||
value = value[1 : len(value)-1]
|
||||
}
|
||||
}
|
||||
|
||||
envVars[key] = value
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error reading env file: %w", err)
|
||||
}
|
||||
|
||||
return envVars, nil
|
||||
}
|
||||
|
||||
// ServerConnection represents a connection to an MCP server
|
||||
type ServerConnection struct {
|
||||
Name string
|
||||
Client *mcp.Client
|
||||
Session *mcp.ClientSession
|
||||
Tools []*mcp.Tool
|
||||
}
|
||||
|
||||
// Manager manages multiple MCP server connections
|
||||
type Manager struct {
|
||||
servers map[string]*ServerConnection
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewManager creates a new MCP manager
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
servers: make(map[string]*ServerConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromConfig loads MCP servers from configuration
|
||||
func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error {
|
||||
if !cfg.Tools.MCP.Enabled {
|
||||
logger.InfoCF("mcp", "MCP integration is disabled", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(cfg.Tools.MCP.Servers) == 0 {
|
||||
logger.InfoCF("mcp", "No MCP servers configured", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.InfoCF("mcp", "Initializing MCP servers",
|
||||
map[string]interface{}{
|
||||
"count": len(cfg.Tools.MCP.Servers),
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, len(cfg.Tools.MCP.Servers))
|
||||
|
||||
for name, serverCfg := range cfg.Tools.MCP.Servers {
|
||||
if !serverCfg.Enabled {
|
||||
logger.DebugCF("mcp", "Skipping disabled server",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(name string, serverCfg config.MCPServerConfig) {
|
||||
defer wg.Done()
|
||||
|
||||
if err := m.ConnectServer(ctx, name, serverCfg); err != nil {
|
||||
logger.ErrorCF("mcp", "Failed to connect to MCP server",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
errs <- fmt.Errorf("failed to connect to server %s: %w", name, err)
|
||||
}
|
||||
}(name, serverCfg)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
// Collect errors
|
||||
var allErrors []error
|
||||
for err := range errs {
|
||||
allErrors = append(allErrors, err)
|
||||
}
|
||||
|
||||
if len(allErrors) > 0 {
|
||||
logger.WarnCF("mcp", "Some MCP servers failed to connect",
|
||||
map[string]interface{}{
|
||||
"failed": len(allErrors),
|
||||
"total": len(cfg.Tools.MCP.Servers),
|
||||
})
|
||||
// Don't fail completely if some servers fail to connect
|
||||
}
|
||||
|
||||
connectedCount := len(m.GetServers())
|
||||
logger.InfoCF("mcp", "MCP server initialization complete",
|
||||
map[string]interface{}{
|
||||
"connected": connectedCount,
|
||||
"total": len(cfg.Tools.MCP.Servers),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectServer connects to a single MCP server
|
||||
func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCPServerConfig) error {
|
||||
logger.InfoCF("mcp", "Connecting to MCP server",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"command": cfg.Command,
|
||||
"args": cfg.Args,
|
||||
})
|
||||
|
||||
// Create client
|
||||
client := mcp.NewClient(&mcp.Implementation{
|
||||
Name: "picoclaw",
|
||||
Version: "1.0.0",
|
||||
}, nil)
|
||||
|
||||
// Create transport based on configuration
|
||||
// Auto-detect transport type if not explicitly specified
|
||||
var transport mcp.Transport
|
||||
transportType := cfg.Type
|
||||
|
||||
// Auto-detect: if URL is provided, use SSE; if command is provided, use stdio
|
||||
if transportType == "" {
|
||||
if cfg.URL != "" {
|
||||
transportType = "sse"
|
||||
} else if cfg.Command != "" {
|
||||
transportType = "stdio"
|
||||
} else {
|
||||
return fmt.Errorf("either URL or command must be provided")
|
||||
}
|
||||
}
|
||||
|
||||
switch transportType {
|
||||
case "sse", "http":
|
||||
if cfg.URL == "" {
|
||||
return fmt.Errorf("URL is required for SSE/HTTP transport")
|
||||
}
|
||||
logger.DebugCF("mcp", "Using SSE/HTTP transport",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"url": cfg.URL,
|
||||
})
|
||||
|
||||
sseTransport := &mcp.StreamableClientTransport{
|
||||
Endpoint: cfg.URL,
|
||||
}
|
||||
|
||||
// Add custom headers if provided
|
||||
if len(cfg.Headers) > 0 {
|
||||
// Create a custom HTTP client with header-injecting transport
|
||||
sseTransport.HTTPClient = &http.Client{
|
||||
Transport: &headerTransport{
|
||||
base: http.DefaultTransport,
|
||||
headers: cfg.Headers,
|
||||
},
|
||||
}
|
||||
logger.DebugCF("mcp", "Added custom HTTP headers",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"header_count": len(cfg.Headers),
|
||||
})
|
||||
}
|
||||
|
||||
transport = sseTransport
|
||||
case "stdio":
|
||||
if cfg.Command == "" {
|
||||
return fmt.Errorf("command is required for stdio transport")
|
||||
}
|
||||
logger.DebugCF("mcp", "Using stdio transport",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"command": cfg.Command,
|
||||
})
|
||||
// Create command with context
|
||||
cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...)
|
||||
|
||||
// Set environment variables
|
||||
env := cmd.Environ()
|
||||
|
||||
// Load environment variables from file if specified
|
||||
if cfg.EnvFile != "" {
|
||||
envVars, err := loadEnvFile(cfg.EnvFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
|
||||
}
|
||||
for k, v := range envVars {
|
||||
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
logger.DebugCF("mcp", "Loaded environment variables from file",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"envFile": cfg.EnvFile,
|
||||
"var_count": len(envVars),
|
||||
})
|
||||
}
|
||||
|
||||
// Environment variables from config override those from file
|
||||
if len(cfg.Env) > 0 {
|
||||
for k, v := range cfg.Env {
|
||||
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
}
|
||||
|
||||
// Set environment if we added any variables
|
||||
if len(env) > len(cmd.Environ()) {
|
||||
cmd.Env = env
|
||||
}
|
||||
|
||||
transport = &mcp.CommandTransport{Command: cmd}
|
||||
default:
|
||||
return fmt.Errorf("unsupported transport type: %s (supported: stdio, sse, http)", transportType)
|
||||
}
|
||||
|
||||
// Connect to server
|
||||
session, err := client.Connect(ctx, transport, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
// Get server info
|
||||
initResult := session.InitializeResult()
|
||||
logger.InfoCF("mcp", "Connected to MCP server",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"serverName": initResult.ServerInfo.Name,
|
||||
"serverVersion": initResult.ServerInfo.Version,
|
||||
"protocol": initResult.ProtocolVersion,
|
||||
})
|
||||
|
||||
// List available tools if supported
|
||||
var tools []*mcp.Tool
|
||||
if initResult.Capabilities.Tools != nil {
|
||||
for tool, err := range session.Tools(ctx, nil) {
|
||||
if err != nil {
|
||||
logger.WarnCF("mcp", "Error listing tool",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
logger.InfoCF("mcp", "Listed tools from MCP server",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"toolCount": len(tools),
|
||||
})
|
||||
}
|
||||
|
||||
// Store connection
|
||||
m.mu.Lock()
|
||||
m.servers[name] = &ServerConnection{
|
||||
Name: name,
|
||||
Client: client,
|
||||
Session: session,
|
||||
Tools: tools,
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServers returns all connected servers
|
||||
func (m *Manager) GetServers() map[string]*ServerConnection {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string]*ServerConnection, len(m.servers))
|
||||
for k, v := range m.servers {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetServer returns a specific server connection
|
||||
func (m *Manager) GetServer(name string) (*ServerConnection, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
conn, ok := m.servers[name]
|
||||
return conn, ok
|
||||
}
|
||||
|
||||
// CallTool calls a tool on a specific server
|
||||
func (m *Manager) CallTool(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
|
||||
conn, ok := m.GetServer(serverName)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server %s not found", serverName)
|
||||
}
|
||||
|
||||
params := &mcp.CallToolParams{
|
||||
Name: toolName,
|
||||
Arguments: arguments,
|
||||
}
|
||||
|
||||
result, err := conn.Session.CallTool(ctx, params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call tool: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Close closes all server connections
|
||||
func (m *Manager) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
logger.InfoCF("mcp", "Closing all MCP server connections",
|
||||
map[string]interface{}{
|
||||
"count": len(m.servers),
|
||||
})
|
||||
|
||||
var errs []error
|
||||
for name, conn := range m.servers {
|
||||
if err := conn.Session.Close(); err != nil {
|
||||
logger.ErrorCF("mcp", "Failed to close server connection",
|
||||
map[string]interface{}{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.servers = make(map[string]*ServerConnection)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to close %d server(s)", len(errs))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllTools returns all tools from all connected servers
|
||||
func (m *Manager) GetAllTools() map[string][]*mcp.Tool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string][]*mcp.Tool)
|
||||
for name, conn := range m.servers {
|
||||
if len(conn.Tools) > 0 {
|
||||
result[name] = conn.Tools
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadEnvFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expected map[string]string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic env file",
|
||||
content: `API_KEY=secret123
|
||||
DATABASE_URL=postgres://localhost/db
|
||||
PORT=8080`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret123",
|
||||
"DATABASE_URL": "postgres://localhost/db",
|
||||
"PORT": "8080",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "with comments and empty lines",
|
||||
content: `# This is a comment
|
||||
API_KEY=secret123
|
||||
|
||||
# Another comment
|
||||
DATABASE_URL=postgres://localhost/db
|
||||
|
||||
PORT=8080`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret123",
|
||||
"DATABASE_URL": "postgres://localhost/db",
|
||||
"PORT": "8080",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "with quoted values",
|
||||
content: `API_KEY="secret with spaces"
|
||||
NAME='single quoted'
|
||||
PLAIN=no-quotes`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret with spaces",
|
||||
"NAME": "single quoted",
|
||||
"PLAIN": "no-quotes",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "with spaces around equals",
|
||||
content: `API_KEY = secret123
|
||||
DATABASE_URL= postgres://localhost/db
|
||||
PORT =8080`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret123",
|
||||
"DATABASE_URL": "postgres://localhost/db",
|
||||
"PORT": "8080",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid format - no equals",
|
||||
content: `INVALID_LINE`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty file",
|
||||
content: ``,
|
||||
expected: map[string]string{},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "only comments",
|
||||
content: `# Comment 1
|
||||
# Comment 2`,
|
||||
expected: map[string]string{},
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
envFile := filepath.Join(tmpDir, ".env")
|
||||
|
||||
if err := os.WriteFile(envFile, []byte(tt.content), 0644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
result, err := loadEnvFile(envFile)
|
||||
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result))
|
||||
}
|
||||
|
||||
for key, expectedValue := range tt.expected {
|
||||
if actualValue, ok := result[key]; !ok {
|
||||
t.Errorf("Expected key %s not found", key)
|
||||
} else if actualValue != expectedValue {
|
||||
t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadEnvFileNotFound(t *testing.T) {
|
||||
_, err := loadEnvFile("/nonexistent/file.env")
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvFilePriority(t *testing.T) {
|
||||
// Create a temporary .env file
|
||||
tmpDir := t.TempDir()
|
||||
envFile := filepath.Join(tmpDir, ".env")
|
||||
|
||||
envContent := `API_KEY=from_file
|
||||
DATABASE_URL=from_file
|
||||
SHARED_VAR=from_file`
|
||||
|
||||
if err := os.WriteFile(envFile, []byte(envContent), 0644); err != nil {
|
||||
t.Fatalf("Failed to create .env file: %v", err)
|
||||
}
|
||||
|
||||
// Load envFile
|
||||
envVars, err := loadEnvFile(envFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load env file: %v", err)
|
||||
}
|
||||
|
||||
// Verify envFile variables
|
||||
if envVars["API_KEY"] != "from_file" {
|
||||
t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"])
|
||||
}
|
||||
|
||||
// Simulate config.Env overriding envFile
|
||||
configEnv := map[string]string{
|
||||
"SHARED_VAR": "from_config",
|
||||
"NEW_VAR": "from_config",
|
||||
}
|
||||
|
||||
// Merge: envFile first, then config overrides
|
||||
merged := make(map[string]string)
|
||||
for k, v := range envVars {
|
||||
merged[k] = v
|
||||
}
|
||||
for k, v := range configEnv {
|
||||
merged[k] = v
|
||||
}
|
||||
|
||||
// Verify priority: config.Env should override envFile
|
||||
if merged["SHARED_VAR"] != "from_config" {
|
||||
t.Errorf("Expected SHARED_VAR=from_config (config should override file), got %s", merged["SHARED_VAR"])
|
||||
}
|
||||
if merged["API_KEY"] != "from_file" {
|
||||
t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"])
|
||||
}
|
||||
if merged["NEW_VAR"] != "from_config" {
|
||||
t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
mcpPkg "github.com/sipeed/picoclaw/pkg/mcp"
|
||||
)
|
||||
|
||||
// MCPManager defines the interface for MCP manager operations
|
||||
// This allows for easier testing with mock implementations
|
||||
type MCPManager interface {
|
||||
CallTool(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error)
|
||||
}
|
||||
|
||||
// MCPTool wraps an MCP tool to implement the Tool interface
|
||||
type MCPTool struct {
|
||||
manager MCPManager
|
||||
serverName string
|
||||
tool *mcp.Tool
|
||||
}
|
||||
|
||||
// NewMCPTool creates a new MCP tool wrapper
|
||||
func NewMCPTool(manager *mcpPkg.Manager, serverName string, tool *mcp.Tool) *MCPTool {
|
||||
return &MCPTool{
|
||||
manager: manager,
|
||||
serverName: serverName,
|
||||
tool: tool,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the tool name, prefixed with the server name
|
||||
func (t *MCPTool) Name() string {
|
||||
// Prefix with server name to avoid conflicts
|
||||
return fmt.Sprintf("mcp_%s_%s", t.serverName, t.tool.Name)
|
||||
}
|
||||
|
||||
// Description returns the tool description
|
||||
func (t *MCPTool) Description() string {
|
||||
desc := t.tool.Description
|
||||
if desc == "" {
|
||||
desc = fmt.Sprintf("MCP tool from %s server", t.serverName)
|
||||
}
|
||||
// Add server info to description
|
||||
return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc)
|
||||
}
|
||||
|
||||
// Parameters returns the tool parameters schema
|
||||
func (t *MCPTool) Parameters() map[string]interface{} {
|
||||
// The InputSchema is already a JSON Schema object
|
||||
schema := t.tool.InputSchema
|
||||
|
||||
// Convert to map[string]interface{} for compatibility
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// Use reflection to convert the schema
|
||||
// The schema should already be in the correct format
|
||||
if schema != nil {
|
||||
// Attempt to convert directly
|
||||
if schemaMap, ok := schema.(map[string]interface{}); ok {
|
||||
return schemaMap
|
||||
}
|
||||
|
||||
// Otherwise, build it manually
|
||||
result["type"] = "object"
|
||||
result["properties"] = map[string]interface{}{}
|
||||
result["required"] = []string{}
|
||||
} else {
|
||||
// Default schema when nil
|
||||
result["type"] = "object"
|
||||
result["properties"] = map[string]interface{}{}
|
||||
result["required"] = []string{}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Execute executes the MCP tool
|
||||
func (t *MCPTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err)
|
||||
}
|
||||
|
||||
// Handle error result from server
|
||||
if result.IsError {
|
||||
errMsg := extractContentText(result.Content)
|
||||
return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)).
|
||||
WithError(fmt.Errorf("MCP tool error: %s", errMsg))
|
||||
}
|
||||
|
||||
// Extract text content from result
|
||||
output := extractContentText(result.Content)
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: output,
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
// extractContentText extracts text from MCP content array
|
||||
func extractContentText(content []mcp.Content) string {
|
||||
var parts []string
|
||||
for _, c := range content {
|
||||
switch v := c.(type) {
|
||||
case *mcp.TextContent:
|
||||
parts = append(parts, v.Text)
|
||||
case *mcp.ImageContent:
|
||||
// For images, just indicate that an image was returned
|
||||
parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType))
|
||||
default:
|
||||
// For other content types, use string representation
|
||||
parts = append(parts, fmt.Sprintf("[Content: %T]", v))
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
@@ -0,0 +1,456 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// MockMCPManager is a mock implementation of MCPManager interface for testing
|
||||
type MockMCPManager struct {
|
||||
callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error)
|
||||
}
|
||||
|
||||
func (m *MockMCPManager) CallTool(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
|
||||
if m.callToolFunc != nil {
|
||||
return m.callToolFunc(ctx, serverName, toolName, arguments)
|
||||
}
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "mock result"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newMCPToolForTest creates an MCP tool for testing with mock manager
|
||||
func newMCPToolForTest(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
|
||||
return &MCPTool{
|
||||
manager: manager,
|
||||
serverName: serverName,
|
||||
tool: tool,
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewMCPTool verifies MCP tool creation
|
||||
func TestNewMCPTool(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
Description: "A test tool",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"input": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Test input",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := newMCPToolForTest(manager, "test_server", tool)
|
||||
|
||||
if mcpTool == nil {
|
||||
t.Fatal("NewMCPTool should not return nil")
|
||||
}
|
||||
// Verify tool properties we can access
|
||||
if mcpTool.Name() != "mcp_test_server_test_tool" {
|
||||
t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Name verifies tool name with server prefix
|
||||
func TestMCPTool_Name(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverName string
|
||||
toolName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple name",
|
||||
serverName: "github",
|
||||
toolName: "create_issue",
|
||||
expected: "mcp_github_create_issue",
|
||||
},
|
||||
{
|
||||
name: "filesystem server",
|
||||
serverName: "filesystem",
|
||||
toolName: "read_file",
|
||||
expected: "mcp_filesystem_read_file",
|
||||
},
|
||||
{
|
||||
name: "remote server",
|
||||
serverName: "remote-api",
|
||||
toolName: "fetch_data",
|
||||
expected: "mcp_remote-api_fetch_data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{Name: tt.toolName}
|
||||
mcpTool := newMCPToolForTest(manager, tt.serverName, tool)
|
||||
|
||||
result := mcpTool.Name()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected name '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Description verifies tool description generation
|
||||
func TestMCPTool_Description(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverName string
|
||||
toolDescription string
|
||||
expectContains []string
|
||||
}{
|
||||
{
|
||||
name: "with description",
|
||||
serverName: "github",
|
||||
toolDescription: "Create a GitHub issue",
|
||||
expectContains: []string{"[MCP:github]", "Create a GitHub issue"},
|
||||
},
|
||||
{
|
||||
name: "empty description",
|
||||
serverName: "filesystem",
|
||||
toolDescription: "",
|
||||
expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
Description: tt.toolDescription,
|
||||
}
|
||||
mcpTool := newMCPToolForTest(manager, tt.serverName, tool)
|
||||
|
||||
result := mcpTool.Description()
|
||||
|
||||
for _, expected := range tt.expectContains {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Description should contain '%s', got: %s", expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Parameters verifies parameter schema conversion
|
||||
func TestMCPTool_Parameters(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputSchema interface{}
|
||||
expectType string
|
||||
}{
|
||||
{
|
||||
name: "map schema",
|
||||
inputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
expectType: "object",
|
||||
},
|
||||
{
|
||||
name: "nil schema",
|
||||
inputSchema: nil,
|
||||
expectType: "object",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
InputSchema: tt.inputSchema,
|
||||
}
|
||||
mcpTool := newMCPToolForTest(manager, "test_server", tool)
|
||||
|
||||
params := mcpTool.Parameters()
|
||||
|
||||
if params == nil {
|
||||
t.Fatal("Parameters should not be nil")
|
||||
}
|
||||
|
||||
if params["type"] != tt.expectType {
|
||||
t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_Success tests successful tool execution
|
||||
func TestMCPTool_Execute_Success(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
|
||||
// Verify correct parameters passed
|
||||
if serverName != "github" {
|
||||
t.Errorf("Expected serverName 'github', got '%s'", serverName)
|
||||
}
|
||||
if toolName != "search_repos" {
|
||||
t.Errorf("Expected toolName 'search_repos', got '%s'", toolName)
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "Found 3 repositories"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{
|
||||
Name: "search_repos",
|
||||
Description: "Search GitHub repositories",
|
||||
}
|
||||
mcpTool := newMCPToolForTest(manager, "github", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"query": "golang mcp",
|
||||
}
|
||||
|
||||
result := mcpTool.Execute(ctx, args)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Errorf("Expected no error, got error: %s", result.ForLLM)
|
||||
}
|
||||
if result.ForLLM != "Found 3 repositories" {
|
||||
t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_ManagerError tests execution when manager returns error
|
||||
func TestMCPTool_Execute_ManagerError(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
|
||||
return nil, fmt.Errorf("connection failed")
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "test_tool"}
|
||||
mcpTool := newMCPToolForTest(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]interface{}{})
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "MCP tool execution failed") {
|
||||
t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "connection failed") {
|
||||
t.Errorf("Error message should include original error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_ServerError tests execution when server returns error
|
||||
func TestMCPTool_Execute_ServerError(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "Invalid API key"},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "test_tool"}
|
||||
mcpTool := newMCPToolForTest(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]interface{}{})
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "MCP tool returned error") {
|
||||
t.Errorf("Error message should mention server error, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Invalid API key") {
|
||||
t.Errorf("Error message should include server message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_MultipleContent tests execution with multiple content items
|
||||
func TestMCPTool_Execute_MultipleContent(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "First line"},
|
||||
&mcp.TextContent{Text: "Second line"},
|
||||
&mcp.TextContent{Text: "Third line"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "multi_output"}
|
||||
mcpTool := newMCPToolForTest(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]interface{}{})
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("Expected no error, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
expected := "First line\nSecond line\nThird line"
|
||||
if result.ForLLM != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_TextContent tests text content extraction
|
||||
func TestExtractContentText_TextContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.TextContent{Text: "Hello World"},
|
||||
&mcp.TextContent{Text: "Second message"},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
expected := "Hello World\nSecond message"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_ImageContent tests image content extraction
|
||||
func TestExtractContentText_ImageContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("base64data"),
|
||||
MIMEType: "image/png",
|
||||
},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if !strings.Contains(result, "[Image:") {
|
||||
t.Errorf("Expected image indicator, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "image/png") {
|
||||
t.Errorf("Expected MIME type in output, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_MixedContent tests mixed content types
|
||||
func TestExtractContentText_MixedContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.TextContent{Text: "Description"},
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("data"),
|
||||
MIMEType: "image/jpeg",
|
||||
},
|
||||
&mcp.TextContent{Text: "More text"},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if !strings.Contains(result, "Description") {
|
||||
t.Errorf("Should contain text content, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "[Image:") {
|
||||
t.Errorf("Should contain image indicator, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "More text") {
|
||||
t.Errorf("Should contain second text, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_EmptyContent tests empty content array
|
||||
func TestExtractContentText_EmptyContent(t *testing.T) {
|
||||
content := []mcp.Content{}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty string for empty content, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface
|
||||
func TestMCPTool_InterfaceCompliance(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{Name: "test"}
|
||||
mcpTool := newMCPToolForTest(manager, "test_server", tool)
|
||||
|
||||
// Verify it implements Tool interface
|
||||
var _ Tool = mcpTool
|
||||
}
|
||||
|
||||
// TestMCPTool_Parameters_MapSchema tests schema that's already a map
|
||||
func TestMCPTool_Parameters_MapSchema(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
schema := map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"name": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The name parameter",
|
||||
},
|
||||
},
|
||||
"required": []string{"name"},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
InputSchema: schema,
|
||||
}
|
||||
mcpTool := newMCPToolForTest(manager, "test_server", tool)
|
||||
|
||||
params := mcpTool.Parameters()
|
||||
|
||||
// Should return the schema as-is when it's already a map
|
||||
if params["type"] != "object" {
|
||||
t.Errorf("Expected type 'object', got '%v'", params["type"])
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Error("Properties should be a map")
|
||||
}
|
||||
|
||||
nameParam, ok := props["name"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Error("Name parameter should exist")
|
||||
}
|
||||
|
||||
if nameParam["type"] != "string" {
|
||||
t.Errorf("Name type should be 'string', got '%v'", nameParam["type"])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user