mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 50c58a3462 | |||
| 403e048821 | |||
| 4fde0175cf | |||
| 0d16525fab | |||
| 4cd3f99dd6 | |||
| 920e30a241 | |||
| 7b9b8104c8 | |||
| 881999aceb | |||
| f929268ab2 | |||
| 684e7413e1 | |||
| 7ce5b75178 | |||
| 40f90281e5 | |||
| 82856bc57a |
@@ -39,6 +39,8 @@ ifeq ($(UNAME_S),Linux)
|
||||
ARCH=amd64
|
||||
else ifeq ($(UNAME_M),aarch64)
|
||||
ARCH=arm64
|
||||
else ifeq ($(UNAME_M),loongarch64)
|
||||
ARCH=loong64
|
||||
else ifeq ($(UNAME_M),riscv64)
|
||||
ARCH=riscv64
|
||||
else
|
||||
@@ -84,6 +86,7 @@ build-all: generate
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
GOOS=linux GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=loong64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=riscv64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
|
||||
GOOS=darwin GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
|
||||
GOOS=windows GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
|
||||
|
||||
@@ -195,6 +195,9 @@ picoclaw onboard
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
@@ -697,6 +700,9 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
|
||||
"search": {
|
||||
"apiKey": "BSA..."
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
|
||||
@@ -774,6 +774,9 @@ picoclaw agent -m "Hello"
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
|
||||
@@ -236,6 +236,9 @@ picoclaw onboard
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -644,6 +647,9 @@ picoclaw agent -m "你好"
|
||||
"search": {
|
||||
"api_key": "BSA..."
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
|
||||
@@ -562,7 +562,8 @@ func gatewayCmd() {
|
||||
})
|
||||
|
||||
// Setup cron tool and service
|
||||
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace)
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace, execTimeout)
|
||||
|
||||
heartbeatService := heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
@@ -987,14 +988,14 @@ func getConfigPath() string {
|
||||
return filepath.Join(home, ".picoclaw", "config.json")
|
||||
}
|
||||
|
||||
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool) *cron.CronService {
|
||||
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration) *cron.CronService {
|
||||
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
|
||||
|
||||
// Create cron service
|
||||
cronService := cron.NewCronService(cronStorePath, nil)
|
||||
|
||||
// Create and register CronTool
|
||||
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict)
|
||||
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout)
|
||||
agentLoop.RegisterTool(cronTool)
|
||||
|
||||
// Set the onJob handler
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
"enabled": false,
|
||||
"token": "YOUR_TELEGRAM_BOT_TOKEN",
|
||||
"proxy": "",
|
||||
"allow_from": ["YOUR_USER_ID"]
|
||||
"allow_from": [
|
||||
"YOUR_USER_ID"
|
||||
]
|
||||
},
|
||||
"discord": {
|
||||
"enabled": false,
|
||||
@@ -115,9 +117,40 @@
|
||||
},
|
||||
"tools": {
|
||||
"web": {
|
||||
"search": {
|
||||
"brave": {
|
||||
"enabled": false,
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
},
|
||||
"perplexity": {
|
||||
"enabled": false,
|
||||
"api_key": "pplx-xxx",
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
},
|
||||
"mcp": {
|
||||
"enabled": false,
|
||||
"servers": {
|
||||
"filesystem": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-filesystem",
|
||||
"/tmp"
|
||||
],
|
||||
"protocol": "mcp",
|
||||
"env": {},
|
||||
"working_dir": "",
|
||||
"init_timeout_seconds": 60,
|
||||
"call_timeout_seconds": 30,
|
||||
"max_response_bytes": 65536,
|
||||
"include_tools": [],
|
||||
"exclude_tools": []
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
+62
-4
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
@@ -44,8 +45,12 @@ type AgentLoop struct {
|
||||
running atomic.Bool
|
||||
summarizing sync.Map // Tracks which sessions are currently being summarized
|
||||
channelManager *channels.Manager
|
||||
mcpManager *mcp.Manager
|
||||
mcpCloseOnce sync.Once
|
||||
}
|
||||
|
||||
const defaultWebFetchMaxChars = 50000
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
type processOptions struct {
|
||||
SessionKey string // Session identifier for history/context
|
||||
@@ -60,7 +65,14 @@ type processOptions struct {
|
||||
|
||||
// createToolRegistry creates a tool registry with common tools.
|
||||
// This is shared between main agent and subagents.
|
||||
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
|
||||
func createToolRegistry(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
cfg *config.Config,
|
||||
msgBus *bus.MessageBus,
|
||||
mcpManager *mcp.Manager,
|
||||
discoveredMCPTools []mcp.RegisteredTool,
|
||||
) *tools.ToolRegistry {
|
||||
registry := tools.NewToolRegistry()
|
||||
|
||||
// File system tools
|
||||
@@ -79,10 +91,15 @@ func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msg
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
|
||||
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
|
||||
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
}); searchTool != nil {
|
||||
registry.Register(searchTool)
|
||||
}
|
||||
registry.Register(tools.NewWebFetchTool(50000))
|
||||
registry.Register(tools.NewWebFetchTool(defaultWebFetchMaxChars))
|
||||
|
||||
tools.RegisterKnownMCPTools(registry, mcpManager, discoveredMCPTools)
|
||||
|
||||
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
||||
registry.Register(tools.NewI2CTool())
|
||||
@@ -110,12 +127,35 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
|
||||
restrict := cfg.Agents.Defaults.RestrictToWorkspace
|
||||
|
||||
var (
|
||||
mcpManager *mcp.Manager
|
||||
discoveredMCPTools []mcp.RegisteredTool
|
||||
)
|
||||
if cfg.Tools.MCP.Enabled {
|
||||
bootstrap, err := bootstrapMCP(cfg.Tools.MCP)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "MCP tool bootstrap failed",
|
||||
map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else if bootstrap != nil {
|
||||
mcpManager = bootstrap.Manager
|
||||
discoveredMCPTools = bootstrap.Tools
|
||||
if len(discoveredMCPTools) > 0 {
|
||||
logger.InfoCF("agent", "MCP tools registered",
|
||||
map[string]interface{}{
|
||||
"count": len(discoveredMCPTools),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create tool registry for main agent
|
||||
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
|
||||
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager, discoveredMCPTools)
|
||||
|
||||
// Create subagent manager with its own tool registry
|
||||
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
|
||||
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
|
||||
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager, discoveredMCPTools)
|
||||
// Subagent doesn't need spawn/subagent tools to avoid recursion
|
||||
subagentManager.SetTools(subagentTools)
|
||||
|
||||
@@ -148,11 +188,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
contextBuilder: contextBuilder,
|
||||
tools: toolsRegistry,
|
||||
summarizing: sync.Map{},
|
||||
mcpManager: mcpManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running.Store(true)
|
||||
defer al.closeMCP()
|
||||
|
||||
for al.running.Load() {
|
||||
select {
|
||||
@@ -195,6 +237,22 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
|
||||
func (al *AgentLoop) Stop() {
|
||||
al.running.Store(false)
|
||||
al.closeMCP()
|
||||
}
|
||||
|
||||
func (al *AgentLoop) closeMCP() {
|
||||
if al.mcpManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
al.mcpCloseOnce.Do(func() {
|
||||
if err := al.mcpManager.Close(); err != nil {
|
||||
logger.WarnCF("agent", "Failed to close MCP manager",
|
||||
map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
mcpBootstrapMinTimeout = 10 * time.Second
|
||||
mcpBootstrapMaxTimeout = 5 * time.Minute
|
||||
mcpBootstrapGraceTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type mcpBootstrapResult struct {
|
||||
Manager *mcp.Manager
|
||||
Tools []mcp.RegisteredTool
|
||||
}
|
||||
|
||||
func bootstrapMCP(cfg config.MCPToolsConfig) (*mcpBootstrapResult, error) {
|
||||
serverConfigs := buildMCPServerConfigs(cfg)
|
||||
if len(serverConfigs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
manager := mcp.NewManager(serverConfigs)
|
||||
|
||||
discoveryTimeout := calculateMCPDiscoveryTimeout(serverConfigs)
|
||||
discoveryCtx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
|
||||
defer cancel()
|
||||
|
||||
discoveredTools, err := manager.DiscoverTools(discoveryCtx)
|
||||
if err != nil {
|
||||
_ = manager.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mcpBootstrapResult{
|
||||
Manager: manager,
|
||||
Tools: discoveredTools,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func calculateMCPDiscoveryTimeout(serverConfigs map[string]mcp.ServerConfig) time.Duration {
|
||||
maxInitTimeout := mcpBootstrapMinTimeout
|
||||
|
||||
for _, serverConfig := range serverConfigs {
|
||||
initTimeout := serverConfig.InitTimeout()
|
||||
if initTimeout > maxInitTimeout {
|
||||
maxInitTimeout = initTimeout
|
||||
}
|
||||
}
|
||||
|
||||
timeout := maxInitTimeout + mcpBootstrapGraceTimeout
|
||||
if timeout < mcpBootstrapMinTimeout {
|
||||
return mcpBootstrapMinTimeout
|
||||
}
|
||||
if timeout > mcpBootstrapMaxTimeout {
|
||||
return mcpBootstrapMaxTimeout
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
func buildMCPServerConfigs(cfg config.MCPToolsConfig) map[string]mcp.ServerConfig {
|
||||
servers := make(map[string]mcp.ServerConfig, len(cfg.Servers))
|
||||
|
||||
for serverName, serverCfg := range cfg.Servers {
|
||||
if !serverCfg.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
envCopy := make(map[string]string, len(serverCfg.Env))
|
||||
for key, value := range serverCfg.Env {
|
||||
envCopy[key] = value
|
||||
}
|
||||
|
||||
servers[serverName] = mcp.ServerConfig{
|
||||
Name: serverName,
|
||||
Command: serverCfg.Command,
|
||||
Args: append([]string{}, serverCfg.Args...),
|
||||
Env: envCopy,
|
||||
WorkingDir: serverCfg.WorkingDir,
|
||||
Protocol: inferMCPProtocol(serverCfg.Protocol, serverCfg.Command),
|
||||
InitTimeoutSeconds: serverCfg.InitTimeoutSeconds,
|
||||
CallTimeoutSeconds: serverCfg.CallTimeoutSeconds,
|
||||
MaxResponseBytes: serverCfg.MaxResponseBytes,
|
||||
IncludeTools: append([]string{}, serverCfg.IncludeTools...),
|
||||
ExcludeTools: append([]string{}, serverCfg.ExcludeTools...),
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
func inferMCPProtocol(configuredProtocol, command string) string {
|
||||
if protocol := strings.TrimSpace(configuredProtocol); protocol != "" {
|
||||
return protocol
|
||||
}
|
||||
|
||||
// Context7 currently emits JSON-RPC messages as JSONL on stdio,
|
||||
// so defaulting avoids long startup waits when protocol is omitted.
|
||||
if strings.Contains(strings.ToLower(command), "context7-mcp") {
|
||||
return mcp.ProtocolJSONLines
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
@@ -18,7 +18,6 @@ type MaixCamChannel struct {
|
||||
listener net.Listener
|
||||
clients map[net.Conn]bool
|
||||
clientsMux sync.RWMutex
|
||||
running bool
|
||||
}
|
||||
|
||||
type MaixCamMessage struct {
|
||||
@@ -35,7 +34,6 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
clients: make(map[net.Conn]bool),
|
||||
running: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
+103
-2
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
@@ -51,7 +52,10 @@ type Config struct {
|
||||
Tools ToolsConfig `json:"tools"`
|
||||
Heartbeat HeartbeatConfig `json:"heartbeat"`
|
||||
Devices DevicesConfig `json:"devices"`
|
||||
mu sync.RWMutex
|
||||
// MCPServers is a compatibility alias for configs using top-level "mcpServers".
|
||||
// Canonical config remains tools.mcp.servers.
|
||||
MCPServers map[string]LegacyMCPServerConfig `json:"mcpServers,omitempty"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type AgentsConfig struct {
|
||||
@@ -206,13 +210,54 @@ type DuckDuckGoConfig struct {
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type PerplexityConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type WebToolsConfig struct {
|
||||
Brave BraveConfig `json:"brave"`
|
||||
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
|
||||
Perplexity PerplexityConfig `json:"perplexity"`
|
||||
}
|
||||
|
||||
type CronToolsConfig struct {
|
||||
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
|
||||
}
|
||||
|
||||
type MCPServerConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Env map[string]string `json:"env"`
|
||||
WorkingDir string `json:"working_dir"`
|
||||
Protocol string `json:"protocol"`
|
||||
InitTimeoutSeconds int `json:"init_timeout_seconds"`
|
||||
CallTimeoutSeconds int `json:"call_timeout_seconds"`
|
||||
MaxResponseBytes int `json:"max_response_bytes"`
|
||||
IncludeTools []string `json:"include_tools"`
|
||||
ExcludeTools []string `json:"exclude_tools"`
|
||||
}
|
||||
|
||||
type MCPToolsConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
|
||||
Servers map[string]MCPServerConfig `json:"servers"`
|
||||
}
|
||||
|
||||
// LegacyMCPServerConfig supports compatibility with "mcpServers" style config.
|
||||
type LegacyMCPServerConfig struct {
|
||||
Type string `json:"type"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Env map[string]string `json:"env"`
|
||||
Protocol string `json:"protocol"`
|
||||
}
|
||||
|
||||
type ToolsConfig struct {
|
||||
Web WebToolsConfig `json:"web"`
|
||||
Web WebToolsConfig `json:"web"`
|
||||
Cron CronToolsConfig `json:"cron"`
|
||||
MCP MCPToolsConfig `json:"mcp"`
|
||||
}
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
@@ -321,6 +366,18 @@ func DefaultConfig() *Config {
|
||||
Enabled: true,
|
||||
MaxResults: 5,
|
||||
},
|
||||
Perplexity: PerplexityConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
},
|
||||
Cron: CronToolsConfig{
|
||||
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
|
||||
},
|
||||
MCP: MCPToolsConfig{
|
||||
Enabled: false,
|
||||
Servers: map[string]MCPServerConfig{},
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
@@ -353,9 +410,53 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg.applyLegacyMCPServers()
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) applyLegacyMCPServers() {
|
||||
// If canonical MCP config already exists, keep it as source of truth.
|
||||
if len(c.Tools.MCP.Servers) > 0 {
|
||||
return
|
||||
}
|
||||
if len(c.MCPServers) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if c.Tools.MCP.Servers == nil {
|
||||
c.Tools.MCP.Servers = map[string]MCPServerConfig{}
|
||||
}
|
||||
|
||||
for name, legacy := range c.MCPServers {
|
||||
if strings.TrimSpace(legacy.Command) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
enabled := true
|
||||
if legacy.Type != "" && legacy.Type != "stdio" {
|
||||
enabled = false
|
||||
}
|
||||
|
||||
envCopy := make(map[string]string, len(legacy.Env))
|
||||
for key, value := range legacy.Env {
|
||||
envCopy[key] = value
|
||||
}
|
||||
|
||||
c.Tools.MCP.Servers[name] = MCPServerConfig{
|
||||
Enabled: enabled,
|
||||
Command: legacy.Command,
|
||||
Args: append([]string{}, legacy.Args...),
|
||||
Env: envCopy,
|
||||
Protocol: legacy.Protocol,
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.Tools.MCP.Servers) > 0 {
|
||||
c.Tools.MCP.Enabled = true
|
||||
}
|
||||
}
|
||||
|
||||
func SaveConfig(path string, cfg *Config) error {
|
||||
cfg.mu.RLock()
|
||||
defer cfg.mu.RUnlock()
|
||||
|
||||
@@ -0,0 +1,603 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// Client is the transport-agnostic MCP client contract.
|
||||
type Client interface {
|
||||
Start(ctx context.Context) error
|
||||
ListTools(ctx context.Context) ([]RemoteTool, error)
|
||||
CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// StdioClient speaks MCP over stdio (JSON-RPC framed with Content-Length headers).
|
||||
type StdioClient struct {
|
||||
config ServerConfig
|
||||
mode string
|
||||
|
||||
mu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
|
||||
started bool
|
||||
closed bool
|
||||
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
stdout io.ReadCloser
|
||||
stderr io.ReadCloser
|
||||
waitCh chan struct{}
|
||||
pending map[string]chan rpcResponse
|
||||
|
||||
nextID uint64
|
||||
}
|
||||
|
||||
type rpcRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Method string `json:"method"`
|
||||
Params any `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type rpcResponseEnvelope struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *rpcError `json:"error,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
}
|
||||
|
||||
type rpcError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type rpcResponse struct {
|
||||
result json.RawMessage
|
||||
rpcErr *rpcError
|
||||
err error
|
||||
}
|
||||
|
||||
type initializeParams struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities map[string]any `json:"capabilities"`
|
||||
ClientInfo map[string]interface{} `json:"clientInfo"`
|
||||
}
|
||||
|
||||
func NewStdioClient(config ServerConfig) *StdioClient {
|
||||
return &StdioClient{
|
||||
config: config,
|
||||
mode: normalizeProtocol(config.Protocol),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) Start(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
if c.started {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(c.config.Command) == "" {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("mcp server %q command is empty", c.config.Name)
|
||||
}
|
||||
|
||||
cmd := exec.Command(c.config.Command, c.config.Args...)
|
||||
if c.config.WorkingDir != "" {
|
||||
cmd.Dir = c.config.WorkingDir
|
||||
}
|
||||
cmd.Env = buildProcessEnv(c.config.Env)
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("create stdin pipe: %w", err)
|
||||
}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("create stderr pipe: %w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("start process: %w", err)
|
||||
}
|
||||
|
||||
c.started = true
|
||||
c.closed = false
|
||||
c.cmd = cmd
|
||||
c.stdin = stdin
|
||||
c.stdout = stdout
|
||||
c.stderr = stderr
|
||||
c.waitCh = make(chan struct{})
|
||||
c.pending = make(map[string]chan rpcResponse)
|
||||
c.mu.Unlock()
|
||||
|
||||
go c.readLoop()
|
||||
go c.waitLoop()
|
||||
go c.drainStderr()
|
||||
|
||||
initCtx, cancel := withTimeoutIfMissing(ctx, c.config.InitTimeout())
|
||||
defer cancel()
|
||||
|
||||
_, err = c.request(initCtx, "initialize", initializeParams{
|
||||
ProtocolVersion: "2024-11-05",
|
||||
Capabilities: map[string]any{
|
||||
"tools": map[string]any{},
|
||||
},
|
||||
ClientInfo: map[string]any{
|
||||
"name": "picoclaw",
|
||||
"version": "dev",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("initialize failed: %w", err)
|
||||
}
|
||||
|
||||
if err := c.notify("notifications/initialized", map[string]any{}); err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("initialized notification failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *StdioClient) ListTools(ctx context.Context) ([]RemoteTool, error) {
|
||||
if err := c.Start(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type listToolsResponse struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]any `json:"inputSchema"`
|
||||
} `json:"tools"`
|
||||
NextCursor string `json:"nextCursor,omitempty"`
|
||||
}
|
||||
|
||||
allTools := make([]RemoteTool, 0, 8)
|
||||
cursor := ""
|
||||
|
||||
for page := 0; page < maxToolListPages; page++ {
|
||||
params := map[string]any{}
|
||||
if cursor != "" {
|
||||
params["cursor"] = cursor
|
||||
}
|
||||
|
||||
callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout())
|
||||
raw, err := c.request(callCtx, "tools/list", params)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var response listToolsResponse
|
||||
if err := json.Unmarshal(raw, &response); err != nil {
|
||||
return nil, fmt.Errorf("decode tools/list response: %w", err)
|
||||
}
|
||||
|
||||
for _, tool := range response.Tools {
|
||||
allTools = append(allTools, RemoteTool{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
InputSchema: tool.InputSchema,
|
||||
})
|
||||
}
|
||||
|
||||
if response.NextCursor == "" {
|
||||
return allTools, nil
|
||||
}
|
||||
cursor = response.NextCursor
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("tools/list exceeded %d pages", maxToolListPages)
|
||||
}
|
||||
|
||||
func (c *StdioClient) CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error) {
|
||||
if err := c.Start(ctx); err != nil {
|
||||
return CallResult{}, err
|
||||
}
|
||||
|
||||
callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout())
|
||||
defer cancel()
|
||||
|
||||
raw, err := c.request(callCtx, "tools/call", map[string]any{
|
||||
"name": toolName,
|
||||
"arguments": arguments,
|
||||
})
|
||||
if err != nil {
|
||||
return CallResult{}, err
|
||||
}
|
||||
|
||||
return formatCallPayload(raw, c.config.ResponseLimit())
|
||||
}
|
||||
|
||||
func (c *StdioClient) Close() error {
|
||||
c.mu.Lock()
|
||||
if !c.started || c.closed {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
c.closed = true
|
||||
cmd := c.cmd
|
||||
stdin := c.stdin
|
||||
waitCh := c.waitCh
|
||||
c.mu.Unlock()
|
||||
|
||||
c.failPending(errors.New("mcp client closed"))
|
||||
|
||||
if stdin != nil {
|
||||
_ = stdin.Close()
|
||||
}
|
||||
if cmd != nil && cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
|
||||
if waitCh != nil {
|
||||
select {
|
||||
case <-waitCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *StdioClient) request(ctx context.Context, method string, params any) (json.RawMessage, error) {
|
||||
id := strconv.FormatUint(atomic.AddUint64(&c.nextID, 1), 10)
|
||||
responseCh := make(chan rpcResponse, 1)
|
||||
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return nil, fmt.Errorf("mcp server %q is closed", c.config.Name)
|
||||
}
|
||||
c.pending[id] = responseCh
|
||||
c.mu.Unlock()
|
||||
|
||||
req := rpcRequest{
|
||||
JSONRPC: "2.0",
|
||||
ID: id,
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
if err := c.writeMessage(req); err != nil {
|
||||
c.removePending(id)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.removePending(id)
|
||||
return nil, ctx.Err()
|
||||
case response := <-responseCh:
|
||||
if response.err != nil {
|
||||
return nil, response.err
|
||||
}
|
||||
if response.rpcErr != nil {
|
||||
return nil, fmt.Errorf("mcp error %d: %s", response.rpcErr.Code, response.rpcErr.Message)
|
||||
}
|
||||
return response.result, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) notify(method string, params any) error {
|
||||
req := rpcRequest{
|
||||
JSONRPC: "2.0",
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
return c.writeMessage(req)
|
||||
}
|
||||
|
||||
func (c *StdioClient) writeMessage(payload any) error {
|
||||
c.mu.Lock()
|
||||
if c.closed || c.stdin == nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("mcp server %q is not writable", c.config.Name)
|
||||
}
|
||||
stdin := c.stdin
|
||||
c.mu.Unlock()
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal json-rpc payload: %w", err)
|
||||
}
|
||||
|
||||
if c.mode == ProtocolJSONLines {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
if _, err := stdin.Write(append(data, '\n')); err != nil {
|
||||
return fmt.Errorf("write jsonl body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
frameHeader := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data))
|
||||
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
if _, err := io.WriteString(stdin, frameHeader); err != nil {
|
||||
return fmt.Errorf("write frame header: %w", err)
|
||||
}
|
||||
if _, err := stdin.Write(data); err != nil {
|
||||
return fmt.Errorf("write frame body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *StdioClient) readLoop() {
|
||||
if c.mode == ProtocolJSONLines {
|
||||
c.readJSONLLoop()
|
||||
return
|
||||
}
|
||||
|
||||
c.readMCPFrameLoop()
|
||||
}
|
||||
|
||||
func (c *StdioClient) readMCPFrameLoop() {
|
||||
reader := bufio.NewReader(c.stdout)
|
||||
|
||||
for {
|
||||
payload, err := readFramePayload(reader)
|
||||
if err != nil {
|
||||
c.failPending(err)
|
||||
return
|
||||
}
|
||||
|
||||
var envelope rpcResponseEnvelope
|
||||
if err := json.Unmarshal(payload, &envelope); err != nil {
|
||||
continue
|
||||
}
|
||||
c.dispatchResponse(envelope)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) readJSONLLoop() {
|
||||
scanner := bufio.NewScanner(c.stdout)
|
||||
scanner.Buffer(make([]byte, 0, defaultScannerBufferBytes), maxFrameBytes)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var envelope rpcResponseEnvelope
|
||||
if err := json.Unmarshal([]byte(line), &envelope); err != nil {
|
||||
continue
|
||||
}
|
||||
c.dispatchResponse(envelope)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
c.failPending(err)
|
||||
return
|
||||
}
|
||||
c.failPending(io.EOF)
|
||||
}
|
||||
|
||||
func (c *StdioClient) dispatchResponse(envelope rpcResponseEnvelope) {
|
||||
if len(envelope.ID) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
id, ok := parseRPCID(envelope.ID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
responseCh := c.pending[id]
|
||||
if responseCh != nil {
|
||||
delete(c.pending, id)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if responseCh == nil {
|
||||
return
|
||||
}
|
||||
|
||||
response := rpcResponse{
|
||||
result: envelope.Result,
|
||||
rpcErr: envelope.Error,
|
||||
}
|
||||
select {
|
||||
case responseCh <- response:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) waitLoop() {
|
||||
c.mu.Lock()
|
||||
cmd := c.cmd
|
||||
waitCh := c.waitCh
|
||||
serverName := c.config.Name
|
||||
c.mu.Unlock()
|
||||
|
||||
if cmd == nil {
|
||||
if waitCh != nil {
|
||||
close(waitCh)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err := cmd.Wait()
|
||||
if waitCh != nil {
|
||||
close(waitCh)
|
||||
}
|
||||
if err != nil {
|
||||
logger.WarnCF("mcp", "MCP process exited with error",
|
||||
map[string]any{
|
||||
"server": serverName,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) drainStderr() {
|
||||
c.mu.Lock()
|
||||
stderr := c.stderr
|
||||
serverName := c.config.Name
|
||||
c.mu.Unlock()
|
||||
|
||||
if stderr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
logger.DebugCF("mcp", "MCP server stderr",
|
||||
map[string]any{
|
||||
"server": serverName,
|
||||
"line": line,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) failPending(err error) {
|
||||
c.mu.Lock()
|
||||
pending := c.pending
|
||||
c.pending = make(map[string]chan rpcResponse)
|
||||
c.mu.Unlock()
|
||||
|
||||
if len(pending) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ch := range pending {
|
||||
select {
|
||||
case ch <- rpcResponse{err: err}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) removePending(id string) {
|
||||
c.mu.Lock()
|
||||
delete(c.pending, id)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func readFramePayload(reader *bufio.Reader) ([]byte, error) {
|
||||
contentLength := -1
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
trimmed := strings.TrimRight(line, "\r\n")
|
||||
if trimmed == "" {
|
||||
break
|
||||
}
|
||||
|
||||
parts := strings.SplitN(trimmed, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
headerName := strings.TrimSpace(strings.ToLower(parts[0]))
|
||||
if headerName != "content-length" {
|
||||
continue
|
||||
}
|
||||
value := strings.TrimSpace(parts[1])
|
||||
length, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid content-length %q: %w", value, err)
|
||||
}
|
||||
contentLength = length
|
||||
}
|
||||
|
||||
if contentLength <= 0 {
|
||||
return nil, fmt.Errorf("missing content-length")
|
||||
}
|
||||
if contentLength > maxFrameBytes {
|
||||
return nil, fmt.Errorf("frame too large (%d bytes)", contentLength)
|
||||
}
|
||||
|
||||
payload := make([]byte, contentLength)
|
||||
if _, err := io.ReadFull(reader, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func parseRPCID(raw json.RawMessage) (string, bool) {
|
||||
var stringID string
|
||||
if err := json.Unmarshal(raw, &stringID); err == nil {
|
||||
return stringID, true
|
||||
}
|
||||
|
||||
var numberID float64
|
||||
if err := json.Unmarshal(raw, &numberID); err == nil {
|
||||
return strconv.FormatInt(int64(numberID), 10), true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func withTimeoutIfMissing(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
if _, hasDeadline := parent.Deadline(); hasDeadline {
|
||||
return context.WithCancel(parent)
|
||||
}
|
||||
return context.WithTimeout(parent, timeout)
|
||||
}
|
||||
|
||||
func buildProcessEnv(custom map[string]string) []string {
|
||||
base := os.Environ()
|
||||
if len(custom) == 0 {
|
||||
return base
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(custom))
|
||||
for key := range custom {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
env := make([]string, 0, len(base)+len(keys))
|
||||
env = append(env, base...)
|
||||
for _, key := range keys {
|
||||
env = append(env, key+"="+custom[key])
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func normalizeProtocol(protocol string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(protocol)) {
|
||||
case "", ProtocolMCPFrames:
|
||||
return ProtocolMCPFrames
|
||||
case ProtocolJSONLines:
|
||||
return ProtocolJSONLines
|
||||
default:
|
||||
return ProtocolMCPFrames
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type callResponse struct {
|
||||
Content []contentBlock `json:"content"`
|
||||
StructuredContent any `json:"structuredContent,omitempty"`
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
|
||||
type contentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
func formatCallPayload(raw json.RawMessage, responseLimit int) (CallResult, error) {
|
||||
var payload callResponse
|
||||
if err := json.Unmarshal(raw, &payload); err != nil {
|
||||
// Fallback for servers that return non-standard payloads.
|
||||
return CallResult{
|
||||
Content: truncateString(strings.TrimSpace(string(raw)), responseLimit),
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(payload.Content)+1)
|
||||
for _, block := range payload.Content {
|
||||
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
|
||||
parts = append(parts, block.Text)
|
||||
}
|
||||
}
|
||||
|
||||
if payload.StructuredContent != nil {
|
||||
if encoded, err := json.Marshal(payload.StructuredContent); err == nil {
|
||||
parts = append(parts, string(encoded))
|
||||
}
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(strings.Join(parts, "\n"))
|
||||
if content == "" {
|
||||
content = "{}"
|
||||
}
|
||||
|
||||
return CallResult{
|
||||
Content: truncateString(content, responseLimit),
|
||||
IsError: payload.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func truncateString(value string, maxBytes int) string {
|
||||
if maxBytes <= 0 || len(value) <= maxBytes {
|
||||
return value
|
||||
}
|
||||
if maxBytes <= 12 {
|
||||
return value[:maxBytes]
|
||||
}
|
||||
return value[:maxBytes-12] + "\n...[truncated]"
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type clientFactory func(config ServerConfig) Client
|
||||
|
||||
type managedServer struct {
|
||||
config ServerConfig
|
||||
client Client
|
||||
}
|
||||
|
||||
// Manager owns MCP servers and maps discovered MCP tools to PicoClaw tools.
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
servers map[string]*managedServer
|
||||
tools map[string]RegisteredTool
|
||||
|
||||
discovered bool
|
||||
newClient clientFactory
|
||||
}
|
||||
|
||||
func NewManager(configs map[string]ServerConfig) *Manager {
|
||||
servers := make(map[string]*managedServer, len(configs))
|
||||
for name, cfg := range configs {
|
||||
copied := cfg
|
||||
copied.Name = name
|
||||
servers[name] = &managedServer{config: copied}
|
||||
}
|
||||
return &Manager{
|
||||
servers: servers,
|
||||
tools: make(map[string]RegisteredTool),
|
||||
discovered: false,
|
||||
newClient: func(config ServerConfig) Client {
|
||||
return NewStdioClient(config)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DiscoverTools starts configured MCP servers and returns discovered tool metadata.
|
||||
func (m *Manager) DiscoverTools(ctx context.Context) ([]RegisteredTool, error) {
|
||||
m.mu.Lock()
|
||||
if m.discovered {
|
||||
tools := toolsFromMap(m.tools)
|
||||
m.mu.Unlock()
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
discoveryErrors := make([]string, 0)
|
||||
|
||||
for serverName, server := range m.servers {
|
||||
client := m.newClient(server.config)
|
||||
if err := client.Start(ctx); err != nil {
|
||||
discoveryErrors = append(discoveryErrors, fmt.Sprintf("%s: %v", serverName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
remoteTools, err := client.ListTools(ctx)
|
||||
if err != nil {
|
||||
_ = client.Close()
|
||||
discoveryErrors = append(discoveryErrors, fmt.Sprintf("%s: %v", serverName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
server.client = client
|
||||
for _, remoteTool := range remoteTools {
|
||||
if !isToolAllowed(remoteTool.Name, server.config.IncludeTools, server.config.ExcludeTools) {
|
||||
continue
|
||||
}
|
||||
|
||||
qualifiedName := m.makeUniqueToolName(serverName, remoteTool.Name)
|
||||
parameters := normalizeSchema(remoteTool.InputSchema)
|
||||
m.tools[qualifiedName] = RegisteredTool{
|
||||
QualifiedName: qualifiedName,
|
||||
ServerName: serverName,
|
||||
ToolName: remoteTool.Name,
|
||||
Description: remoteTool.Description,
|
||||
Parameters: parameters,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.discovered = true
|
||||
tools := toolsFromMap(m.tools)
|
||||
m.mu.Unlock()
|
||||
|
||||
if len(tools) == 0 && len(discoveryErrors) > 0 {
|
||||
return nil, fmt.Errorf("mcp tool discovery failed: %s", strings.Join(discoveryErrors, "; "))
|
||||
}
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
func (m *Manager) CallTool(ctx context.Context, qualifiedName string, args map[string]any) (CallResult, error) {
|
||||
m.mu.RLock()
|
||||
tool, ok := m.tools[qualifiedName]
|
||||
if !ok {
|
||||
m.mu.RUnlock()
|
||||
return CallResult{}, fmt.Errorf("mcp tool %q not found", qualifiedName)
|
||||
}
|
||||
|
||||
server := m.servers[tool.ServerName]
|
||||
if server == nil || server.client == nil {
|
||||
m.mu.RUnlock()
|
||||
return CallResult{}, fmt.Errorf("mcp server %q is not active", tool.ServerName)
|
||||
}
|
||||
client := server.client
|
||||
toolName := tool.ToolName
|
||||
m.mu.RUnlock()
|
||||
|
||||
if args == nil {
|
||||
args = map[string]any{}
|
||||
}
|
||||
return client.CallTool(ctx, toolName, args)
|
||||
}
|
||||
|
||||
func (m *Manager) Close() error {
|
||||
m.mu.Lock()
|
||||
servers := make([]*managedServer, 0, len(m.servers))
|
||||
for _, server := range m.servers {
|
||||
servers = append(servers, server)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
var firstErr error
|
||||
for _, server := range servers {
|
||||
if server.client == nil {
|
||||
continue
|
||||
}
|
||||
if err := server.client.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (m *Manager) makeUniqueToolName(serverName, toolName string) string {
|
||||
base := QualifiedToolName(serverName, toolName)
|
||||
if _, exists := m.tools[base]; !exists {
|
||||
return base
|
||||
}
|
||||
|
||||
for index := 2; ; index++ {
|
||||
candidate := fmt.Sprintf("%s_%d", base, index)
|
||||
if len(candidate) > qualifiedNameMaxLen {
|
||||
overflow := len(candidate) - qualifiedNameMaxLen
|
||||
if overflow < len(base) {
|
||||
candidate = base[:len(base)-overflow] + fmt.Sprintf("_%d", index)
|
||||
} else {
|
||||
candidate = candidate[:qualifiedNameMaxLen]
|
||||
}
|
||||
}
|
||||
if _, exists := m.tools[candidate]; !exists {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSchema(schema map[string]any) map[string]any {
|
||||
if len(schema) == 0 {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func isToolAllowed(name string, include, exclude []string) bool {
|
||||
if len(include) > 0 && !slices.Contains(include, name) {
|
||||
return false
|
||||
}
|
||||
if slices.Contains(exclude, name) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func toolsFromMap(tools map[string]RegisteredTool) []RegisteredTool {
|
||||
out := make([]RegisteredTool, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
out = append(out, tool)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package mcp
|
||||
|
||||
import "strings"
|
||||
|
||||
const qualifiedNameMaxLen = 64
|
||||
|
||||
// QualifiedToolName creates a stable, provider-safe function name.
|
||||
func QualifiedToolName(serverName, toolName string) string {
|
||||
prefix := "mcp_" + sanitizeName(serverName) + "__"
|
||||
tool := sanitizeName(toolName)
|
||||
maxToolLen := qualifiedNameMaxLen - len(prefix)
|
||||
if maxToolLen <= 0 {
|
||||
return prefix[:qualifiedNameMaxLen]
|
||||
}
|
||||
if len(tool) > maxToolLen {
|
||||
tool = tool[:maxToolLen]
|
||||
}
|
||||
return prefix + tool
|
||||
}
|
||||
|
||||
func sanitizeName(value string) string {
|
||||
trimmed := strings.TrimSpace(strings.ToLower(value))
|
||||
if trimmed == "" {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(trimmed))
|
||||
|
||||
lastUnderscore := false
|
||||
for i := 0; i < len(trimmed); i++ {
|
||||
ch := trimmed[i]
|
||||
isAlphaNum := (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9')
|
||||
if isAlphaNum {
|
||||
b.WriteByte(ch)
|
||||
lastUnderscore = false
|
||||
continue
|
||||
}
|
||||
if !lastUnderscore {
|
||||
b.WriteByte('_')
|
||||
lastUnderscore = true
|
||||
}
|
||||
}
|
||||
|
||||
s := strings.Trim(b.String(), "_")
|
||||
if s == "" {
|
||||
s = "unknown"
|
||||
}
|
||||
if s[0] >= '0' && s[0] <= '9' {
|
||||
return "t_" + s
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package mcp
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
defaultInitTimeoutSeconds = 60
|
||||
defaultCallTimeoutSeconds = 30
|
||||
defaultMaxResponseBytes = 64 * 1024
|
||||
defaultScannerBufferBytes = 64 * 1024
|
||||
maxFrameBytes = 2 * 1024 * 1024
|
||||
maxToolListPages = 50
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolMCPFrames = "mcp"
|
||||
ProtocolJSONLines = "jsonl"
|
||||
)
|
||||
|
||||
// ServerConfig defines one MCP server connection.
|
||||
type ServerConfig struct {
|
||||
Name string
|
||||
Command string
|
||||
Args []string
|
||||
Env map[string]string
|
||||
WorkingDir string
|
||||
Protocol string
|
||||
InitTimeoutSeconds int
|
||||
CallTimeoutSeconds int
|
||||
MaxResponseBytes int
|
||||
IncludeTools []string
|
||||
ExcludeTools []string
|
||||
}
|
||||
|
||||
func (c ServerConfig) InitTimeout() time.Duration {
|
||||
seconds := c.InitTimeoutSeconds
|
||||
if seconds <= 0 {
|
||||
seconds = defaultInitTimeoutSeconds
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
func (c ServerConfig) CallTimeout() time.Duration {
|
||||
seconds := c.CallTimeoutSeconds
|
||||
if seconds <= 0 {
|
||||
seconds = defaultCallTimeoutSeconds
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
func (c ServerConfig) ResponseLimit() int {
|
||||
if c.MaxResponseBytes <= 0 {
|
||||
return defaultMaxResponseBytes
|
||||
}
|
||||
return c.MaxResponseBytes
|
||||
}
|
||||
|
||||
// RemoteTool is an MCP tool discovered from a server.
|
||||
type RemoteTool struct {
|
||||
Name string
|
||||
Description string
|
||||
InputSchema map[string]any
|
||||
}
|
||||
|
||||
// RegisteredTool is a discovered tool with a PicoClaw-facing qualified name.
|
||||
type RegisteredTool struct {
|
||||
QualifiedName string
|
||||
ServerName string
|
||||
ToolName string
|
||||
Description string
|
||||
Parameters map[string]any
|
||||
}
|
||||
|
||||
// CallResult is a normalized MCP tool call result.
|
||||
type CallResult struct {
|
||||
Content string
|
||||
IsError bool
|
||||
}
|
||||
@@ -217,12 +217,18 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
})
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
name, args, ok := resolveCodexToolCall(tc)
|
||||
if !ok {
|
||||
logger.WarnCF("provider.codex", "Skipping invalid tool call in history", map[string]interface{}{
|
||||
"call_id": tc.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
|
||||
CallID: tc.ID,
|
||||
Name: tc.Name,
|
||||
Arguments: string(argsJSON),
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -260,10 +266,6 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
params.Instructions = openai.Opt(defaultCodexInstructions)
|
||||
}
|
||||
|
||||
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
params.Tools = translateToolsForCodex(tools)
|
||||
}
|
||||
@@ -271,6 +273,30 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
return params
|
||||
}
|
||||
|
||||
func resolveCodexToolCall(tc ToolCall) (name string, arguments string, ok bool) {
|
||||
name = tc.Name
|
||||
if name == "" && tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
}
|
||||
if name == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
if len(tc.Arguments) > 0 {
|
||||
argsJSON, err := json.Marshal(tc.Arguments)
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
return name, string(argsJSON), true
|
||||
}
|
||||
|
||||
if tc.Function != nil && tc.Function.Arguments != "" {
|
||||
return name, tc.Function.Arguments, true
|
||||
}
|
||||
|
||||
return name, "{}", true
|
||||
}
|
||||
|
||||
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
|
||||
result := make([]responses.ToolUnionParam, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
|
||||
@@ -29,6 +29,9 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||
if params.Instructions.Or("") != defaultCodexInstructions {
|
||||
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), defaultCodexInstructions)
|
||||
}
|
||||
if params.MaxOutputTokens.Valid() {
|
||||
t.Fatalf("MaxOutputTokens should not be set for Codex backend")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
|
||||
@@ -65,6 +68,45 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_ToolCallFunctionFallback(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Read a file"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"README.md"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "ok", ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
if params.Input.OfInputItemList == nil {
|
||||
t.Fatal("Input.OfInputItemList should not be nil")
|
||||
}
|
||||
if len(params.Input.OfInputItemList) != 3 {
|
||||
t.Fatalf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
|
||||
}
|
||||
|
||||
fc := params.Input.OfInputItemList[1].OfFunctionCall
|
||||
if fc == nil {
|
||||
t.Fatal("assistant tool call should be converted to function_call input item")
|
||||
}
|
||||
if fc.Name != "read_file" {
|
||||
t.Errorf("Function call name = %q, want %q", fc.Name, "read_file")
|
||||
}
|
||||
if fc.Arguments != `{"path":"README.md"}` {
|
||||
t.Errorf("Function call arguments = %q, want %q", fc.Arguments, `{"path":"README.md"}`)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
@@ -214,6 +256,10 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["max_output_tokens"]; ok {
|
||||
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
@@ -293,6 +339,10 @@ func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T)
|
||||
http.Error(w, "temperature is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["max_output_tokens"]; ok {
|
||||
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["stream"] != true {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
|
||||
+5
-2
@@ -28,12 +28,15 @@ type CronTool struct {
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool) *CronTool {
|
||||
// execTimeout: 0 means no timeout, >0 sets the timeout duration
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration) *CronTool {
|
||||
execTool := NewExecTool(workspace, restrict)
|
||||
execTool.SetTimeout(execTimeout) // 0 means no timeout
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: NewExecTool(workspace, restrict),
|
||||
execTool: execTool,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
)
|
||||
|
||||
type MCPTool struct {
|
||||
manager *mcp.Manager
|
||||
name string
|
||||
description string
|
||||
parameters map[string]any
|
||||
}
|
||||
|
||||
func NewMCPTool(manager *mcp.Manager, tool mcp.RegisteredTool) *MCPTool {
|
||||
description := tool.Description
|
||||
if description == "" {
|
||||
description = fmt.Sprintf("MCP tool %s from server %s", tool.ToolName, tool.ServerName)
|
||||
}
|
||||
|
||||
return &MCPTool{
|
||||
manager: manager,
|
||||
name: tool.QualifiedName,
|
||||
description: description,
|
||||
parameters: tool.Parameters,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MCPTool) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *MCPTool) Description() string {
|
||||
return t.description
|
||||
}
|
||||
|
||||
func (t *MCPTool) Parameters() map[string]interface{} {
|
||||
return t.parameters
|
||||
}
|
||||
|
||||
func (t *MCPTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
if t.manager == nil {
|
||||
return ErrorResult("MCP manager is not configured")
|
||||
}
|
||||
|
||||
result, err := t.manager.CallTool(ctx, t.name, args)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("MCP tool %s failed: %v", t.name, err)).WithError(err)
|
||||
}
|
||||
if result.IsError {
|
||||
err := errors.New(result.Content)
|
||||
return ErrorResult(result.Content).WithError(err)
|
||||
}
|
||||
return SilentResult(result.Content)
|
||||
}
|
||||
|
||||
// RegisterMCPTools discovers tools from MCP servers and registers them into the registry.
|
||||
func RegisterMCPTools(ctx context.Context, registry *ToolRegistry, manager *mcp.Manager) (int, error) {
|
||||
if registry == nil || manager == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
discoveredTools, err := manager.DiscoverTools(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return RegisterKnownMCPTools(registry, manager, discoveredTools), nil
|
||||
}
|
||||
|
||||
// RegisterKnownMCPTools registers already-discovered MCP tools.
|
||||
// This avoids repeated discovery work when multiple registries share one manager.
|
||||
func RegisterKnownMCPTools(registry *ToolRegistry, manager *mcp.Manager, discoveredTools []mcp.RegisteredTool) int {
|
||||
if registry == nil || manager == nil || len(discoveredTools) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
for _, tool := range discoveredTools {
|
||||
registry.Register(NewMCPTool(manager, tool))
|
||||
}
|
||||
return len(discoveredTools)
|
||||
}
|
||||
+8
-1
@@ -89,7 +89,14 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To
|
||||
return ErrorResult(guardError)
|
||||
}
|
||||
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, t.timeout)
|
||||
// timeout == 0 means no timeout
|
||||
var cmdCtx context.Context
|
||||
var cancel context.CancelFunc
|
||||
if t.timeout > 0 {
|
||||
cmdCtx, cancel = context.WithTimeout(ctx, t.timeout)
|
||||
} else {
|
||||
cmdCtx, cancel = context.WithCancel(ctx)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
var cmd *exec.Cmd
|
||||
|
||||
+75
-2
@@ -176,6 +176,71 @@ func stripTags(content string) string {
|
||||
return re.ReplaceAllString(content, "")
|
||||
}
|
||||
|
||||
type PerplexitySearchProvider struct {
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := "https://api.perplexity.ai/chat/completions"
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"model": "sonar",
|
||||
"messages": []map[string]string{
|
||||
{"role": "system", "content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary."},
|
||||
{"role": "user", "content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count)},
|
||||
},
|
||||
"max_tokens": 1000,
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("Perplexity API error: %s", string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(searchResp.Choices) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
type WebSearchTool struct {
|
||||
provider SearchProvider
|
||||
maxResults int
|
||||
@@ -187,14 +252,22 @@ type WebSearchToolOptions struct {
|
||||
BraveEnabled bool
|
||||
DuckDuckGoMaxResults int
|
||||
DuckDuckGoEnabled bool
|
||||
PerplexityAPIKey string
|
||||
PerplexityMaxResults int
|
||||
PerplexityEnabled bool
|
||||
}
|
||||
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Brave > DuckDuckGo
|
||||
if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
// Priority: Perplexity > Brave > DuckDuckGo
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey}
|
||||
if opts.PerplexityMaxResults > 0 {
|
||||
maxResults = opts.PerplexityMaxResults
|
||||
}
|
||||
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey}
|
||||
if opts.BraveMaxResults > 0 {
|
||||
maxResults = opts.BraveMaxResults
|
||||
|
||||
Reference in New Issue
Block a user