mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into version
This commit is contained in:
@@ -11,18 +11,20 @@ PicoClaw 通过 QQ 开放平台的官方机器人 API 提供对 QQ 的支持。
|
||||
"enabled": true,
|
||||
"app_id": "YOUR_APP_ID",
|
||||
"app_secret": "YOUR_APP_SECRET",
|
||||
"allow_from": []
|
||||
"allow_from": [],
|
||||
"max_base64_file_size_mib": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 必填 | 描述 |
|
||||
| ---------- | ------ | ---- | -------------------------------- |
|
||||
| enabled | bool | 是 | 是否启用 QQ Channel |
|
||||
| app_id | string | 是 | QQ 机器人应用的 App ID |
|
||||
| app_secret | string | 是 | QQ 机器人应用的 App Secret |
|
||||
| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
|
||||
| 字段 | 类型 | 必填 | 描述 |
|
||||
| -------------------- | ------ | ---- | ------------------------------------------------------------ |
|
||||
| enabled | bool | 是 | 是否启用 QQ Channel |
|
||||
| app_id | string | 是 | QQ 机器人应用的 App ID |
|
||||
| app_secret | string | 是 | QQ 机器人应用的 App Secret |
|
||||
| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
|
||||
| max_base64_file_size_mib | int | 否 | 本地文件转 base64 上传的最大体积,单位 MiB;`0` 表示不限制。仅影响本地文件,不影响 URL 直传 |
|
||||
|
||||
## 设置流程
|
||||
|
||||
|
||||
+56
-11
@@ -158,7 +158,7 @@ and injected into the context for a configured number of turns (`ttl`).
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|----------------------|------|---------|-----------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `enabled` | bool | false | If true, MCP tools are hidden and loaded on-demand via search. If false, all tools are loaded |
|
||||
| `enabled` | bool | false | Global default: if `true`, all MCP tools are hidden and loaded on-demand via search; if `false`, all tools are loaded into context. Individual servers can override this with the per-server `deferred` field. |
|
||||
| `ttl` | int | 5 | Number of conversational turns a discovered tool remains unlocked |
|
||||
| `max_search_results` | int | 5 | Maximum number of tools returned per search query |
|
||||
| `use_bm25` | bool | true | Enable the natural language/keyword search tool (`tool_search_tool_bm25`). **Warning**: consumes more resources than regex search |
|
||||
@@ -169,16 +169,17 @@ and injected into the context for a configured number of turns (`ttl`).
|
||||
|
||||
### Per-Server Config
|
||||
|
||||
| Config | Type | Required | Description |
|
||||
|------------|--------|----------|--------------------------------------------|
|
||||
| `enabled` | bool | yes | Enable this MCP server |
|
||||
| `type` | string | no | Transport type: `stdio`, `sse`, `http` |
|
||||
| `command` | string | stdio | Executable command for stdio transport |
|
||||
| `args` | array | no | Command arguments for stdio transport |
|
||||
| `env` | object | no | Environment variables for stdio process |
|
||||
| `env_file` | string | no | Path to environment file for stdio process |
|
||||
| `url` | string | sse/http | Endpoint URL for `sse`/`http` transport |
|
||||
| `headers` | object | no | HTTP headers for `sse`/`http` transport |
|
||||
| Config | Type | Required | Description |
|
||||
|------------|---------|----------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `enabled` | bool | yes | Enable this MCP server |
|
||||
| `deferred` | bool | no | Override deferred mode for this server only. `true` = tools are hidden and discoverable via search; `false` = tools are always visible in context. When omitted, the global `discovery.enabled` value applies. |
|
||||
| `type` | string | no | Transport type: `stdio`, `sse`, `http` |
|
||||
| `command` | string | stdio | Executable command for stdio transport |
|
||||
| `args` | array | no | Command arguments for stdio transport |
|
||||
| `env` | object | no | Environment variables for stdio process |
|
||||
| `env_file` | string | no | Path to environment file for stdio process |
|
||||
| `url` | string | sse/http | Endpoint URL for `sse`/`http` transport |
|
||||
| `headers` | object | no | HTTP headers for `sse`/`http` transport |
|
||||
|
||||
### Transport Behavior
|
||||
|
||||
@@ -291,6 +292,50 @@ dynamically only when requested by the user.*
|
||||
}
|
||||
```
|
||||
|
||||
#### 4) Mixed setup: per-server deferred override
|
||||
|
||||
*Discovery is enabled globally, but `filesystem` is pinned as always-visible while `context7` follows the global
|
||||
default (deferred). `aws` explicitly opts in to deferred mode even though it is the same as the global default.*
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"mcp": {
|
||||
"enabled": true,
|
||||
"discovery": {
|
||||
"enabled": true,
|
||||
"ttl": 5,
|
||||
"max_search_results": 5,
|
||||
"use_bm25": true
|
||||
},
|
||||
"servers": {
|
||||
"filesystem": {
|
||||
"enabled": true,
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/workspace"],
|
||||
"deferred": false
|
||||
},
|
||||
"context7": {
|
||||
"enabled": true,
|
||||
"command": "npx",
|
||||
"args": ["-y", "@upstash/context7-mcp"]
|
||||
},
|
||||
"aws": {
|
||||
"enabled": true,
|
||||
"command": "npx",
|
||||
"args": ["-y", "aws-mcp-server"],
|
||||
"deferred": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> **Tip:** `deferred` on a per-server basis is independent of `discovery.enabled`. You can keep
|
||||
> `discovery.enabled: false` globally (all tools visible by default) and still mark individual
|
||||
> high-volume servers as `"deferred": true` to avoid polluting the context with their tools.
|
||||
|
||||
## Skills Tool
|
||||
|
||||
The skills tool configures skill discovery and installation via registries like ClawHub.
|
||||
|
||||
+24
-1
@@ -11,6 +11,7 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
@@ -111,6 +112,12 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
|
||||
for serverName, conn := range servers {
|
||||
uniqueTools += len(conn.Tools)
|
||||
|
||||
// Determine whether this server's tools should be deferred (hidden).
|
||||
// Per-server "deferred" field takes precedence over the global Discovery.Enabled.
|
||||
serverCfg := al.cfg.Tools.MCP.Servers[serverName]
|
||||
registerAsHidden := serverIsDeferred(al.cfg.Tools.MCP.Discovery.Enabled, serverCfg)
|
||||
|
||||
for _, tool := range conn.Tools {
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
@@ -120,7 +127,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
|
||||
if al.cfg.Tools.MCP.Discovery.Enabled {
|
||||
if registerAsHidden {
|
||||
agent.Tools.RegisterHidden(mcpTool)
|
||||
} else {
|
||||
agent.Tools.Register(mcpTool)
|
||||
@@ -133,6 +140,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
"server": serverName,
|
||||
"tool": tool.Name,
|
||||
"name": mcpTool.Name(),
|
||||
"deferred": registerAsHidden,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -198,3 +206,18 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
|
||||
return al.mcp.getInitErr()
|
||||
}
|
||||
|
||||
// serverIsDeferred reports whether an MCP server's tools should be registered
|
||||
// as hidden (deferred/discovery mode).
|
||||
//
|
||||
// The per-server Deferred field takes precedence over the global discoveryEnabled
|
||||
// default. When Deferred is nil, discoveryEnabled is used as the fallback.
|
||||
func serverIsDeferred(discoveryEnabled bool, serverCfg config.MCPServerConfig) bool {
|
||||
if !discoveryEnabled {
|
||||
return false
|
||||
}
|
||||
if serverCfg.Deferred != nil {
|
||||
return *serverCfg.Deferred
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func boolPtr(b bool) *bool { return &b }
|
||||
|
||||
func TestServerIsDeferred(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
discoveryEnabled bool
|
||||
serverDeferred *bool
|
||||
want bool
|
||||
}{
|
||||
// --- global false always wins: per-server deferred is ignored ---
|
||||
{
|
||||
name: "global false: per-server deferred=true is ignored",
|
||||
discoveryEnabled: false,
|
||||
serverDeferred: boolPtr(true),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "global false: per-server deferred=false stays false",
|
||||
discoveryEnabled: false,
|
||||
serverDeferred: boolPtr(false),
|
||||
want: false,
|
||||
},
|
||||
// --- global true: per-server override applies ---
|
||||
{
|
||||
name: "global true: per-server deferred=false opts out",
|
||||
discoveryEnabled: true,
|
||||
serverDeferred: boolPtr(false),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "global true: per-server deferred=true stays true",
|
||||
discoveryEnabled: true,
|
||||
serverDeferred: boolPtr(true),
|
||||
want: true,
|
||||
},
|
||||
// --- no per-server override: fall back to global ---
|
||||
{
|
||||
name: "no per-server field, global discovery enabled",
|
||||
discoveryEnabled: true,
|
||||
serverDeferred: nil,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no per-server field, global discovery disabled",
|
||||
discoveryEnabled: false,
|
||||
serverDeferred: nil,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
serverCfg := config.MCPServerConfig{Deferred: tt.serverDeferred}
|
||||
got := serverIsDeferred(tt.discoveryEnabled, serverCfg)
|
||||
if got != tt.want {
|
||||
t.Errorf("serverIsDeferred(discoveryEnabled=%v, deferred=%v) = %v, want %v",
|
||||
tt.discoveryEnabled, tt.serverDeferred, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+94
-27
@@ -86,9 +86,10 @@ type Manager struct {
|
||||
mux *http.ServeMux
|
||||
httpServer *http.Server
|
||||
mu sync.RWMutex
|
||||
placeholders sync.Map // "channel:chatID" → placeholderID (string)
|
||||
typingStops sync.Map // "channel:chatID" → func()
|
||||
reactionUndos sync.Map // "channel:chatID" → reactionEntry
|
||||
placeholders sync.Map // "channel:chatID" → placeholderID (string)
|
||||
typingStops sync.Map // "channel:chatID" → func()
|
||||
reactionUndos sync.Map // "channel:chatID" → reactionEntry
|
||||
channelHashes map[string]string // channel name → config hash
|
||||
}
|
||||
|
||||
type asyncTask struct {
|
||||
@@ -178,17 +179,21 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
|
||||
func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) {
|
||||
m := &Manager{
|
||||
channels: make(map[string]Channel),
|
||||
workers: make(map[string]*channelWorker),
|
||||
bus: messageBus,
|
||||
config: cfg,
|
||||
mediaStore: store,
|
||||
channels: make(map[string]Channel),
|
||||
workers: make(map[string]*channelWorker),
|
||||
bus: messageBus,
|
||||
config: cfg,
|
||||
mediaStore: store,
|
||||
channelHashes: make(map[string]string),
|
||||
}
|
||||
|
||||
if err := m.initChannels(); err != nil {
|
||||
if err := m.initChannels(&cfg.Channels); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store initial config hashes for all channels
|
||||
m.channelHashes = toChannelHashes(cfg)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
@@ -232,15 +237,15 @@ func (m *Manager) initChannel(name, displayName string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) initChannels() error {
|
||||
func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
|
||||
logger.InfoC("channels", "Initializing channel manager")
|
||||
|
||||
if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" {
|
||||
if channels.Telegram.Enabled && channels.Telegram.Token != "" {
|
||||
m.initChannel("telegram", "Telegram")
|
||||
}
|
||||
|
||||
if m.config.Channels.WhatsApp.Enabled {
|
||||
waCfg := m.config.Channels.WhatsApp
|
||||
if channels.WhatsApp.Enabled {
|
||||
waCfg := channels.WhatsApp
|
||||
if waCfg.UseNative {
|
||||
m.initChannel("whatsapp_native", "WhatsApp Native")
|
||||
} else if waCfg.BridgeURL != "" {
|
||||
@@ -248,62 +253,62 @@ func (m *Manager) initChannels() error {
|
||||
}
|
||||
}
|
||||
|
||||
if m.config.Channels.Feishu.Enabled {
|
||||
if channels.Feishu.Enabled {
|
||||
m.initChannel("feishu", "Feishu")
|
||||
}
|
||||
|
||||
if m.config.Channels.Discord.Enabled && m.config.Channels.Discord.Token != "" {
|
||||
if channels.Discord.Enabled && channels.Discord.Token != "" {
|
||||
m.initChannel("discord", "Discord")
|
||||
}
|
||||
|
||||
if m.config.Channels.MaixCam.Enabled {
|
||||
if channels.MaixCam.Enabled {
|
||||
m.initChannel("maixcam", "MaixCam")
|
||||
}
|
||||
|
||||
if m.config.Channels.QQ.Enabled {
|
||||
if channels.QQ.Enabled {
|
||||
m.initChannel("qq", "QQ")
|
||||
}
|
||||
|
||||
if m.config.Channels.DingTalk.Enabled && m.config.Channels.DingTalk.ClientID != "" {
|
||||
if channels.DingTalk.Enabled && channels.DingTalk.ClientID != "" {
|
||||
m.initChannel("dingtalk", "DingTalk")
|
||||
}
|
||||
|
||||
if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" {
|
||||
if channels.Slack.Enabled && channels.Slack.BotToken != "" {
|
||||
m.initChannel("slack", "Slack")
|
||||
}
|
||||
|
||||
if m.config.Channels.Matrix.Enabled &&
|
||||
if channels.Matrix.Enabled &&
|
||||
m.config.Channels.Matrix.Homeserver != "" &&
|
||||
m.config.Channels.Matrix.UserID != "" &&
|
||||
m.config.Channels.Matrix.AccessToken != "" {
|
||||
m.initChannel("matrix", "Matrix")
|
||||
}
|
||||
|
||||
if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" {
|
||||
if channels.LINE.Enabled && channels.LINE.ChannelAccessToken != "" {
|
||||
m.initChannel("line", "LINE")
|
||||
}
|
||||
|
||||
if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" {
|
||||
if channels.OneBot.Enabled && channels.OneBot.WSUrl != "" {
|
||||
m.initChannel("onebot", "OneBot")
|
||||
}
|
||||
|
||||
if m.config.Channels.WeCom.Enabled && m.config.Channels.WeCom.Token != "" {
|
||||
if channels.WeCom.Enabled && channels.WeCom.Token != "" {
|
||||
m.initChannel("wecom", "WeCom")
|
||||
}
|
||||
|
||||
if m.config.Channels.WeComAIBot.Enabled && m.config.Channels.WeComAIBot.Token != "" {
|
||||
if channels.WeComAIBot.Enabled && channels.WeComAIBot.Token != "" {
|
||||
m.initChannel("wecom_aibot", "WeCom AI Bot")
|
||||
}
|
||||
|
||||
if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" {
|
||||
if channels.WeComApp.Enabled && channels.WeComApp.CorpID != "" {
|
||||
m.initChannel("wecom_app", "WeCom App")
|
||||
}
|
||||
|
||||
if m.config.Channels.Pico.Enabled && m.config.Channels.Pico.Token != "" {
|
||||
if channels.Pico.Enabled && channels.Pico.Token != "" {
|
||||
m.initChannel("pico", "Pico")
|
||||
}
|
||||
|
||||
if m.config.Channels.IRC.Enabled && m.config.Channels.IRC.Server != "" {
|
||||
if channels.IRC.Enabled && channels.IRC.Server != "" {
|
||||
m.initChannel("irc", "IRC")
|
||||
}
|
||||
|
||||
@@ -825,6 +830,68 @@ func (m *Manager) GetEnabledChannels() []string {
|
||||
return names
|
||||
}
|
||||
|
||||
// Reload updates the config reference without restarting channels.
|
||||
// This is used when channel config hasn't changed but other parts of the config have.
|
||||
func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
list := toChannelHashes(cfg)
|
||||
added, removed := compareChannels(m.channelHashes, list)
|
||||
for _, name := range removed {
|
||||
// Stop all channels
|
||||
channel := m.channels[name]
|
||||
logger.InfoCF("channels", "Stopping channel", map[string]any{
|
||||
"channel": name,
|
||||
})
|
||||
if err := channel.Stop(ctx); err != nil {
|
||||
logger.ErrorCF("channels", "Error stopping channel", map[string]any{
|
||||
"channel": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
go func() {
|
||||
m.UnregisterChannel(name)
|
||||
}()
|
||||
}
|
||||
dispatchCtx, cancel := context.WithCancel(ctx)
|
||||
m.dispatchTask = &asyncTask{cancel: cancel}
|
||||
cc, err := toChannelConfig(cfg, added)
|
||||
if err != nil {
|
||||
logger.ErrorC("channels", fmt.Sprintf("toChannelConfig error: %v", err))
|
||||
return err
|
||||
}
|
||||
err = m.initChannels(cc)
|
||||
if err != nil {
|
||||
logger.ErrorC("channels", fmt.Sprintf("initChannels error: %v", err))
|
||||
return err
|
||||
}
|
||||
for _, name := range added {
|
||||
channel := m.channels[name]
|
||||
logger.InfoCF("channels", "Starting channel", map[string]any{
|
||||
"channel": name,
|
||||
})
|
||||
if err := channel.Start(ctx); err != nil {
|
||||
logger.ErrorCF("channels", "Failed to start channel", map[string]any{
|
||||
"channel": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
// Lazily create worker only after channel starts successfully
|
||||
w := newChannelWorker(name, channel)
|
||||
m.workers[name] = w
|
||||
go m.runWorker(dispatchCtx, name, w)
|
||||
go m.runMediaWorker(dispatchCtx, name, w)
|
||||
go func() {
|
||||
m.RegisterChannel(name, channel)
|
||||
}()
|
||||
}
|
||||
|
||||
m.config = cfg
|
||||
m.channelHashes = toChannelHashes(cfg)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) RegisterChannel(name string, channel Channel) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
func toChannelHashes(cfg *config.Config) map[string]string {
|
||||
result := make(map[string]string)
|
||||
ch := cfg.Channels
|
||||
// should not be error
|
||||
marshal, _ := json.Marshal(ch)
|
||||
var channelConfig map[string]map[string]any
|
||||
_ = json.Unmarshal(marshal, &channelConfig)
|
||||
|
||||
for key, value := range channelConfig {
|
||||
if !value["enabled"].(bool) {
|
||||
continue
|
||||
}
|
||||
valueBytes, _ := json.Marshal(value)
|
||||
hash := md5.Sum(valueBytes)
|
||||
result[key] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func compareChannels(old, news map[string]string) (added, removed []string) {
|
||||
for key, newHash := range news {
|
||||
if oldHash, ok := old[key]; ok {
|
||||
if newHash != oldHash {
|
||||
removed = append(removed, key)
|
||||
added = append(added, key)
|
||||
}
|
||||
} else {
|
||||
added = append(added, key)
|
||||
}
|
||||
}
|
||||
for key := range old {
|
||||
if _, ok := news[key]; !ok {
|
||||
removed = append(removed, key)
|
||||
}
|
||||
}
|
||||
return added, removed
|
||||
}
|
||||
|
||||
func toChannelConfig(cfg *config.Config, list []string) (*config.ChannelsConfig, error) {
|
||||
result := &config.ChannelsConfig{}
|
||||
ch := cfg.Channels
|
||||
// should not be error
|
||||
marshal, _ := json.Marshal(ch)
|
||||
var channelConfig map[string]map[string]any
|
||||
_ = json.Unmarshal(marshal, &channelConfig)
|
||||
temp := make(map[string]map[string]any, 0)
|
||||
|
||||
for key, value := range channelConfig {
|
||||
found := false
|
||||
for _, s := range list {
|
||||
if key == s {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found || !value["enabled"].(bool) {
|
||||
continue
|
||||
}
|
||||
temp[key] = value
|
||||
}
|
||||
|
||||
marshal, err := json.Marshal(temp)
|
||||
if err != nil {
|
||||
logger.Errorf("marshal error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
err = json.Unmarshal(marshal, result)
|
||||
if err != nil {
|
||||
logger.Errorf("unmarshal error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
func TestToChannelHashes(t *testing.T) {
|
||||
logger.SetLevel(logger.DEBUG)
|
||||
cfg := config.DefaultConfig()
|
||||
results := toChannelHashes(cfg)
|
||||
assert.Equal(t, 0, len(results))
|
||||
logger.Debugf("results: %v", results)
|
||||
cfg2 := config.DefaultConfig()
|
||||
cfg2.Channels.DingTalk.Enabled = true
|
||||
results2 := toChannelHashes(cfg2)
|
||||
assert.Equal(t, 1, len(results2))
|
||||
logger.Debugf("results2: %v", results2)
|
||||
added, removed := compareChannels(results, results2)
|
||||
assert.EqualValues(t, []string{"dingtalk"}, added)
|
||||
assert.EqualValues(t, []string(nil), removed)
|
||||
cfg3 := config.DefaultConfig()
|
||||
cfg3.Channels.Telegram.Enabled = true
|
||||
results3 := toChannelHashes(cfg3)
|
||||
assert.Equal(t, 1, len(results3))
|
||||
logger.Debugf("results3: %v", results3)
|
||||
added, removed = compareChannels(results2, results3)
|
||||
assert.EqualValues(t, []string{"dingtalk"}, removed)
|
||||
assert.EqualValues(t, []string{"telegram"}, added)
|
||||
cfg3.Channels.Telegram.Token = "114314"
|
||||
results4 := toChannelHashes(cfg3)
|
||||
assert.Equal(t, 1, len(results4))
|
||||
logger.Debugf("results4: %v", results4)
|
||||
added, removed = compareChannels(results3, results4)
|
||||
assert.EqualValues(t, []string{"telegram"}, removed)
|
||||
assert.EqualValues(t, []string{"telegram"}, added)
|
||||
cc, err := toChannelConfig(cfg3, added)
|
||||
assert.NoError(t, err)
|
||||
logger.Debugf("cc: %#v", cc.Telegram)
|
||||
assert.Equal(t, "114314", cc.Telegram.Token)
|
||||
assert.Equal(t, true, cc.Telegram.Enabled)
|
||||
cc, err = toChannelConfig(cfg2, added)
|
||||
assert.NoError(t, err)
|
||||
logger.Debugf("cc: %#v", cc.Telegram)
|
||||
assert.Equal(t, "", cc.Telegram.Token)
|
||||
assert.Equal(t, false, cc.Telegram.Enabled)
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package qq
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// botGoLogger preserves useful SDK info logs while demoting noisy heartbeat
|
||||
// traffic to DEBUG so long-running QQ sessions do not spam the console.
|
||||
type botGoLogger struct {
|
||||
*logger.Logger
|
||||
}
|
||||
|
||||
func newBotGoLogger(component string) *botGoLogger {
|
||||
return &botGoLogger{Logger: logger.NewLogger(component)}
|
||||
}
|
||||
|
||||
func (b *botGoLogger) Info(v ...any) {
|
||||
message := fmt.Sprint(v...)
|
||||
if shouldDemoteBotGoInfo(message) {
|
||||
b.Logger.Debug(message)
|
||||
return
|
||||
}
|
||||
b.Logger.Info(message)
|
||||
}
|
||||
|
||||
func (b *botGoLogger) Infof(format string, v ...any) {
|
||||
message := fmt.Sprintf(format, v...)
|
||||
if shouldDemoteBotGoInfo(message) {
|
||||
b.Logger.Debug(message)
|
||||
return
|
||||
}
|
||||
b.Logger.Info(message)
|
||||
}
|
||||
|
||||
func shouldDemoteBotGoInfo(message string) bool {
|
||||
return strings.Contains(message, " write Heartbeat message") ||
|
||||
strings.Contains(message, " receive HeartbeatAck message")
|
||||
}
|
||||
+403
-113
@@ -2,7 +2,15 @@ package qq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -10,9 +18,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/tencent-connect/botgo"
|
||||
"github.com/tencent-connect/botgo/constant"
|
||||
"github.com/tencent-connect/botgo/dto"
|
||||
"github.com/tencent-connect/botgo/event"
|
||||
"github.com/tencent-connect/botgo/openapi"
|
||||
"github.com/tencent-connect/botgo/openapi/options"
|
||||
"github.com/tencent-connect/botgo/token"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
@@ -21,6 +30,8 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -29,16 +40,29 @@ const (
|
||||
dedupMaxSize = 10000 // hard cap on dedup map entries
|
||||
typingResend = 8 * time.Second
|
||||
typingSeconds = 10
|
||||
bytesPerMiB = 1024 * 1024
|
||||
)
|
||||
|
||||
type qqAPI interface {
|
||||
WS(ctx context.Context, params map[string]string, body string) (*dto.WebsocketAP, error)
|
||||
PostGroupMessage(
|
||||
ctx context.Context, groupID string, msg dto.APIMessage, opt ...options.Option,
|
||||
) (*dto.Message, error)
|
||||
PostC2CMessage(
|
||||
ctx context.Context, userID string, msg dto.APIMessage, opt ...options.Option,
|
||||
) (*dto.Message, error)
|
||||
Transport(ctx context.Context, method, url string, body any) ([]byte, error)
|
||||
}
|
||||
|
||||
type QQChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.QQConfig
|
||||
api openapi.OpenAPI
|
||||
api qqAPI
|
||||
tokenSource oauth2.TokenSource
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
sessionManager botgo.SessionManager
|
||||
downloadFn func(urlStr, filename string) string
|
||||
|
||||
// Chat routing: track whether a chatID is group or direct.
|
||||
chatType sync.Map // chatID → "group" | "direct"
|
||||
@@ -78,7 +102,7 @@ func (c *QQChannel) Start(ctx context.Context) error {
|
||||
return fmt.Errorf("QQ app_id and app_secret not configured")
|
||||
}
|
||||
|
||||
botgo.SetLogger(logger.NewLogger("botgo"))
|
||||
botgo.SetLogger(newBotGoLogger("botgo"))
|
||||
logger.InfoC("qq", "Starting QQ bot (WebSocket mode)")
|
||||
|
||||
// Reinitialize shutdown signal for clean restart.
|
||||
@@ -199,20 +223,7 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
msgToCreate.Content = ""
|
||||
}
|
||||
|
||||
// Attach passive reply msg_id and msg_seq if available.
|
||||
if v, ok := c.lastMsgID.Load(msg.ChatID); ok {
|
||||
if msgID, ok := v.(string); ok && msgID != "" {
|
||||
msgToCreate.MsgID = msgID
|
||||
|
||||
// Increment msg_seq atomically for multi-part replies.
|
||||
if counterVal, ok := c.msgSeqCounters.Load(msg.ChatID); ok {
|
||||
if counter, ok := counterVal.(*atomic.Uint64); ok {
|
||||
seq := counter.Add(1)
|
||||
msgToCreate.MsgSeq = uint32(seq)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
c.applyPassiveReplyMetadata(msg.ChatID, msgToCreate)
|
||||
|
||||
// Sanitize URLs in group messages to avoid QQ's URL blacklist rejection.
|
||||
if chatKind == "group" {
|
||||
@@ -305,9 +316,9 @@ func (c *QQChannel) StartTyping(ctx context.Context, chatID string) (func(), err
|
||||
}
|
||||
|
||||
// SendMedia implements the channels.MediaSender interface.
|
||||
// QQ RichMediaMessage requires an HTTP/HTTPS URL — local file paths are not supported.
|
||||
// If part.Ref is already an http(s) URL it is used directly; otherwise we try
|
||||
// the media store, and skip with a warning if the resolved path is not an HTTP URL.
|
||||
// QQ group/C2C media sending is a two-step flow:
|
||||
// 1. Upload media to /files using a remote URL or base64-encoded local bytes.
|
||||
// 2. Send a msg_type=7 message using the returned file_info.
|
||||
func (c *QQChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
@@ -316,69 +327,24 @@ func (c *QQChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage)
|
||||
chatKind := c.getChatKind(msg.ChatID)
|
||||
|
||||
for _, part := range msg.Parts {
|
||||
// If the ref is already an HTTP(S) URL, use it directly.
|
||||
mediaURL := part.Ref
|
||||
if !isHTTPURL(mediaURL) {
|
||||
// Try resolving through media store.
|
||||
store := c.GetMediaStore()
|
||||
if store == nil {
|
||||
logger.WarnCF("qq", "QQ media requires HTTP/HTTPS URL, no media store available", map[string]any{
|
||||
"ref": part.Ref,
|
||||
})
|
||||
continue
|
||||
fileInfo, err := c.uploadMedia(ctx, chatKind, msg.ChatID, part)
|
||||
if err != nil {
|
||||
logger.ErrorCF("qq", "Failed to upload media", map[string]any{
|
||||
"type": part.Type,
|
||||
"chat_id": msg.ChatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
if errors.Is(err, channels.ErrSendFailed) {
|
||||
return err
|
||||
}
|
||||
|
||||
resolved, err := store.Resolve(part.Ref)
|
||||
if err != nil {
|
||||
logger.ErrorCF("qq", "Failed to resolve media ref", map[string]any{
|
||||
"ref": part.Ref,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if !isHTTPURL(resolved) {
|
||||
logger.WarnCF("qq", "QQ media requires HTTP/HTTPS URL, local files not supported", map[string]any{
|
||||
"ref": part.Ref,
|
||||
"resolved": resolved,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
mediaURL = resolved
|
||||
return fmt.Errorf("qq send media: %w", channels.ErrTemporary)
|
||||
}
|
||||
|
||||
// Map part type to QQ file type: 1=image, 2=video, 3=audio, 4=file.
|
||||
var fileType uint64
|
||||
switch part.Type {
|
||||
case "image":
|
||||
fileType = 1
|
||||
case "video":
|
||||
fileType = 2
|
||||
case "audio":
|
||||
fileType = 3
|
||||
default:
|
||||
fileType = 4 // file
|
||||
}
|
||||
|
||||
richMedia := &dto.RichMediaMessage{
|
||||
FileType: fileType,
|
||||
URL: mediaURL,
|
||||
SrvSendMsg: true,
|
||||
}
|
||||
|
||||
var sendErr error
|
||||
if chatKind == "group" {
|
||||
_, sendErr = c.api.PostGroupMessage(ctx, msg.ChatID, richMedia)
|
||||
} else {
|
||||
_, sendErr = c.api.PostC2CMessage(ctx, msg.ChatID, richMedia)
|
||||
}
|
||||
|
||||
if sendErr != nil {
|
||||
if err := c.sendUploadedMedia(ctx, chatKind, msg.ChatID, part, fileInfo); err != nil {
|
||||
logger.ErrorCF("qq", "Failed to send media", map[string]any{
|
||||
"type": part.Type,
|
||||
"chat_id": msg.ChatID,
|
||||
"error": sendErr.Error(),
|
||||
"error": err.Error(),
|
||||
})
|
||||
return fmt.Errorf("qq send media: %w", channels.ErrTemporary)
|
||||
}
|
||||
@@ -387,6 +353,161 @@ func (c *QQChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage)
|
||||
return nil
|
||||
}
|
||||
|
||||
type qqMediaUpload struct {
|
||||
FileType uint64 `json:"file_type"`
|
||||
URL string `json:"url,omitempty"`
|
||||
FileData string `json:"file_data,omitempty"`
|
||||
SrvSendMsg bool `json:"srv_send_msg,omitempty"`
|
||||
}
|
||||
|
||||
func (c *QQChannel) uploadMedia(
|
||||
ctx context.Context,
|
||||
chatKind, chatID string,
|
||||
part bus.MediaPart,
|
||||
) ([]byte, error) {
|
||||
payload, err := c.buildMediaUpload(part)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := c.api.Transport(ctx, http.MethodPost, c.mediaUploadURL(chatKind, chatID), payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var uploaded dto.Message
|
||||
if err := json.Unmarshal(body, &uploaded); err != nil {
|
||||
return nil, fmt.Errorf("qq decode media upload response: %w", err)
|
||||
}
|
||||
if len(uploaded.FileInfo) == 0 {
|
||||
return nil, fmt.Errorf("qq upload media: missing file_info")
|
||||
}
|
||||
|
||||
return uploaded.FileInfo, nil
|
||||
}
|
||||
|
||||
func (c *QQChannel) buildMediaUpload(part bus.MediaPart) (*qqMediaUpload, error) {
|
||||
payload := &qqMediaUpload{
|
||||
FileType: qqFileType(part.Type),
|
||||
}
|
||||
|
||||
mediaRef := part.Ref
|
||||
if isHTTPURL(mediaRef) {
|
||||
payload.URL = mediaRef
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
store := c.GetMediaStore()
|
||||
if store == nil {
|
||||
return nil, fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
resolved, err := store.Resolve(part.Ref)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qq resolve media ref %q: %v: %w", part.Ref, err, channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
if isHTTPURL(resolved) {
|
||||
payload.URL = resolved
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
if limitBytes := c.maxBase64FileSizeBytes(); limitBytes > 0 {
|
||||
info, statErr := os.Stat(resolved)
|
||||
if statErr != nil {
|
||||
return nil, fmt.Errorf("qq stat local media %q: %v: %w", resolved, statErr, channels.ErrSendFailed)
|
||||
}
|
||||
if info.Size() > limitBytes {
|
||||
return nil, fmt.Errorf(
|
||||
"qq local media %q exceeds max_base64_file_size_mib (%d > %d bytes): %w",
|
||||
resolved,
|
||||
info.Size(),
|
||||
limitBytes,
|
||||
channels.ErrSendFailed,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(resolved)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qq read local media %q: %v: %w", resolved, err, channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
payload.FileData = base64.StdEncoding.EncodeToString(data)
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (c *QQChannel) sendUploadedMedia(
|
||||
ctx context.Context,
|
||||
chatKind, chatID string,
|
||||
part bus.MediaPart,
|
||||
fileInfo []byte,
|
||||
) error {
|
||||
msg := &dto.MessageToCreate{
|
||||
Content: part.Caption,
|
||||
MsgType: dto.RichMediaMsg,
|
||||
Media: &dto.MediaInfo{
|
||||
FileInfo: fileInfo,
|
||||
},
|
||||
}
|
||||
c.applyPassiveReplyMetadata(chatID, msg)
|
||||
|
||||
if chatKind == "group" && msg.Content != "" {
|
||||
msg.Content = sanitizeURLs(msg.Content)
|
||||
}
|
||||
|
||||
if chatKind == "group" {
|
||||
_, err := c.api.PostGroupMessage(ctx, chatID, msg)
|
||||
return err
|
||||
}
|
||||
_, err := c.api.PostC2CMessage(ctx, chatID, msg)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *QQChannel) applyPassiveReplyMetadata(chatID string, msg *dto.MessageToCreate) {
|
||||
if v, ok := c.lastMsgID.Load(chatID); ok {
|
||||
if msgID, ok := v.(string); ok && msgID != "" {
|
||||
msg.MsgID = msgID
|
||||
|
||||
// Increment msg_seq atomically for multi-part replies.
|
||||
if counterVal, ok := c.msgSeqCounters.Load(chatID); ok {
|
||||
if counter, ok := counterVal.(*atomic.Uint64); ok {
|
||||
seq := counter.Add(1)
|
||||
msg.MsgSeq = uint32(seq)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *QQChannel) mediaUploadURL(chatKind, chatID string) string {
|
||||
base := constant.APIDomain
|
||||
if chatKind == "group" {
|
||||
return fmt.Sprintf("%s/v2/groups/%s/files", base, chatID)
|
||||
}
|
||||
return fmt.Sprintf("%s/v2/users/%s/files", base, chatID)
|
||||
}
|
||||
|
||||
func qqFileType(partType string) uint64 {
|
||||
switch partType {
|
||||
case "image":
|
||||
return 1
|
||||
case "video":
|
||||
return 2
|
||||
case "audio":
|
||||
return 3
|
||||
default:
|
||||
return 4
|
||||
}
|
||||
}
|
||||
|
||||
func (c *QQChannel) maxBase64FileSizeBytes() int64 {
|
||||
if c.config.MaxBase64FileSizeMiB <= 0 {
|
||||
return 0
|
||||
}
|
||||
return c.config.MaxBase64FileSizeMiB * bytesPerMiB
|
||||
}
|
||||
|
||||
// handleC2CMessage handles QQ private messages.
|
||||
func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
|
||||
return func(event *dto.WSPayload, data *dto.WSC2CMessageData) error {
|
||||
@@ -404,16 +525,30 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
|
||||
return nil
|
||||
}
|
||||
|
||||
// extract message content
|
||||
content := data.Content
|
||||
if content == "" {
|
||||
logger.DebugC("qq", "Received empty message, ignoring")
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "qq",
|
||||
PlatformID: data.Author.ID,
|
||||
CanonicalID: identity.BuildCanonicalID("qq", data.Author.ID),
|
||||
}
|
||||
|
||||
if !c.IsAllowedSender(sender) {
|
||||
return nil
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(data.Content)
|
||||
mediaPaths, attachmentNotes := c.extractInboundAttachments(senderID, data.ID, data.Attachments)
|
||||
for _, note := range attachmentNotes {
|
||||
content = appendContent(content, note)
|
||||
}
|
||||
if content == "" && len(mediaPaths) == 0 {
|
||||
logger.DebugC("qq", "Received empty C2C message with no attachments, ignoring")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.InfoCF("qq", "Received C2C message", map[string]any{
|
||||
"sender": senderID,
|
||||
"length": len(content),
|
||||
"sender": senderID,
|
||||
"length": len(content),
|
||||
"media_count": len(mediaPaths),
|
||||
})
|
||||
|
||||
// Store chat routing context.
|
||||
@@ -427,23 +562,13 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
|
||||
"account_id": senderID,
|
||||
}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "qq",
|
||||
PlatformID: data.Author.ID,
|
||||
CanonicalID: identity.BuildCanonicalID("qq", data.Author.ID),
|
||||
}
|
||||
|
||||
if !c.IsAllowedSender(sender) {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx,
|
||||
bus.Peer{Kind: "direct", ID: senderID},
|
||||
data.ID,
|
||||
senderID,
|
||||
senderID,
|
||||
content,
|
||||
[]string{},
|
||||
mediaPaths,
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
@@ -469,24 +594,38 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
|
||||
return nil
|
||||
}
|
||||
|
||||
// extract message content (remove @ bot part)
|
||||
content := data.Content
|
||||
if content == "" {
|
||||
logger.DebugC("qq", "Received empty group message, ignoring")
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "qq",
|
||||
PlatformID: data.Author.ID,
|
||||
CanonicalID: identity.BuildCanonicalID("qq", data.Author.ID),
|
||||
}
|
||||
|
||||
if !c.IsAllowedSender(sender) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GroupAT event means bot is always mentioned; apply group trigger filtering
|
||||
content := strings.TrimSpace(data.Content)
|
||||
mediaPaths, attachmentNotes := c.extractInboundAttachments(data.GroupID, data.ID, data.Attachments)
|
||||
for _, note := range attachmentNotes {
|
||||
content = appendContent(content, note)
|
||||
}
|
||||
|
||||
// GroupAT event means bot is always mentioned; apply group trigger filtering.
|
||||
respond, cleaned := c.ShouldRespondInGroup(true, content)
|
||||
if !respond {
|
||||
return nil
|
||||
}
|
||||
content = cleaned
|
||||
if content == "" && len(mediaPaths) == 0 {
|
||||
logger.DebugC("qq", "Received empty group message with no attachments, ignoring")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.InfoCF("qq", "Received group AT message", map[string]any{
|
||||
"sender": senderID,
|
||||
"group": data.GroupID,
|
||||
"length": len(content),
|
||||
"sender": senderID,
|
||||
"group": data.GroupID,
|
||||
"length": len(content),
|
||||
"media_count": len(mediaPaths),
|
||||
})
|
||||
|
||||
// Store chat routing context using GroupID as chatID.
|
||||
@@ -501,23 +640,13 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
|
||||
"group_id": data.GroupID,
|
||||
}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "qq",
|
||||
PlatformID: data.Author.ID,
|
||||
CanonicalID: identity.BuildCanonicalID("qq", data.Author.ID),
|
||||
}
|
||||
|
||||
if !c.IsAllowedSender(sender) {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx,
|
||||
bus.Peer{Kind: "group", ID: data.GroupID},
|
||||
data.ID,
|
||||
senderID,
|
||||
data.GroupID,
|
||||
content,
|
||||
[]string{},
|
||||
mediaPaths,
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
@@ -526,6 +655,157 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *QQChannel) extractInboundAttachments(
|
||||
chatID, messageID string,
|
||||
attachments []*dto.MessageAttachment,
|
||||
) ([]string, []string) {
|
||||
if len(attachments) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
scope := channels.BuildMediaScope("qq", chatID, messageID)
|
||||
mediaPaths := make([]string, 0, len(attachments))
|
||||
notes := make([]string, 0, len(attachments))
|
||||
|
||||
storeMedia := func(localPath string, attachment *dto.MessageAttachment) string {
|
||||
if store := c.GetMediaStore(); store != nil {
|
||||
ref, err := store.Store(localPath, media.MediaMeta{
|
||||
Filename: qqAttachmentFilename(attachment),
|
||||
ContentType: attachment.ContentType,
|
||||
Source: "qq",
|
||||
}, scope)
|
||||
if err == nil {
|
||||
return ref
|
||||
}
|
||||
}
|
||||
return localPath
|
||||
}
|
||||
|
||||
for _, attachment := range attachments {
|
||||
if attachment == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
filename := qqAttachmentFilename(attachment)
|
||||
if localPath := c.downloadAttachment(attachment.URL, filename); localPath != "" {
|
||||
mediaPaths = append(mediaPaths, storeMedia(localPath, attachment))
|
||||
} else if attachment.URL != "" {
|
||||
mediaPaths = append(mediaPaths, attachment.URL)
|
||||
}
|
||||
|
||||
notes = append(notes, qqAttachmentNote(attachment))
|
||||
}
|
||||
|
||||
return mediaPaths, notes
|
||||
}
|
||||
|
||||
func (c *QQChannel) downloadAttachment(urlStr, filename string) string {
|
||||
if urlStr == "" {
|
||||
return ""
|
||||
}
|
||||
if c.downloadFn != nil {
|
||||
return c.downloadFn(urlStr, filename)
|
||||
}
|
||||
|
||||
return utils.DownloadFile(urlStr, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "qq",
|
||||
ExtraHeaders: c.downloadHeaders(),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *QQChannel) downloadHeaders() map[string]string {
|
||||
headers := map[string]string{}
|
||||
|
||||
if c.config.AppID != "" {
|
||||
headers["X-Union-Appid"] = c.config.AppID
|
||||
}
|
||||
|
||||
if c.tokenSource != nil {
|
||||
if tk, err := c.tokenSource.Token(); err == nil && tk.AccessToken != "" {
|
||||
auth := strings.TrimSpace(tk.TokenType + " " + tk.AccessToken)
|
||||
if auth != "" {
|
||||
headers["Authorization"] = auth
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(headers) == 0 {
|
||||
return nil
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func qqAttachmentFilename(attachment *dto.MessageAttachment) string {
|
||||
if attachment == nil {
|
||||
return "attachment"
|
||||
}
|
||||
if attachment.FileName != "" {
|
||||
return attachment.FileName
|
||||
}
|
||||
if attachment.URL != "" {
|
||||
if parsed, err := url.Parse(attachment.URL); err == nil {
|
||||
if base := path.Base(parsed.Path); base != "" && base != "." && base != "/" {
|
||||
return base
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch qqAttachmentKind(attachment) {
|
||||
case "image":
|
||||
return "image"
|
||||
case "audio":
|
||||
return "audio"
|
||||
case "video":
|
||||
return "video"
|
||||
default:
|
||||
return "attachment"
|
||||
}
|
||||
}
|
||||
|
||||
func qqAttachmentKind(attachment *dto.MessageAttachment) string {
|
||||
if attachment == nil {
|
||||
return "file"
|
||||
}
|
||||
|
||||
contentType := strings.ToLower(attachment.ContentType)
|
||||
filename := strings.ToLower(attachment.FileName)
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(contentType, "image/"):
|
||||
return "image"
|
||||
case strings.HasPrefix(contentType, "video/"):
|
||||
return "video"
|
||||
case strings.HasPrefix(contentType, "audio/"), contentType == "application/ogg", contentType == "application/x-ogg":
|
||||
return "audio"
|
||||
}
|
||||
|
||||
switch filepath.Ext(filename) {
|
||||
case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".svg":
|
||||
return "image"
|
||||
case ".mp4", ".avi", ".mov", ".webm", ".mkv":
|
||||
return "video"
|
||||
case ".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma", ".opus", ".silk":
|
||||
return "audio"
|
||||
default:
|
||||
return "file"
|
||||
}
|
||||
}
|
||||
|
||||
func qqAttachmentNote(attachment *dto.MessageAttachment) string {
|
||||
filename := qqAttachmentFilename(attachment)
|
||||
|
||||
switch qqAttachmentKind(attachment) {
|
||||
case "image":
|
||||
return fmt.Sprintf("[image: %s]", filename)
|
||||
case "audio":
|
||||
return fmt.Sprintf("[audio: %s]", filename)
|
||||
case "video":
|
||||
return fmt.Sprintf("[video: %s]", filename)
|
||||
default:
|
||||
return fmt.Sprintf("[file: %s]", filename)
|
||||
}
|
||||
}
|
||||
|
||||
// isDuplicate checks whether a message has been seen within the TTL window.
|
||||
// It also enforces a hard cap on map size by evicting oldest entries.
|
||||
func (c *QQChannel) isDuplicate(messageID string) bool {
|
||||
@@ -587,6 +867,16 @@ func isHTTPURL(s string) bool {
|
||||
return strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://")
|
||||
}
|
||||
|
||||
func appendContent(content, suffix string) string {
|
||||
if suffix == "" {
|
||||
return content
|
||||
}
|
||||
if content == "" {
|
||||
return suffix
|
||||
}
|
||||
return content + "\n" + suffix
|
||||
}
|
||||
|
||||
// urlPattern matches URLs with explicit http(s):// scheme.
|
||||
// Only scheme-prefixed URLs are matched to avoid false positives on bare text
|
||||
// like version numbers (e.g., "1.2.3") or domain-like fragments.
|
||||
|
||||
@@ -2,13 +2,22 @@ package qq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tencent-connect/botgo/dto"
|
||||
"github.com/tencent-connect/botgo/openapi/options"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) {
|
||||
@@ -50,3 +59,438 @@ func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleC2CMessage_AttachmentOnlyPublishesMedia(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
store := media.NewFileMediaStore()
|
||||
localPath := writeTempFile(t, t.TempDir(), "image.png", []byte("fake-image"))
|
||||
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
ctx: context.Background(),
|
||||
downloadFn: func(urlStr, filename string) string {
|
||||
if filename != "image.png" {
|
||||
t.Fatalf("download filename = %q, want image.png", filename)
|
||||
}
|
||||
return localPath
|
||||
},
|
||||
}
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
err := ch.handleC2CMessage()(nil, &dto.WSC2CMessageData{
|
||||
ID: "msg-attachment",
|
||||
Content: "",
|
||||
Author: &dto.User{
|
||||
ID: "7750283E123456",
|
||||
},
|
||||
Attachments: []*dto.MessageAttachment{{
|
||||
URL: "https://example.com/image.png",
|
||||
FileName: "image.png",
|
||||
ContentType: "image/png",
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("handleC2CMessage() error = %v", err)
|
||||
}
|
||||
|
||||
inbound := waitInboundMessage(t, messageBus)
|
||||
if inbound.Content != "[image: image.png]" {
|
||||
t.Fatalf("inbound.Content = %q", inbound.Content)
|
||||
}
|
||||
if len(inbound.Media) != 1 {
|
||||
t.Fatalf("len(inbound.Media) = %d, want 1", len(inbound.Media))
|
||||
}
|
||||
if !strings.HasPrefix(inbound.Media[0], "media://") {
|
||||
t.Fatalf("inbound.Media[0] = %q, want media:// ref", inbound.Media[0])
|
||||
}
|
||||
_, meta, err := store.ResolveWithMeta(inbound.Media[0])
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveWithMeta() error = %v", err)
|
||||
}
|
||||
if meta.Filename != "image.png" {
|
||||
t.Fatalf("meta.Filename = %q, want image.png", meta.Filename)
|
||||
}
|
||||
if meta.ContentType != "image/png" {
|
||||
t.Fatalf("meta.ContentType = %q, want image/png", meta.ContentType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGroupATMessage_AttachmentOnlyPublishesMedia(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
store := media.NewFileMediaStore()
|
||||
localPath := writeTempFile(t, t.TempDir(), "report.pdf", []byte("fake-pdf"))
|
||||
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
ctx: context.Background(),
|
||||
downloadFn: func(urlStr, filename string) string {
|
||||
if filename != "report.pdf" {
|
||||
t.Fatalf("download filename = %q, want report.pdf", filename)
|
||||
}
|
||||
return localPath
|
||||
},
|
||||
}
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
err := ch.handleGroupATMessage()(nil, &dto.WSGroupATMessageData{
|
||||
ID: "group-attachment",
|
||||
GroupID: "group-1",
|
||||
Content: "",
|
||||
Author: &dto.User{
|
||||
ID: "7750283E123456",
|
||||
},
|
||||
Attachments: []*dto.MessageAttachment{{
|
||||
URL: "https://example.com/report.pdf",
|
||||
FileName: "report.pdf",
|
||||
ContentType: "application/pdf",
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("handleGroupATMessage() error = %v", err)
|
||||
}
|
||||
|
||||
inbound := waitInboundMessage(t, messageBus)
|
||||
if inbound.Content != "[file: report.pdf]" {
|
||||
t.Fatalf("inbound.Content = %q", inbound.Content)
|
||||
}
|
||||
if len(inbound.Media) != 1 {
|
||||
t.Fatalf("len(inbound.Media) = %d, want 1", len(inbound.Media))
|
||||
}
|
||||
if !strings.HasPrefix(inbound.Media[0], "media://") {
|
||||
t.Fatalf("inbound.Media[0] = %q, want media:// ref", inbound.Media[0])
|
||||
}
|
||||
if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-1" {
|
||||
t.Fatalf("inbound.Peer = %+v, want group/group-1", inbound.Peer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_UploadsLocalFileAsBase64(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
store := media.NewFileMediaStore()
|
||||
|
||||
tmpFile, err := os.CreateTemp(t.TempDir(), "qq-media-*.png")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp() error = %v", err)
|
||||
}
|
||||
defer tmpFile.Close()
|
||||
|
||||
content := []byte("local-image-data")
|
||||
if _, writeErr := tmpFile.Write(content); writeErr != nil {
|
||||
t.Fatalf("Write() error = %v", writeErr)
|
||||
}
|
||||
|
||||
ref, err := store.Store(tmpFile.Name(), media.MediaMeta{
|
||||
Filename: "reply.png",
|
||||
ContentType: "image/png",
|
||||
}, "qq:test")
|
||||
if err != nil {
|
||||
t.Fatalf("Store() error = %v", err)
|
||||
}
|
||||
|
||||
api := &fakeQQAPI{
|
||||
transportResp: mustJSON(t, dto.Message{FileInfo: []byte("uploaded-file-info")}),
|
||||
}
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
api: api,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
ch.SetRunning(true)
|
||||
ch.SetMediaStore(store)
|
||||
ch.chatType.Store("group-1", "group")
|
||||
ch.lastMsgID.Store("group-1", "msg-1")
|
||||
ch.msgSeqCounters.Store("group-1", new(atomic.Uint64))
|
||||
|
||||
err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "group-1",
|
||||
Parts: []bus.MediaPart{{
|
||||
Type: "image",
|
||||
Ref: ref,
|
||||
Caption: "see https://example.com/image",
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
|
||||
if len(api.transportCalls) != 1 {
|
||||
t.Fatalf("transportCalls = %d, want 1", len(api.transportCalls))
|
||||
}
|
||||
upload := api.transportCalls[0]
|
||||
if upload.method != "POST" {
|
||||
t.Fatalf("upload method = %q, want POST", upload.method)
|
||||
}
|
||||
if upload.url != "https://api.sgroup.qq.com/v2/groups/group-1/files" {
|
||||
t.Fatalf("upload url = %q", upload.url)
|
||||
}
|
||||
if upload.body.URL != "" {
|
||||
t.Fatalf("upload URL = %q, want empty", upload.body.URL)
|
||||
}
|
||||
wantBase64 := base64.StdEncoding.EncodeToString(content)
|
||||
if upload.body.FileData != wantBase64 {
|
||||
t.Fatalf("upload file_data = %q, want %q", upload.body.FileData, wantBase64)
|
||||
}
|
||||
if upload.body.FileType != 1 {
|
||||
t.Fatalf("upload file_type = %d, want 1", upload.body.FileType)
|
||||
}
|
||||
|
||||
if len(api.groupMessages) != 1 {
|
||||
t.Fatalf("groupMessages = %d, want 1", len(api.groupMessages))
|
||||
}
|
||||
msg, ok := api.groupMessages[0].(*dto.MessageToCreate)
|
||||
if !ok {
|
||||
t.Fatalf("groupMessages[0] type = %T, want *dto.MessageToCreate", api.groupMessages[0])
|
||||
}
|
||||
if msg.MsgType != dto.RichMediaMsg {
|
||||
t.Fatalf("msg.MsgType = %d, want %d", msg.MsgType, dto.RichMediaMsg)
|
||||
}
|
||||
if msg.MsgID != "msg-1" {
|
||||
t.Fatalf("msg.MsgID = %q, want msg-1", msg.MsgID)
|
||||
}
|
||||
if msg.MsgSeq != 1 {
|
||||
t.Fatalf("msg.MsgSeq = %d, want 1", msg.MsgSeq)
|
||||
}
|
||||
if msg.Content != "see https://example。com/image" {
|
||||
t.Fatalf("msg.Content = %q", msg.Content)
|
||||
}
|
||||
if msg.Media == nil || string(msg.Media.FileInfo) != "uploaded-file-info" {
|
||||
t.Fatalf("msg.Media.FileInfo = %q, want uploaded-file-info", string(msg.Media.FileInfo))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_UsesRemoteURLUploadForC2C(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
api := &fakeQQAPI{
|
||||
transportResp: mustJSON(t, dto.Message{FileInfo: []byte("remote-file-info")}),
|
||||
}
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
api: api,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
ch.SetRunning(true)
|
||||
ch.chatType.Store("user-1", "direct")
|
||||
|
||||
err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "user-1",
|
||||
Parts: []bus.MediaPart{{
|
||||
Type: "file",
|
||||
Ref: "https://cdn.example.com/report.pdf",
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
|
||||
if len(api.transportCalls) != 1 {
|
||||
t.Fatalf("transportCalls = %d, want 1", len(api.transportCalls))
|
||||
}
|
||||
upload := api.transportCalls[0]
|
||||
if upload.url != "https://api.sgroup.qq.com/v2/users/user-1/files" {
|
||||
t.Fatalf("upload url = %q", upload.url)
|
||||
}
|
||||
if upload.body.URL != "https://cdn.example.com/report.pdf" {
|
||||
t.Fatalf("upload URL = %q", upload.body.URL)
|
||||
}
|
||||
if upload.body.FileData != "" {
|
||||
t.Fatalf("upload file_data = %q, want empty", upload.body.FileData)
|
||||
}
|
||||
if upload.body.FileType != 4 {
|
||||
t.Fatalf("upload file_type = %d, want 4", upload.body.FileType)
|
||||
}
|
||||
|
||||
if len(api.c2cMessages) != 1 {
|
||||
t.Fatalf("c2cMessages = %d, want 1", len(api.c2cMessages))
|
||||
}
|
||||
msg, ok := api.c2cMessages[0].(*dto.MessageToCreate)
|
||||
if !ok {
|
||||
t.Fatalf("c2cMessages[0] type = %T, want *dto.MessageToCreate", api.c2cMessages[0])
|
||||
}
|
||||
if msg.MsgType != dto.RichMediaMsg {
|
||||
t.Fatalf("msg.MsgType = %d, want %d", msg.MsgType, dto.RichMediaMsg)
|
||||
}
|
||||
if msg.Media == nil || string(msg.Media.FileInfo) != "remote-file-info" {
|
||||
t.Fatalf("msg.Media.FileInfo = %q, want remote-file-info", string(msg.Media.FileInfo))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_ReturnsSendFailedWithoutMediaStore(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
api: &fakeQQAPI{},
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
ch.SetRunning(true)
|
||||
ch.chatType.Store("group-1", "group")
|
||||
|
||||
err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "group-1",
|
||||
Parts: []bus.MediaPart{{
|
||||
Type: "image",
|
||||
Ref: "media://missing",
|
||||
}},
|
||||
})
|
||||
if !errors.Is(err, channels.ErrSendFailed) {
|
||||
t.Fatalf("SendMedia() error = %v, want ErrSendFailed", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_ReturnsSendFailedWhenLocalFileExceedsBase64MiBLimit(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
store := media.NewFileMediaStore()
|
||||
|
||||
tmpFile, err := os.CreateTemp(t.TempDir(), "qq-media-too-large-*.bin")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp() error = %v", err)
|
||||
}
|
||||
defer tmpFile.Close()
|
||||
|
||||
content := make([]byte, bytesPerMiB+1)
|
||||
if _, writeErr := tmpFile.Write(content); writeErr != nil {
|
||||
t.Fatalf("Write() error = %v", writeErr)
|
||||
}
|
||||
|
||||
ref, err := store.Store(tmpFile.Name(), media.MediaMeta{
|
||||
Filename: "large.bin",
|
||||
ContentType: "application/octet-stream",
|
||||
}, "qq:test")
|
||||
if err != nil {
|
||||
t.Fatalf("Store() error = %v", err)
|
||||
}
|
||||
|
||||
api := &fakeQQAPI{}
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
config: config.QQConfig{
|
||||
MaxBase64FileSizeMiB: 1,
|
||||
},
|
||||
api: api,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
ch.SetRunning(true)
|
||||
ch.SetMediaStore(store)
|
||||
ch.chatType.Store("group-1", "group")
|
||||
|
||||
err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "group-1",
|
||||
Parts: []bus.MediaPart{{
|
||||
Type: "file",
|
||||
Ref: ref,
|
||||
}},
|
||||
})
|
||||
if !errors.Is(err, channels.ErrSendFailed) {
|
||||
t.Fatalf("SendMedia() error = %v, want ErrSendFailed", err)
|
||||
}
|
||||
if len(api.transportCalls) != 0 {
|
||||
t.Fatalf("transportCalls = %d, want 0", len(api.transportCalls))
|
||||
}
|
||||
}
|
||||
|
||||
type fakeQQAPI struct {
|
||||
transportResp []byte
|
||||
transportErr error
|
||||
groupErr error
|
||||
c2cErr error
|
||||
transportCalls []fakeTransportCall
|
||||
groupMessages []dto.APIMessage
|
||||
c2cMessages []dto.APIMessage
|
||||
}
|
||||
|
||||
type fakeTransportCall struct {
|
||||
method string
|
||||
url string
|
||||
body qqMediaUpload
|
||||
}
|
||||
|
||||
func (f *fakeQQAPI) WS(
|
||||
context.Context,
|
||||
map[string]string,
|
||||
string,
|
||||
) (*dto.WebsocketAP, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeQQAPI) PostGroupMessage(
|
||||
_ context.Context,
|
||||
_ string,
|
||||
msg dto.APIMessage,
|
||||
_ ...options.Option,
|
||||
) (*dto.Message, error) {
|
||||
f.groupMessages = append(f.groupMessages, msg)
|
||||
return &dto.Message{}, f.groupErr
|
||||
}
|
||||
|
||||
func (f *fakeQQAPI) PostC2CMessage(
|
||||
_ context.Context,
|
||||
_ string,
|
||||
msg dto.APIMessage,
|
||||
_ ...options.Option,
|
||||
) (*dto.Message, error) {
|
||||
f.c2cMessages = append(f.c2cMessages, msg)
|
||||
return &dto.Message{}, f.c2cErr
|
||||
}
|
||||
|
||||
func (f *fakeQQAPI) Transport(_ context.Context, method, url string, body any) ([]byte, error) {
|
||||
upload, ok := body.(*qqMediaUpload)
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected transport body type")
|
||||
}
|
||||
f.transportCalls = append(f.transportCalls, fakeTransportCall{
|
||||
method: method,
|
||||
url: url,
|
||||
body: *upload,
|
||||
})
|
||||
return f.transportResp, f.transportErr
|
||||
}
|
||||
|
||||
func mustJSON(t *testing.T, v any) []byte {
|
||||
t.Helper()
|
||||
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func waitInboundMessage(t *testing.T, messageBus *bus.MessageBus) bus.InboundMessage {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for inbound message")
|
||||
case inbound, ok := <-messageBus.InboundChan():
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message")
|
||||
}
|
||||
return inbound
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeTempFile(t *testing.T, dir, name string, content []byte) string {
|
||||
t.Helper()
|
||||
|
||||
path := dir + "/" + name
|
||||
if err := os.WriteFile(path, content, 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
+13
-8
@@ -346,14 +346,15 @@ type MaixCamConfig struct {
|
||||
}
|
||||
|
||||
type QQConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"`
|
||||
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"`
|
||||
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
MaxMessageLength int `json:"max_message_length" env:"PICOCLAW_CHANNELS_QQ_MAX_MESSAGE_LENGTH"`
|
||||
SendMarkdown bool `json:"send_markdown" env:"PICOCLAW_CHANNELS_QQ_SEND_MARKDOWN"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_QQ_REASONING_CHANNEL_ID"`
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"`
|
||||
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"`
|
||||
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
MaxMessageLength int `json:"max_message_length" env:"PICOCLAW_CHANNELS_QQ_MAX_MESSAGE_LENGTH"`
|
||||
MaxBase64FileSizeMiB int64 `json:"max_base64_file_size_mib" env:"PICOCLAW_CHANNELS_QQ_MAX_BASE64_FILE_SIZE_MIB"`
|
||||
SendMarkdown bool `json:"send_markdown" env:"PICOCLAW_CHANNELS_QQ_SEND_MARKDOWN"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_QQ_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
type DingTalkConfig struct {
|
||||
@@ -802,6 +803,10 @@ type ClawHubRegistryConfig struct {
|
||||
type MCPServerConfig struct {
|
||||
// Enabled indicates whether this MCP server is active
|
||||
Enabled bool `json:"enabled"`
|
||||
// Deferred controls whether this server's tools are registered as hidden (deferred/discovery mode).
|
||||
// When nil, the global Discovery.Enabled setting applies.
|
||||
// When explicitly set to true or false, it overrides the global setting for this server only.
|
||||
Deferred *bool `json:"deferred,omitempty"`
|
||||
// Command is the executable to run (e.g., "npx", "python", "/path/to/server")
|
||||
Command string `json:"command"`
|
||||
// Args are the arguments to pass to the command
|
||||
|
||||
@@ -83,11 +83,12 @@ func DefaultConfig() *Config {
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
QQ: QQConfig{
|
||||
Enabled: false,
|
||||
AppID: "",
|
||||
AppSecret: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
MaxMessageLength: 2000,
|
||||
Enabled: false,
|
||||
AppID: "",
|
||||
AppSecret: "",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
MaxMessageLength: 2000,
|
||||
MaxBase64FileSizeMiB: 0,
|
||||
},
|
||||
DingTalk: DingTalkConfig{
|
||||
Enabled: false,
|
||||
|
||||
@@ -324,11 +324,12 @@ func setupAndStartServices(
|
||||
return runningServices, nil
|
||||
}
|
||||
|
||||
func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration) {
|
||||
func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration, isReload bool) {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
defer shutdownCancel()
|
||||
|
||||
if runningServices.ChannelManager != nil {
|
||||
// reload should not stop channel manager
|
||||
if !isReload && runningServices.ChannelManager != nil {
|
||||
runningServices.ChannelManager.StopAll(shutdownCtx)
|
||||
}
|
||||
if runningServices.DeviceService != nil {
|
||||
@@ -357,7 +358,7 @@ func shutdownGateway(
|
||||
cp.Close()
|
||||
}
|
||||
|
||||
stopAndCleanupServices(runningServices, gracefulShutdownTimeout)
|
||||
stopAndCleanupServices(runningServices, gracefulShutdownTimeout, false)
|
||||
|
||||
agentLoop.Stop()
|
||||
agentLoop.Close()
|
||||
@@ -381,7 +382,7 @@ func handleConfigReload(
|
||||
logger.Infof(" New model is '%s', recreating provider...", newModel)
|
||||
|
||||
logger.Info(" Stopping all services...")
|
||||
stopAndCleanupServices(runningServices, serviceShutdownTimeout)
|
||||
stopAndCleanupServices(runningServices, serviceShutdownTimeout, true)
|
||||
|
||||
newProvider, newModelID, err := createStartupProvider(newCfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
@@ -491,8 +492,8 @@ func restartServices(
|
||||
}
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
return fmt.Errorf("error restarting channels: %w", err)
|
||||
if err = runningServices.ChannelManager.Reload(context.Background(), cfg); err != nil {
|
||||
return fmt.Errorf("error reload channels: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Channels restarted.")
|
||||
|
||||
|
||||
+47
-15
@@ -16,12 +16,14 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
userAgentHonest = "picoclaw/%s (+https://github.com/sipeed/picoclaw; AI assistant bot)"
|
||||
|
||||
// HTTP client timeouts for web tool providers.
|
||||
searchTimeout = 10 * time.Second // Brave, Tavily, DuckDuckGo
|
||||
@@ -913,28 +915,58 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to create request: %v", err))
|
||||
doFetch := func(ua string) (*http.Response, []byte, error) {
|
||||
req, reqErr := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
|
||||
if reqErr != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create request: %w", reqErr)
|
||||
}
|
||||
req.Header.Set("User-Agent", ua)
|
||||
resp, doErr := t.client.Do(req)
|
||||
if doErr != nil {
|
||||
return nil, nil, fmt.Errorf("request failed: %w", doErr)
|
||||
}
|
||||
resp.Body = http.MaxBytesReader(nil, resp.Body, t.fetchLimitBytes)
|
||||
|
||||
b, readErr := io.ReadAll(resp.Body)
|
||||
return resp, b, readErr
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("request failed: %v", err))
|
||||
resp, body, err := doFetch(userAgent)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
resp.Body = http.MaxBytesReader(nil, resp.Body, t.fetchLimitBytes)
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
return ErrorResult(fmt.Sprintf("failed to read response: size exceeded %d bytes limit", t.fetchLimitBytes))
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
// Cloudflare (and similar WAFs) signal bot challenges with 403 + cf-mitigated: challenge.
|
||||
// Retry once with an honest User-Agent that identifies picoclaw, which some
|
||||
// operators explicitly allow-list for AI assistants.
|
||||
if resp.StatusCode == http.StatusForbidden && resp.Header.Get("Cf-Mitigated") == "challenge" {
|
||||
logger.DebugCF("tool", "Cloudflare challenge detected, retrying with honest User-Agent",
|
||||
map[string]any{"url": urlStr})
|
||||
honestUA := fmt.Sprintf(userAgentHonest, config.Version)
|
||||
resp2, body2, err2 := doFetch(honestUA)
|
||||
if resp2 != nil && resp2.Body != nil {
|
||||
defer resp2.Body.Close()
|
||||
}
|
||||
|
||||
if err2 == nil {
|
||||
resp, body = resp2, body2
|
||||
} else {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err2, &maxBytesErr) {
|
||||
return ErrorResult(
|
||||
fmt.Sprintf("failed to read response: size exceeded %d bytes limit", t.fetchLimitBytes),
|
||||
)
|
||||
}
|
||||
return ErrorResult(err2.Error())
|
||||
}
|
||||
}
|
||||
|
||||
bodyStr := string(body)
|
||||
@@ -1004,7 +1036,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
|
||||
truncated := len(text) > maxChars
|
||||
if truncated {
|
||||
text = text[:maxChars]
|
||||
text = text[:maxChars] + "\n[Content truncated due to size limit]"
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
|
||||
@@ -212,6 +212,132 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
|
||||
t.Errorf("Expected 'truncated' to be true in result")
|
||||
}
|
||||
|
||||
// Text should end with the truncation notice
|
||||
if text, ok := resultMap["text"].(string); ok {
|
||||
if !strings.HasSuffix(text, "[Content truncated due to size limit]") {
|
||||
t.Errorf("Expected text to end with truncation notice, got: %q", text[max(0, len(text)-60):])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_TruncationNotice verifies the truncation notice is appended
|
||||
// for all content formats (text/plain, text/html, markdown, application/json).
|
||||
func TestWebTool_WebFetch_TruncationNotice(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
const truncationNotice = "[Content truncated due to size limit]"
|
||||
const maxChars = 100
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
contentType string
|
||||
body string
|
||||
format string
|
||||
}{
|
||||
{
|
||||
name: "plain text",
|
||||
contentType: "text/plain",
|
||||
body: strings.Repeat("a", 500),
|
||||
format: "plaintext",
|
||||
},
|
||||
{
|
||||
name: "html plaintext extractor",
|
||||
contentType: "text/html",
|
||||
body: "<html><body>" + strings.Repeat("b", 500) + "</body></html>",
|
||||
format: "plaintext",
|
||||
},
|
||||
{
|
||||
name: "html markdown extractor",
|
||||
contentType: "text/html",
|
||||
body: "<html><body>" + strings.Repeat("c", 500) + "</body></html>",
|
||||
format: "markdown",
|
||||
},
|
||||
{
|
||||
name: "json",
|
||||
contentType: "application/json",
|
||||
body: `"` + strings.Repeat("d", 500) + `"`,
|
||||
format: "plaintext",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", tt.contentType)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(tt.body))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebFetchTool(maxChars, tt.format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebFetchTool() error: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"url": server.URL})
|
||||
if result.IsError {
|
||||
t.Fatalf("unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
var resultMap map[string]any
|
||||
if err := json.Unmarshal([]byte(result.ForLLM), &resultMap); err != nil {
|
||||
t.Fatalf("failed to unmarshal result JSON: %v", err)
|
||||
}
|
||||
|
||||
text, ok := resultMap["text"].(string)
|
||||
if !ok {
|
||||
t.Fatal("missing 'text' field in result")
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(text, truncationNotice) {
|
||||
t.Errorf("expected text to end with %q, got suffix: %q", truncationNotice, text[max(0, len(text)-60):])
|
||||
}
|
||||
|
||||
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
|
||||
t.Errorf("expected truncated=true in result")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_NoTruncationNoticeWhenFitsInLimit verifies that the notice
|
||||
// is NOT appended when the content fits within the limit.
|
||||
func TestWebTool_WebFetch_NoTruncationNoticeWhenFitsInLimit(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
const truncationNotice = "[Content truncated due to size limit]"
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("short content"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebFetchTool() error: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"url": server.URL})
|
||||
if result.IsError {
|
||||
t.Fatalf("unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
var resultMap map[string]any
|
||||
if err := json.Unmarshal([]byte(result.ForLLM), &resultMap); err != nil {
|
||||
t.Fatalf("failed to unmarshal result JSON: %v", err)
|
||||
}
|
||||
|
||||
text, _ := resultMap["text"].(string)
|
||||
if strings.Contains(text, truncationNotice) {
|
||||
t.Errorf("expected no truncation notice for content within limit, got: %q", text)
|
||||
}
|
||||
|
||||
if truncated, _ := resultMap["truncated"].(bool); truncated {
|
||||
t.Errorf("expected truncated=false for content within limit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
||||
@@ -943,6 +1069,119 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebFetchTool_CloudflareChallenge_RetryWithHonestUA verifies that a 403 response
|
||||
// with cf-mitigated: challenge triggers a retry using the honest picoclaw User-Agent,
|
||||
// and that the retry response is returned when it succeeds.
|
||||
func TestWebFetchTool_CloudflareChallenge_RetryWithHonestUA(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
requestCount := 0
|
||||
var receivedUAs []string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount++
|
||||
receivedUAs = append(receivedUAs, r.Header.Get("User-Agent"))
|
||||
|
||||
if requestCount == 1 {
|
||||
// First request: simulate Cloudflare challenge
|
||||
w.Header().Set("Cf-Mitigated", "challenge")
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte("<html><body>Cloudflare challenge</body></html>"))
|
||||
return
|
||||
}
|
||||
// Second request (honest UA retry): success
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("real content"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebFetchTool() error: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"url": server.URL})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success after retry, got error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "real content") {
|
||||
t.Errorf("expected retry response content, got: %s", result.ForLLM)
|
||||
}
|
||||
if requestCount != 2 {
|
||||
t.Errorf("expected exactly 2 requests, got %d", requestCount)
|
||||
}
|
||||
|
||||
// First request must use the generic user agent
|
||||
if receivedUAs[0] != userAgent {
|
||||
t.Errorf("first request UA = %q, want %q", receivedUAs[0], userAgent)
|
||||
}
|
||||
// Second request must use the honest picoclaw user agent
|
||||
if !strings.Contains(receivedUAs[1], "picoclaw") {
|
||||
t.Errorf("retry request UA = %q, want it to contain 'picoclaw'", receivedUAs[1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebFetchTool_CloudflareChallenge_NoRetryOnOtherErrors verifies that a plain 403
|
||||
// (without cf-mitigated: challenge) does NOT trigger a retry.
|
||||
func TestWebFetchTool_CloudflareChallenge_NoRetryOnOtherErrors(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
requestCount := 0
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount++
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte("plain forbidden"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebFetchTool() error: %v", err)
|
||||
}
|
||||
|
||||
tool.Execute(context.Background(), map[string]any{"url": server.URL})
|
||||
|
||||
if requestCount != 1 {
|
||||
t.Errorf("expected exactly 1 request for plain 403, got %d", requestCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebFetchTool_CloudflareChallenge_RetryFailsToo verifies that if the honest-UA
|
||||
// retry also fails (e.g. still blocked), the error from the retry is returned.
|
||||
func TestWebFetchTool_CloudflareChallenge_RetryFailsToo(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Always return CF challenge regardless of UA
|
||||
w.Header().Set("Cf-Mitigated", "challenge")
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte("<html><body>still blocked</body></html>"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebFetchTool() error: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"url": server.URL})
|
||||
|
||||
// Should not be an error — the retry response is used as-is (403 is a valid HTTP response)
|
||||
if result.IsError {
|
||||
t.Fatalf("expected non-error result even when retry is also blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
// Status in the JSON result should reflect the 403
|
||||
if !strings.Contains(result.ForLLM, "403") {
|
||||
t.Errorf("expected status 403 in result, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyPool(t *testing.T) {
|
||||
pool := NewAPIKeyPool([]string{"key1", "key2", "key3"})
|
||||
if len(pool.keys) != 3 {
|
||||
|
||||
@@ -138,6 +138,7 @@ export function GenericForm({
|
||||
real_name: t("channels.form.desc.realName"),
|
||||
channels: t("channels.form.desc.channels"),
|
||||
request_caps: t("channels.form.desc.requestCaps"),
|
||||
max_base64_file_size_mib: t("channels.form.desc.maxBase64FileSizeMiB"),
|
||||
}
|
||||
return (
|
||||
descriptions[key] ??
|
||||
|
||||
@@ -327,6 +327,7 @@
|
||||
"realName": "Displayed real name.",
|
||||
"channels": "IRC channels to join.",
|
||||
"requestCaps": "IRC capability list requested on connect.",
|
||||
"maxBase64FileSizeMiB": "Maximum size in MiB for converting local files to base64 before upload. 0 means unlimited. Applies only to local files, not URL uploads.",
|
||||
"genericField": "Used to configure {{field}}."
|
||||
}
|
||||
},
|
||||
|
||||
@@ -327,6 +327,7 @@
|
||||
"realName": "显示名称。",
|
||||
"channels": "要加入的 IRC 频道列表。",
|
||||
"requestCaps": "连接时请求的 IRC 扩展能力列表。",
|
||||
"maxBase64FileSizeMiB": "本地文件转为 base64 上传的最大体积,单位 MiB;0 表示不限制,仅影响本地文件,不影响 URL 直传。",
|
||||
"genericField": "用于配置{{field}}。"
|
||||
}
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user