diff --git a/config/config.example.json b/config/config.example.json index aa75c8338..75478af4d 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -14,7 +14,9 @@ "enabled": false, "token": "YOUR_TELEGRAM_BOT_TOKEN", "proxy": "", - "allow_from": ["YOUR_USER_ID"] + "allow_from": [ + "YOUR_USER_ID" + ] }, "discord": { "enabled": false, @@ -115,6 +117,62 @@ "api_key": "YOUR_BRAVE_API_KEY", "max_results": 5 } + }, + "mcp": { + "enabled": false, + "servers": { + "filesystem": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + "/tmp" + ], + "env": {} + }, + "github": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-github" + ], + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN" + }, + "envFile": ".env" + }, + "brave-search": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-brave-search" + ], + "env": { + "BRAVE_API_KEY": "YOUR_BRAVE_API_KEY" + } + }, + "postgres": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-postgres", + "postgresql://user:password@localhost/dbname" + ] + }, + "remote-http-example": { + "enabled": false, + "url": "https://mcp-server.example.com/stream", + "type": "sse", + "headers": { + "Authorization": "Bearer YOUR_TOKEN", + "X-Custom-Header": "custom-value" + } + } + } } }, "heartbeat": { @@ -129,4 +187,4 @@ "host": "0.0.0.0", "port": 18790 } -} +} \ No newline at end of file diff --git a/go.mod b/go.mod index 98aecd6ab..528bc40b9 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 + github.com/modelcontextprotocol/go-sdk v1.3.0 github.com/mymmrac/telego v1.6.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/openai/openai-go/v3 v3.22.0 @@ -19,7 +20,7 @@ require ( golang.org/x/oauth2 v0.35.0 ) - +require github.com/yosida95/uritemplate/v3 v3.0.2 // indirect require ( github.com/andybalholm/brotli v1.2.0 // indirect @@ -28,9 +29,9 @@ require ( github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/github/copilot-sdk/go v0.1.23 - github.com/google/jsonschema-go v0.4.2 // indirect github.com/go-resty/resty/v2 v2.17.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/google/jsonschema-go v0.4.2 // indirect github.com/grbit/go-json v0.11.0 // indirect github.com/klauspost/compress v1.18.4 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect @@ -47,5 +48,4 @@ require ( golang.org/x/net v0.50.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.41.0 // indirect - ) diff --git a/go.sum b/go.sum index 6a565b93e..5469dd7dd 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= @@ -58,6 +60,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -84,6 +88,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk= github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI= +github.com/modelcontextprotocol/go-sdk v1.3.0 h1:gMfZkv3DzQF5q/DcQePo5rahEY+sguyPfXDfNBcT0Zs= +github.com/modelcontextprotocol/go-sdk v1.3.0/go.mod h1:AnQ//Qc6+4nIyyrB4cxBU7UW9VibK4iOZBeyP/rF1IE= github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0= github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= @@ -141,6 +147,8 @@ github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpB github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -228,6 +236,8 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f3dd94090..742ea5496 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -22,6 +22,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/mcp" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/state" @@ -40,6 +41,7 @@ type AgentLoop struct { state *state.Manager contextBuilder *ContextBuilder tools *tools.ToolRegistry + mcpManager *mcp.Manager // MCP server manager for resource cleanup running atomic.Bool summarizing sync.Map // Tracks which sessions are currently being summarized } @@ -58,7 +60,7 @@ type processOptions struct { // createToolRegistry creates a tool registry with common tools. // This is shared between main agent and subagents. -func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry { +func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus, mcpManager *mcp.Manager) *tools.ToolRegistry { registry := tools.NewToolRegistry() // File system tools @@ -99,6 +101,23 @@ func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msg }) registry.Register(messageTool) + // Register MCP tools from all connected servers + if mcpManager != nil { + servers := mcpManager.GetServers() + for serverName, conn := range servers { + for _, tool := range conn.Tools { + mcpTool := tools.NewMCPTool(mcpManager, serverName, tool) + registry.Register(mcpTool) + logger.DebugCF("agent", "Registered MCP tool", + map[string]interface{}{ + "server": serverName, + "tool": tool.Name, + "name": mcpTool.Name(), + }) + } + } + } + return registry } @@ -108,12 +127,22 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers restrict := cfg.Agents.Defaults.RestrictToWorkspace + // Initialize MCP Manager and load servers + mcpManager := mcp.NewManager() + ctx := context.Background() + if err := mcpManager.LoadFromConfig(ctx, cfg); err != nil { + logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available", + map[string]interface{}{ + "error": err.Error(), + }) + } + // Create tool registry for main agent - toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus) + toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager) // Create subagent manager with its own tool registry subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus) - subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus) + subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager) // Subagent doesn't need spawn/subagent tools to avoid recursion subagentManager.SetTools(subagentTools) @@ -145,6 +174,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers state: stateManager, contextBuilder: contextBuilder, tools: toolsRegistry, + mcpManager: mcpManager, summarizing: sync.Map{}, } } @@ -193,6 +223,16 @@ func (al *AgentLoop) Run(ctx context.Context) error { func (al *AgentLoop) Stop() { al.running.Store(false) + + // Clean up MCP connections + if al.mcpManager != nil { + if err := al.mcpManager.Close(); err != nil { + logger.ErrorCF("agent", "Failed to close MCP manager", + map[string]interface{}{ + "error": err.Error(), + }) + } + } } func (al *AgentLoop) RegisterTool(tool tools.Tool) { diff --git a/pkg/config/config.go b/pkg/config/config.go index d76ec8095..237eade65 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -212,6 +212,35 @@ type WebToolsConfig struct { type ToolsConfig struct { Web WebToolsConfig `json:"web"` + MCP MCPConfig `json:"mcp"` +} + +// MCPServerConfig defines configuration for a single MCP server +type MCPServerConfig struct { + // Enabled indicates whether this MCP server is active + Enabled bool `json:"enabled"` + // Command is the executable to run (e.g., "npx", "python", "/path/to/server") + Command string `json:"command"` + // Args are the arguments to pass to the command + Args []string `json:"args,omitempty"` + // Env are environment variables to set for the server process (stdio only) + Env map[string]string `json:"env,omitempty"` + // EnvFile is the path to a file containing environment variables (stdio only) + EnvFile string `json:"envFile,omitempty"` + // Type is "stdio", "sse", or "http" (default: stdio if command is set, sse if url is set) + Type string `json:"type,omitempty"` + // URL is used for SSE/HTTP transport + URL string `json:"url,omitempty"` + // Headers are HTTP headers to send with requests (sse/http only) + Headers map[string]string `json:"headers,omitempty"` +} + +// MCPConfig defines configuration for all MCP servers +type MCPConfig struct { + // Enabled globally enables/disables MCP integration + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"` + // Servers is a map of server name to server configuration + Servers map[string]MCPServerConfig `json:"servers,omitempty"` } func DefaultConfig() *Config { @@ -321,6 +350,38 @@ func DefaultConfig() *Config { MaxResults: 5, }, }, + MCP: MCPConfig{ + Enabled: false, + Servers: map[string]MCPServerConfig{ + "filesystem": { + Enabled: false, + Command: "npx", + Args: []string{"-y", "@modelcontextprotocol/server-filesystem", "/tmp"}, + Env: map[string]string{}, + }, + "github": { + Enabled: false, + Command: "npx", + Args: []string{"-y", "@modelcontextprotocol/server-github"}, + Env: map[string]string{ + "GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN", + }, + }, + "brave-search": { + Enabled: false, + Command: "npx", + Args: []string{"-y", "@modelcontextprotocol/server-brave-search"}, + Env: map[string]string{ + "BRAVE_API_KEY": "YOUR_BRAVE_API_KEY", + }, + }, + "postgres": { + Enabled: false, + Command: "npx", + Args: []string{"-y", "@modelcontextprotocol/server-postgres", "postgresql://user:password@localhost/dbname"}, + }, + }, + }, }, Heartbeat: HeartbeatConfig{ Enabled: true, diff --git a/pkg/mcp/manager.go b/pkg/mcp/manager.go new file mode 100644 index 000000000..d6ca28f76 --- /dev/null +++ b/pkg/mcp/manager.go @@ -0,0 +1,432 @@ +package mcp + +import ( + "bufio" + "context" + "fmt" + "net/http" + "os" + "os/exec" + "strings" + "sync" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// headerTransport is an http.RoundTripper that adds custom headers to requests +type headerTransport struct { + base http.RoundTripper + headers map[string]string +} + +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone the request to avoid modifying the original + req = req.Clone(req.Context()) + + // Add custom headers + for key, value := range t.headers { + req.Header.Set(key, value) + } + + // Use the base transport + base := t.base + if base == nil { + base = http.DefaultTransport + } + return base.RoundTrip(req) +} + +// loadEnvFile loads environment variables from a file in .env format +// Each line should be in the format: KEY=value +// Lines starting with # are comments +// Empty lines are ignored +func loadEnvFile(path string) (map[string]string, error) { + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("failed to open env file: %w", err) + } + defer file.Close() + + envVars := make(map[string]string) + scanner := bufio.NewScanner(file) + lineNum := 0 + + for scanner.Scan() { + lineNum++ + line := strings.TrimSpace(scanner.Text()) + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + // Parse KEY=value + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line) + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + // Remove surrounding quotes if present + if len(value) >= 2 { + if (value[0] == '"' && value[len(value)-1] == '"') || + (value[0] == '\'' && value[len(value)-1] == '\'') { + value = value[1 : len(value)-1] + } + } + + envVars[key] = value + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("error reading env file: %w", err) + } + + return envVars, nil +} + +// ServerConnection represents a connection to an MCP server +type ServerConnection struct { + Name string + Client *mcp.Client + Session *mcp.ClientSession + Tools []*mcp.Tool +} + +// Manager manages multiple MCP server connections +type Manager struct { + servers map[string]*ServerConnection + mu sync.RWMutex +} + +// NewManager creates a new MCP manager +func NewManager() *Manager { + return &Manager{ + servers: make(map[string]*ServerConnection), + } +} + +// LoadFromConfig loads MCP servers from configuration +func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error { + if !cfg.Tools.MCP.Enabled { + logger.InfoCF("mcp", "MCP integration is disabled", nil) + return nil + } + + if len(cfg.Tools.MCP.Servers) == 0 { + logger.InfoCF("mcp", "No MCP servers configured", nil) + return nil + } + + logger.InfoCF("mcp", "Initializing MCP servers", + map[string]interface{}{ + "count": len(cfg.Tools.MCP.Servers), + }) + + var wg sync.WaitGroup + errs := make(chan error, len(cfg.Tools.MCP.Servers)) + + for name, serverCfg := range cfg.Tools.MCP.Servers { + if !serverCfg.Enabled { + logger.DebugCF("mcp", "Skipping disabled server", + map[string]interface{}{ + "server": name, + }) + continue + } + + wg.Add(1) + go func(name string, serverCfg config.MCPServerConfig) { + defer wg.Done() + + if err := m.ConnectServer(ctx, name, serverCfg); err != nil { + logger.ErrorCF("mcp", "Failed to connect to MCP server", + map[string]interface{}{ + "server": name, + "error": err.Error(), + }) + errs <- fmt.Errorf("failed to connect to server %s: %w", name, err) + } + }(name, serverCfg) + } + + wg.Wait() + close(errs) + + // Collect errors + var allErrors []error + for err := range errs { + allErrors = append(allErrors, err) + } + + if len(allErrors) > 0 { + logger.WarnCF("mcp", "Some MCP servers failed to connect", + map[string]interface{}{ + "failed": len(allErrors), + "total": len(cfg.Tools.MCP.Servers), + }) + // Don't fail completely if some servers fail to connect + } + + connectedCount := len(m.GetServers()) + logger.InfoCF("mcp", "MCP server initialization complete", + map[string]interface{}{ + "connected": connectedCount, + "total": len(cfg.Tools.MCP.Servers), + }) + + return nil +} + +// ConnectServer connects to a single MCP server +func (m *Manager) ConnectServer(ctx context.Context, name string, cfg config.MCPServerConfig) error { + logger.InfoCF("mcp", "Connecting to MCP server", + map[string]interface{}{ + "server": name, + "command": cfg.Command, + "args": cfg.Args, + }) + + // Create client + client := mcp.NewClient(&mcp.Implementation{ + Name: "picoclaw", + Version: "1.0.0", + }, nil) + + // Create transport based on configuration + // Auto-detect transport type if not explicitly specified + var transport mcp.Transport + transportType := cfg.Type + + // Auto-detect: if URL is provided, use SSE; if command is provided, use stdio + if transportType == "" { + if cfg.URL != "" { + transportType = "sse" + } else if cfg.Command != "" { + transportType = "stdio" + } else { + return fmt.Errorf("either URL or command must be provided") + } + } + + switch transportType { + case "sse", "http": + if cfg.URL == "" { + return fmt.Errorf("URL is required for SSE/HTTP transport") + } + logger.DebugCF("mcp", "Using SSE/HTTP transport", + map[string]interface{}{ + "server": name, + "url": cfg.URL, + }) + + sseTransport := &mcp.StreamableClientTransport{ + Endpoint: cfg.URL, + } + + // Add custom headers if provided + if len(cfg.Headers) > 0 { + // Create a custom HTTP client with header-injecting transport + sseTransport.HTTPClient = &http.Client{ + Transport: &headerTransport{ + base: http.DefaultTransport, + headers: cfg.Headers, + }, + } + logger.DebugCF("mcp", "Added custom HTTP headers", + map[string]interface{}{ + "server": name, + "header_count": len(cfg.Headers), + }) + } + + transport = sseTransport + case "stdio": + if cfg.Command == "" { + return fmt.Errorf("command is required for stdio transport") + } + logger.DebugCF("mcp", "Using stdio transport", + map[string]interface{}{ + "server": name, + "command": cfg.Command, + }) + // Create command with context + cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...) + + // Set environment variables + env := cmd.Environ() + + // Load environment variables from file if specified + if cfg.EnvFile != "" { + envVars, err := loadEnvFile(cfg.EnvFile) + if err != nil { + return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err) + } + for k, v := range envVars { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + logger.DebugCF("mcp", "Loaded environment variables from file", + map[string]interface{}{ + "server": name, + "envFile": cfg.EnvFile, + "var_count": len(envVars), + }) + } + + // Environment variables from config override those from file + if len(cfg.Env) > 0 { + for k, v := range cfg.Env { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + } + + // Set environment if we added any variables + if len(env) > len(cmd.Environ()) { + cmd.Env = env + } + + transport = &mcp.CommandTransport{Command: cmd} + default: + return fmt.Errorf("unsupported transport type: %s (supported: stdio, sse, http)", transportType) + } + + // Connect to server + session, err := client.Connect(ctx, transport, nil) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + + // Get server info + initResult := session.InitializeResult() + logger.InfoCF("mcp", "Connected to MCP server", + map[string]interface{}{ + "server": name, + "serverName": initResult.ServerInfo.Name, + "serverVersion": initResult.ServerInfo.Version, + "protocol": initResult.ProtocolVersion, + }) + + // List available tools if supported + var tools []*mcp.Tool + if initResult.Capabilities.Tools != nil { + for tool, err := range session.Tools(ctx, nil) { + if err != nil { + logger.WarnCF("mcp", "Error listing tool", + map[string]interface{}{ + "server": name, + "error": err.Error(), + }) + continue + } + tools = append(tools, tool) + } + + logger.InfoCF("mcp", "Listed tools from MCP server", + map[string]interface{}{ + "server": name, + "toolCount": len(tools), + }) + } + + // Store connection + m.mu.Lock() + m.servers[name] = &ServerConnection{ + Name: name, + Client: client, + Session: session, + Tools: tools, + } + m.mu.Unlock() + + return nil +} + +// GetServers returns all connected servers +func (m *Manager) GetServers() map[string]*ServerConnection { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make(map[string]*ServerConnection, len(m.servers)) + for k, v := range m.servers { + result[k] = v + } + return result +} + +// GetServer returns a specific server connection +func (m *Manager) GetServer(name string) (*ServerConnection, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + conn, ok := m.servers[name] + return conn, ok +} + +// CallTool calls a tool on a specific server +func (m *Manager) CallTool(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) { + conn, ok := m.GetServer(serverName) + if !ok { + return nil, fmt.Errorf("server %s not found", serverName) + } + + params := &mcp.CallToolParams{ + Name: toolName, + Arguments: arguments, + } + + result, err := conn.Session.CallTool(ctx, params) + if err != nil { + return nil, fmt.Errorf("failed to call tool: %w", err) + } + + return result, nil +} + +// Close closes all server connections +func (m *Manager) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + + logger.InfoCF("mcp", "Closing all MCP server connections", + map[string]interface{}{ + "count": len(m.servers), + }) + + var errs []error + for name, conn := range m.servers { + if err := conn.Session.Close(); err != nil { + logger.ErrorCF("mcp", "Failed to close server connection", + map[string]interface{}{ + "server": name, + "error": err.Error(), + }) + errs = append(errs, err) + } + } + + m.servers = make(map[string]*ServerConnection) + + if len(errs) > 0 { + return fmt.Errorf("failed to close %d server(s)", len(errs)) + } + + return nil +} + +// GetAllTools returns all tools from all connected servers +func (m *Manager) GetAllTools() map[string][]*mcp.Tool { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make(map[string][]*mcp.Tool) + for name, conn := range m.servers { + if len(conn.Tools) > 0 { + result[name] = conn.Tools + } + } + return result +} diff --git a/pkg/mcp/manager_test.go b/pkg/mcp/manager_test.go new file mode 100644 index 000000000..e888d8764 --- /dev/null +++ b/pkg/mcp/manager_test.go @@ -0,0 +1,182 @@ +package mcp + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadEnvFile(t *testing.T) { + tests := []struct { + name string + content string + expected map[string]string + expectErr bool + }{ + { + name: "basic env file", + content: `API_KEY=secret123 +DATABASE_URL=postgres://localhost/db +PORT=8080`, + expected: map[string]string{ + "API_KEY": "secret123", + "DATABASE_URL": "postgres://localhost/db", + "PORT": "8080", + }, + expectErr: false, + }, + { + name: "with comments and empty lines", + content: `# This is a comment +API_KEY=secret123 + +# Another comment +DATABASE_URL=postgres://localhost/db + +PORT=8080`, + expected: map[string]string{ + "API_KEY": "secret123", + "DATABASE_URL": "postgres://localhost/db", + "PORT": "8080", + }, + expectErr: false, + }, + { + name: "with quoted values", + content: `API_KEY="secret with spaces" +NAME='single quoted' +PLAIN=no-quotes`, + expected: map[string]string{ + "API_KEY": "secret with spaces", + "NAME": "single quoted", + "PLAIN": "no-quotes", + }, + expectErr: false, + }, + { + name: "with spaces around equals", + content: `API_KEY = secret123 +DATABASE_URL= postgres://localhost/db +PORT =8080`, + expected: map[string]string{ + "API_KEY": "secret123", + "DATABASE_URL": "postgres://localhost/db", + "PORT": "8080", + }, + expectErr: false, + }, + { + name: "invalid format - no equals", + content: `INVALID_LINE`, + expectErr: true, + }, + { + name: "empty file", + content: ``, + expected: map[string]string{}, + expectErr: false, + }, + { + name: "only comments", + content: `# Comment 1 +# Comment 2`, + expected: map[string]string{}, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + envFile := filepath.Join(tmpDir, ".env") + + if err := os.WriteFile(envFile, []byte(tt.content), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + result, err := loadEnvFile(envFile) + + if tt.expectErr { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if len(result) != len(tt.expected) { + t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result)) + } + + for key, expectedValue := range tt.expected { + if actualValue, ok := result[key]; !ok { + t.Errorf("Expected key %s not found", key) + } else if actualValue != expectedValue { + t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue) + } + } + }) + } +} + +func TestLoadEnvFileNotFound(t *testing.T) { + _, err := loadEnvFile("/nonexistent/file.env") + if err == nil { + t.Error("Expected error for nonexistent file") + } +} + +func TestEnvFilePriority(t *testing.T) { + // Create a temporary .env file + tmpDir := t.TempDir() + envFile := filepath.Join(tmpDir, ".env") + + envContent := `API_KEY=from_file +DATABASE_URL=from_file +SHARED_VAR=from_file` + + if err := os.WriteFile(envFile, []byte(envContent), 0644); err != nil { + t.Fatalf("Failed to create .env file: %v", err) + } + + // Load envFile + envVars, err := loadEnvFile(envFile) + if err != nil { + t.Fatalf("Failed to load env file: %v", err) + } + + // Verify envFile variables + if envVars["API_KEY"] != "from_file" { + t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"]) + } + + // Simulate config.Env overriding envFile + configEnv := map[string]string{ + "SHARED_VAR": "from_config", + "NEW_VAR": "from_config", + } + + // Merge: envFile first, then config overrides + merged := make(map[string]string) + for k, v := range envVars { + merged[k] = v + } + for k, v := range configEnv { + merged[k] = v + } + + // Verify priority: config.Env should override envFile + if merged["SHARED_VAR"] != "from_config" { + t.Errorf("Expected SHARED_VAR=from_config (config should override file), got %s", merged["SHARED_VAR"]) + } + if merged["API_KEY"] != "from_file" { + t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"]) + } + if merged["NEW_VAR"] != "from_config" { + t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"]) + } +} diff --git a/pkg/tools/mcp_tool.go b/pkg/tools/mcp_tool.go new file mode 100644 index 000000000..e08a526ca --- /dev/null +++ b/pkg/tools/mcp_tool.go @@ -0,0 +1,119 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + mcpPkg "github.com/sipeed/picoclaw/pkg/mcp" +) + +// MCPManager defines the interface for MCP manager operations +// This allows for easier testing with mock implementations +type MCPManager interface { + CallTool(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) +} + +// MCPTool wraps an MCP tool to implement the Tool interface +type MCPTool struct { + manager MCPManager + serverName string + tool *mcp.Tool +} + +// NewMCPTool creates a new MCP tool wrapper +func NewMCPTool(manager *mcpPkg.Manager, serverName string, tool *mcp.Tool) *MCPTool { + return &MCPTool{ + manager: manager, + serverName: serverName, + tool: tool, + } +} + +// Name returns the tool name, prefixed with the server name +func (t *MCPTool) Name() string { + // Prefix with server name to avoid conflicts + return fmt.Sprintf("mcp_%s_%s", t.serverName, t.tool.Name) +} + +// Description returns the tool description +func (t *MCPTool) Description() string { + desc := t.tool.Description + if desc == "" { + desc = fmt.Sprintf("MCP tool from %s server", t.serverName) + } + // Add server info to description + return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc) +} + +// Parameters returns the tool parameters schema +func (t *MCPTool) Parameters() map[string]interface{} { + // The InputSchema is already a JSON Schema object + schema := t.tool.InputSchema + + // Convert to map[string]interface{} for compatibility + result := make(map[string]interface{}) + + // Use reflection to convert the schema + // The schema should already be in the correct format + if schema != nil { + // Attempt to convert directly + if schemaMap, ok := schema.(map[string]interface{}); ok { + return schemaMap + } + + // Otherwise, build it manually + result["type"] = "object" + result["properties"] = map[string]interface{}{} + result["required"] = []string{} + } else { + // Default schema when nil + result["type"] = "object" + result["properties"] = map[string]interface{}{} + result["required"] = []string{} + } + + return result +} + +// Execute executes the MCP tool +func (t *MCPTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { + result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args) + if err != nil { + return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err) + } + + // Handle error result from server + if result.IsError { + errMsg := extractContentText(result.Content) + return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)). + WithError(fmt.Errorf("MCP tool error: %s", errMsg)) + } + + // Extract text content from result + output := extractContentText(result.Content) + + return &ToolResult{ + ForLLM: output, + IsError: false, + } +} + +// extractContentText extracts text from MCP content array +func extractContentText(content []mcp.Content) string { + var parts []string + for _, c := range content { + switch v := c.(type) { + case *mcp.TextContent: + parts = append(parts, v.Text) + case *mcp.ImageContent: + // For images, just indicate that an image was returned + parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType)) + default: + // For other content types, use string representation + parts = append(parts, fmt.Sprintf("[Content: %T]", v)) + } + } + return strings.Join(parts, "\n") +} diff --git a/pkg/tools/mcp_tool_test.go b/pkg/tools/mcp_tool_test.go new file mode 100644 index 000000000..3be5d95a0 --- /dev/null +++ b/pkg/tools/mcp_tool_test.go @@ -0,0 +1,456 @@ +package tools + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// MockMCPManager is a mock implementation of MCPManager interface for testing +type MockMCPManager struct { + callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) +} + +func (m *MockMCPManager) CallTool(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) { + if m.callToolFunc != nil { + return m.callToolFunc(ctx, serverName, toolName, arguments) + } + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "mock result"}, + }, + IsError: false, + }, nil +} + +// newMCPToolForTest creates an MCP tool for testing with mock manager +func newMCPToolForTest(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool { + return &MCPTool{ + manager: manager, + serverName: serverName, + tool: tool, + } +} + +// TestNewMCPTool verifies MCP tool creation +func TestNewMCPTool(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{ + Name: "test_tool", + Description: "A test tool", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "input": map[string]interface{}{ + "type": "string", + "description": "Test input", + }, + }, + }, + } + + mcpTool := newMCPToolForTest(manager, "test_server", tool) + + if mcpTool == nil { + t.Fatal("NewMCPTool should not return nil") + } + // Verify tool properties we can access + if mcpTool.Name() != "mcp_test_server_test_tool" { + t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name()) + } +} + +// TestMCPTool_Name verifies tool name with server prefix +func TestMCPTool_Name(t *testing.T) { + tests := []struct { + name string + serverName string + toolName string + expected string + }{ + { + name: "simple name", + serverName: "github", + toolName: "create_issue", + expected: "mcp_github_create_issue", + }, + { + name: "filesystem server", + serverName: "filesystem", + toolName: "read_file", + expected: "mcp_filesystem_read_file", + }, + { + name: "remote server", + serverName: "remote-api", + toolName: "fetch_data", + expected: "mcp_remote-api_fetch_data", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{Name: tt.toolName} + mcpTool := newMCPToolForTest(manager, tt.serverName, tool) + + result := mcpTool.Name() + if result != tt.expected { + t.Errorf("Expected name '%s', got '%s'", tt.expected, result) + } + }) + } +} + +// TestMCPTool_Description verifies tool description generation +func TestMCPTool_Description(t *testing.T) { + tests := []struct { + name string + serverName string + toolDescription string + expectContains []string + }{ + { + name: "with description", + serverName: "github", + toolDescription: "Create a GitHub issue", + expectContains: []string{"[MCP:github]", "Create a GitHub issue"}, + }, + { + name: "empty description", + serverName: "filesystem", + toolDescription: "", + expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{ + Name: "test_tool", + Description: tt.toolDescription, + } + mcpTool := newMCPToolForTest(manager, tt.serverName, tool) + + result := mcpTool.Description() + + for _, expected := range tt.expectContains { + if !strings.Contains(result, expected) { + t.Errorf("Description should contain '%s', got: %s", expected, result) + } + } + }) + } +} + +// TestMCPTool_Parameters verifies parameter schema conversion +func TestMCPTool_Parameters(t *testing.T) { + tests := []struct { + name string + inputSchema interface{} + expectType string + }{ + { + name: "map schema", + inputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "Search query", + }, + }, + "required": []string{"query"}, + }, + expectType: "object", + }, + { + name: "nil schema", + inputSchema: nil, + expectType: "object", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{ + Name: "test_tool", + InputSchema: tt.inputSchema, + } + mcpTool := newMCPToolForTest(manager, "test_server", tool) + + params := mcpTool.Parameters() + + if params == nil { + t.Fatal("Parameters should not be nil") + } + + if params["type"] != tt.expectType { + t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"]) + } + }) + } +} + +// TestMCPTool_Execute_Success tests successful tool execution +func TestMCPTool_Execute_Success(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) { + // Verify correct parameters passed + if serverName != "github" { + t.Errorf("Expected serverName 'github', got '%s'", serverName) + } + if toolName != "search_repos" { + t.Errorf("Expected toolName 'search_repos', got '%s'", toolName) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Found 3 repositories"}, + }, + IsError: false, + }, nil + }, + } + + tool := &mcp.Tool{ + Name: "search_repos", + Description: "Search GitHub repositories", + } + mcpTool := newMCPToolForTest(manager, "github", tool) + + ctx := context.Background() + args := map[string]interface{}{ + "query": "golang mcp", + } + + result := mcpTool.Execute(ctx, args) + + if result == nil { + t.Fatal("Result should not be nil") + } + if result.IsError { + t.Errorf("Expected no error, got error: %s", result.ForLLM) + } + if result.ForLLM != "Found 3 repositories" { + t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM) + } +} + +// TestMCPTool_Execute_ManagerError tests execution when manager returns error +func TestMCPTool_Execute_ManagerError(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) { + return nil, fmt.Errorf("connection failed") + }, + } + + tool := &mcp.Tool{Name: "test_tool"} + mcpTool := newMCPToolForTest(manager, "test_server", tool) + + ctx := context.Background() + result := mcpTool.Execute(ctx, map[string]interface{}{}) + + if result == nil { + t.Fatal("Result should not be nil") + } + if !result.IsError { + t.Error("Expected IsError to be true") + } + if !strings.Contains(result.ForLLM, "MCP tool execution failed") { + t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "connection failed") { + t.Errorf("Error message should include original error, got: %s", result.ForLLM) + } +} + +// TestMCPTool_Execute_ServerError tests execution when server returns error +func TestMCPTool_Execute_ServerError(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Invalid API key"}, + }, + IsError: true, + }, nil + }, + } + + tool := &mcp.Tool{Name: "test_tool"} + mcpTool := newMCPToolForTest(manager, "test_server", tool) + + ctx := context.Background() + result := mcpTool.Execute(ctx, map[string]interface{}{}) + + if result == nil { + t.Fatal("Result should not be nil") + } + if !result.IsError { + t.Error("Expected IsError to be true") + } + if !strings.Contains(result.ForLLM, "MCP tool returned error") { + t.Errorf("Error message should mention server error, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "Invalid API key") { + t.Errorf("Error message should include server message, got: %s", result.ForLLM) + } +} + +// TestMCPTool_Execute_MultipleContent tests execution with multiple content items +func TestMCPTool_Execute_MultipleContent(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]interface{}) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "First line"}, + &mcp.TextContent{Text: "Second line"}, + &mcp.TextContent{Text: "Third line"}, + }, + IsError: false, + }, nil + }, + } + + tool := &mcp.Tool{Name: "multi_output"} + mcpTool := newMCPToolForTest(manager, "test_server", tool) + + ctx := context.Background() + result := mcpTool.Execute(ctx, map[string]interface{}{}) + + if result.IsError { + t.Errorf("Expected no error, got: %s", result.ForLLM) + } + + expected := "First line\nSecond line\nThird line" + if result.ForLLM != expected { + t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM) + } +} + +// TestExtractContentText_TextContent tests text content extraction +func TestExtractContentText_TextContent(t *testing.T) { + content := []mcp.Content{ + &mcp.TextContent{Text: "Hello World"}, + &mcp.TextContent{Text: "Second message"}, + } + + result := extractContentText(content) + expected := "Hello World\nSecond message" + + if result != expected { + t.Errorf("Expected '%s', got '%s'", expected, result) + } +} + +// TestExtractContentText_ImageContent tests image content extraction +func TestExtractContentText_ImageContent(t *testing.T) { + content := []mcp.Content{ + &mcp.ImageContent{ + Data: []byte("base64data"), + MIMEType: "image/png", + }, + } + + result := extractContentText(content) + + if !strings.Contains(result, "[Image:") { + t.Errorf("Expected image indicator, got: %s", result) + } + if !strings.Contains(result, "image/png") { + t.Errorf("Expected MIME type in output, got: %s", result) + } +} + +// TestExtractContentText_MixedContent tests mixed content types +func TestExtractContentText_MixedContent(t *testing.T) { + content := []mcp.Content{ + &mcp.TextContent{Text: "Description"}, + &mcp.ImageContent{ + Data: []byte("data"), + MIMEType: "image/jpeg", + }, + &mcp.TextContent{Text: "More text"}, + } + + result := extractContentText(content) + + if !strings.Contains(result, "Description") { + t.Errorf("Should contain text content, got: %s", result) + } + if !strings.Contains(result, "[Image:") { + t.Errorf("Should contain image indicator, got: %s", result) + } + if !strings.Contains(result, "More text") { + t.Errorf("Should contain second text, got: %s", result) + } +} + +// TestExtractContentText_EmptyContent tests empty content array +func TestExtractContentText_EmptyContent(t *testing.T) { + content := []mcp.Content{} + + result := extractContentText(content) + + if result != "" { + t.Errorf("Expected empty string for empty content, got: %s", result) + } +} + +// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface +func TestMCPTool_InterfaceCompliance(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{Name: "test"} + mcpTool := newMCPToolForTest(manager, "test_server", tool) + + // Verify it implements Tool interface + var _ Tool = mcpTool +} + +// TestMCPTool_Parameters_MapSchema tests schema that's already a map +func TestMCPTool_Parameters_MapSchema(t *testing.T) { + manager := &MockMCPManager{} + schema := map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "The name parameter", + }, + }, + "required": []string{"name"}, + } + + tool := &mcp.Tool{ + Name: "test_tool", + InputSchema: schema, + } + mcpTool := newMCPToolForTest(manager, "test_server", tool) + + params := mcpTool.Parameters() + + // Should return the schema as-is when it's already a map + if params["type"] != "object" { + t.Errorf("Expected type 'object', got '%v'", params["type"]) + } + + props, ok := params["properties"].(map[string]interface{}) + if !ok { + t.Error("Properties should be a map") + } + + nameParam, ok := props["name"].(map[string]interface{}) + if !ok { + t.Error("Name parameter should exist") + } + + if nameParam["type"] != "string" { + t.Errorf("Name type should be 'string', got '%v'", nameParam["type"]) + } +}