Merge branch 'main' into feat/subturn-poc

This commit is contained in:
Administrator
2026-03-19 20:20:31 +08:00
12 changed files with 2673 additions and 142 deletions
+22
View File
@@ -1569,8 +1569,29 @@ func (al *AgentLoop) runLLMIteration(
"iteration": iteration,
})
// Send tool feedback to chat channel if enabled
if al.cfg.Agents.Defaults.IsToolFeedbackEnabled() && opts.Channel != "" {
feedbackPreview := utils.Truncate(
string(argsJSON),
al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(),
)
feedbackMsg := fmt.Sprintf("\U0001f527 `%s`\n```\n%s\n```", tc.Name, feedbackPreview)
fbCtx, fbCancel := context.WithTimeout(ctx, 3*time.Second)
_ = al.bus.PublishOutbound(fbCtx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: feedbackMsg,
})
fbCancel()
}
// Create async callback for tools that implement AsyncExecutor.
// When the background work completes, this publishes the result
// as an inbound system message so processSystemMessage routes it
// back to the user via the normal agent loop.
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
// Send ForUser content directly to the user (immediate feedback),
// mirroring the synchronous tool execution path.
if !result.Silent && result.ForUser != "" {
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer outCancel()
@@ -1581,6 +1602,7 @@ func (al *AgentLoop) runLLMIteration(
})
}
// Determine content for the agent loop (ForLLM or error).
content := result.ForLLM
if content == "" && result.Err != nil {
content = result.Err.Error()
+3 -1
View File
@@ -296,7 +296,9 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
m.initChannel("wecom", "WeCom")
}
if channels.WeComAIBot.Enabled && channels.WeComAIBot.Token != "" {
if m.config.Channels.WeComAIBot.Enabled &&
((m.config.Channels.WeComAIBot.BotID != "" && m.config.Channels.WeComAIBot.Secret != "") ||
m.config.Channels.WeComAIBot.Token != "") {
m.initChannel("wecom_aibot", "WeCom AI Bot")
}
+90 -11
View File
@@ -22,6 +22,10 @@ import (
"github.com/sipeed/picoclaw/pkg/utils"
)
// responseURLHTTPClient is a shared HTTP client for posting to WeCom response_url.
// Reusing it enables connection pooling across replies.
var responseURLHTTPClient = &http.Client{Timeout: 15 * time.Second}
// WeComAIBotChannel implements the Channel interface for WeCom AI Bot (企业微信智能机器人)
type WeComAIBotChannel struct {
*channels.BaseChannel
@@ -134,13 +138,25 @@ type WeComAIBotEncryptedResponse struct {
Nonce string `json:"nonce"`
}
// NewWeComAIBotChannel creates a new WeCom AI Bot channel instance
// NewWeComAIBotChannel creates a WeCom AI Bot channel instance.
// If cfg.BotID and cfg.Secret are both set, it returns a WeComAIBotWSChannel
// using the WebSocket long-connection API.
// Otherwise it returns the webhook-mode WeComAIBotChannel (requires Token +
// EncodingAESKey).
func NewWeComAIBotChannel(
cfg config.WeComAIBotConfig,
messageBus *bus.MessageBus,
) (*WeComAIBotChannel, error) {
) (channels.Channel, error) {
// WebSocket long-connection mode takes priority when BotID + Secret are set.
if cfg.BotID != "" && cfg.Secret != "" {
logger.InfoC("wecom_aibot", "BotID and Secret provided, using WebSocket mode")
return newWeComAIBotWSChannel(cfg, messageBus)
}
// Webhook (short-connection) mode.
if cfg.Token == "" || cfg.EncodingAESKey == "" {
return nil, fmt.Errorf("token and encoding_aes_key are required for WeCom AI Bot")
return nil, fmt.Errorf(
"WeCom AI Bot requires either (bot_id + secret) for WebSocket mode " +
"or (token + encoding_aes_key) for webhook mode")
}
base := channels.NewBaseChannel("wecom_aibot", cfg, messageBus, cfg.AllowFrom,
@@ -782,8 +798,7 @@ func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) erro
}
req.Header.Set("Content-Type", "application/json; charset=utf-8")
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req)
resp, err := responseURLHTTPClient.Do(req)
if err != nil {
return fmt.Errorf("post to response_url failed: %w: %w", channels.ErrTemporary, err)
}
@@ -793,7 +808,8 @@ func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) erro
return nil
}
respBody, err := io.ReadAll(resp.Body)
const maxErrBody = 64 << 10 // 64 KB is more than enough for any error response
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
if err != nil {
return fmt.Errorf("reading response_url body: %w: %w", channels.ErrTemporary, err)
}
@@ -895,17 +911,80 @@ func (c *WeComAIBotChannel) encryptMessage(plaintext, receiveid string) (string,
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// generateStreamID generates a random stream ID
func (c *WeComAIBotChannel) generateStreamID() string {
// func (c *WeComAIBotChannel) downloadAndDecryptImage(
// ctx context.Context,
// imageURL string,
// ) ([]byte, error) {
// // Download image
// req, err := http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil)
// if err != nil {
// return nil, fmt.Errorf("failed to create request: %w", err)
// }
// client := &http.Client{
// Timeout: 15 * time.Second,
// }
// resp, err := client.Do(req)
// if err != nil {
// return nil, fmt.Errorf("failed to download image: %w", err)
// }
// defer resp.Body.Close()
// if resp.StatusCode != http.StatusOK {
// return nil, fmt.Errorf("download failed with status: %d", resp.StatusCode)
// }
// // Limit image download to 20 MB to prevent memory exhaustion
// const maxImageSize = 20 << 20 // 20 MB
// encryptedData, err := io.ReadAll(io.LimitReader(resp.Body, maxImageSize+1))
// if err != nil {
// return nil, fmt.Errorf("failed to read image data: %w", err)
// }
// if len(encryptedData) > maxImageSize {
// return nil, fmt.Errorf("image too large (exceeds %d MB)", maxImageSize>>20)
// }
// logger.DebugCF("wecom_aibot", "Image downloaded", map[string]any{
// "size": len(encryptedData),
// })
// // Decode AES key
// aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey)
// if err != nil {
// return nil, err
// }
// // Decrypt image (AES-CBC with IV = first 16 bytes of key, PKCS7 padding stripped)
// decryptedData, err := decryptAESCBC(aesKey, encryptedData)
// if err != nil {
// return nil, fmt.Errorf("failed to decrypt image: %w", err)
// }
// logger.DebugCF("wecom_aibot", "Image decrypted", map[string]any{
// "size": len(decryptedData),
// })
// return decryptedData, nil
// }
// generateRandomID generates a cryptographically random alphanumeric ID of
// length n. Used for stream IDs and WebSocket request IDs.
func generateRandomID(n int) string {
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, 10)
b := make([]byte, n)
for i := range b {
n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
b[i] = letters[n.Int64()]
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
b[i] = letters[num.Int64()]
}
return string(b)
}
// generateStreamID generates a random 10-character stream ID (webhook mode).
func (c *WeComAIBotChannel) generateStreamID() string {
return generateRandomID(10)
}
// cleanupLoop periodically cleans up old streaming tasks
func (c *WeComAIBotChannel) cleanupLoop() {
ticker := time.NewTicker(5 * time.Minute)
+280 -35
View File
@@ -3,12 +3,16 @@ package wecom
import (
"context"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestNewWeComAIBotChannel(t *testing.T) {
// ---- Webhook mode tests ----
func TestNewWeComAIBotChannel_WebhookMode(t *testing.T) {
t.Run("success with valid config", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
@@ -22,14 +26,16 @@ func TestNewWeComAIBotChannel(t *testing.T) {
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if ch == nil {
t.Fatal("Expected channel to be created")
}
if ch.Name() != "wecom_aibot" {
t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
}
// Webhook mode must implement WebhookHandler.
if _, ok := ch.(channels.WebhookHandler); !ok {
t.Error("Webhook mode channel should implement WebhookHandler")
}
})
t.Run("error with missing token", func(t *testing.T) {
@@ -37,10 +43,8 @@ func TestNewWeComAIBotChannel(t *testing.T) {
Enabled: true,
EncodingAESKey: "testkey1234567890123456789012345678901234567",
}
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
if err == nil {
t.Fatal("Expected error for missing token, got nil")
}
@@ -51,17 +55,15 @@ func TestNewWeComAIBotChannel(t *testing.T) {
Enabled: true,
Token: "test_token",
}
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
if err == nil {
t.Fatal("Expected error for missing encoding key, got nil")
}
})
}
func TestWeComAIBotChannelStartStop(t *testing.T) {
func TestWeComAIBotWebhookChannelStartStop(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
@@ -76,22 +78,18 @@ func TestWeComAIBotChannelStartStop(t *testing.T) {
ctx := context.Background()
// Test Start
if err := ch.Start(ctx); err != nil {
t.Fatalf("Failed to start channel: %v", err)
}
if !ch.IsRunning() {
t.Error("Expected channel to be running")
t.Error("Expected channel to be running after Start")
}
// Test Stop
if err := ch.Stop(ctx); err != nil {
t.Fatalf("Failed to stop channel: %v", err)
}
if ch.IsRunning() {
t.Error("Expected channel to be stopped")
t.Error("Expected channel to be stopped after Stop")
}
}
@@ -102,13 +100,16 @@ func TestWeComAIBotChannelWebhookPath(t *testing.T) {
Token: "test_token",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
}
messageBus := bus.NewMessageBus()
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
wh, ok := ch.(channels.WebhookHandler)
if !ok {
t.Fatal("Expected channel to implement WebhookHandler")
}
expectedPath := "/webhook/wecom-aibot"
if ch.WebhookPath() != expectedPath {
t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, ch.WebhookPath())
if wh.WebhookPath() != expectedPath {
t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, wh.WebhookPath())
}
})
@@ -120,12 +121,15 @@ func TestWeComAIBotChannelWebhookPath(t *testing.T) {
EncodingAESKey: "testkey1234567890123456789012345678901234567",
WebhookPath: customPath,
}
messageBus := bus.NewMessageBus()
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
if ch.WebhookPath() != customPath {
t.Errorf("Expected webhook path '%s', got '%s'", customPath, ch.WebhookPath())
wh, ok := ch.(channels.WebhookHandler)
if !ok {
t.Fatal("Expected channel to implement WebhookHandler")
}
if wh.WebhookPath() != customPath {
t.Errorf("Expected webhook path '%s', got '%s'", customPath, wh.WebhookPath())
}
})
}
@@ -136,19 +140,19 @@ func TestGenerateStreamID(t *testing.T) {
Token: "test_token",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
}
messageBus := bus.NewMessageBus()
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
webhookCh, ok := ch.(*WeComAIBotChannel)
if !ok {
t.Fatal("Expected webhook mode channel")
}
// Generate multiple IDs and check they are unique
ids := make(map[string]bool)
for i := 0; i < 100; i++ {
id := ch.generateStreamID()
id := webhookCh.generateStreamID()
if len(id) != 10 {
t.Errorf("Expected stream ID length 10, got %d", len(id))
}
if ids[id] {
t.Errorf("Duplicate stream ID generated: %s", id)
}
@@ -157,35 +161,33 @@ func TestGenerateStreamID(t *testing.T) {
}
func TestEncryptDecrypt(t *testing.T) {
// Use a valid 43-character base64 key (企业微信标准格式)
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG", // 43 characters
}
messageBus := bus.NewMessageBus()
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
webhookCh, ok := ch.(*WeComAIBotChannel)
if !ok {
t.Fatal("Expected webhook mode channel")
}
plaintext := "Hello, World!"
receiveid := ""
// Encrypt
encrypted, err := ch.encryptMessage(plaintext, receiveid)
encrypted, err := webhookCh.encryptMessage(plaintext, receiveid)
if err != nil {
t.Fatalf("Failed to encrypt message: %v", err)
}
if encrypted == "" {
t.Fatal("Encrypted message is empty")
}
// Decrypt
decrypted, err := decryptMessageWithVerify(encrypted, cfg.EncodingAESKey, receiveid)
if err != nil {
t.Fatalf("Failed to decrypt message: %v", err)
}
if decrypted != plaintext {
t.Errorf("Expected decrypted message '%s', got '%s'", plaintext, decrypted)
}
@@ -198,13 +200,256 @@ func TestGenerateSignature(t *testing.T) {
encrypt := "encrypted_msg"
signature := computeSignature(token, timestamp, nonce, encrypt)
if signature == "" {
t.Error("Generated signature is empty")
}
// Verify signature using verifySignature function
if !verifySignature(token, signature, timestamp, nonce, encrypt) {
t.Error("Generated signature does not verify correctly")
}
}
// ---- WebSocket long-connection mode tests ----
func TestNewWeComAIBotChannel_WSMode(t *testing.T) {
t.Run("success with bot_id and secret", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
Secret: "test_secret",
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if ch == nil {
t.Fatal("Expected channel to be created")
}
if ch.Name() != "wecom_aibot" {
t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
}
// WebSocket mode must NOT implement WebhookHandler.
if _, ok := ch.(channels.WebhookHandler); ok {
t.Error("WebSocket mode channel should NOT implement WebhookHandler")
}
})
t.Run("ws mode takes priority over webhook fields", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
Secret: "test_secret",
Token: "also_set",
EncodingAESKey: "testkey1234567890123456789012345678901234567",
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if _, ok := ch.(*WeComAIBotWSChannel); !ok {
t.Error("Expected WebSocket mode channel when both BotID+Secret and Token+Key are set")
}
})
t.Run("error with missing bot_id", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
Secret: "test_secret",
}
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
// Missing bot_id alone means neither WS mode nor webhook mode is fully configured.
if err == nil {
t.Fatal("Expected error for missing bot_id, got nil")
}
})
t.Run("error with missing secret", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
}
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
if err == nil {
t.Fatal("Expected error for missing secret, got nil")
}
})
}
func TestWeComAIBotWSChannelStartStop(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
Secret: "test_secret",
}
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
ctx := context.Background()
// Start launches a background goroutine; it should not block or return an error.
if err := ch.Start(ctx); err != nil {
t.Fatalf("Failed to start channel: %v", err)
}
if !ch.IsRunning() {
t.Error("Expected channel to be running after Start")
}
// Stop should work regardless of whether the WebSocket actually connected.
if err := ch.Stop(ctx); err != nil {
t.Fatalf("Failed to stop channel: %v", err)
}
if ch.IsRunning() {
t.Error("Expected channel to be stopped after Stop")
}
}
func TestGenerateRandomID(t *testing.T) {
ids := make(map[string]bool)
for i := 0; i < 200; i++ {
id := generateRandomID(10)
if len(id) != 10 {
t.Errorf("Expected ID length 10, got %d", len(id))
}
if ids[id] {
t.Errorf("Duplicate ID generated: %s", id)
}
ids[id] = true
}
}
func TestWSGenerateID(t *testing.T) {
ids := make(map[string]bool)
for i := 0; i < 200; i++ {
id := wsGenerateID()
if len(id) != 10 {
t.Errorf("Expected ID length 10, got %d", len(id))
}
if ids[id] {
t.Errorf("Duplicate wsGenerateID result: %s", id)
}
ids[id] = true
}
}
// ---- Webhook streaming fallback tests ----
// makeWebhookChannel creates a started WeComAIBotChannel for testing.
func makeWebhookChannel(t *testing.T) *WeComAIBotChannel {
t.Helper()
cfg := config.WeComAIBotConfig{
Enabled: true,
Token: "test_token",
EncodingAESKey: "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG",
}
ch, err := NewWeComAIBotChannel(cfg, bus.NewMessageBus())
if err != nil {
t.Fatalf("create channel: %v", err)
}
wc := ch.(*WeComAIBotChannel)
wc.ctx, wc.cancel = context.WithCancel(context.Background())
return wc
}
// makeStreamTask creates and registers a streamTask for testing.
func makeStreamTask(t *testing.T, ch *WeComAIBotChannel, streamID, chatID string, deadline time.Time) *streamTask {
t.Helper()
task := &streamTask{
StreamID: streamID,
ChatID: chatID,
Deadline: deadline,
answerCh: make(chan string, 1),
}
task.ctx, task.cancel = context.WithCancel(ch.ctx)
ch.taskMu.Lock()
ch.streamTasks[streamID] = task
ch.chatTasks[chatID] = append(ch.chatTasks[chatID], task)
ch.taskMu.Unlock()
return task
}
// TestGetStreamResponse_ImmediateAnswer verifies that when the agent has already
// placed its answer in answerCh, getStreamResponse returns a finish=true response
// and fully removes the task.
func TestGetStreamResponse_ImmediateAnswer(t *testing.T) {
ch := makeWebhookChannel(t)
defer ch.cancel()
task := makeStreamTask(t, ch, "stream-1", "chat-1", time.Now().Add(30*time.Second))
task.answerCh <- "hello from agent"
result := ch.getStreamResponse(task, "ts123", "nonce123")
if result == "" {
t.Fatal("expected non-empty encrypted response")
}
ch.taskMu.RLock()
_, exists := ch.streamTasks["stream-1"]
ch.taskMu.RUnlock()
if exists {
t.Error("task should have been removed from streamTasks after normal finish")
}
if !task.Finished {
t.Error("task.Finished should be true after normal finish")
}
}
// TestGetStreamResponse_DeadlinePassed verifies that when the stream deadline has
// elapsed (no agent reply yet), getStreamResponse closes the stream but keeps the
// task alive so the response_url fallback can still deliver the answer.
func TestGetStreamResponse_DeadlinePassed(t *testing.T) {
ch := makeWebhookChannel(t)
defer ch.cancel()
task := makeStreamTask(t, ch, "stream-2", "chat-2", time.Now().Add(-time.Millisecond))
result := ch.getStreamResponse(task, "ts456", "nonce456")
if result == "" {
t.Fatal("expected non-empty encrypted response")
}
ch.taskMu.RLock()
_, stillStreaming := ch.streamTasks["stream-2"]
ch.taskMu.RUnlock()
if stillStreaming {
t.Error("task should have been removed from streamTasks after deadline")
}
if !task.StreamClosed {
t.Error("task.StreamClosed should be true after deadline")
}
if task.Finished {
t.Error("task.Finished must remain false: agent reply still expected via response_url")
}
}
// TestGetStreamResponse_StillPending verifies that when neither the agent has
// replied nor the deadline has passed, getStreamResponse returns without altering
// task state (client should poll again).
func TestGetStreamResponse_StillPending(t *testing.T) {
ch := makeWebhookChannel(t)
defer ch.cancel()
task := makeStreamTask(t, ch, "stream-3", "chat-3", time.Now().Add(30*time.Second))
result := ch.getStreamResponse(task, "ts789", "nonce789")
if result == "" {
t.Fatal("expected non-empty encrypted response")
}
ch.taskMu.RLock()
_, exists := ch.streamTasks["stream-3"]
ch.taskMu.RUnlock()
if !exists {
t.Error("pending task should still be in streamTasks")
}
if task.Finished || task.StreamClosed {
t.Error("pending task should not be finished or stream-closed")
}
// Cleanup.
ch.removeTask(task)
}
File diff suppressed because it is too large Load Diff
+295
View File
@@ -0,0 +1,295 @@
package wecom
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
// newTestWSChannel creates a WeComAIBotWSChannel ready for unit testing.
func newTestWSChannel(t *testing.T) *WeComAIBotWSChannel {
t.Helper()
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
Secret: "test_secret",
}
ch, err := newWeComAIBotWSChannel(cfg, bus.NewMessageBus())
if err != nil {
t.Fatalf("create WS channel: %v", err)
}
return ch
}
// TestStoreWSMedia_NilStore verifies that storeWSMedia returns an error when no
// MediaStore has been injected.
func TestStoreWSMedia_NilStore(t *testing.T) {
ch := newTestWSChannel(t)
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://any", "", ".jpg")
if err == nil {
t.Fatal("expected error when no MediaStore is set")
}
}
// TestStoreWSMedia_HTTPError verifies that storeWSMedia propagates HTTP errors
// from the media server.
func TestStoreWSMedia_HTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "not found", http.StatusNotFound)
}))
defer srv.Close()
ch := newTestWSChannel(t)
ch.SetMediaStore(media.NewFileMediaStore())
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg")
if err == nil {
t.Fatal("expected error for HTTP 404")
}
}
// TestStoreWSMedia_ServerUnavailable verifies that storeWSMedia returns a clear
// error when the media server cannot be reached.
func TestStoreWSMedia_ServerUnavailable(t *testing.T) {
ch := newTestWSChannel(t)
ch.SetMediaStore(media.NewFileMediaStore())
// Port 1 is reserved and will refuse the connection immediately.
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://127.0.0.1:1", "", ".jpg")
if err == nil {
t.Fatal("expected error for unreachable server")
}
}
// TestStoreWSMedia_Success_NoAES verifies the happy path: the media is downloaded,
// a media ref is returned, and the file persists and is readable via Resolve until
// ReleaseAll is called. The server returns no Content-Type, so the defaultExt is used.
func TestStoreWSMedia_Success_NoAES(t *testing.T) {
imageData := bytes.Repeat([]byte("x"), 256)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(imageData)
}))
defer srv.Close()
ch := newTestWSChannel(t)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
ref, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg")
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if ref == "" {
t.Fatal("expected non-empty ref")
}
// File must be accessible after storeWSMedia returns (no premature deletion).
path, err := store.Resolve(ref)
if err != nil {
t.Fatalf("ref should resolve: %v", err)
}
got, err := os.ReadFile(path)
if err != nil {
t.Fatalf("file should exist at %s: %v", path, err)
}
if !bytes.Equal(got, imageData) {
t.Errorf("content mismatch: got len=%d, want len=%d", len(got), len(imageData))
}
// ReleaseAll must delete the file (store owns lifecycle).
scope := channels.BuildMediaScope("wecom_aibot", "chat1", "msg1")
if err := store.ReleaseAll(scope); err != nil {
t.Fatalf("ReleaseAll failed: %v", err)
}
if _, err := os.Stat(path); !os.IsNotExist(err) {
t.Errorf("file should have been deleted by ReleaseAll, stat err: %v", err)
}
}
// TestStoreWSMedia_MultipleMessages verifies that concurrent media messages with
// different msgIDs do not collide and each resolve to distinct files.
func TestStoreWSMedia_MultipleMessages(t *testing.T) {
imageA := bytes.Repeat([]byte("a"), 64)
imageB := bytes.Repeat([]byte("b"), 64)
srvA := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(imageA)
}))
defer srvA.Close()
srvB := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(imageB)
}))
defer srvB.Close()
ch := newTestWSChannel(t)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
refA, err := ch.storeWSMedia(context.Background(), "chat1", "msgA", srvA.URL, "", ".jpg")
if err != nil {
t.Fatalf("storeWSMedia A: %v", err)
}
refB, err := ch.storeWSMedia(context.Background(), "chat1", "msgB", srvB.URL, "", ".jpg")
if err != nil {
t.Fatalf("storeWSMedia B: %v", err)
}
if refA == refB {
t.Fatal("distinct messages must produce distinct refs")
}
pathA, _ := store.Resolve(refA)
pathB, _ := store.Resolve(refB)
if pathA == pathB {
t.Fatal("distinct messages must be stored at distinct paths")
}
gotA, _ := os.ReadFile(pathA)
gotB, _ := os.ReadFile(pathB)
if !bytes.Equal(gotA, imageA) {
t.Errorf("content mismatch for message A")
}
if !bytes.Equal(gotB, imageB) {
t.Errorf("content mismatch for message B")
}
}
// TestStoreWSMedia_ContentTypeExt verifies that the file extension is inferred
// from the HTTP Content-Type header and the defaultExt fallback is used when the
// type is absent or unrecognized.
func TestStoreWSMedia_ContentTypeExt(t *testing.T) {
tests := []struct {
contentType string
wantExt string
}{
{"image/jpeg", ".jpg"},
{"image/png", ".png"},
{"video/mp4", ".mp4"},
{"application/pdf", ".pdf"},
{"application/zip", ".zip"},
// With parameters stripped.
{"video/mp4; codecs=avc1", ".mp4"},
// Unknown type → falls back to defaultExt.
{"", ""},
{"application/octet-stream", ""},
}
for _, tc := range tests {
got := wsMediaExtFromContentType(tc.contentType)
if got != tc.wantExt {
t.Errorf("wsMediaExtFromContentType(%q) = %q, want %q", tc.contentType, got, tc.wantExt)
}
}
// End-to-end: server returns Content-Type: video/mp4, defaultExt is .bin.
// The stored file should carry the .mp4 extension, not .bin.
payload := bytes.Repeat([]byte("v"), 128)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "video/mp4")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(payload)
}))
defer srv.Close()
ch := newTestWSChannel(t)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
ref, err := ch.storeWSMedia(context.Background(), "chat1", "vid1", srv.URL, "", ".bin")
if err != nil {
t.Fatalf("storeWSMedia: %v", err)
}
path, err := store.Resolve(ref)
if err != nil {
t.Fatalf("resolve: %v", err)
}
if ext := path[len(path)-4:]; ext != ".mp4" {
t.Errorf("expected .mp4 extension from Content-Type, got %q", ext)
}
}
// TestSplitWSContent verifies byte-aware splitting of stream content.
func TestSplitWSContent(t *testing.T) {
t.Run("short content is not split", func(t *testing.T) {
chunks := splitWSContent("hello", 20480)
if len(chunks) != 1 || chunks[0] != "hello" {
t.Fatalf("unexpected chunks: %v", chunks)
}
})
t.Run("ASCII content split at byte boundary", func(t *testing.T) {
// Build a string just over the limit.
content := strings.Repeat("a", 20481)
chunks := splitWSContent(content, 20480)
if len(chunks) < 2 {
t.Fatalf("expected >= 2 chunks, got %d", len(chunks))
}
for i, c := range chunks {
if len(c) > 20480 {
t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c))
}
}
// Reassembled content must equal the original (possibly without leading
// whitespace that splitWSContent trims between chunks).
joined := strings.Join(chunks, "")
if len(joined) < len(content)-len(chunks) {
t.Errorf("joined length %d too short (original %d)", len(joined), len(content))
}
})
t.Run("CJK content split within byte limit", func(t *testing.T) {
// Each CJK rune is 3 bytes in UTF-8.
// 7000 CJK chars = 21000 bytes, which exceeds 20480.
content := strings.Repeat("\u4e2d", 7000)
chunks := splitWSContent(content, 20480)
if len(chunks) < 2 {
t.Fatalf("expected >= 2 chunks for 21000-byte CJK content, got %d", len(chunks))
}
for i, c := range chunks {
if len(c) > 20480 {
t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c))
}
// Every chunk must be valid UTF-8.
if !strings.ContainsRune(c, '\u4e2d') && len(c) > 0 {
// quick plausibility check — content was pure CJK
}
}
})
}
// TestSplitAtByteBoundary verifies the last-resort byte-boundary splitter.
func TestSplitAtByteBoundary(t *testing.T) {
t.Run("ASCII fits in one chunk", func(t *testing.T) {
parts := splitAtByteBoundary("hello world", 100)
if len(parts) != 1 {
t.Fatalf("expected 1 part, got %d", len(parts))
}
})
t.Run("splits at byte boundary, never mid-rune", func(t *testing.T) {
// 10 CJK characters = 30 bytes; split at 20 bytes.
s := strings.Repeat("\u6587", 10) // 10 × 3 bytes = 30 bytes
parts := splitAtByteBoundary(s, 20)
for i, p := range parts {
if len(p) > 20 {
t.Errorf("part %d has %d bytes, want <= 20", i, len(p))
}
// Must be valid UTF-8 (no torn multi-byte sequences).
for j, r := range p {
if r == '\uFFFD' {
t.Errorf("part %d has replacement rune at position %d: torn UTF-8", i, j)
}
}
}
})
}
+48 -27
View File
@@ -228,25 +228,31 @@ type SubTurnConfig struct {
ConcurrencyTimeoutSec int `json:"concurrency_timeout_sec" env:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_CONCURRENCY_TIMEOUT_SEC"`
}
type ToolFeedbackConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_ENABLED"`
MaxArgsLength int `json:"max_args_length" env:"PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_MAX_ARGS_LENGTH"`
}
type AgentDefaults struct {
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
ModelName string `json:"model_name" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
Model string `json:"model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Routing *RoutingConfig `json:"routing,omitempty"`
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
SubTurn SubTurnConfig `json:"subturn" envPrefix:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_"`
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
ModelName string `json:"model_name" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
Model string `json:"model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Routing *RoutingConfig `json:"routing,omitempty"`
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
SubTurn SubTurnConfig `json:"subturn" envPrefix:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_"`
ToolFeedback ToolFeedbackConfig `json:"tool_feedback,omitempty"`
}
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
@@ -258,6 +264,19 @@ func (d *AgentDefaults) GetMaxMediaSize() int {
return DefaultMaxMediaSize
}
// GetToolFeedbackMaxArgsLength returns the max args preview length for tool feedback messages.
func (d *AgentDefaults) GetToolFeedbackMaxArgsLength() int {
if d.ToolFeedback.MaxArgsLength > 0 {
return d.ToolFeedback.MaxArgsLength
}
return 300
}
// IsToolFeedbackEnabled returns true when tool feedback messages should be sent to the chat.
func (d *AgentDefaults) IsToolFeedbackEnabled() bool {
return d.ToolFeedback.Enabled
}
// GetModelName returns the effective model name for the agent defaults.
// It prefers the new "model_name" field but falls back to "model" for backward compatibility.
func (d *AgentDefaults) GetModelName() string {
@@ -463,15 +482,17 @@ type WeComAppConfig struct {
}
type WeComAIBotConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"`
EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"`
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"`
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"`
MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` // Maximum streaming steps
WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` // Sent on enter_chat event; empty = no welcome
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"`
BotID string `json:"bot_id,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_BOT_ID"`
Secret string `json:"secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_SECRET"`
Token string `json:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"`
EncodingAESKey string `json:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"`
WebhookPath string `json:"webhook_path,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"`
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"`
MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"`
WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"`
}
type PicoConfig struct {
+4
View File
@@ -36,6 +36,10 @@ func DefaultConfig() *Config {
SummarizeMessageThreshold: 20,
SummarizeTokenPercent: 75,
SteeringMode: "one-at-a-time",
ToolFeedback: ToolFeedbackConfig{
Enabled: true,
MaxArgsLength: 300,
},
},
},
Bindings: []AgentBinding{},