mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6612ca099a | |||
| 49204df678 | |||
| d920b78b41 | |||
| 9222351871 | |||
| 8431fa3e04 | |||
| 39a451d312 | |||
| 4a80c6f58c | |||
| 9b0a48ac6d |
@@ -980,6 +980,7 @@ Cette conception permet également le **support multi-agent** avec une sélectio
|
||||
| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Obtenir Clé](https://cerebras.ai) |
|
||||
| **Volcengine** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Obtenir Clé](https://console.volcengine.com) |
|
||||
| **ShengsuanYun** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - |
|
||||
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obtenir une clé](https://longcat.chat/platform) |
|
||||
| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth uniquement |
|
||||
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
|
||||
|
||||
|
||||
@@ -921,6 +921,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
|
||||
| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [キーを取得](https://cerebras.ai) |
|
||||
| **Volcengine** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [キーを取得](https://console.volcengine.com) |
|
||||
| **ShengsuanYun** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - |
|
||||
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [キーを取得](https://longcat.chat/platform) |
|
||||
| **Antigravity** | `antigravity/` | Google Cloud | カスタム | OAuthのみ |
|
||||
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
|
||||
|
||||
|
||||
@@ -1034,6 +1034,7 @@ This design also enables **multi-agent support** with flexible provider selectio
|
||||
| **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Get Key](https://console.volcengine.com) |
|
||||
| **神算云** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - |
|
||||
| **Vivgrid** | `vivgrid/` | `https://api.vivgrid.com/v1` | OpenAI | [Get Key](https://vivgrid.com) |
|
||||
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Get Key](https://longcat.chat/platform) |
|
||||
| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth only |
|
||||
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
|
||||
|
||||
@@ -1504,3 +1505,4 @@ This happens when another instance of the bot is running. Make sure only one `pi
|
||||
| **SearXNG** | Unlimited (self-hosted) | Privacy-focused metasearch (70+ engines) |
|
||||
| **Groq** | Free tier available | Fast inference (Llama, Mixtral) |
|
||||
| **Cerebras** | Free tier available | Fast inference (Llama, Qwen, etc.) |
|
||||
| **LongCat** | Up to 5M tokens/day | Fast inference (free tier) |
|
||||
|
||||
@@ -976,6 +976,7 @@ Este design também possibilita o **suporte multi-agent** com seleção flexíve
|
||||
| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Obter Chave](https://cerebras.ai) |
|
||||
| **Volcengine** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Obter Chave](https://console.volcengine.com) |
|
||||
| **ShengsuanYun** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - |
|
||||
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obter Chave](https://longcat.chat/platform) |
|
||||
| **Antigravity** | `antigravity/` | Google Cloud | Custom | Apenas OAuth |
|
||||
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
|
||||
|
||||
|
||||
@@ -945,6 +945,7 @@ Thiết kế này cũng cho phép **hỗ trợ đa tác nhân** với lựa ch
|
||||
| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Lấy Khóa](https://cerebras.ai) |
|
||||
| **Volcengine** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Lấy Khóa](https://console.volcengine.com) |
|
||||
| **ShengsuanYun** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - |
|
||||
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Lấy Key](https://longcat.chat/platform) |
|
||||
| **Antigravity** | `antigravity/` | Google Cloud | Tùy chỉnh | Chỉ OAuth |
|
||||
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
|
||||
|
||||
|
||||
@@ -517,6 +517,7 @@ Agent 读取 HEARTBEAT.md
|
||||
| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [获取密钥](https://cerebras.ai) |
|
||||
| **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [获取密钥](https://console.volcengine.com) |
|
||||
| **神算云** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - |
|
||||
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [获取密钥](https://longcat.chat/platform) |
|
||||
| **Antigravity** | `antigravity/` | Google Cloud | 自定义 | 仅 OAuth |
|
||||
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
|
||||
|
||||
@@ -879,3 +880,4 @@ Discord: [https://discord.gg/V4sAZ9XWpN](https://discord.gg/V4sAZ9XWpN)
|
||||
| **Brave Search** | 2000 次查询/月 | 网络搜索功能 |
|
||||
| **Tavily** | 1000 次查询/月 | AI Agent 搜索优化 |
|
||||
| **Groq** | 提供免费层级 | 极速推理 (Llama, Mixtral) |
|
||||
| **LongCat** | 最多 5M tokens/天 | 推理速度快 (免费额度) |
|
||||
|
||||
@@ -35,6 +35,11 @@
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"api_key": "sk-your-deepseek-key"
|
||||
},
|
||||
{
|
||||
"model_name": "longcat",
|
||||
"model": "longcat/LongCat-Flash-Thinking",
|
||||
"api_key": "your-longcat-api-key"
|
||||
},
|
||||
{
|
||||
"model_name": "loadbalanced-gpt4",
|
||||
"model": "openai/gpt-5.2",
|
||||
@@ -274,6 +279,10 @@
|
||||
"avian": {
|
||||
"api_key": "",
|
||||
"api_base": "https://api.avian.io/v1"
|
||||
},
|
||||
"longcat": {
|
||||
"api_key": "",
|
||||
"api_base": "https://api.longcat.chat/openai"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
|
||||
@@ -22,7 +22,8 @@ Add this to `config.json`:
|
||||
"enabled": true,
|
||||
"text": "Thinking..."
|
||||
},
|
||||
"reasoning_channel_id": ""
|
||||
"reasoning_channel_id": "",
|
||||
"message_format": "richtext"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -42,10 +43,12 @@ Add this to `config.json`:
|
||||
| group_trigger | object | No | Group trigger strategy (`mention_only` / `prefixes`) |
|
||||
| placeholder | object | No | Placeholder message config |
|
||||
| reasoning_channel_id | string | No | Target channel for reasoning output |
|
||||
| message_format | string | No | Output format: `"richtext"` (default) renders markdown as HTML; `"plain"` sends plain text only |
|
||||
|
||||
## 3. Currently Supported
|
||||
|
||||
- Text message send/receive
|
||||
- Text message send/receive with markdown rendering (bold, italic, headers, code blocks, etc.)
|
||||
- Configurable message format (`richtext` / `plain`)
|
||||
- Incoming image/audio/video/file download (MediaStore first, local path fallback)
|
||||
- Incoming audio normalization into existing transcription flow (`[audio: ...]`)
|
||||
- Outgoing image/audio/video/file upload and send
|
||||
|
||||
@@ -21,6 +21,7 @@ require (
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||
github.com/openai/openai-go/v3 v3.22.0
|
||||
github.com/rivo/tview v0.42.0
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/slack-go/slack v0.17.3
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
@@ -50,7 +51,6 @@ require (
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rs/zerolog v1.34.0 // indirect
|
||||
github.com/segmentio/asm v1.1.3 // indirect
|
||||
github.com/segmentio/encoding v0.5.3 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
|
||||
+18
-114
@@ -25,7 +25,6 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
@@ -48,6 +47,7 @@ type AgentLoop struct {
|
||||
mediaStore media.MediaStore
|
||||
transcriber voice.Transcriber
|
||||
cmdRegistry *commands.Registry
|
||||
mcp mcpRuntime
|
||||
}
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
@@ -239,119 +239,8 @@ func registerSharedTools(
|
||||
|
||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running.Store(true)
|
||||
|
||||
// Initialize MCP servers for all agents
|
||||
if al.cfg.Tools.IsToolEnabled("mcp") {
|
||||
mcpManager := mcp.NewManager()
|
||||
// Ensure MCP connections are cleaned up on exit, regardless of initialization success
|
||||
// This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
|
||||
defer func() {
|
||||
if err := mcpManager.Close(); err != nil {
|
||||
logger.ErrorCF("agent", "Failed to close MCP manager",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
var workspacePath string
|
||||
if defaultAgent != nil && defaultAgent.Workspace != "" {
|
||||
workspacePath = defaultAgent.Workspace
|
||||
} else {
|
||||
workspacePath = al.cfg.WorkspacePath()
|
||||
}
|
||||
|
||||
if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil {
|
||||
logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
// Register MCP tools for all agents
|
||||
servers := mcpManager.GetServers()
|
||||
uniqueTools := 0
|
||||
totalRegistrations := 0
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
agentCount := len(agentIDs)
|
||||
|
||||
for serverName, conn := range servers {
|
||||
uniqueTools += len(conn.Tools)
|
||||
for _, tool := range conn.Tools {
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
|
||||
if al.cfg.Tools.MCP.Discovery.Enabled {
|
||||
agent.Tools.RegisterHidden(mcpTool)
|
||||
} else {
|
||||
agent.Tools.Register(mcpTool)
|
||||
}
|
||||
|
||||
totalRegistrations++
|
||||
logger.DebugCF("agent", "Registered MCP tool",
|
||||
map[string]any{
|
||||
"agent_id": agentID,
|
||||
"server": serverName,
|
||||
"tool": tool.Name,
|
||||
"name": mcpTool.Name(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.InfoCF("agent", "MCP tools registered successfully",
|
||||
map[string]any{
|
||||
"server_count": len(servers),
|
||||
"unique_tools": uniqueTools,
|
||||
"total_registrations": totalRegistrations,
|
||||
"agent_count": agentCount,
|
||||
})
|
||||
|
||||
// Initializes Discovery Tools only if enabled by configuration
|
||||
if al.cfg.Tools.MCP.Enabled && al.cfg.Tools.MCP.Discovery.Enabled {
|
||||
useBM25 := al.cfg.Tools.MCP.Discovery.UseBM25
|
||||
useRegex := al.cfg.Tools.MCP.Discovery.UseRegex
|
||||
|
||||
// Fail fast: If discovery is enabled but no search method is turned on
|
||||
if !useBM25 && !useRegex {
|
||||
return fmt.Errorf(
|
||||
"tool discovery is enabled but neither 'use_bm25' nor 'use_regex' is set to true in the configuration",
|
||||
)
|
||||
}
|
||||
|
||||
ttl := al.cfg.Tools.MCP.Discovery.TTL
|
||||
if ttl <= 0 {
|
||||
ttl = 5 // Default value
|
||||
}
|
||||
|
||||
maxSearchResults := al.cfg.Tools.MCP.Discovery.MaxSearchResults
|
||||
if maxSearchResults <= 0 {
|
||||
maxSearchResults = 5 // Default value
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Initializing tool discovery", map[string]any{
|
||||
"bm25": useBM25, "regex": useRegex, "ttl": ttl, "max_results": maxSearchResults,
|
||||
})
|
||||
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if useRegex {
|
||||
agent.Tools.Register(tools.NewRegexSearchTool(agent.Tools, ttl, maxSearchResults))
|
||||
}
|
||||
if useBM25 {
|
||||
agent.Tools.Register(tools.NewBM25SearchTool(agent.Tools, ttl, maxSearchResults))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := al.ensureMCPInitialized(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for al.running.Load() {
|
||||
@@ -431,6 +320,17 @@ func (al *AgentLoop) Stop() {
|
||||
|
||||
// Close releases resources held by agent session stores. Call after Stop.
|
||||
func (al *AgentLoop) Close() {
|
||||
mcpManager := al.mcp.takeManager()
|
||||
|
||||
if mcpManager != nil {
|
||||
if err := mcpManager.Close(); err != nil {
|
||||
logger.ErrorCF("agent", "Failed to close MCP manager",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
al.registry.Close()
|
||||
}
|
||||
|
||||
@@ -619,6 +519,10 @@ func (al *AgentLoop) ProcessDirectWithChannel(
|
||||
ctx context.Context,
|
||||
content, sessionKey, channel, chatID string,
|
||||
) (string, error) {
|
||||
if err := al.ensureMCPInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
msg := bus.InboundMessage{
|
||||
Channel: channel,
|
||||
SenderID: "cron",
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
type mcpRuntime struct {
|
||||
initOnce sync.Once
|
||||
mu sync.Mutex
|
||||
manager *mcp.Manager
|
||||
initErr error
|
||||
}
|
||||
|
||||
func (r *mcpRuntime) setManager(manager *mcp.Manager) {
|
||||
r.mu.Lock()
|
||||
r.manager = manager
|
||||
r.initErr = nil
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *mcpRuntime) setInitErr(err error) {
|
||||
r.mu.Lock()
|
||||
r.initErr = err
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *mcpRuntime) getInitErr() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.initErr
|
||||
}
|
||||
|
||||
func (r *mcpRuntime) takeManager() *mcp.Manager {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
manager := r.manager
|
||||
r.manager = nil
|
||||
return manager
|
||||
}
|
||||
|
||||
func (r *mcpRuntime) hasManager() bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.manager != nil
|
||||
}
|
||||
|
||||
// ensureMCPInitialized loads MCP servers/tools once so both Run() and direct
|
||||
// agent mode share the same initialization path.
|
||||
func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
if !al.cfg.Tools.IsToolEnabled("mcp") {
|
||||
return nil
|
||||
}
|
||||
|
||||
al.mcp.initOnce.Do(func() {
|
||||
mcpManager := mcp.NewManager()
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
workspacePath := al.cfg.WorkspacePath()
|
||||
if defaultAgent != nil && defaultAgent.Workspace != "" {
|
||||
workspacePath = defaultAgent.Workspace
|
||||
}
|
||||
|
||||
if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil {
|
||||
logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
if closeErr := mcpManager.Close(); closeErr != nil {
|
||||
logger.ErrorCF("agent", "Failed to close MCP manager",
|
||||
map[string]any{
|
||||
"error": closeErr.Error(),
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Register MCP tools for all agents
|
||||
servers := mcpManager.GetServers()
|
||||
uniqueTools := 0
|
||||
totalRegistrations := 0
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
agentCount := len(agentIDs)
|
||||
|
||||
for serverName, conn := range servers {
|
||||
uniqueTools += len(conn.Tools)
|
||||
for _, tool := range conn.Tools {
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
|
||||
if al.cfg.Tools.MCP.Discovery.Enabled {
|
||||
agent.Tools.RegisterHidden(mcpTool)
|
||||
} else {
|
||||
agent.Tools.Register(mcpTool)
|
||||
}
|
||||
|
||||
totalRegistrations++
|
||||
logger.DebugCF("agent", "Registered MCP tool",
|
||||
map[string]any{
|
||||
"agent_id": agentID,
|
||||
"server": serverName,
|
||||
"tool": tool.Name,
|
||||
"name": mcpTool.Name(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.InfoCF("agent", "MCP tools registered successfully",
|
||||
map[string]any{
|
||||
"server_count": len(servers),
|
||||
"unique_tools": uniqueTools,
|
||||
"total_registrations": totalRegistrations,
|
||||
"agent_count": agentCount,
|
||||
})
|
||||
|
||||
// Initializes Discovery Tools only if enabled by configuration
|
||||
if al.cfg.Tools.MCP.Enabled && al.cfg.Tools.MCP.Discovery.Enabled {
|
||||
useBM25 := al.cfg.Tools.MCP.Discovery.UseBM25
|
||||
useRegex := al.cfg.Tools.MCP.Discovery.UseRegex
|
||||
|
||||
// Fail fast: If discovery is enabled but no search method is turned on
|
||||
if !useBM25 && !useRegex {
|
||||
al.mcp.setInitErr(fmt.Errorf(
|
||||
"tool discovery is enabled but neither 'use_bm25' nor 'use_regex' is set to true in the configuration",
|
||||
))
|
||||
if closeErr := mcpManager.Close(); closeErr != nil {
|
||||
logger.ErrorCF("agent", "Failed to close MCP manager",
|
||||
map[string]any{
|
||||
"error": closeErr.Error(),
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ttl := al.cfg.Tools.MCP.Discovery.TTL
|
||||
if ttl <= 0 {
|
||||
ttl = 5 // Default value
|
||||
}
|
||||
|
||||
maxSearchResults := al.cfg.Tools.MCP.Discovery.MaxSearchResults
|
||||
if maxSearchResults <= 0 {
|
||||
maxSearchResults = 5 // Default value
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Initializing tool discovery", map[string]any{
|
||||
"bm25": useBM25, "regex": useRegex, "ttl": ttl, "max_results": maxSearchResults,
|
||||
})
|
||||
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if useRegex {
|
||||
agent.Tools.Register(tools.NewRegexSearchTool(agent.Tools, ttl, maxSearchResults))
|
||||
}
|
||||
if useBM25 {
|
||||
agent.Tools.Register(tools.NewBM25SearchTool(agent.Tools, ttl, maxSearchResults))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
al.mcp.setManager(mcpManager)
|
||||
})
|
||||
|
||||
return al.mcp.getInitErr()
|
||||
}
|
||||
@@ -770,6 +770,56 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessDirectWithChannel_InitializesMCPInAgentMode(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
Tools: config.ToolsConfig{
|
||||
MCP: config.MCPConfig{
|
||||
ToolConfig: config.ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
defer al.Close()
|
||||
|
||||
if al.mcp.hasManager() {
|
||||
t.Fatal("expected MCP manager to be nil before first direct processing")
|
||||
}
|
||||
|
||||
_, err = al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"hello",
|
||||
"session-1",
|
||||
"cli",
|
||||
"direct",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
|
||||
if !al.mcp.hasManager() {
|
||||
t.Fatal("expected MCP manager to be initialized in direct agent mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
||||
dinglog "github.com/open-dingtalk/dingtalk-stream-sdk-go/logger"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
@@ -39,6 +40,9 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (
|
||||
return nil, fmt.Errorf("dingtalk client_id and client_secret are required")
|
||||
}
|
||||
|
||||
// Set the logger for the Stream SDK
|
||||
dinglog.SetLogger(logger.NewLogger("dingtalk"))
|
||||
|
||||
base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(20000),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
|
||||
@@ -45,6 +45,14 @@ type DiscordChannel struct {
|
||||
}
|
||||
|
||||
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
|
||||
discordgo.Logger = logger.NewLogger("discord").
|
||||
WithLevels(map[int]logger.LogLevel{
|
||||
discordgo.LogError: logger.ERROR,
|
||||
discordgo.LogWarning: logger.WARN,
|
||||
discordgo.LogInformational: logger.INFO,
|
||||
discordgo.LogDebug: logger.DEBUG,
|
||||
}).Log
|
||||
|
||||
session, err := discordgo.New("Bot " + cfg.Token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create discord session: %w", err)
|
||||
|
||||
@@ -13,6 +13,9 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gomarkdown/markdown"
|
||||
mdhtml "github.com/gomarkdown/markdown/html"
|
||||
"github.com/gomarkdown/markdown/parser"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
@@ -268,6 +271,12 @@ func (c *MatrixChannel) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func markdownToHTML(md string) string {
|
||||
p := parser.NewWithExtensions(parser.CommonExtensions | parser.AutoHeadingIDs)
|
||||
renderer := mdhtml.NewRenderer(mdhtml.RendererOptions{Flags: mdhtml.CommonFlags})
|
||||
return strings.TrimSpace(string(markdown.ToHTML([]byte(md), p, renderer)))
|
||||
}
|
||||
|
||||
func (c *MatrixChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
@@ -283,16 +292,22 @@ func (c *MatrixChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := c.client.SendMessageEvent(ctx, roomID, event.EventMessage, &event.MessageEventContent{
|
||||
MsgType: event.MsgText,
|
||||
Body: content,
|
||||
})
|
||||
_, err := c.client.SendMessageEvent(ctx, roomID, event.EventMessage, c.messageContent(content))
|
||||
if err != nil {
|
||||
return fmt.Errorf("matrix send: %w", channels.ErrTemporary)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MatrixChannel) messageContent(text string) *event.MessageEventContent {
|
||||
mc := &event.MessageEventContent{MsgType: event.MsgText, Body: text}
|
||||
if c.config.MessageFormat != "plain" {
|
||||
mc.Format = event.FormatHTML
|
||||
mc.FormattedBody = markdownToHTML(text)
|
||||
}
|
||||
return mc
|
||||
}
|
||||
|
||||
// SendMedia implements channels.MediaSender.
|
||||
func (c *MatrixChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
if !c.IsRunning() {
|
||||
@@ -482,10 +497,7 @@ func (c *MatrixChannel) EditMessage(ctx context.Context, chatID string, messageI
|
||||
return fmt.Errorf("matrix message ID is empty")
|
||||
}
|
||||
|
||||
editContent := &event.MessageEventContent{
|
||||
MsgType: event.MsgText,
|
||||
Body: content,
|
||||
}
|
||||
editContent := c.messageContent(content)
|
||||
editContent.SetEdit(id.EventID(messageID))
|
||||
|
||||
_, err := c.client.SendMessageEvent(ctx, roomID, event.EventMessage, editContent)
|
||||
|
||||
@@ -4,12 +4,15 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestMatrixLocalpartMentionRegexp(t *testing.T) {
|
||||
@@ -289,3 +292,50 @@ func TestMatrixOutboundContent(t *testing.T) {
|
||||
t.Fatalf("unexpected fallback body: %q", noCaption.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownToHTML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
contains string
|
||||
}{
|
||||
{"bold", "**hello**", "<strong>hello</strong>"},
|
||||
{"italic", "_world_", "<em>world</em>"},
|
||||
{"header", "### Title", "<h3"},
|
||||
{"code block", "```\nfoo()\n```", "<code>"},
|
||||
{"inline code", "`x`", "<code>x</code>"},
|
||||
{"plain text", "just text", "just text"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := markdownToHTML(tt.input)
|
||||
if !strings.Contains(got, tt.contains) {
|
||||
t.Fatalf("markdownToHTML(%q) = %q, want it to contain %q", tt.input, got, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageContent(t *testing.T) {
|
||||
richtext := &MatrixChannel{config: config.MatrixConfig{MessageFormat: "richtext"}}
|
||||
plain := &MatrixChannel{config: config.MatrixConfig{MessageFormat: "plain"}}
|
||||
defaultt := &MatrixChannel{config: config.MatrixConfig{}}
|
||||
|
||||
for _, c := range []*MatrixChannel{richtext, defaultt} {
|
||||
mc := c.messageContent("**hi**")
|
||||
if mc.Format != event.FormatHTML {
|
||||
t.Errorf("format %q: expected FormatHTML, got %q", c.config.MessageFormat, mc.Format)
|
||||
}
|
||||
if !strings.Contains(mc.FormattedBody, "<strong>hi</strong>") {
|
||||
t.Errorf("format %q: FormattedBody %q missing <strong>", c.config.MessageFormat, mc.FormattedBody)
|
||||
}
|
||||
if mc.Body != "**hi**" {
|
||||
t.Errorf("format %q: Body should remain plain, got %q", c.config.MessageFormat, mc.Body)
|
||||
}
|
||||
}
|
||||
|
||||
mc := plain.messageContent("**hi**")
|
||||
if mc.Format != "" || mc.FormattedBody != "" {
|
||||
t.Errorf("plain: expected no formatting, got format=%q formattedBody=%q", mc.Format, mc.FormattedBody)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,6 +78,7 @@ func (c *QQChannel) Start(ctx context.Context) error {
|
||||
return fmt.Errorf("QQ app_id and app_secret not configured")
|
||||
}
|
||||
|
||||
botgo.SetLogger(logger.NewLogger("botgo"))
|
||||
logger.InfoC("qq", "Starting QQ bot (WebSocket mode)")
|
||||
|
||||
// Reinitialize shutdown signal for clean restart.
|
||||
|
||||
@@ -77,6 +77,7 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
|
||||
if baseURL := strings.TrimRight(strings.TrimSpace(telegramCfg.BaseURL), "/"); baseURL != "" {
|
||||
opts = append(opts, telego.WithAPIServer(baseURL))
|
||||
}
|
||||
opts = append(opts, telego.WithLogger(logger.NewLogger("telego")))
|
||||
|
||||
bot, err := telego.NewBot(telegramCfg.Token, opts...)
|
||||
if err != nil {
|
||||
|
||||
+38
-9
@@ -17,6 +17,8 @@ var rrCounter atomic.Uint64
|
||||
|
||||
// FlexibleStringSlice is a []string that also accepts JSON numbers,
|
||||
// so allow_from can contain both "123" and 123.
|
||||
// It also supports parsing comma-separated strings from environment variables,
|
||||
// including both English (,) and Chinese (,) commas.
|
||||
type FlexibleStringSlice []string
|
||||
|
||||
func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
|
||||
@@ -48,6 +50,30 @@ func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements encoding.TextUnmarshaler to support env variable parsing.
|
||||
// It handles comma-separated values with both English (,) and Chinese (,) commas.
|
||||
func (f *FlexibleStringSlice) UnmarshalText(text []byte) error {
|
||||
if len(text) == 0 {
|
||||
*f = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
s := string(text)
|
||||
// Replace Chinese comma with English comma, then split
|
||||
s = strings.ReplaceAll(s, ",", ",")
|
||||
parts := strings.Split(s, ",")
|
||||
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
}
|
||||
*f = result
|
||||
return nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Agents AgentsConfig `json:"agents"`
|
||||
Bindings []AgentBinding `json:"bindings,omitempty"`
|
||||
@@ -350,16 +376,17 @@ type SlackConfig struct {
|
||||
}
|
||||
|
||||
type MatrixConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MATRIX_ENABLED"`
|
||||
Homeserver string `json:"homeserver" env:"PICOCLAW_CHANNELS_MATRIX_HOMESERVER"`
|
||||
UserID string `json:"user_id" env:"PICOCLAW_CHANNELS_MATRIX_USER_ID"`
|
||||
AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_MATRIX_ACCESS_TOKEN"`
|
||||
DeviceID string `json:"device_id,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_DEVICE_ID"`
|
||||
JoinOnInvite bool `json:"join_on_invite" env:"PICOCLAW_CHANNELS_MATRIX_JOIN_ON_INVITE"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MATRIX_ALLOW_FROM"`
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MATRIX_ENABLED"`
|
||||
Homeserver string `json:"homeserver" env:"PICOCLAW_CHANNELS_MATRIX_HOMESERVER"`
|
||||
UserID string `json:"user_id" env:"PICOCLAW_CHANNELS_MATRIX_USER_ID"`
|
||||
AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_MATRIX_ACCESS_TOKEN"`
|
||||
DeviceID string `json:"device_id,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_DEVICE_ID"`
|
||||
JoinOnInvite bool `json:"join_on_invite" env:"PICOCLAW_CHANNELS_MATRIX_JOIN_ON_INVITE"`
|
||||
MessageFormat string `json:"message_format,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_MESSAGE_FORMAT"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MATRIX_ALLOW_FROM"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_MATRIX_REASONING_CHANNEL_ID"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_MATRIX_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
type LINEConfig struct {
|
||||
@@ -500,6 +527,7 @@ type ProvidersConfig struct {
|
||||
Mistral ProviderConfig `json:"mistral"`
|
||||
Avian ProviderConfig `json:"avian"`
|
||||
Minimax ProviderConfig `json:"minimax"`
|
||||
LongCat ProviderConfig `json:"longcat"`
|
||||
}
|
||||
|
||||
// IsEmpty checks if all provider configs are empty (no API keys or API bases set)
|
||||
@@ -526,7 +554,8 @@ func (p ProvidersConfig) IsEmpty() bool {
|
||||
p.Qwen.APIKey == "" && p.Qwen.APIBase == "" &&
|
||||
p.Mistral.APIKey == "" && p.Mistral.APIBase == "" &&
|
||||
p.Avian.APIKey == "" && p.Avian.APIBase == "" &&
|
||||
p.Minimax.APIKey == "" && p.Minimax.APIBase == ""
|
||||
p.Minimax.APIKey == "" && p.Minimax.APIBase == "" &&
|
||||
p.LongCat.APIKey == "" && p.LongCat.APIBase == ""
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for ProvidersConfig
|
||||
|
||||
@@ -505,3 +505,119 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
|
||||
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFlexibleStringSlice_UnmarshalText tests UnmarshalText with various comma separators
|
||||
func TestFlexibleStringSlice_UnmarshalText(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "English commas only",
|
||||
input: "123,456,789",
|
||||
expected: []string{"123", "456", "789"},
|
||||
},
|
||||
{
|
||||
name: "Chinese commas only",
|
||||
input: "123,456,789",
|
||||
expected: []string{"123", "456", "789"},
|
||||
},
|
||||
{
|
||||
name: "Mixed English and Chinese commas",
|
||||
input: "123,456,789",
|
||||
expected: []string{"123", "456", "789"},
|
||||
},
|
||||
{
|
||||
name: "Single value",
|
||||
input: "123",
|
||||
expected: []string{"123"},
|
||||
},
|
||||
{
|
||||
name: "Values with whitespace",
|
||||
input: " 123 , 456 , 789 ",
|
||||
expected: []string{"123", "456", "789"},
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Only commas - English",
|
||||
input: ",,",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Only commas - Chinese",
|
||||
input: ",,",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Mixed commas with empty parts",
|
||||
input: "123,,456,,789",
|
||||
expected: []string{"123", "456", "789"},
|
||||
},
|
||||
{
|
||||
name: "Complex mixed values",
|
||||
input: "user1@example.com,user2@test.com, admin@domain.org",
|
||||
expected: []string{"user1@example.com", "user2@test.com", "admin@domain.org"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var f FlexibleStringSlice
|
||||
err := f.UnmarshalText([]byte(tt.input))
|
||||
if err != nil {
|
||||
t.Fatalf("UnmarshalText(%q) error = %v", tt.input, err)
|
||||
}
|
||||
|
||||
if tt.expected == nil {
|
||||
if f != nil {
|
||||
t.Errorf("UnmarshalText(%q) = %v, want nil", tt.input, f)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(f) != len(tt.expected) {
|
||||
t.Errorf("UnmarshalText(%q) length = %d, want %d", tt.input, len(f), len(tt.expected))
|
||||
return
|
||||
}
|
||||
|
||||
for i, v := range tt.expected {
|
||||
if f[i] != v {
|
||||
t.Errorf("UnmarshalText(%q)[%d] = %q, want %q", tt.input, i, f[i], v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFlexibleStringSlice_UnmarshalText_EmptySliceConsistency tests nil vs empty slice behavior
|
||||
func TestFlexibleStringSlice_UnmarshalText_EmptySliceConsistency(t *testing.T) {
|
||||
t.Run("Empty string returns nil", func(t *testing.T) {
|
||||
var f FlexibleStringSlice
|
||||
err := f.UnmarshalText([]byte(""))
|
||||
if err != nil {
|
||||
t.Fatalf("UnmarshalText error = %v", err)
|
||||
}
|
||||
if f != nil {
|
||||
t.Errorf("Empty string should return nil, got %v", f)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Commas only returns empty slice", func(t *testing.T) {
|
||||
var f FlexibleStringSlice
|
||||
err := f.UnmarshalText([]byte(",,,"))
|
||||
if err != nil {
|
||||
t.Fatalf("UnmarshalText error = %v", err)
|
||||
}
|
||||
if f == nil {
|
||||
t.Error("Commas only should return empty slice, not nil")
|
||||
}
|
||||
if len(f) != 0 {
|
||||
t.Errorf("Expected empty slice, got %v", f)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -355,6 +355,14 @@ func DefaultConfig() *Config {
|
||||
APIKey: "",
|
||||
},
|
||||
|
||||
// LongCat - https://longcat.chat/platform
|
||||
{
|
||||
ModelName: "LongCat-Flash-Thinking",
|
||||
Model: "longcat/LongCat-Flash-Thinking",
|
||||
APIBase: "https://api.longcat.chat/openai",
|
||||
APIKey: "",
|
||||
},
|
||||
|
||||
// VLLM (local) - http://localhost:8000
|
||||
{
|
||||
ModelName: "local-model",
|
||||
|
||||
@@ -407,6 +407,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
}, true
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"longcat"},
|
||||
protocol: "longcat",
|
||||
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
|
||||
if p.LongCat.APIKey == "" && p.LongCat.APIBase == "" {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "longcat",
|
||||
Model: "longcat/LongCat-Flash-Thinking",
|
||||
APIKey: p.LongCat.APIKey,
|
||||
APIBase: p.LongCat.APIBase,
|
||||
Proxy: p.LongCat.Proxy,
|
||||
RequestTimeout: p.LongCat.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Process each provider migration
|
||||
|
||||
@@ -162,14 +162,15 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
Qwen: ProviderConfig{APIKey: "key17"},
|
||||
Mistral: ProviderConfig{APIKey: "key18"},
|
||||
Avian: ProviderConfig{APIKey: "key19"},
|
||||
LongCat: ProviderConfig{APIKey: "key-longcat"},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
// All 21 providers should be converted
|
||||
if len(result) != 21 {
|
||||
t.Errorf("len(result) = %d, want 21", len(result))
|
||||
// All 22 providers should be converted
|
||||
if len(result) != 22 {
|
||||
t.Errorf("len(result) = %d, want 22", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+110
-80
@@ -1,24 +1,24 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type LogLevel int
|
||||
type LogLevel = zerolog.Level
|
||||
|
||||
const (
|
||||
DEBUG LogLevel = iota
|
||||
INFO
|
||||
WARN
|
||||
ERROR
|
||||
FATAL
|
||||
DEBUG = zerolog.DebugLevel
|
||||
INFO = zerolog.InfoLevel
|
||||
WARN = zerolog.WarnLevel
|
||||
ERROR = zerolog.ErrorLevel
|
||||
FATAL = zerolog.FatalLevel
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -31,27 +31,24 @@ var (
|
||||
}
|
||||
|
||||
currentLevel = INFO
|
||||
logger *Logger
|
||||
logger zerolog.Logger
|
||||
fileLogger zerolog.Logger
|
||||
logFile *os.File
|
||||
once sync.Once
|
||||
mu sync.RWMutex
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
file *os.File
|
||||
}
|
||||
|
||||
type LogEntry struct {
|
||||
Level string `json:"level"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Component string `json:"component,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Fields map[string]any `json:"fields,omitempty"`
|
||||
Caller string `json:"caller,omitempty"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
once.Do(func() {
|
||||
logger = &Logger{}
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
|
||||
consoleWriter := zerolog.ConsoleWriter{
|
||||
Out: os.Stdout,
|
||||
TimeFormat: "15:04:05", // TODO: make it configurable???
|
||||
}
|
||||
|
||||
logger = zerolog.New(consoleWriter).With().Timestamp().Logger()
|
||||
fileLogger = zerolog.Logger{}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -59,6 +56,7 @@ func SetLevel(level LogLevel) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
currentLevel = level
|
||||
zerolog.SetGlobalLevel(level)
|
||||
}
|
||||
|
||||
func GetLevel() LogLevel {
|
||||
@@ -71,17 +69,22 @@ func EnableFileLogging(filePath string) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
|
||||
return fmt.Errorf("failed to create log directory: %w", err)
|
||||
}
|
||||
|
||||
newFile, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open log file: %w", err)
|
||||
}
|
||||
|
||||
if logger.file != nil {
|
||||
logger.file.Close()
|
||||
// Close old file if exists
|
||||
if logFile != nil {
|
||||
logFile.Close()
|
||||
}
|
||||
|
||||
logger.file = file
|
||||
log.Println("File logging enabled:", filePath)
|
||||
logFile = newFile
|
||||
fileLogger = zerolog.New(logFile).With().Timestamp().Caller().Logger()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -89,10 +92,57 @@ func DisableFileLogging() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if logger.file != nil {
|
||||
logger.file.Close()
|
||||
logger.file = nil
|
||||
log.Println("File logging disabled")
|
||||
if logFile != nil {
|
||||
logFile.Close()
|
||||
logFile = nil
|
||||
}
|
||||
fileLogger = zerolog.Logger{}
|
||||
}
|
||||
|
||||
func getCallerInfo() (string, int, string) {
|
||||
for i := 2; i < 15; i++ {
|
||||
pc, file, line, ok := runtime.Caller(i)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
fn := runtime.FuncForPC(pc)
|
||||
if fn == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// bypass common loggers
|
||||
if strings.HasSuffix(file, "/logger.go") ||
|
||||
strings.HasSuffix(file, "/log.go") {
|
||||
continue
|
||||
}
|
||||
|
||||
funcName := fn.Name()
|
||||
if strings.HasPrefix(funcName, "runtime.") {
|
||||
continue
|
||||
}
|
||||
|
||||
return filepath.Base(file), line, filepath.Base(funcName)
|
||||
}
|
||||
|
||||
return "???", 0, "???"
|
||||
}
|
||||
|
||||
//nolint:zerologlint
|
||||
func getEvent(logger zerolog.Logger, level LogLevel) *zerolog.Event {
|
||||
switch level {
|
||||
case zerolog.DebugLevel:
|
||||
return logger.Debug()
|
||||
case zerolog.InfoLevel:
|
||||
return logger.Info()
|
||||
case zerolog.WarnLevel:
|
||||
return logger.Warn()
|
||||
case zerolog.ErrorLevel:
|
||||
return logger.Error()
|
||||
case zerolog.FatalLevel:
|
||||
return logger.Fatal()
|
||||
default:
|
||||
return logger.Info()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,65 +151,41 @@ func logMessage(level LogLevel, component string, message string, fields map[str
|
||||
return
|
||||
}
|
||||
|
||||
entry := LogEntry{
|
||||
Level: logLevelNames[level],
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
Component: component,
|
||||
Message: message,
|
||||
Fields: fields,
|
||||
}
|
||||
callerFile, callerLine, callerFunc := getCallerInfo()
|
||||
|
||||
if pc, file, line, ok := runtime.Caller(2); ok {
|
||||
fn := runtime.FuncForPC(pc)
|
||||
if fn != nil {
|
||||
entry.Caller = fmt.Sprintf("%s:%d (%s)", file, line, fn.Name())
|
||||
}
|
||||
}
|
||||
event := getEvent(logger, level)
|
||||
|
||||
if logger.file != nil {
|
||||
jsonData, err := json.Marshal(entry)
|
||||
if err == nil {
|
||||
logger.file.Write(append(jsonData, '\n'))
|
||||
}
|
||||
}
|
||||
|
||||
var fieldStr string
|
||||
if len(fields) > 0 {
|
||||
fieldStr = " " + formatFields(fields)
|
||||
// Build combined field with component and caller
|
||||
if component != "" {
|
||||
event.Str("caller", fmt.Sprintf("%-6s %s:%d (%s)", component, callerFile, callerLine, callerFunc))
|
||||
} else {
|
||||
fieldStr = ""
|
||||
event.Str("caller", fmt.Sprintf("<none> %s:%d (%s)", callerFile, callerLine, callerFunc))
|
||||
}
|
||||
|
||||
logLine := fmt.Sprintf("[%s] [%s]%s %s%s",
|
||||
entry.Timestamp,
|
||||
logLevelNames[level],
|
||||
formatComponent(component),
|
||||
message,
|
||||
fieldStr,
|
||||
)
|
||||
for k, v := range fields {
|
||||
event.Interface(k, v)
|
||||
}
|
||||
|
||||
log.Println(logLine)
|
||||
event.Msg(message)
|
||||
|
||||
// Also log to file if enabled
|
||||
if fileLogger.GetLevel() != zerolog.NoLevel {
|
||||
fileEvent := getEvent(fileLogger, level)
|
||||
|
||||
if component != "" {
|
||||
fileEvent.Str("component", component)
|
||||
}
|
||||
for k, v := range fields {
|
||||
fileEvent.Interface(k, v)
|
||||
}
|
||||
fileEvent.Msg(message)
|
||||
}
|
||||
|
||||
if level == FATAL {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func formatComponent(component string) string {
|
||||
if component == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(" %s:", component)
|
||||
}
|
||||
|
||||
func formatFields(fields map[string]any) string {
|
||||
parts := make([]string, 0, len(fields))
|
||||
for k, v := range fields {
|
||||
parts = append(parts, fmt.Sprintf("%s=%v", k, v))
|
||||
}
|
||||
return fmt.Sprintf("{%s}", strings.Join(parts, ", "))
|
||||
}
|
||||
|
||||
func Debug(message string) {
|
||||
logMessage(DEBUG, "", message, nil)
|
||||
}
|
||||
@@ -232,6 +258,10 @@ func FatalC(component string, message string) {
|
||||
logMessage(FATAL, component, message, nil)
|
||||
}
|
||||
|
||||
func Fatalf(message string, ss ...any) {
|
||||
logMessage(FATAL, "", fmt.Sprintf(message, ss...), nil)
|
||||
}
|
||||
|
||||
func FatalF(message string, fields map[string]any) {
|
||||
logMessage(FATAL, "", message, fields)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
// this file is for compatible with 3rd party loggers, should not be called in PicoClaw project
|
||||
|
||||
package logger
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Logger implements common Logger interface
|
||||
type Logger struct {
|
||||
component string
|
||||
levels map[int]LogLevel
|
||||
}
|
||||
|
||||
// Debug logs debug messages
|
||||
func (b *Logger) Debug(v ...any) {
|
||||
logMessage(DEBUG, b.component, fmt.Sprint(v...), nil)
|
||||
}
|
||||
|
||||
// Info logs info messages
|
||||
func (b *Logger) Info(v ...any) {
|
||||
logMessage(INFO, b.component, fmt.Sprint(v...), nil)
|
||||
}
|
||||
|
||||
// Warn logs warning messages
|
||||
func (b *Logger) Warn(v ...any) {
|
||||
logMessage(WARN, b.component, fmt.Sprint(v...), nil)
|
||||
}
|
||||
|
||||
// Error logs error messages
|
||||
func (b *Logger) Error(v ...any) {
|
||||
logMessage(ERROR, b.component, fmt.Sprint(v...), nil)
|
||||
}
|
||||
|
||||
// Debugf logs formatted debug messages
|
||||
func (b *Logger) Debugf(format string, v ...any) {
|
||||
logMessage(DEBUG, b.component, fmt.Sprintf(format, v...), nil)
|
||||
}
|
||||
|
||||
// Infof logs formatted info messages
|
||||
func (b *Logger) Infof(format string, v ...any) {
|
||||
logMessage(INFO, b.component, fmt.Sprintf(format, v...), nil)
|
||||
}
|
||||
|
||||
// Warnf logs formatted warning messages
|
||||
func (b *Logger) Warnf(format string, v ...any) {
|
||||
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
|
||||
}
|
||||
|
||||
// Warningf logs formatted warning messages
|
||||
func (b *Logger) Warningf(format string, v ...any) {
|
||||
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
|
||||
}
|
||||
|
||||
// Errorf logs formatted error messages
|
||||
func (b *Logger) Errorf(format string, v ...any) {
|
||||
logMessage(ERROR, b.component, fmt.Sprintf(format, v...), nil)
|
||||
}
|
||||
|
||||
// Fatalf logs formatted fatal messages and exits
|
||||
func (b *Logger) Fatalf(format string, v ...any) {
|
||||
logMessage(FATAL, b.component, fmt.Sprintf(format, v...), nil)
|
||||
}
|
||||
|
||||
// Log logs a message at a given level with caller information
|
||||
// the func name must be this because 3rd party loggers expect this
|
||||
// msgL: message level (DEBUG, INFO, WARN, ERROR, FATAL)
|
||||
// caller: unused parameter reserved for compatibility
|
||||
// format: format string
|
||||
// a: format arguments
|
||||
//
|
||||
//nolint:goprintffuncname
|
||||
func (b *Logger) Log(msgL, caller int, format string, a ...any) {
|
||||
level := LogLevel(msgL)
|
||||
if b.levels != nil {
|
||||
if lvl, ok := b.levels[msgL]; ok {
|
||||
level = lvl
|
||||
}
|
||||
}
|
||||
logMessage(level, b.component, fmt.Sprintf(format, a...), nil)
|
||||
}
|
||||
|
||||
// Sync flushes log buffer (no-op for this implementation)
|
||||
func (b *Logger) Sync() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithLevels sets log levels mapping for this logger
|
||||
func (b *Logger) WithLevels(levels map[int]LogLevel) *Logger {
|
||||
b.levels = levels
|
||||
return b
|
||||
}
|
||||
|
||||
// NewLogger creates a new logger instance with optional component name
|
||||
func NewLogger(component string) *Logger {
|
||||
return &Logger{component: component}
|
||||
}
|
||||
@@ -221,6 +221,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
sel.apiBase = "https://api.minimaxi.com/v1"
|
||||
}
|
||||
}
|
||||
case "longcat":
|
||||
if cfg.Providers.LongCat.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.LongCat.APIKey
|
||||
sel.apiBase = cfg.Providers.LongCat.APIBase
|
||||
sel.proxy = cfg.Providers.LongCat.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.longcat.chat/openai"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
sel.providerType = providerTypeGitHubCopilot
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
@@ -352,6 +361,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.avian.io/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "longcat") || strings.HasPrefix(model, "longcat/")) && cfg.Providers.LongCat.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.LongCat.APIKey
|
||||
sel.apiBase = cfg.Providers.LongCat.APIBase
|
||||
sel.proxy = cfg.Providers.LongCat.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.longcat.chat/openai"
|
||||
}
|
||||
case cfg.Providers.VLLM.APIBase != "":
|
||||
sel.apiKey = cfg.Providers.VLLM.APIKey
|
||||
sel.apiBase = cfg.Providers.VLLM.APIBase
|
||||
|
||||
@@ -95,7 +95,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
|
||||
"minimax":
|
||||
"minimax", "longcat":
|
||||
// All other OpenAI-compatible HTTP providers
|
||||
if cfg.APIKey == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
@@ -215,6 +215,8 @@ func getDefaultAPIBase(protocol string) string {
|
||||
return "https://api.avian.io/v1"
|
||||
case "minimax":
|
||||
return "https://api.minimaxi.com/v1"
|
||||
case "longcat":
|
||||
return "https://api.longcat.chat/openai"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -113,6 +113,7 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
|
||||
{"vllm", "vllm"},
|
||||
{"deepseek", "deepseek"},
|
||||
{"ollama", "ollama"},
|
||||
{"longcat", "longcat"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -162,6 +163,29 @@ func TestCreateProviderFromConfig_LiteLLM(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_LongCat(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-longcat",
|
||||
Model: "longcat/LongCat-Flash-Thinking",
|
||||
APIKey: "test-key",
|
||||
APIBase: "https://api.longcat.chat/openai",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "LongCat-Flash-Thinking" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "LongCat-Flash-Thinking")
|
||||
}
|
||||
if _, ok := provider.(*HTTPProvider); !ok {
|
||||
t.Fatalf("expected *HTTPProvider, got %T", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-anthropic",
|
||||
|
||||
@@ -178,6 +178,26 @@ func TestResolveProviderSelection(t *testing.T) {
|
||||
wantAPIBase: "https://api.moonshot.cn/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit longcat provider uses defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "longcat"
|
||||
cfg.Providers.LongCat.APIKey = "longcat-key"
|
||||
cfg.Providers.LongCat.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.longcat.chat/openai",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "longcat model fallback uses longcat base default",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "longcat/LongCat-Flash-Thinking"
|
||||
cfg.Providers.LongCat.APIKey = "longcat-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.longcat.chat/openai",
|
||||
},
|
||||
{
|
||||
name: "missing keys returns model config error",
|
||||
setup: func(cfg *config.Config) {
|
||||
|
||||
@@ -156,9 +156,10 @@ func (p *Provider) Chat(
|
||||
// The key is typically the agent ID — stable per agent, shared across requests.
|
||||
// See: https://platform.openai.com/docs/guides/prompt-caching
|
||||
// Prompt caching is only supported by OpenAI-native endpoints.
|
||||
// Gemini and other providers reject unknown fields, so skip for non-OpenAI APIs.
|
||||
// Non-OpenAI providers (Mistral, Gemini, DeepSeek, etc.) reject unknown
|
||||
// fields with 422 errors, so only include it for OpenAI APIs.
|
||||
if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" {
|
||||
if !strings.Contains(p.apiBase, "generativelanguage.googleapis.com") {
|
||||
if supportsPromptCacheKey(p.apiBase) {
|
||||
requestBody["prompt_cache_key"] = cacheKey
|
||||
}
|
||||
}
|
||||
@@ -283,8 +284,8 @@ func parseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
} `json:"function"`
|
||||
ExtraContent *struct {
|
||||
Google *struct {
|
||||
@@ -323,12 +324,7 @@ func parseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
if tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
|
||||
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
|
||||
arguments["raw"] = tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
arguments = decodeToolCallArguments(tc.Function.Arguments, name)
|
||||
}
|
||||
|
||||
// Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence
|
||||
@@ -361,6 +357,39 @@ func parseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
|
||||
arguments := make(map[string]any)
|
||||
raw = bytes.TrimSpace(raw)
|
||||
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
|
||||
return arguments
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := json.Unmarshal(raw, &decoded); err != nil {
|
||||
log.Printf("openai_compat: failed to decode tool call arguments payload for %q: %v", name, err)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
|
||||
switch v := decoded.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return arguments
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &arguments); err != nil {
|
||||
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
|
||||
arguments["raw"] = v
|
||||
}
|
||||
return arguments
|
||||
case map[string]any:
|
||||
return v
|
||||
default:
|
||||
log.Printf("openai_compat: unsupported tool call arguments type for %q: %T", name, decoded)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
}
|
||||
|
||||
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
|
||||
// It mirrors protocoltypes.Message but omits SystemParts, which is an
|
||||
// internal field that would be unknown to third-party endpoints.
|
||||
@@ -476,3 +505,16 @@ func asFloat(v any) (float64, bool) {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// supportsPromptCacheKey reports whether the given API base is known to
|
||||
// support the prompt_cache_key request field. Currently only OpenAI's own
|
||||
// API and Azure OpenAI support this. All other OpenAI-compatible providers
|
||||
// (Mistral, Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors.
|
||||
func supportsPromptCacheKey(apiBase string) bool {
|
||||
u, err := url.Parse(apiBase)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host := u.Hostname()
|
||||
return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
|
||||
}
|
||||
|
||||
@@ -108,6 +108,55 @@ func TestProviderChat_ParsesToolCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_ParsesToolCallsWithObjectArguments(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{
|
||||
"content": "",
|
||||
"tool_calls": []map[string]any{
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "get_weather",
|
||||
"arguments": map[string]any{
|
||||
"city": "SF",
|
||||
"metric": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].Name != "get_weather" {
|
||||
t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
|
||||
}
|
||||
if out.ToolCalls[0].Arguments["city"] != "SF" {
|
||||
t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
|
||||
}
|
||||
if out.ToolCalls[0].Arguments["metric"] != true {
|
||||
t.Fatalf("ToolCalls[0].Arguments[metric] = %v, want true", out.ToolCalls[0].Arguments["metric"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_ParsesReasoningContent(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]any{
|
||||
@@ -669,6 +718,111 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// chatWithCacheKey sets up a test server, sends a Chat request with prompt_cache_key,
|
||||
// and returns the decoded request body for assertion.
|
||||
func chatWithCacheKey(t *testing.T, apiBase string) map[string]any {
|
||||
t.Helper()
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
p.apiBase = apiBase
|
||||
p.httpClient = &http.Client{
|
||||
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
r.URL, _ = url.Parse(server.URL + r.URL.Path)
|
||||
return http.DefaultTransport.RoundTrip(r)
|
||||
}),
|
||||
}
|
||||
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"test-model",
|
||||
map[string]any{"prompt_cache_key": "agent-main"},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
return requestBody
|
||||
}
|
||||
|
||||
func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) {
|
||||
body := chatWithCacheKey(t, "https://api.openai.com/v1")
|
||||
if body["prompt_cache_key"] != "agent-main" {
|
||||
t.Fatalf("prompt_cache_key = %v, want %q", body["prompt_cache_key"], "agent-main")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_PromptCacheKeyOmittedForNonOpenAI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiBase string
|
||||
}{
|
||||
{"mistral", "https://api.mistral.ai/v1"},
|
||||
{"gemini", "https://generativelanguage.googleapis.com/v1beta"},
|
||||
{"deepseek", "https://api.deepseek.com/v1"},
|
||||
{"groq", "https://api.groq.com/openai/v1"},
|
||||
{"minimax", "https://api.minimaxi.com/v1"},
|
||||
{"ollama_local", "http://localhost:11434/v1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body := chatWithCacheKey(t, tt.apiBase)
|
||||
if _, exists := body["prompt_cache_key"]; exists {
|
||||
t.Fatalf("prompt_cache_key should NOT be sent to %s, but was included in request", tt.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsPromptCacheKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
apiBase string
|
||||
want bool
|
||||
}{
|
||||
{"https://api.openai.com/v1", true},
|
||||
{"https://api.openai.com/v1/", true},
|
||||
{"https://myresource.openai.azure.com/openai/deployments/gpt-4", true},
|
||||
{"https://eastus.openai.azure.com/v1", true},
|
||||
{"https://api.mistral.ai/v1", false},
|
||||
{"https://generativelanguage.googleapis.com/v1beta", false},
|
||||
{"https://api.deepseek.com/v1", false},
|
||||
{"https://api.groq.com/openai/v1", false},
|
||||
{"http://localhost:11434/v1", false},
|
||||
{"https://openrouter.ai/api/v1", false},
|
||||
// Edge cases: proxy URLs with openai.com in path should NOT match
|
||||
{"https://my-proxy.com/api.openai.com/v1", false},
|
||||
{"https://proxy.example.com/openai.azure.com/v1", false},
|
||||
// Malformed or empty
|
||||
{"", false},
|
||||
{"not-a-url", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := supportsPromptCacheKey(tt.apiBase); got != tt.want {
|
||||
t.Errorf("supportsPromptCacheKey(%q) = %v, want %v", tt.apiBase, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user