mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #1955 from alexhoshina/refactor/wecom
Refactor/wecom
This commit is contained in:
@@ -17,6 +17,7 @@ func NewAuthCommand() *cobra.Command {
|
||||
newStatusCommand(),
|
||||
newModelsCommand(),
|
||||
newWeixinCommand(),
|
||||
newWeComCommand(),
|
||||
)
|
||||
|
||||
return cmd
|
||||
|
||||
@@ -33,6 +33,7 @@ func TestNewAuthCommand(t *testing.T) {
|
||||
"status",
|
||||
"models",
|
||||
"weixin",
|
||||
"wecom",
|
||||
}
|
||||
|
||||
subcommands := cmd.Commands()
|
||||
|
||||
@@ -0,0 +1,407 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mdp/qrterminal/v3"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
const (
|
||||
wecomQRSourceID = "picoclaw"
|
||||
wecomQRGenerateEndpoint = "https://work.weixin.qq.com/ai/qc/generate"
|
||||
wecomQRQueryEndpoint = "https://work.weixin.qq.com/ai/qc/query_result"
|
||||
wecomQRPageEndpoint = "https://work.weixin.qq.com/ai/qc/gen"
|
||||
wecomQRHTTPTimeout = 15 * time.Second
|
||||
wecomQRPollInterval = 3 * time.Second
|
||||
wecomQRPollTimeout = 5 * time.Minute
|
||||
wecomDefaultWebSocketURL = "wss://openws.work.weixin.qq.com"
|
||||
)
|
||||
|
||||
type wecomQRScanner func(context.Context, wecomQRFlowOptions) (wecomQRBotInfo, error)
|
||||
|
||||
type wecomQRFlowOptions struct {
|
||||
HTTPClient *http.Client
|
||||
GenerateURL string
|
||||
QueryURL string
|
||||
QRCodePageURL string
|
||||
SourceID string
|
||||
PollInterval time.Duration
|
||||
PollTimeout time.Duration
|
||||
Writer io.Writer
|
||||
}
|
||||
|
||||
type wecomQRBotInfo struct {
|
||||
BotID string
|
||||
Secret string
|
||||
}
|
||||
|
||||
type wecomQRSession struct {
|
||||
SCode string
|
||||
AuthURL string
|
||||
}
|
||||
|
||||
type wecomQRGenerateResponse struct {
|
||||
ErrCode int `json:"errcode,omitempty"`
|
||||
ErrMsg string `json:"errmsg,omitempty"`
|
||||
Data struct {
|
||||
SCode string `json:"scode"`
|
||||
AuthURL string `json:"auth_url"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type wecomQRQueryResponse struct {
|
||||
ErrCode int `json:"errcode,omitempty"`
|
||||
ErrMsg string `json:"errmsg,omitempty"`
|
||||
Data struct {
|
||||
Status string `json:"status"`
|
||||
BotInfo struct {
|
||||
BotID string `json:"botid"`
|
||||
Secret string `json:"secret"`
|
||||
} `json:"bot_info"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func newWeComCommand() *cobra.Command {
|
||||
var timeout time.Duration
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "wecom",
|
||||
Short: "Scan a WeCom QR code and configure channels.wecom",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(_ *cobra.Command, _ []string) error {
|
||||
return authWeComCmd(timeout)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().DurationVar(&timeout, "timeout", wecomQRPollTimeout, "How long to wait for QR confirmation")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func authWeComCmd(timeout time.Duration) error {
|
||||
return authWeComCmdWithScanner(context.Background(), os.Stdout, timeout, scanWeComQRCodeInteractive)
|
||||
}
|
||||
|
||||
func authWeComCmdWithScanner(
|
||||
ctx context.Context,
|
||||
writer io.Writer,
|
||||
timeout time.Duration,
|
||||
scanner wecomQRScanner,
|
||||
) error {
|
||||
if scanner == nil {
|
||||
return fmt.Errorf("wecom QR scanner is nil")
|
||||
}
|
||||
if writer == nil {
|
||||
writer = os.Stdout
|
||||
}
|
||||
|
||||
cfg, err := internal.LoadConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
opts := defaultWeComQRFlowOptions(timeout)
|
||||
opts.Writer = writer
|
||||
|
||||
botInfo, err := scanner(ctx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
applyWeComAuthResult(cfg, botInfo)
|
||||
|
||||
if saveErr := config.SaveConfig(internal.GetConfigPath(), cfg); saveErr != nil {
|
||||
return fmt.Errorf("failed to save config: %w", saveErr)
|
||||
}
|
||||
|
||||
fmt.Fprintln(writer)
|
||||
fmt.Fprintln(writer, "WeCom connected.")
|
||||
fmt.Fprintf(writer, "Bot ID: %s\n", botInfo.BotID)
|
||||
fmt.Fprintf(writer, "Config: %s\n", internal.GetConfigPath())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultWeComQRFlowOptions(timeout time.Duration) wecomQRFlowOptions {
|
||||
if timeout <= 0 {
|
||||
timeout = wecomQRPollTimeout
|
||||
}
|
||||
|
||||
return wecomQRFlowOptions{
|
||||
HTTPClient: &http.Client{Timeout: wecomQRHTTPTimeout},
|
||||
GenerateURL: wecomQRGenerateEndpoint,
|
||||
QueryURL: wecomQRQueryEndpoint,
|
||||
QRCodePageURL: wecomQRPageEndpoint,
|
||||
SourceID: wecomQRSourceID,
|
||||
PollInterval: wecomQRPollInterval,
|
||||
PollTimeout: timeout,
|
||||
Writer: os.Stdout,
|
||||
}
|
||||
}
|
||||
|
||||
func applyWeComAuthResult(cfg *config.Config, botInfo wecomQRBotInfo) {
|
||||
cfg.Channels.WeCom.Enabled = true
|
||||
cfg.Channels.WeCom.BotID = botInfo.BotID
|
||||
cfg.Channels.WeCom.SetSecret(botInfo.Secret)
|
||||
if strings.TrimSpace(cfg.Channels.WeCom.WebSocketURL) == "" {
|
||||
cfg.Channels.WeCom.WebSocketURL = wecomDefaultWebSocketURL
|
||||
}
|
||||
}
|
||||
|
||||
func scanWeComQRCodeInteractive(ctx context.Context, opts wecomQRFlowOptions) (wecomQRBotInfo, error) {
|
||||
opts = normalizeWeComQRFlowOptions(opts)
|
||||
|
||||
fmt.Fprintln(opts.Writer, "Requesting WeCom QR code...")
|
||||
|
||||
session, err := fetchWeComQRCode(ctx, opts)
|
||||
if err != nil {
|
||||
return wecomQRBotInfo{}, err
|
||||
}
|
||||
|
||||
fmt.Fprintln(opts.Writer)
|
||||
fmt.Fprintln(opts.Writer, "=======================================================")
|
||||
fmt.Fprintln(opts.Writer, "Please scan the following QR code with WeCom:")
|
||||
fmt.Fprintln(opts.Writer, "=======================================================")
|
||||
fmt.Fprintln(opts.Writer)
|
||||
|
||||
qrterminal.GenerateWithConfig(session.AuthURL, qrterminal.Config{
|
||||
Level: qrterminal.L,
|
||||
Writer: opts.Writer,
|
||||
HalfBlocks: true,
|
||||
})
|
||||
|
||||
pageURL, err := buildWeComQRCodePageURL(opts.QRCodePageURL, opts.SourceID, session.SCode)
|
||||
if err != nil {
|
||||
return wecomQRBotInfo{}, err
|
||||
}
|
||||
|
||||
fmt.Fprintln(opts.Writer)
|
||||
fmt.Fprintf(opts.Writer, "QR Code Link: %s\n", pageURL)
|
||||
fmt.Fprintln(opts.Writer)
|
||||
fmt.Fprintln(opts.Writer, "Waiting for scan...")
|
||||
|
||||
return pollWeComQRCodeResult(ctx, opts, session.SCode)
|
||||
}
|
||||
|
||||
func normalizeWeComQRFlowOptions(opts wecomQRFlowOptions) wecomQRFlowOptions {
|
||||
if opts.HTTPClient == nil {
|
||||
opts.HTTPClient = &http.Client{Timeout: wecomQRHTTPTimeout}
|
||||
}
|
||||
if strings.TrimSpace(opts.GenerateURL) == "" {
|
||||
opts.GenerateURL = wecomQRGenerateEndpoint
|
||||
}
|
||||
if strings.TrimSpace(opts.QueryURL) == "" {
|
||||
opts.QueryURL = wecomQRQueryEndpoint
|
||||
}
|
||||
if strings.TrimSpace(opts.QRCodePageURL) == "" {
|
||||
opts.QRCodePageURL = wecomQRPageEndpoint
|
||||
}
|
||||
if strings.TrimSpace(opts.SourceID) == "" {
|
||||
opts.SourceID = wecomQRSourceID
|
||||
}
|
||||
if opts.PollInterval <= 0 {
|
||||
opts.PollInterval = wecomQRPollInterval
|
||||
}
|
||||
if opts.PollTimeout <= 0 {
|
||||
opts.PollTimeout = wecomQRPollTimeout
|
||||
}
|
||||
if opts.Writer == nil {
|
||||
opts.Writer = os.Stdout
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
func fetchWeComQRCode(ctx context.Context, opts wecomQRFlowOptions) (wecomQRSession, error) {
|
||||
generateURL, err := buildWeComQRGenerateURL(opts.GenerateURL, opts.SourceID, wecomPlatformCode())
|
||||
if err != nil {
|
||||
return wecomQRSession{}, err
|
||||
}
|
||||
|
||||
var resp wecomQRGenerateResponse
|
||||
if err := doWeComJSONGet(ctx, opts.HTTPClient, generateURL, &resp); err != nil {
|
||||
return wecomQRSession{}, fmt.Errorf("failed to get WeCom QR code: %w", err)
|
||||
}
|
||||
if resp.ErrCode != 0 {
|
||||
return wecomQRSession{}, fmt.Errorf(
|
||||
"failed to get WeCom QR code: errcode=%d errmsg=%s",
|
||||
resp.ErrCode,
|
||||
resp.ErrMsg,
|
||||
)
|
||||
}
|
||||
if resp.Data.SCode == "" || resp.Data.AuthURL == "" {
|
||||
return wecomQRSession{}, fmt.Errorf("failed to get WeCom QR code: response missing scode or auth_url")
|
||||
}
|
||||
|
||||
return wecomQRSession{
|
||||
SCode: resp.Data.SCode,
|
||||
AuthURL: resp.Data.AuthURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func pollWeComQRCodeResult(ctx context.Context, opts wecomQRFlowOptions, scode string) (wecomQRBotInfo, error) {
|
||||
if strings.TrimSpace(scode) == "" {
|
||||
return wecomQRBotInfo{}, fmt.Errorf("missing WeCom QR scode")
|
||||
}
|
||||
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, opts.PollTimeout)
|
||||
defer cancel()
|
||||
|
||||
var scannedPrinted bool
|
||||
|
||||
for {
|
||||
status, err := queryWeComQRCodeStatus(timeoutCtx, opts, scode)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
|
||||
return wecomQRBotInfo{}, fmt.Errorf("WeCom QR scan timed out after %s", opts.PollTimeout)
|
||||
}
|
||||
return wecomQRBotInfo{}, err
|
||||
}
|
||||
|
||||
switch strings.ToLower(status.Data.Status) {
|
||||
case "success":
|
||||
if status.Data.BotInfo.BotID == "" || status.Data.BotInfo.Secret == "" {
|
||||
return wecomQRBotInfo{}, fmt.Errorf("WeCom QR scan succeeded but bot credentials are missing")
|
||||
}
|
||||
return wecomQRBotInfo{
|
||||
BotID: status.Data.BotInfo.BotID,
|
||||
Secret: status.Data.BotInfo.Secret,
|
||||
}, nil
|
||||
case "expired":
|
||||
return wecomQRBotInfo{}, fmt.Errorf("WeCom QR code expired, please retry")
|
||||
case "scaned", "scanned":
|
||||
if !scannedPrinted {
|
||||
fmt.Fprintln(opts.Writer, "QR code scanned. Confirm the login in WeCom.")
|
||||
scannedPrinted = true
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
if errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
|
||||
return wecomQRBotInfo{}, fmt.Errorf("WeCom QR scan timed out after %s", opts.PollTimeout)
|
||||
}
|
||||
return wecomQRBotInfo{}, timeoutCtx.Err()
|
||||
case <-time.After(opts.PollInterval):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func queryWeComQRCodeStatus(ctx context.Context, opts wecomQRFlowOptions, scode string) (wecomQRQueryResponse, error) {
|
||||
queryURL, err := buildWeComQRQueryURL(opts.QueryURL, scode)
|
||||
if err != nil {
|
||||
return wecomQRQueryResponse{}, err
|
||||
}
|
||||
|
||||
var resp wecomQRQueryResponse
|
||||
if err := doWeComJSONGet(ctx, opts.HTTPClient, queryURL, &resp); err != nil {
|
||||
return wecomQRQueryResponse{}, fmt.Errorf("failed to query WeCom QR result: %w", err)
|
||||
}
|
||||
if resp.ErrCode != 0 {
|
||||
return wecomQRQueryResponse{}, fmt.Errorf(
|
||||
"failed to query WeCom QR result: errcode=%d errmsg=%s",
|
||||
resp.ErrCode,
|
||||
resp.ErrMsg,
|
||||
)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func buildWeComQRGenerateURL(baseURL, sourceID string, platformCode int) (string, error) {
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid WeCom QR generate URL: %w", err)
|
||||
}
|
||||
|
||||
query := u.Query()
|
||||
query.Set("source", sourceID)
|
||||
query.Set("sourceID", sourceID)
|
||||
query.Set("plat", strconv.Itoa(platformCode))
|
||||
u.RawQuery = query.Encode()
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func buildWeComQRQueryURL(baseURL, scode string) (string, error) {
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid WeCom QR query URL: %w", err)
|
||||
}
|
||||
|
||||
query := u.Query()
|
||||
query.Set("scode", scode)
|
||||
u.RawQuery = query.Encode()
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func buildWeComQRCodePageURL(baseURL, sourceID, scode string) (string, error) {
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid WeCom QR page URL: %w", err)
|
||||
}
|
||||
|
||||
query := u.Query()
|
||||
query.Set("source", sourceID)
|
||||
query.Set("sourceID", sourceID)
|
||||
query.Set("scode", scode)
|
||||
u.RawQuery = query.Encode()
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func doWeComJSONGet(ctx context.Context, client *http.Client, targetURL string, out any) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
|
||||
if readErr != nil {
|
||||
return fmt.Errorf("unexpected status %s", resp.Status)
|
||||
}
|
||||
return fmt.Errorf("unexpected status %s: %s", resp.Status, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
|
||||
return fmt.Errorf("decode JSON response: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func wecomPlatformCode() int {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return 1
|
||||
case "windows":
|
||||
return 2
|
||||
case "linux":
|
||||
return 3
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestNewWeComCommand(t *testing.T) {
|
||||
cmd := newWeComCommand()
|
||||
|
||||
require.NotNil(t, cmd)
|
||||
assert.Equal(t, "wecom", cmd.Use)
|
||||
assert.Equal(t, "Scan a WeCom QR code and configure channels.wecom", cmd.Short)
|
||||
assert.NotNil(t, cmd.Flags().Lookup("timeout"))
|
||||
}
|
||||
|
||||
func TestBuildWeComQRGenerateURL(t *testing.T) {
|
||||
rawURL, err := buildWeComQRGenerateURL("https://example.com/ai/qc/generate", wecomQRSourceID, 3)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wecomQRSourceID, parsed.Query().Get("source"))
|
||||
assert.Equal(t, wecomQRSourceID, parsed.Query().Get("sourceID"))
|
||||
assert.Equal(t, "3", parsed.Query().Get("plat"))
|
||||
}
|
||||
|
||||
func TestBuildWeComQRCodePageURL(t *testing.T) {
|
||||
rawURL, err := buildWeComQRCodePageURL("https://example.com/ai/qc/gen", wecomQRSourceID, "scode-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wecomQRSourceID, parsed.Query().Get("source"))
|
||||
assert.Equal(t, wecomQRSourceID, parsed.Query().Get("sourceID"))
|
||||
assert.Equal(t, "scode-1", parsed.Query().Get("scode"))
|
||||
}
|
||||
|
||||
func TestFetchWeComQRCode(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/generate", r.URL.Path)
|
||||
assert.Equal(t, wecomQRSourceID, r.URL.Query().Get("source"))
|
||||
assert.Equal(t, wecomQRSourceID, r.URL.Query().Get("sourceID"))
|
||||
assert.Equal(t, strconv.Itoa(wecomPlatformCode()), r.URL.Query().Get("plat"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"data":{"scode":"scode-1","auth_url":"https://example.com/qr"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
opts := normalizeWeComQRFlowOptions(wecomQRFlowOptions{
|
||||
HTTPClient: server.Client(),
|
||||
GenerateURL: server.URL + "/generate",
|
||||
Writer: bytes.NewBuffer(nil),
|
||||
})
|
||||
|
||||
session, err := fetchWeComQRCode(context.Background(), opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "scode-1", session.SCode)
|
||||
assert.Equal(t, "https://example.com/qr", session.AuthURL)
|
||||
}
|
||||
|
||||
func TestPollWeComQRCodeResult(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
call := calls.Add(1)
|
||||
assert.Equal(t, "/query", r.URL.Path)
|
||||
assert.Equal(t, "scode-1", r.URL.Query().Get("scode"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch call {
|
||||
case 1:
|
||||
_, _ = w.Write([]byte(`{"data":{"status":"wait"}}`))
|
||||
case 2:
|
||||
_, _ = w.Write([]byte(`{"data":{"status":"scaned"}}`))
|
||||
default:
|
||||
_, _ = w.Write([]byte(`{"data":{"status":"success","bot_info":{"botid":"bot-1","secret":"secret-1"}}}`))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
var output bytes.Buffer
|
||||
opts := normalizeWeComQRFlowOptions(wecomQRFlowOptions{
|
||||
HTTPClient: server.Client(),
|
||||
QueryURL: server.URL + "/query",
|
||||
PollInterval: time.Millisecond,
|
||||
PollTimeout: time.Second,
|
||||
Writer: &output,
|
||||
})
|
||||
|
||||
botInfo, err := pollWeComQRCodeResult(context.Background(), opts, "scode-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "bot-1", botInfo.BotID)
|
||||
assert.Equal(t, "secret-1", botInfo.Secret)
|
||||
assert.Contains(t, output.String(), "QR code scanned. Confirm the login in WeCom.")
|
||||
}
|
||||
|
||||
func TestApplyWeComAuthResult(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Channels.WeCom.WebSocketURL = ""
|
||||
|
||||
applyWeComAuthResult(cfg, wecomQRBotInfo{
|
||||
BotID: "bot-1",
|
||||
Secret: "secret-1",
|
||||
})
|
||||
|
||||
assert.True(t, cfg.Channels.WeCom.Enabled)
|
||||
assert.Equal(t, "bot-1", cfg.Channels.WeCom.BotID)
|
||||
assert.Equal(t, "secret-1", cfg.Channels.WeCom.Secret())
|
||||
assert.Equal(t, wecomDefaultWebSocketURL, cfg.Channels.WeCom.WebSocketURL)
|
||||
}
|
||||
|
||||
func TestAuthWeComCmdWithScanner(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
t.Setenv(config.EnvHome, tmpDir)
|
||||
t.Setenv(config.EnvConfig, configPath)
|
||||
|
||||
var output bytes.Buffer
|
||||
err := authWeComCmdWithScanner(
|
||||
context.Background(),
|
||||
&output,
|
||||
time.Second,
|
||||
func(_ context.Context, opts wecomQRFlowOptions) (wecomQRBotInfo, error) {
|
||||
assert.Equal(t, wecomQRSourceID, opts.SourceID)
|
||||
return wecomQRBotInfo{
|
||||
BotID: "bot-1",
|
||||
Secret: "secret-1",
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := config.LoadConfig(internal.GetConfigPath())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, cfg.Channels.WeCom.Enabled)
|
||||
assert.Equal(t, "bot-1", cfg.Channels.WeCom.BotID)
|
||||
assert.Equal(t, "secret-1", cfg.Channels.WeCom.Secret())
|
||||
assert.Equal(t, wecomDefaultWebSocketURL, cfg.Channels.WeCom.WebSocketURL)
|
||||
assert.Contains(t, output.String(), "WeCom connected.")
|
||||
}
|
||||
@@ -184,39 +184,13 @@
|
||||
"reasoning_channel_id": ""
|
||||
},
|
||||
"wecom": {
|
||||
"_comment": "WeCom Bot - Easier setup, supports group chats",
|
||||
"enabled": false,
|
||||
"token": "YOUR_TOKEN",
|
||||
"encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
|
||||
"webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
|
||||
"webhook_path": "/webhook/wecom",
|
||||
"allow_from": [],
|
||||
"reply_timeout": 5,
|
||||
"reasoning_channel_id": ""
|
||||
},
|
||||
"wecom_app": {
|
||||
"_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only.",
|
||||
"enabled": false,
|
||||
"corp_id": "YOUR_CORP_ID",
|
||||
"corp_secret": "YOUR_CORP_SECRET",
|
||||
"agent_id": 1000002,
|
||||
"token": "YOUR_TOKEN",
|
||||
"encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
|
||||
"webhook_path": "/webhook/wecom-app",
|
||||
"allow_from": [],
|
||||
"reply_timeout": 5,
|
||||
"reasoning_channel_id": ""
|
||||
},
|
||||
"wecom_aibot": {
|
||||
"_comment": "WeCom AI Bot (智能机器人) - Official WeCom AI Bot integration, supports proactive messaging and private chats.",
|
||||
"_comment": "WeCom AI Bot over WebSocket.",
|
||||
"enabled": false,
|
||||
"bot_id": "YOUR_BOT_ID",
|
||||
"secret": "YOUR_SECRET",
|
||||
"token": "YOUR_TOKEN",
|
||||
"encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
|
||||
"webhook_path": "/webhook/wecom-aibot",
|
||||
"max_steps": 10,
|
||||
"welcome_message": "Hello! I'm your AI assistant. How can I help you today?",
|
||||
"websocket_url": "wss://openws.work.weixin.qq.com",
|
||||
"send_thinking_message": true,
|
||||
"allow_from": [],
|
||||
"reasoning_channel_id": ""
|
||||
},
|
||||
"pico": {
|
||||
|
||||
+11
-13
@@ -2042,18 +2042,17 @@ func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
|
||||
t.Fatalf("Failed to create channel manager: %v", err)
|
||||
}
|
||||
for name, id := range map[string]string{
|
||||
"whatsapp": "rid-whatsapp",
|
||||
"telegram": "rid-telegram",
|
||||
"feishu": "rid-feishu",
|
||||
"discord": "rid-discord",
|
||||
"maixcam": "rid-maixcam",
|
||||
"qq": "rid-qq",
|
||||
"dingtalk": "rid-dingtalk",
|
||||
"slack": "rid-slack",
|
||||
"line": "rid-line",
|
||||
"onebot": "rid-onebot",
|
||||
"wecom": "rid-wecom",
|
||||
"wecom_app": "rid-wecom-app",
|
||||
"whatsapp": "rid-whatsapp",
|
||||
"telegram": "rid-telegram",
|
||||
"feishu": "rid-feishu",
|
||||
"discord": "rid-discord",
|
||||
"maixcam": "rid-maixcam",
|
||||
"qq": "rid-qq",
|
||||
"dingtalk": "rid-dingtalk",
|
||||
"slack": "rid-slack",
|
||||
"line": "rid-line",
|
||||
"onebot": "rid-onebot",
|
||||
"wecom": "rid-wecom",
|
||||
} {
|
||||
chManager.RegisterChannel(name, &fakeChannel{id: id})
|
||||
}
|
||||
@@ -2073,7 +2072,6 @@ func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
|
||||
{channel: "line", wantID: "rid-line"},
|
||||
{channel: "onebot", wantID: "rid-onebot"},
|
||||
{channel: "wecom", wantID: "rid-wecom"},
|
||||
{channel: "wecom_app", wantID: "rid-wecom-app"},
|
||||
{channel: "unknown", wantID: ""},
|
||||
}
|
||||
|
||||
|
||||
@@ -1255,8 +1255,7 @@ make test # Full test suite
|
||||
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender |
|
||||
| `pkg/channels/dingtalk/` | `"dingtalk"` | — |
|
||||
| `pkg/channels/feishu/` | `"feishu"` | — (architecture-specific build tags: `feishu_32.go` / `feishu_64.go`) |
|
||||
| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker |
|
||||
| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker |
|
||||
| `pkg/channels/wecom/` | `"wecom"` | MediaSender |
|
||||
| `pkg/channels/qq/` | `"qq"` | — |
|
||||
| `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge mode) |
|
||||
| `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (Native whatsmeow mode) |
|
||||
@@ -1371,7 +1370,7 @@ agentLoop.Stop() // Stop Agent
|
||||
|
||||
2. **Feishu architecture-specific compilation**: The Feishu channel uses build tags to distinguish 32-bit and 64-bit architectures (`feishu_32.go` / `feishu_64.go`). Feishu uses the SDK's WebSocket mode (not HTTP webhook), so it does not implement `WebhookHandler`.
|
||||
|
||||
3. **WeCom has two factories**: `"wecom"` (Bot mode, webhook only) and `"wecom_app"` (App mode, supports MediaSender) are registered separately. Both implement `WebhookHandler` and `HealthChecker`.
|
||||
3. **WeCom is now a single channel**: `"wecom"` is implemented as a WebSocket-based AI Bot channel with route persistence. Access control uses the shared channel allowlist mechanism. It no longer exposes the legacy webhook/app split.
|
||||
|
||||
4. **Pico Protocol**: `pkg/channels/pico/` implements a custom PicoClaw native protocol channel that receives messages via WebSocket webhook (`/pico/ws`).
|
||||
|
||||
@@ -1381,4 +1380,4 @@ agentLoop.Stop() // Stop Agent
|
||||
|
||||
7. **PlaceholderConfig vs implementation**: `PlaceholderConfig` appears in 6 channel configs (Telegram, Discord, Slack, LINE, OneBot, Pico), but only channels that implement both `PlaceholderCapable` + `MessageEditor` (Telegram, Discord, Pico) can actually use placeholder message editing. The rest are reserved fields.
|
||||
|
||||
8. **ReasoningChannelID**: Most channel configs include a `reasoning_channel_id` field to route LLM reasoning/thinking output to a designated channel (WhatsApp, Telegram, Feishu, Discord, MaixCam, QQ, DingTalk, Slack, LINE, OneBot, WeCom, WeComApp). Note: `PicoConfig` does not currently expose this field. `BaseChannel` exposes this via the `WithReasoningChannelID` option and `ReasoningChannelID()` method.
|
||||
8. **ReasoningChannelID**: Most channel configs include a `reasoning_channel_id` field to route LLM reasoning/thinking output to a designated channel (WhatsApp, Telegram, Feishu, Discord, MaixCam, QQ, DingTalk, Slack, LINE, OneBot, WeCom). Note: `PicoConfig` does not currently expose this field. `BaseChannel` exposes this via the `WithReasoningChannelID` option and `ReasoningChannelID()` method.
|
||||
|
||||
@@ -1254,8 +1254,7 @@ make test # 全量测试
|
||||
| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender |
|
||||
| `pkg/channels/dingtalk/` | `"dingtalk"` | — |
|
||||
| `pkg/channels/feishu/` | `"feishu"` | — (架构特定 build tags: `feishu_32.go` / `feishu_64.go`) |
|
||||
| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker |
|
||||
| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker |
|
||||
| `pkg/channels/wecom/` | `"wecom"` | MediaSender |
|
||||
| `pkg/channels/qq/` | `"qq"` | — |
|
||||
| `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge 模式) |
|
||||
| `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (原生 whatsmeow 模式) |
|
||||
@@ -1370,7 +1369,7 @@ agentLoop.Stop() // 停止 Agent
|
||||
|
||||
2. **Feishu 架构特定编译**:Feishu channel 使用 build tags 区分 32 位和 64 位架构(`feishu_32.go` / `feishu_64.go`)。Feishu 使用 SDK 的 WebSocket 模式(非 HTTP webhook),因此不实现 `WebhookHandler`。
|
||||
|
||||
3. **WeCom 有两个工厂**:`"wecom"`(Bot 模式,纯 webhook)和 `"wecom_app"`(应用模式,支持 MediaSender)分别注册。两者都实现了 `WebhookHandler` 和 `HealthChecker`。
|
||||
3. **WeCom 现在只有一个 channel**:`"wecom"` 采用 WebSocket AI Bot 实现,带路由持久化;访问控制走统一的 channel 白名单机制,不再保留旧的 webhook/app 双分支。
|
||||
|
||||
4. **Pico Protocol**:`pkg/channels/pico/` 实现了一个自定义的 PicoClaw 原生协议 channel,通过 WebSocket webhook (`/pico/ws`) 接收消息。
|
||||
|
||||
@@ -1380,4 +1379,4 @@ agentLoop.Stop() // 停止 Agent
|
||||
|
||||
7. **PlaceholderConfig 的配置与实现**:`PlaceholderConfig` 出现在 6 个 channel config 中(Telegram、Discord、Slack、LINE、OneBot、Pico),但只有实现了 `PlaceholderCapable` + `MessageEditor` 的 channel(Telegram、Discord、Pico)能真正使用占位消息编辑功能。其余 channel 的 `PlaceholderConfig` 为预留字段。
|
||||
|
||||
8. **ReasoningChannelID**:大多数 channel config 都包含 `reasoning_channel_id` 字段,用于将 LLM 的思维链(reasoning/thinking)路由到指定 channel(WhatsApp、Telegram、Feishu、Discord、MaixCam、QQ、DingTalk、Slack、LINE、OneBot、WeCom、WeComApp)。注意:`PicoConfig` 目前不包含该字段。`BaseChannel` 通过 `WithReasoningChannelID` 选项和 `ReasoningChannelID()` 方法暴露此配置。
|
||||
8. **ReasoningChannelID**:大多数 channel config 都包含 `reasoning_channel_id` 字段,用于将 LLM 的思维链(reasoning/thinking)路由到指定 channel(WhatsApp、Telegram、Feishu、Discord、MaixCam、QQ、DingTalk、Slack、LINE、OneBot、WeCom)。注意:`PicoConfig` 目前不包含该字段。`BaseChannel` 通过 `WithReasoningChannelID` 选项和 `ReasoningChannelID()` 方法暴露此配置。
|
||||
|
||||
+1
-10
@@ -405,19 +405,10 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
|
||||
m.initChannel("onebot", "OneBot")
|
||||
}
|
||||
|
||||
if channels.WeCom.Enabled && channels.WeCom.Token() != "" {
|
||||
if channels.WeCom.Enabled && channels.WeCom.BotID != "" && channels.WeCom.Secret() != "" {
|
||||
m.initChannel("wecom", "WeCom")
|
||||
}
|
||||
|
||||
if channels.WeComAIBot.Enabled && (channels.WeComAIBot.Token() != "" ||
|
||||
(channels.WeComAIBot.Secret() != "" && channels.WeComAIBot.BotID != "")) {
|
||||
m.initChannel("wecom_aibot", "WeCom AI Bot")
|
||||
}
|
||||
|
||||
if channels.WeComApp.Enabled && channels.WeComApp.CorpID != "" {
|
||||
m.initChannel("wecom_app", "WeCom App")
|
||||
}
|
||||
|
||||
if channels.Weixin.Enabled && channels.Weixin.Token() != "" {
|
||||
m.initChannel("weixin", "Weixin")
|
||||
}
|
||||
|
||||
@@ -49,15 +49,7 @@ func hiddenValues(key string, value map[string]any, ch config.ChannelsConfig) {
|
||||
value["token"] = ch.LINE.ChannelAccessToken()
|
||||
value["secret"] = ch.LINE.ChannelSecret()
|
||||
case "wecom":
|
||||
value["token"] = ch.WeCom.Token()
|
||||
value["key"] = ch.WeCom.EncodingAESKey()
|
||||
case "wecom_app":
|
||||
value["token"] = ch.WeComApp.Token()
|
||||
value["secret"] = ch.WeComApp.CorpSecret()
|
||||
case "wecom_aibot":
|
||||
value["token"] = ch.WeComAIBot.Token()
|
||||
value["key"] = ch.WeComAIBot.EncodingAESKey()
|
||||
value["secret"] = ch.WeComAIBot.Secret()
|
||||
value["secret"] = ch.WeCom.Secret()
|
||||
case "dingtalk":
|
||||
value["secret"] = ch.QQ.AppSecret()
|
||||
case "qq":
|
||||
@@ -156,16 +148,7 @@ func updateKeys(newcfg, old *config.ChannelsConfig) {
|
||||
newcfg.LINE.SetChannelSecret(old.LINE.ChannelSecret())
|
||||
}
|
||||
if newcfg.WeCom.Enabled {
|
||||
newcfg.WeCom.SetToken(old.WeCom.Token())
|
||||
newcfg.WeCom.SetEncodingAESKey(old.WeCom.EncodingAESKey())
|
||||
}
|
||||
if newcfg.WeComApp.Enabled {
|
||||
newcfg.WeComApp.SetToken(old.WeComApp.Token())
|
||||
newcfg.WeComApp.SetCorpSecret(old.WeComApp.CorpSecret())
|
||||
}
|
||||
if newcfg.WeComAIBot.Enabled {
|
||||
newcfg.WeComAIBot.SetToken(old.WeComAIBot.Token())
|
||||
newcfg.WeComAIBot.SetEncodingAESKey(old.WeComAIBot.EncodingAESKey())
|
||||
newcfg.WeCom.SetSecret(old.WeCom.Secret())
|
||||
}
|
||||
if newcfg.DingTalk.Enabled {
|
||||
newcfg.DingTalk.SetClientSecret(old.DingTalk.ClientSecret())
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,559 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// ---- Webhook mode tests ----
|
||||
|
||||
func TestNewWeComAIBotChannel_WebhookMode(t *testing.T) {
|
||||
t.Run("success with valid config", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
cfg.WebhookPath = "/webhook/test"
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if ch == nil {
|
||||
t.Fatal("Expected channel to be created")
|
||||
}
|
||||
if ch.Name() != "wecom_aibot" {
|
||||
t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
|
||||
}
|
||||
// Webhook mode must implement WebhookHandler.
|
||||
if _, ok := ch.(channels.WebhookHandler); !ok {
|
||||
t.Error("Webhook mode channel should implement WebhookHandler")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error with missing token", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
_, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing token, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error with missing encoding key", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetToken("test_token")
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
_, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing encoding key, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComAIBotWebhookChannelStartStop(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
if err := ch.Start(ctx); err != nil {
|
||||
t.Fatalf("Failed to start channel: %v", err)
|
||||
}
|
||||
if !ch.IsRunning() {
|
||||
t.Error("Expected channel to be running after Start")
|
||||
}
|
||||
|
||||
if err := ch.Stop(ctx); err != nil {
|
||||
t.Fatalf("Failed to stop channel: %v", err)
|
||||
}
|
||||
if ch.IsRunning() {
|
||||
t.Error("Expected channel to be stopped after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComAIBotChannelWebhookPath(t *testing.T) {
|
||||
t.Run("default path", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
|
||||
|
||||
wh, ok := ch.(channels.WebhookHandler)
|
||||
if !ok {
|
||||
t.Fatal("Expected channel to implement WebhookHandler")
|
||||
}
|
||||
expectedPath := "/webhook/wecom-aibot"
|
||||
if wh.WebhookPath() != expectedPath {
|
||||
t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, wh.WebhookPath())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom path", func(t *testing.T) {
|
||||
customPath := "/custom/webhook"
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
cfg.WebhookPath = customPath
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
|
||||
|
||||
wh, ok := ch.(channels.WebhookHandler)
|
||||
if !ok {
|
||||
t.Fatal("Expected channel to implement WebhookHandler")
|
||||
}
|
||||
if wh.WebhookPath() != customPath {
|
||||
t.Errorf("Expected webhook path '%s', got '%s'", customPath, wh.WebhookPath())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComAIBotChannelGetStreamResponseProcessingMessage(t *testing.T) {
|
||||
validAESKey := "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG"
|
||||
|
||||
t.Run("uses default processing message", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey(validAESKey)
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
channel, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel: %v", err)
|
||||
}
|
||||
ch, ok := channel.(*WeComAIBotChannel)
|
||||
if !ok {
|
||||
t.Fatal("Expected webhook mode channel")
|
||||
}
|
||||
|
||||
task := &streamTask{
|
||||
StreamID: "stream-default",
|
||||
ChatID: "chat-default",
|
||||
Deadline: time.Now().Add(-time.Second),
|
||||
}
|
||||
ch.streamTasks[task.StreamID] = task
|
||||
ch.chatTasks[task.ChatID] = []*streamTask{task}
|
||||
|
||||
resp := decodeStreamResponse(t, ch, ch.getStreamResponse(task, "1234567890", "nonce"))
|
||||
|
||||
if !resp.Stream.Finish {
|
||||
t.Fatal("Expected finished stream response after deadline")
|
||||
}
|
||||
if resp.Stream.Content != config.DefaultWeComAIBotProcessingMessage {
|
||||
t.Fatalf("Expected default processing message %q, got %q",
|
||||
config.DefaultWeComAIBotProcessingMessage, resp.Stream.Content)
|
||||
}
|
||||
if !task.StreamClosed {
|
||||
t.Fatal("Expected task stream to be marked closed")
|
||||
}
|
||||
if _, ok := ch.streamTasks[task.StreamID]; ok {
|
||||
t.Fatal("Expected closed stream task to be removed from streamTasks")
|
||||
}
|
||||
if len(ch.chatTasks[task.ChatID]) != 1 {
|
||||
t.Fatalf("Expected task to remain queued for response_url delivery, got %d entries",
|
||||
len(ch.chatTasks[task.ChatID]))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses custom processing message", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
ProcessingMessage: "Please wait a moment. The result will be delivered in a follow-up message.",
|
||||
}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey(validAESKey)
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
channel, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel: %v", err)
|
||||
}
|
||||
ch, ok := channel.(*WeComAIBotChannel)
|
||||
if !ok {
|
||||
t.Fatal("Expected webhook mode channel")
|
||||
}
|
||||
|
||||
task := &streamTask{
|
||||
StreamID: "stream-custom",
|
||||
ChatID: "chat-custom",
|
||||
Deadline: time.Now().Add(-time.Second),
|
||||
}
|
||||
|
||||
resp := decodeStreamResponse(t, ch, ch.getStreamResponse(task, "1234567890", "nonce"))
|
||||
|
||||
if resp.Stream.Content != cfg.ProcessingMessage {
|
||||
t.Fatalf("Expected custom processing message %q, got %q", cfg.ProcessingMessage, resp.Stream.Content)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateStreamID(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
|
||||
webhookCh, ok := ch.(*WeComAIBotChannel)
|
||||
if !ok {
|
||||
t.Fatal("Expected webhook mode channel")
|
||||
}
|
||||
|
||||
ids := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
id := webhookCh.generateStreamID()
|
||||
if len(id) != 10 {
|
||||
t.Errorf("Expected stream ID length 10, got %d", len(id))
|
||||
}
|
||||
if ids[id] {
|
||||
t.Errorf("Duplicate stream ID generated: %s", id)
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
// Use a valid 43-character base64 key (企业微信标准格式)
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG") // 43 characters
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
|
||||
webhookCh, ok := ch.(*WeComAIBotChannel)
|
||||
if !ok {
|
||||
t.Fatal("Expected webhook mode channel")
|
||||
}
|
||||
|
||||
plaintext := "Hello, World!"
|
||||
receiveid := ""
|
||||
|
||||
encrypted, err := webhookCh.encryptMessage(plaintext, receiveid)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt message: %v", err)
|
||||
}
|
||||
if encrypted == "" {
|
||||
t.Fatal("Encrypted message is empty")
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
decrypted, err := decryptMessageWithVerify(encrypted, cfg.EncodingAESKey(), receiveid)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt message: %v", err)
|
||||
}
|
||||
if decrypted != plaintext {
|
||||
t.Errorf("Expected decrypted message '%s', got '%s'", plaintext, decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSignature(t *testing.T) {
|
||||
token := "test_token"
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
encrypt := "encrypted_msg"
|
||||
|
||||
signature := computeSignature(token, timestamp, nonce, encrypt)
|
||||
if signature == "" {
|
||||
t.Error("Generated signature is empty")
|
||||
}
|
||||
if !verifySignature(token, signature, timestamp, nonce, encrypt) {
|
||||
t.Error("Generated signature does not verify correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func decodeStreamResponse(t *testing.T, ch *WeComAIBotChannel, encryptedResponse string) WeComAIBotStreamResponse {
|
||||
t.Helper()
|
||||
|
||||
var wrapped WeComAIBotEncryptedResponse
|
||||
if err := json.Unmarshal([]byte(encryptedResponse), &wrapped); err != nil {
|
||||
t.Fatalf("Failed to unmarshal encrypted response: %v", err)
|
||||
}
|
||||
|
||||
plaintext, err := decryptMessageWithVerify(wrapped.Encrypt, ch.config.EncodingAESKey(), "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt response: %v", err)
|
||||
}
|
||||
|
||||
var resp WeComAIBotStreamResponse
|
||||
if err := json.Unmarshal([]byte(plaintext), &resp); err != nil {
|
||||
t.Fatalf("Failed to unmarshal decrypted response: %v", err)
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// ---- WebSocket long-connection mode tests ----
|
||||
|
||||
func TestNewWeComAIBotChannel_WSMode(t *testing.T) {
|
||||
t.Run("success with bot_id and secret", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
BotID: "test_bot_id",
|
||||
}
|
||||
cfg.SetSecret("test_secret")
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if ch == nil {
|
||||
t.Fatal("Expected channel to be created")
|
||||
}
|
||||
if ch.Name() != "wecom_aibot" {
|
||||
t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
|
||||
}
|
||||
// WebSocket mode must NOT implement WebhookHandler.
|
||||
if _, ok := ch.(channels.WebhookHandler); ok {
|
||||
t.Error("WebSocket mode channel should NOT implement WebhookHandler")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ws mode takes priority over webhook fields", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
BotID: "test_bot_id",
|
||||
}
|
||||
cfg.SetSecret("test_secret")
|
||||
cfg.SetToken("also_set")
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if _, ok := ch.(*WeComAIBotWSChannel); !ok {
|
||||
t.Error("Expected WebSocket mode channel when both BotID+secret and Token+Key are set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error with missing bot_id", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
cfg.SetSecret("test_secret")
|
||||
messageBus := bus.NewMessageBus()
|
||||
_, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
// Missing bot_id alone means neither WS mode nor webhook mode is fully configured.
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing bot_id, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error with missing secret", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
BotID: "test_bot_id",
|
||||
}
|
||||
messageBus := bus.NewMessageBus()
|
||||
_, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing secret, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComAIBotWSChannelStartStop(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
BotID: "test_bot_id",
|
||||
}
|
||||
cfg.SetSecret("test_secret")
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Start launches a background goroutine; it should not block or return an error.
|
||||
if err := ch.Start(ctx); err != nil {
|
||||
t.Fatalf("Failed to start channel: %v", err)
|
||||
}
|
||||
if !ch.IsRunning() {
|
||||
t.Error("Expected channel to be running after Start")
|
||||
}
|
||||
|
||||
// Stop should work regardless of whether the WebSocket actually connected.
|
||||
if err := ch.Stop(ctx); err != nil {
|
||||
t.Fatalf("Failed to stop channel: %v", err)
|
||||
}
|
||||
if ch.IsRunning() {
|
||||
t.Error("Expected channel to be stopped after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomID(t *testing.T) {
|
||||
ids := make(map[string]bool)
|
||||
for i := 0; i < 200; i++ {
|
||||
id := generateRandomID(10)
|
||||
if len(id) != 10 {
|
||||
t.Errorf("Expected ID length 10, got %d", len(id))
|
||||
}
|
||||
if ids[id] {
|
||||
t.Errorf("Duplicate ID generated: %s", id)
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSGenerateID(t *testing.T) {
|
||||
ids := make(map[string]bool)
|
||||
for i := 0; i < 200; i++ {
|
||||
id := wsGenerateID()
|
||||
if len(id) != 10 {
|
||||
t.Errorf("Expected ID length 10, got %d", len(id))
|
||||
}
|
||||
if ids[id] {
|
||||
t.Errorf("Duplicate wsGenerateID result: %s", id)
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Webhook streaming fallback tests ----
|
||||
|
||||
// makeWebhookChannel creates a started WeComAIBotChannel for testing.
|
||||
func makeWebhookChannel(t *testing.T) *WeComAIBotChannel {
|
||||
t.Helper()
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG")
|
||||
ch, err := NewWeComAIBotChannel(cfg, bus.NewMessageBus())
|
||||
if err != nil {
|
||||
t.Fatalf("create channel: %v", err)
|
||||
}
|
||||
wc := ch.(*WeComAIBotChannel)
|
||||
wc.ctx, wc.cancel = context.WithCancel(context.Background())
|
||||
return wc
|
||||
}
|
||||
|
||||
// makeStreamTask creates and registers a streamTask for testing.
|
||||
func makeStreamTask(t *testing.T, ch *WeComAIBotChannel, streamID, chatID string, deadline time.Time) *streamTask {
|
||||
t.Helper()
|
||||
task := &streamTask{
|
||||
StreamID: streamID,
|
||||
ChatID: chatID,
|
||||
Deadline: deadline,
|
||||
answerCh: make(chan string, 1),
|
||||
}
|
||||
task.ctx, task.cancel = context.WithCancel(ch.ctx)
|
||||
ch.taskMu.Lock()
|
||||
ch.streamTasks[streamID] = task
|
||||
ch.chatTasks[chatID] = append(ch.chatTasks[chatID], task)
|
||||
ch.taskMu.Unlock()
|
||||
return task
|
||||
}
|
||||
|
||||
// TestGetStreamResponse_ImmediateAnswer verifies that when the agent has already
|
||||
// placed its answer in answerCh, getStreamResponse returns a finish=true response
|
||||
// and fully removes the task.
|
||||
func TestGetStreamResponse_ImmediateAnswer(t *testing.T) {
|
||||
ch := makeWebhookChannel(t)
|
||||
defer ch.cancel()
|
||||
|
||||
task := makeStreamTask(t, ch, "stream-1", "chat-1", time.Now().Add(30*time.Second))
|
||||
task.answerCh <- "hello from agent"
|
||||
|
||||
result := ch.getStreamResponse(task, "ts123", "nonce123")
|
||||
if result == "" {
|
||||
t.Fatal("expected non-empty encrypted response")
|
||||
}
|
||||
|
||||
ch.taskMu.RLock()
|
||||
_, exists := ch.streamTasks["stream-1"]
|
||||
ch.taskMu.RUnlock()
|
||||
if exists {
|
||||
t.Error("task should have been removed from streamTasks after normal finish")
|
||||
}
|
||||
if !task.Finished {
|
||||
t.Error("task.Finished should be true after normal finish")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetStreamResponse_DeadlinePassed verifies that when the stream deadline has
|
||||
// elapsed (no agent reply yet), getStreamResponse closes the stream but keeps the
|
||||
// task alive so the response_url fallback can still deliver the answer.
|
||||
func TestGetStreamResponse_DeadlinePassed(t *testing.T) {
|
||||
ch := makeWebhookChannel(t)
|
||||
defer ch.cancel()
|
||||
|
||||
task := makeStreamTask(t, ch, "stream-2", "chat-2", time.Now().Add(-time.Millisecond))
|
||||
|
||||
result := ch.getStreamResponse(task, "ts456", "nonce456")
|
||||
if result == "" {
|
||||
t.Fatal("expected non-empty encrypted response")
|
||||
}
|
||||
|
||||
ch.taskMu.RLock()
|
||||
_, stillStreaming := ch.streamTasks["stream-2"]
|
||||
ch.taskMu.RUnlock()
|
||||
if stillStreaming {
|
||||
t.Error("task should have been removed from streamTasks after deadline")
|
||||
}
|
||||
if !task.StreamClosed {
|
||||
t.Error("task.StreamClosed should be true after deadline")
|
||||
}
|
||||
if task.Finished {
|
||||
t.Error("task.Finished must remain false: agent reply still expected via response_url")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetStreamResponse_StillPending verifies that when neither the agent has
|
||||
// replied nor the deadline has passed, getStreamResponse returns without altering
|
||||
// task state (client should poll again).
|
||||
func TestGetStreamResponse_StillPending(t *testing.T) {
|
||||
ch := makeWebhookChannel(t)
|
||||
defer ch.cancel()
|
||||
|
||||
task := makeStreamTask(t, ch, "stream-3", "chat-3", time.Now().Add(30*time.Second))
|
||||
|
||||
result := ch.getStreamResponse(task, "ts789", "nonce789")
|
||||
if result == "" {
|
||||
t.Fatal("expected non-empty encrypted response")
|
||||
}
|
||||
|
||||
ch.taskMu.RLock()
|
||||
_, exists := ch.streamTasks["stream-3"]
|
||||
ch.taskMu.RUnlock()
|
||||
if !exists {
|
||||
t.Error("pending task should still be in streamTasks")
|
||||
}
|
||||
if task.Finished || task.StreamClosed {
|
||||
t.Error("pending task should not be finished or stream-closed")
|
||||
}
|
||||
// Cleanup.
|
||||
ch.removeTask(task)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,295 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
// newTestWSChannel creates a WeComAIBotWSChannel ready for unit testing.
|
||||
func newTestWSChannel(t *testing.T) *WeComAIBotWSChannel {
|
||||
t.Helper()
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
BotID: "test_bot_id",
|
||||
}
|
||||
cfg.SetSecret("test_secret")
|
||||
ch, err := newWeComAIBotWSChannel(cfg, bus.NewMessageBus())
|
||||
if err != nil {
|
||||
t.Fatalf("create WS channel: %v", err)
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
// TestStoreWSMedia_NilStore verifies that storeWSMedia returns an error when no
|
||||
// MediaStore has been injected.
|
||||
func TestStoreWSMedia_NilStore(t *testing.T) {
|
||||
ch := newTestWSChannel(t)
|
||||
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://any", "", ".jpg")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when no MediaStore is set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStoreWSMedia_HTTPError verifies that storeWSMedia propagates HTTP errors
|
||||
// from the media server.
|
||||
func TestStoreWSMedia_HTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
ch := newTestWSChannel(t)
|
||||
ch.SetMediaStore(media.NewFileMediaStore())
|
||||
|
||||
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTTP 404")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStoreWSMedia_ServerUnavailable verifies that storeWSMedia returns a clear
|
||||
// error when the media server cannot be reached.
|
||||
func TestStoreWSMedia_ServerUnavailable(t *testing.T) {
|
||||
ch := newTestWSChannel(t)
|
||||
ch.SetMediaStore(media.NewFileMediaStore())
|
||||
|
||||
// Port 1 is reserved and will refuse the connection immediately.
|
||||
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://127.0.0.1:1", "", ".jpg")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unreachable server")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStoreWSMedia_Success_NoAES verifies the happy path: the media is downloaded,
|
||||
// a media ref is returned, and the file persists and is readable via Resolve until
|
||||
// ReleaseAll is called. The server returns no Content-Type, so the defaultExt is used.
|
||||
func TestStoreWSMedia_Success_NoAES(t *testing.T) {
|
||||
imageData := bytes.Repeat([]byte("x"), 256)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(imageData)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
ch := newTestWSChannel(t)
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
ref, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if ref == "" {
|
||||
t.Fatal("expected non-empty ref")
|
||||
}
|
||||
|
||||
// File must be accessible after storeWSMedia returns (no premature deletion).
|
||||
path, err := store.Resolve(ref)
|
||||
if err != nil {
|
||||
t.Fatalf("ref should resolve: %v", err)
|
||||
}
|
||||
got, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("file should exist at %s: %v", path, err)
|
||||
}
|
||||
if !bytes.Equal(got, imageData) {
|
||||
t.Errorf("content mismatch: got len=%d, want len=%d", len(got), len(imageData))
|
||||
}
|
||||
|
||||
// ReleaseAll must delete the file (store owns lifecycle).
|
||||
scope := channels.BuildMediaScope("wecom_aibot", "chat1", "msg1")
|
||||
if err := store.ReleaseAll(scope); err != nil {
|
||||
t.Fatalf("ReleaseAll failed: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(path); !os.IsNotExist(err) {
|
||||
t.Errorf("file should have been deleted by ReleaseAll, stat err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStoreWSMedia_MultipleMessages verifies that concurrent media messages with
|
||||
// different msgIDs do not collide and each resolve to distinct files.
|
||||
func TestStoreWSMedia_MultipleMessages(t *testing.T) {
|
||||
imageA := bytes.Repeat([]byte("a"), 64)
|
||||
imageB := bytes.Repeat([]byte("b"), 64)
|
||||
|
||||
srvA := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(imageA)
|
||||
}))
|
||||
defer srvA.Close()
|
||||
srvB := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(imageB)
|
||||
}))
|
||||
defer srvB.Close()
|
||||
|
||||
ch := newTestWSChannel(t)
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
refA, err := ch.storeWSMedia(context.Background(), "chat1", "msgA", srvA.URL, "", ".jpg")
|
||||
if err != nil {
|
||||
t.Fatalf("storeWSMedia A: %v", err)
|
||||
}
|
||||
refB, err := ch.storeWSMedia(context.Background(), "chat1", "msgB", srvB.URL, "", ".jpg")
|
||||
if err != nil {
|
||||
t.Fatalf("storeWSMedia B: %v", err)
|
||||
}
|
||||
if refA == refB {
|
||||
t.Fatal("distinct messages must produce distinct refs")
|
||||
}
|
||||
|
||||
pathA, _ := store.Resolve(refA)
|
||||
pathB, _ := store.Resolve(refB)
|
||||
if pathA == pathB {
|
||||
t.Fatal("distinct messages must be stored at distinct paths")
|
||||
}
|
||||
|
||||
gotA, _ := os.ReadFile(pathA)
|
||||
gotB, _ := os.ReadFile(pathB)
|
||||
if !bytes.Equal(gotA, imageA) {
|
||||
t.Errorf("content mismatch for message A")
|
||||
}
|
||||
if !bytes.Equal(gotB, imageB) {
|
||||
t.Errorf("content mismatch for message B")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStoreWSMedia_ContentTypeExt verifies that the file extension is inferred
|
||||
// from the HTTP Content-Type header and the defaultExt fallback is used when the
|
||||
// type is absent or unrecognized.
|
||||
func TestStoreWSMedia_ContentTypeExt(t *testing.T) {
|
||||
tests := []struct {
|
||||
contentType string
|
||||
wantExt string
|
||||
}{
|
||||
{"image/jpeg", ".jpg"},
|
||||
{"image/png", ".png"},
|
||||
{"video/mp4", ".mp4"},
|
||||
{"application/pdf", ".pdf"},
|
||||
{"application/zip", ".zip"},
|
||||
// With parameters stripped.
|
||||
{"video/mp4; codecs=avc1", ".mp4"},
|
||||
// Unknown type → falls back to defaultExt.
|
||||
{"", ""},
|
||||
{"application/octet-stream", ""},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := wsMediaExtFromContentType(tc.contentType)
|
||||
if got != tc.wantExt {
|
||||
t.Errorf("wsMediaExtFromContentType(%q) = %q, want %q", tc.contentType, got, tc.wantExt)
|
||||
}
|
||||
}
|
||||
|
||||
// End-to-end: server returns Content-Type: video/mp4, defaultExt is .bin.
|
||||
// The stored file should carry the .mp4 extension, not .bin.
|
||||
payload := bytes.Repeat([]byte("v"), 128)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "video/mp4")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(payload)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
ch := newTestWSChannel(t)
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
ref, err := ch.storeWSMedia(context.Background(), "chat1", "vid1", srv.URL, "", ".bin")
|
||||
if err != nil {
|
||||
t.Fatalf("storeWSMedia: %v", err)
|
||||
}
|
||||
path, err := store.Resolve(ref)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve: %v", err)
|
||||
}
|
||||
if ext := path[len(path)-4:]; ext != ".mp4" {
|
||||
t.Errorf("expected .mp4 extension from Content-Type, got %q", ext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSplitWSContent verifies byte-aware splitting of stream content.
|
||||
func TestSplitWSContent(t *testing.T) {
|
||||
t.Run("short content is not split", func(t *testing.T) {
|
||||
chunks := splitWSContent("hello", 20480)
|
||||
if len(chunks) != 1 || chunks[0] != "hello" {
|
||||
t.Fatalf("unexpected chunks: %v", chunks)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ASCII content split at byte boundary", func(t *testing.T) {
|
||||
// Build a string just over the limit.
|
||||
content := strings.Repeat("a", 20481)
|
||||
chunks := splitWSContent(content, 20480)
|
||||
if len(chunks) < 2 {
|
||||
t.Fatalf("expected >= 2 chunks, got %d", len(chunks))
|
||||
}
|
||||
for i, c := range chunks {
|
||||
if len(c) > 20480 {
|
||||
t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c))
|
||||
}
|
||||
}
|
||||
// Reassembled content must equal the original (possibly without leading
|
||||
// whitespace that splitWSContent trims between chunks).
|
||||
joined := strings.Join(chunks, "")
|
||||
if len(joined) < len(content)-len(chunks) {
|
||||
t.Errorf("joined length %d too short (original %d)", len(joined), len(content))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CJK content split within byte limit", func(t *testing.T) {
|
||||
// Each CJK rune is 3 bytes in UTF-8.
|
||||
// 7000 CJK chars = 21000 bytes, which exceeds 20480.
|
||||
content := strings.Repeat("\u4e2d", 7000)
|
||||
chunks := splitWSContent(content, 20480)
|
||||
if len(chunks) < 2 {
|
||||
t.Fatalf("expected >= 2 chunks for 21000-byte CJK content, got %d", len(chunks))
|
||||
}
|
||||
for i, c := range chunks {
|
||||
if len(c) > 20480 {
|
||||
t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c))
|
||||
}
|
||||
// Every chunk must be valid UTF-8.
|
||||
if !strings.ContainsRune(c, '\u4e2d') && len(c) > 0 {
|
||||
// quick plausibility check — content was pure CJK
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSplitAtByteBoundary verifies the last-resort byte-boundary splitter.
|
||||
func TestSplitAtByteBoundary(t *testing.T) {
|
||||
t.Run("ASCII fits in one chunk", func(t *testing.T) {
|
||||
parts := splitAtByteBoundary("hello world", 100)
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 part, got %d", len(parts))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("splits at byte boundary, never mid-rune", func(t *testing.T) {
|
||||
// 10 CJK characters = 30 bytes; split at 20 bytes.
|
||||
s := strings.Repeat("\u6587", 10) // 10 × 3 bytes = 30 bytes
|
||||
parts := splitAtByteBoundary(s, 20)
|
||||
for i, p := range parts {
|
||||
if len(p) > 20 {
|
||||
t.Errorf("part %d has %d bytes, want <= 20", i, len(p))
|
||||
}
|
||||
// Must be valid UTF-8 (no torn multi-byte sequences).
|
||||
for j, r := range p {
|
||||
if r == '\uFFFD' {
|
||||
t.Errorf("part %d has replacement rune at position %d: torn UTF-8", i, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,756 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
wecomAPIBase = "https://qyapi.weixin.qq.com"
|
||||
)
|
||||
|
||||
// WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用)
|
||||
type WeComAppChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.WeComAppConfig
|
||||
client *http.Client
|
||||
accessToken string
|
||||
tokenExpiry time.Time
|
||||
tokenMu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
processedMsgs *MessageDeduplicator
|
||||
}
|
||||
|
||||
// WeComXMLMessage represents the XML message structure from WeCom
|
||||
type WeComXMLMessage struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
FromUserName string `xml:"FromUserName"`
|
||||
CreateTime int64 `xml:"CreateTime"`
|
||||
MsgType string `xml:"MsgType"`
|
||||
Content string `xml:"Content"`
|
||||
MsgId int64 `xml:"MsgId"`
|
||||
AgentID int64 `xml:"AgentID"`
|
||||
PicUrl string `xml:"PicUrl"`
|
||||
MediaId string `xml:"MediaId"`
|
||||
Format string `xml:"Format"`
|
||||
ThumbMediaId string `xml:"ThumbMediaId"`
|
||||
LocationX float64 `xml:"Location_X"`
|
||||
LocationY float64 `xml:"Location_Y"`
|
||||
Scale int `xml:"Scale"`
|
||||
Label string `xml:"Label"`
|
||||
Title string `xml:"Title"`
|
||||
Description string `xml:"Description"`
|
||||
Url string `xml:"Url"`
|
||||
Event string `xml:"Event"`
|
||||
EventKey string `xml:"EventKey"`
|
||||
}
|
||||
|
||||
// WeComTextMessage represents text message for sending
|
||||
type WeComTextMessage struct {
|
||||
ToUser string `json:"touser"`
|
||||
MsgType string `json:"msgtype"`
|
||||
AgentID int64 `json:"agentid"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
Safe int `json:"safe,omitempty"`
|
||||
}
|
||||
|
||||
// WeComMarkdownMessage represents markdown message for sending
|
||||
type WeComMarkdownMessage struct {
|
||||
ToUser string `json:"touser"`
|
||||
MsgType string `json:"msgtype"`
|
||||
AgentID int64 `json:"agentid"`
|
||||
Markdown struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"markdown"`
|
||||
}
|
||||
|
||||
// WeComImageMessage represents image message for sending
|
||||
type WeComImageMessage struct {
|
||||
ToUser string `json:"touser"`
|
||||
MsgType string `json:"msgtype"`
|
||||
AgentID int64 `json:"agentid"`
|
||||
Image struct {
|
||||
MediaID string `json:"media_id"`
|
||||
} `json:"image"`
|
||||
}
|
||||
|
||||
// WeComAccessTokenResponse represents the access token API response
|
||||
type WeComAccessTokenResponse struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// WeComSendMessageResponse represents the send message API response
|
||||
type WeComSendMessageResponse struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
InvalidUser string `json:"invaliduser"`
|
||||
InvalidParty string `json:"invalidparty"`
|
||||
InvalidTag string `json:"invalidtag"`
|
||||
}
|
||||
|
||||
// PKCS7Padding adds PKCS7 padding
|
||||
type PKCS7Padding struct{}
|
||||
|
||||
// NewWeComAppChannel creates a new WeCom App channel instance
|
||||
func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (*WeComAppChannel, error) {
|
||||
if cfg.CorpID == "" || cfg.CorpSecret() == "" || cfg.AgentID == 0 {
|
||||
return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required")
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(2048),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
// Client timeout must be >= the configured ReplyTimeout so the
|
||||
// per-request context deadline is always the effective limit.
|
||||
clientTimeout := 30 * time.Second
|
||||
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
|
||||
clientTimeout = d
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &WeComAppChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: &http.Client{Timeout: clientTimeout},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the channel name
|
||||
func (c *WeComAppChannel) Name() string {
|
||||
return "wecom_app"
|
||||
}
|
||||
|
||||
// Start initializes the WeCom App channel
|
||||
func (c *WeComAppChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("wecom_app", "Starting WeCom App channel...")
|
||||
|
||||
// Cancel the context created in the constructor to avoid a resource leak.
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Get initial access token
|
||||
if err := c.refreshAccessToken(); err != nil {
|
||||
logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Start token refresh goroutine
|
||||
go c.tokenRefreshLoop()
|
||||
|
||||
c.SetRunning(true)
|
||||
logger.InfoC("wecom_app", "WeCom App channel started")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the WeCom App channel
|
||||
func (c *WeComAppChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("wecom_app", "Stopping WeCom App channel...")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
c.SetRunning(false)
|
||||
logger.InfoC("wecom_app", "WeCom App channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send sends a message to WeCom user proactively using access token
|
||||
func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
accessToken := c.getAccessToken()
|
||||
if accessToken == "" {
|
||||
return fmt.Errorf("no valid access token available")
|
||||
}
|
||||
|
||||
logger.DebugCF("wecom_app", "Sending message", map[string]any{
|
||||
"chat_id": msg.ChatID,
|
||||
"preview": utils.Truncate(msg.Content, 100),
|
||||
})
|
||||
|
||||
return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content)
|
||||
}
|
||||
|
||||
// SendMedia implements the channels.MediaSender interface.
|
||||
func (c *WeComAppChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
accessToken := c.getAccessToken()
|
||||
if accessToken == "" {
|
||||
return fmt.Errorf("no valid access token available: %w", channels.ErrTemporary)
|
||||
}
|
||||
|
||||
store := c.GetMediaStore()
|
||||
if store == nil {
|
||||
return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
for _, part := range msg.Parts {
|
||||
localPath, err := store.Resolve(part.Ref)
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to resolve media ref", map[string]any{
|
||||
"ref": part.Ref,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Map part type to WeCom media type
|
||||
var mediaType string
|
||||
switch part.Type {
|
||||
case "image":
|
||||
mediaType = "image"
|
||||
case "audio":
|
||||
mediaType = "voice"
|
||||
case "video":
|
||||
mediaType = "video"
|
||||
default:
|
||||
mediaType = "file"
|
||||
}
|
||||
|
||||
// Upload media to get media_id
|
||||
mediaID, err := c.uploadMedia(ctx, accessToken, mediaType, localPath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to upload media", map[string]any{
|
||||
"type": mediaType,
|
||||
"error": err.Error(),
|
||||
})
|
||||
// Fallback: send caption as text
|
||||
if part.Caption != "" {
|
||||
_ = c.sendTextMessage(ctx, accessToken, msg.ChatID, part.Caption)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Send media message using the media_id
|
||||
if mediaType == "image" {
|
||||
err = c.sendImageMessage(ctx, accessToken, msg.ChatID, mediaID)
|
||||
} else {
|
||||
// For non-image types, send as text fallback with caption
|
||||
caption := part.Caption
|
||||
if caption == "" {
|
||||
caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename)
|
||||
}
|
||||
err = c.sendTextMessage(ctx, accessToken, msg.ChatID, caption)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// uploadMedia uploads a local file to WeCom temporary media storage.
|
||||
func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaType, localPath string) (string, error) {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/media/upload?access_token=%s&type=%s",
|
||||
wecomAPIBase, url.QueryEscape(accessToken), url.QueryEscape(mediaType))
|
||||
|
||||
file, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
filename := filepath.Base(localPath)
|
||||
formFile, err := writer.CreateFormFile("media", filename)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create form file: %w", err)
|
||||
}
|
||||
|
||||
if _, err = io.Copy(formFile, file); err != nil {
|
||||
return "", fmt.Errorf("failed to copy file content: %w", err)
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return "", channels.ClassifyNetError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return "", channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("reading wecom upload error response: %w", readErr),
|
||||
)
|
||||
}
|
||||
return "", channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("wecom upload error: %s", string(respBody)),
|
||||
)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
MediaID string `json:"media_id"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse upload response: %w", err)
|
||||
}
|
||||
|
||||
if result.ErrCode != 0 {
|
||||
return "", fmt.Errorf("upload API error: %s (code: %d)", result.ErrMsg, result.ErrCode)
|
||||
}
|
||||
|
||||
return result.MediaID, nil
|
||||
}
|
||||
|
||||
// sendWeComMessage marshals payload and POSTs it to the WeCom message API.
|
||||
func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
|
||||
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
timeout := c.config.ReplyTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("reading wecom_app error response: %w", readErr),
|
||||
)
|
||||
}
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("wecom_app API error: %s", string(respBody)),
|
||||
)
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var sendResp WeComSendMessageResponse
|
||||
if err := json.Unmarshal(respBody, &sendResp); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if sendResp.ErrCode != 0 {
|
||||
return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendImageMessage sends an image message using a media_id.
|
||||
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
|
||||
msg := WeComImageMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "image",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Image.MediaID = mediaID
|
||||
return c.sendWeComMessage(ctx, accessToken, msg)
|
||||
}
|
||||
|
||||
// WebhookPath returns the path for registering on the shared HTTP server.
|
||||
func (c *WeComAppChannel) WebhookPath() string {
|
||||
if c.config.WebhookPath != "" {
|
||||
return c.config.WebhookPath
|
||||
}
|
||||
return "/webhook/wecom-app"
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler for the shared HTTP server.
|
||||
func (c *WeComAppChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
c.handleWebhook(w, r)
|
||||
}
|
||||
|
||||
// HealthPath returns the health check endpoint path.
|
||||
func (c *WeComAppChannel) HealthPath() string {
|
||||
return "/health/wecom-app"
|
||||
}
|
||||
|
||||
// HealthHandler handles health check requests.
|
||||
func (c *WeComAppChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
c.handleHealth(w, r)
|
||||
}
|
||||
|
||||
// handleWebhook handles incoming webhook requests from WeCom
|
||||
func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Log all incoming requests for debugging
|
||||
logger.DebugCF("wecom_app", "Received webhook request", map[string]any{
|
||||
"method": r.Method,
|
||||
"url": r.URL.String(),
|
||||
"path": r.URL.Path,
|
||||
"query": r.URL.RawQuery,
|
||||
})
|
||||
|
||||
if r.Method == http.MethodGet {
|
||||
// Handle verification request
|
||||
c.handleVerification(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == http.MethodPost {
|
||||
// Handle message callback
|
||||
c.handleMessageCallback(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
logger.WarnCF("wecom_app", "Method not allowed", map[string]any{
|
||||
"method": r.Method,
|
||||
})
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// handleVerification handles the URL verification request from WeCom
|
||||
func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
msgSignature := query.Get("msg_signature")
|
||||
timestamp := query.Get("timestamp")
|
||||
nonce := query.Get("nonce")
|
||||
echostr := query.Get("echostr")
|
||||
|
||||
logger.DebugCF("wecom_app", "Handling verification request", map[string]any{
|
||||
"msg_signature": msgSignature,
|
||||
"timestamp": timestamp,
|
||||
"nonce": nonce,
|
||||
"echostr": echostr,
|
||||
"corp_id": c.config.CorpID,
|
||||
})
|
||||
|
||||
if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" {
|
||||
logger.ErrorC("wecom_app", "Missing parameters in verification request")
|
||||
http.Error(w, "Missing parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, echostr) {
|
||||
logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{
|
||||
"token": c.config.Token(),
|
||||
"msg_signature": msgSignature,
|
||||
"timestamp": timestamp,
|
||||
"nonce": nonce,
|
||||
})
|
||||
http.Error(w, "Invalid signature", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
logger.DebugC("wecom_app", "Signature verification passed")
|
||||
|
||||
// Decrypt echostr with CorpID verification
|
||||
// For WeCom App (自建应用), receiveid should be corp_id
|
||||
logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]any{
|
||||
"encoding_aes_key": c.config.EncodingAESKey(),
|
||||
"corp_id": c.config.CorpID,
|
||||
})
|
||||
decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey(), c.config.CorpID)
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{
|
||||
"error": err.Error(),
|
||||
"encoding_aes_key": c.config.EncodingAESKey,
|
||||
"corp_id": c.config.CorpID,
|
||||
})
|
||||
http.Error(w, "Decryption failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]any{
|
||||
"decrypted": decryptedEchoStr,
|
||||
})
|
||||
|
||||
// Remove BOM and whitespace as per WeCom documentation
|
||||
// The response must be plain text without quotes, BOM, or newlines
|
||||
decryptedEchoStr = strings.TrimSpace(decryptedEchoStr)
|
||||
decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM
|
||||
w.Write([]byte(decryptedEchoStr))
|
||||
}
|
||||
|
||||
// handleMessageCallback handles incoming messages from WeCom
|
||||
func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
msgSignature := query.Get("msg_signature")
|
||||
timestamp := query.Get("timestamp")
|
||||
nonce := query.Get("nonce")
|
||||
|
||||
if msgSignature == "" || timestamp == "" || nonce == "" {
|
||||
http.Error(w, "Missing parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
// Parse XML to get encrypted message
|
||||
var encryptedMsg struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
AgentID string `xml:"AgentID"`
|
||||
}
|
||||
|
||||
if err = xml.Unmarshal(body, &encryptedMsg); err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Invalid XML", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
|
||||
logger.WarnC("wecom_app", "Message signature verification failed")
|
||||
http.Error(w, "Invalid signature", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt message with CorpID verification
|
||||
// For WeCom App (自建应用), receiveid should be corp_id
|
||||
decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey(), c.config.CorpID)
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Decryption failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse decrypted XML message
|
||||
var msg WeComXMLMessage
|
||||
if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Invalid message format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Process the message with the channel's long-lived context (not the HTTP
|
||||
// request context, which is canceled as soon as we return the response).
|
||||
go c.processMessage(c.ctx, msg)
|
||||
|
||||
// Return success response immediately
|
||||
// WeCom App requires response within configured timeout (default 5 seconds)
|
||||
w.Write([]byte("success"))
|
||||
}
|
||||
|
||||
// processMessage processes the received message
|
||||
func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessage) {
|
||||
// Skip non-text messages for now (can be extended)
|
||||
if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" {
|
||||
logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]any{
|
||||
"msg_type": msg.MsgType,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Message deduplication: Use msg_id to prevent duplicate processing
|
||||
// As per WeCom documentation, use msg_id for deduplication
|
||||
msgID := fmt.Sprintf("%d", msg.MsgId)
|
||||
if !c.processedMsgs.MarkMessageProcessed(msgID) {
|
||||
logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{
|
||||
"msg_id": msgID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := msg.FromUserName
|
||||
chatID := senderID // WeCom App uses user ID as chat ID for direct messages
|
||||
|
||||
// Build metadata
|
||||
// WeCom App only supports direct messages (private chat)
|
||||
peer := bus.Peer{Kind: "direct", ID: senderID}
|
||||
messageID := fmt.Sprintf("%d", msg.MsgId)
|
||||
|
||||
metadata := map[string]string{
|
||||
"msg_type": msg.MsgType,
|
||||
"msg_id": fmt.Sprintf("%d", msg.MsgId),
|
||||
"agent_id": fmt.Sprintf("%d", msg.AgentID),
|
||||
"platform": "wecom_app",
|
||||
"media_id": msg.MediaId,
|
||||
"create_time": fmt.Sprintf("%d", msg.CreateTime),
|
||||
}
|
||||
|
||||
content := msg.Content
|
||||
|
||||
logger.DebugCF("wecom_app", "Received message", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"msg_type": msg.MsgType,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Build sender info
|
||||
appSender := bus.SenderInfo{
|
||||
Platform: "wecom",
|
||||
PlatformID: senderID,
|
||||
CanonicalID: identity.BuildCanonicalID("wecom", senderID),
|
||||
}
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, appSender)
|
||||
}
|
||||
|
||||
// tokenRefreshLoop periodically refreshes the access token
|
||||
func (c *WeComAppChannel) tokenRefreshLoop() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := c.refreshAccessToken(); err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to refresh access token", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// refreshAccessToken gets a new access token from WeCom API
|
||||
func (c *WeComAppChannel) refreshAccessToken() error {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/gettoken?corpid=%s&corpsecret=%s",
|
||||
wecomAPIBase, url.QueryEscape(c.config.CorpID), url.QueryEscape(c.config.CorpSecret()))
|
||||
|
||||
resp, err := http.Get(apiURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to request access token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var tokenResp WeComAccessTokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.ErrCode != 0 {
|
||||
return fmt.Errorf("API error: %s (code: %d)", tokenResp.ErrMsg, tokenResp.ErrCode)
|
||||
}
|
||||
|
||||
c.tokenMu.Lock()
|
||||
c.accessToken = tokenResp.AccessToken
|
||||
c.tokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second) // Refresh 5 minutes early
|
||||
c.tokenMu.Unlock()
|
||||
|
||||
logger.DebugC("wecom_app", "Access token refreshed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAccessToken returns the current valid access token
|
||||
func (c *WeComAppChannel) getAccessToken() string {
|
||||
c.tokenMu.RLock()
|
||||
defer c.tokenMu.RUnlock()
|
||||
|
||||
if time.Now().After(c.tokenExpiry) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return c.accessToken
|
||||
}
|
||||
|
||||
// sendTextMessage sends a text message to a user.
|
||||
func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error {
|
||||
msg := WeComTextMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "text",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Text.Content = content
|
||||
return c.sendWeComMessage(ctx, accessToken, msg)
|
||||
}
|
||||
|
||||
// handleHealth handles health check requests
|
||||
func (c *WeComAppChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
status := map[string]any{
|
||||
"status": "ok",
|
||||
"running": c.IsRunning(),
|
||||
"has_token": c.getAccessToken() != "",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,499 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人)
|
||||
// Uses webhook callback mode - simpler than WeCom App but only supports passive replies
|
||||
type WeComBotChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.WeComConfig
|
||||
client *http.Client
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
processedMsgs *MessageDeduplicator
|
||||
}
|
||||
|
||||
// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT)
|
||||
type WeComBotMessage struct {
|
||||
MsgID string `json:"msgid"`
|
||||
AIBotID string `json:"aibotid"`
|
||||
ChatID string `json:"chatid"` // Session ID, only present for group chats
|
||||
ChatType string `json:"chattype"` // "single" for DM, "group" for group chat
|
||||
From struct {
|
||||
UserID string `json:"userid"`
|
||||
} `json:"from"`
|
||||
ResponseURL string `json:"response_url"`
|
||||
MsgType string `json:"msgtype"` // text, image, voice, file, mixed
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
Image struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"image"`
|
||||
Voice struct {
|
||||
Content string `json:"content"` // Voice to text content
|
||||
} `json:"voice"`
|
||||
File struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"file"`
|
||||
Mixed struct {
|
||||
MsgItem []struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
Image struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"image"`
|
||||
} `json:"msg_item"`
|
||||
} `json:"mixed"`
|
||||
Quote struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
} `json:"quote"`
|
||||
}
|
||||
|
||||
// WeComBotReplyMessage represents the reply message structure
|
||||
type WeComBotReplyMessage struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// NewWeComBotChannel creates a new WeCom Bot channel instance
|
||||
func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComBotChannel, error) {
|
||||
if cfg.Token() == "" || cfg.WebhookURL == "" {
|
||||
return nil, fmt.Errorf("wecom token and webhook_url are required")
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(2048),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
// Client timeout must be >= the configured ReplyTimeout so the
|
||||
// per-request context deadline is always the effective limit.
|
||||
clientTimeout := 30 * time.Second
|
||||
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
|
||||
clientTimeout = d
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &WeComBotChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: &http.Client{Timeout: clientTimeout},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the channel name
|
||||
func (c *WeComBotChannel) Name() string {
|
||||
return "wecom"
|
||||
}
|
||||
|
||||
// Start initializes the WeCom Bot channel
|
||||
func (c *WeComBotChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("wecom", "Starting WeCom Bot channel...")
|
||||
|
||||
// Cancel the context created in the constructor to avoid a resource leak.
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
c.SetRunning(true)
|
||||
logger.InfoC("wecom", "WeCom Bot channel started")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the WeCom Bot channel
|
||||
func (c *WeComBotChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("wecom", "Stopping WeCom Bot channel...")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
c.SetRunning(false)
|
||||
logger.InfoC("wecom", "WeCom Bot channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send sends a message to WeCom user via webhook API
|
||||
// Note: WeCom Bot can only reply within the configured timeout (default 5 seconds) of receiving a message
|
||||
// For delayed responses, we use the webhook URL
|
||||
func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
logger.DebugCF("wecom", "Sending message via webhook", map[string]any{
|
||||
"chat_id": msg.ChatID,
|
||||
"preview": utils.Truncate(msg.Content, 100),
|
||||
})
|
||||
|
||||
return c.sendWebhookReply(ctx, msg.ChatID, msg.Content)
|
||||
}
|
||||
|
||||
// WebhookPath returns the path for registering on the shared HTTP server.
|
||||
func (c *WeComBotChannel) WebhookPath() string {
|
||||
if c.config.WebhookPath != "" {
|
||||
return c.config.WebhookPath
|
||||
}
|
||||
return "/webhook/wecom"
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler for the shared HTTP server.
|
||||
func (c *WeComBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
c.handleWebhook(w, r)
|
||||
}
|
||||
|
||||
// HealthPath returns the health check endpoint path.
|
||||
func (c *WeComBotChannel) HealthPath() string {
|
||||
return "/health/wecom"
|
||||
}
|
||||
|
||||
// HealthHandler handles health check requests.
|
||||
func (c *WeComBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
c.handleHealth(w, r)
|
||||
}
|
||||
|
||||
// handleWebhook handles incoming webhook requests from WeCom
|
||||
func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
if r.Method == http.MethodGet {
|
||||
// Handle verification request
|
||||
c.handleVerification(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == http.MethodPost {
|
||||
// Handle message callback
|
||||
c.handleMessageCallback(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// handleVerification handles the URL verification request from WeCom
|
||||
func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
msgSignature := query.Get("msg_signature")
|
||||
timestamp := query.Get("timestamp")
|
||||
nonce := query.Get("nonce")
|
||||
echostr := query.Get("echostr")
|
||||
|
||||
if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" {
|
||||
http.Error(w, "Missing parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, echostr) {
|
||||
logger.WarnC("wecom", "Signature verification failed")
|
||||
http.Error(w, "Invalid signature", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt echostr
|
||||
// For AIBOT (智能机器人), receiveid should be empty string ""
|
||||
// Reference: https://developer.work.weixin.qq.com/document/path/101033
|
||||
decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey(), "")
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Decryption failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove BOM and whitespace as per WeCom documentation
|
||||
// The response must be plain text without quotes, BOM, or newlines
|
||||
decryptedEchoStr = strings.TrimSpace(decryptedEchoStr)
|
||||
decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM
|
||||
w.Write([]byte(decryptedEchoStr))
|
||||
}
|
||||
|
||||
// handleMessageCallback handles incoming messages from WeCom
|
||||
func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
msgSignature := query.Get("msg_signature")
|
||||
timestamp := query.Get("timestamp")
|
||||
nonce := query.Get("nonce")
|
||||
|
||||
if msgSignature == "" || timestamp == "" || nonce == "" {
|
||||
http.Error(w, "Missing parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
// Parse XML to get encrypted message
|
||||
var encryptedMsg struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
AgentID string `xml:"AgentID"`
|
||||
}
|
||||
|
||||
if err = xml.Unmarshal(body, &encryptedMsg); err != nil {
|
||||
logger.ErrorCF("wecom", "Failed to parse XML", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Invalid XML", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
|
||||
logger.WarnC("wecom", "Message signature verification failed")
|
||||
http.Error(w, "Invalid signature", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt message
|
||||
// For AIBOT (智能机器人), receiveid should be empty string ""
|
||||
// Reference: https://developer.work.weixin.qq.com/document/path/101033
|
||||
decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey(), "")
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Decryption failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse decrypted JSON message (AIBOT uses JSON format)
|
||||
var msg WeComBotMessage
|
||||
if err := json.Unmarshal([]byte(decryptedMsg), &msg); err != nil {
|
||||
logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Invalid message format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Process the message with the channel's long-lived context (not the HTTP
|
||||
// request context, which is canceled as soon as we return the response).
|
||||
go c.processMessage(c.ctx, msg)
|
||||
|
||||
// Return success response immediately
|
||||
// WeCom Bot requires response within configured timeout (default 5 seconds)
|
||||
w.Write([]byte("success"))
|
||||
}
|
||||
|
||||
// processMessage processes the received message
|
||||
func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessage) {
|
||||
// Skip unsupported message types
|
||||
if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" &&
|
||||
msg.MsgType != "mixed" {
|
||||
logger.DebugCF("wecom", "Skipping non-supported message type", map[string]any{
|
||||
"msg_type": msg.MsgType,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Message deduplication: Use msg_id to prevent duplicate processing
|
||||
msgID := msg.MsgID
|
||||
if !c.processedMsgs.MarkMessageProcessed(msgID) {
|
||||
logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{
|
||||
"msg_id": msgID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := msg.From.UserID
|
||||
|
||||
// Determine if this is a group chat or direct message
|
||||
// ChatType: "single" for DM, "group" for group chat
|
||||
isGroupChat := msg.ChatType == "group"
|
||||
|
||||
var chatID, peerKind, peerID string
|
||||
if isGroupChat {
|
||||
// Group chat: use ChatID as chatID and peer_id
|
||||
chatID = msg.ChatID
|
||||
peerKind = "group"
|
||||
peerID = msg.ChatID
|
||||
} else {
|
||||
// Direct message: use senderID as chatID and peer_id
|
||||
chatID = senderID
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
// Extract content based on message type
|
||||
var content string
|
||||
switch msg.MsgType {
|
||||
case "text":
|
||||
content = msg.Text.Content
|
||||
case "voice":
|
||||
content = msg.Voice.Content // Voice to text content
|
||||
case "mixed":
|
||||
// For mixed messages, concatenate text items
|
||||
for _, item := range msg.Mixed.MsgItem {
|
||||
if item.MsgType == "text" {
|
||||
content += item.Text.Content
|
||||
}
|
||||
}
|
||||
case "image", "file":
|
||||
// For image and file, we don't have text content
|
||||
content = ""
|
||||
}
|
||||
|
||||
// Build metadata
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerID}
|
||||
|
||||
// In group chats, apply unified group trigger filtering
|
||||
if isGroupChat {
|
||||
respond, cleaned := c.ShouldRespondInGroup(false, content)
|
||||
if !respond {
|
||||
return
|
||||
}
|
||||
content = cleaned
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"msg_type": msg.MsgType,
|
||||
"msg_id": msg.MsgID,
|
||||
"platform": "wecom",
|
||||
"response_url": msg.ResponseURL,
|
||||
}
|
||||
if isGroupChat {
|
||||
metadata["chat_id"] = msg.ChatID
|
||||
metadata["sender_id"] = senderID
|
||||
}
|
||||
|
||||
logger.DebugCF("wecom", "Received message", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"msg_type": msg.MsgType,
|
||||
"peer_kind": peerKind,
|
||||
"is_group_chat": isGroupChat,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Build sender info
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "wecom",
|
||||
PlatformID: senderID,
|
||||
CanonicalID: identity.BuildCanonicalID("wecom", senderID),
|
||||
}
|
||||
|
||||
if !c.IsAllowedSender(sender) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(ctx, peer, msg.MsgID, senderID, chatID, content, nil, metadata, sender)
|
||||
}
|
||||
|
||||
// sendWebhookReply sends a reply using the webhook URL
|
||||
func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content string) error {
|
||||
reply := WeComBotReplyMessage{
|
||||
MsgType: "text",
|
||||
}
|
||||
reply.Text.Content = content
|
||||
|
||||
jsonData, err := json.Marshal(reply)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal reply: %w", err)
|
||||
}
|
||||
|
||||
// Use configurable timeout (default 5 seconds)
|
||||
timeout := c.config.ReplyTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.config.WebhookURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("reading webhook error response: %w", readErr),
|
||||
)
|
||||
}
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("webhook API error: %s", string(body)),
|
||||
)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// Check response
|
||||
var result struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if result.ErrCode != 0 {
|
||||
return fmt.Errorf("webhook API error: %s (code: %d)", result.ErrMsg, result.ErrCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleHealth handles health check requests
|
||||
func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
status := map[string]any{
|
||||
"status": "ok",
|
||||
"running": c.IsRunning(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
@@ -1,734 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// generateTestAESKey generates a valid test AES key
|
||||
func generateTestAESKey() string {
|
||||
// AES key needs to be 32 bytes (256 bits) for AES-256
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
// Return base64 encoded key without padding
|
||||
return base64.StdEncoding.EncodeToString(key)[:43]
|
||||
}
|
||||
|
||||
// encryptTestMessage encrypts a message for testing (AIBOT JSON format)
|
||||
func encryptTestMessage(message, aesKey string) (string, error) {
|
||||
// Decode AES key
|
||||
key, err := base64.StdEncoding.DecodeString(aesKey + "=")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Prepare message: random(16) + msg_len(4) + msg + receiveid
|
||||
random := make([]byte, 0, 16)
|
||||
for i := range 16 {
|
||||
random = append(random, byte(i))
|
||||
}
|
||||
|
||||
msgBytes := []byte(message)
|
||||
receiveID := []byte("test_aibot_id")
|
||||
|
||||
msgLen := uint32(len(msgBytes))
|
||||
lenBytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(lenBytes, msgLen)
|
||||
|
||||
plainText := append(random, lenBytes...)
|
||||
plainText = append(plainText, msgBytes...)
|
||||
plainText = append(plainText, receiveID...)
|
||||
|
||||
// PKCS7 padding
|
||||
blockSize := aes.BlockSize
|
||||
padding := blockSize - len(plainText)%blockSize
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
plainText = append(plainText, padText...)
|
||||
|
||||
// Encrypt
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize])
|
||||
cipherText := make([]byte, len(plainText))
|
||||
mode.CryptBlocks(cipherText, plainText)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(cipherText), nil
|
||||
}
|
||||
|
||||
// generateSignature generates a signature for testing
|
||||
func generateSignature(token, timestamp, nonce, msgEncrypt string) string {
|
||||
params := []string{token, timestamp, nonce, msgEncrypt}
|
||||
sort.Strings(params)
|
||||
str := strings.Join(params, "")
|
||||
hash := sha1.Sum([]byte(str))
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
|
||||
func TestNewWeComBotChannel(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("missing token", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
_, err := NewWeComBotChannel(cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing token, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing webhook_url", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = ""
|
||||
_, err := NewWeComBotChannel(cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing webhook_url, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.AllowFrom = []string{"user1", "user2"}
|
||||
ch, err := NewWeComBotChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ch.Name() != "wecom" {
|
||||
t.Errorf("Name() = %q, want %q", ch.Name(), "wecom")
|
||||
}
|
||||
if ch.IsRunning() {
|
||||
t.Error("new channel should not be running")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotChannelIsAllowed(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("empty allowlist allows all", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.AllowFrom = []string{}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
if !ch.IsAllowed("any_user") {
|
||||
t.Error("empty allowlist should allow all users")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("allowlist restricts users", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.AllowFrom = []string{"allowed_user"}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
if !ch.IsAllowed("allowed_user") {
|
||||
t.Error("allowed user should pass allowlist check")
|
||||
}
|
||||
if ch.IsAllowed("blocked_user") {
|
||||
t.Error("non-allowed user should be blocked")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotVerifySignature(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("valid signature", func(t *testing.T) {
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
msgEncrypt := "test_message"
|
||||
expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt)
|
||||
|
||||
if !verifySignature(ch.config.Token(), expectedSig, timestamp, nonce, msgEncrypt) {
|
||||
t.Error("valid signature should pass verification")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid signature", func(t *testing.T) {
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
msgEncrypt := "test_message"
|
||||
|
||||
if verifySignature(ch.config.Token(), "invalid_sig", timestamp, nonce, msgEncrypt) {
|
||||
t.Error("invalid signature should fail verification")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) {
|
||||
cfgEmpty := config.WeComConfig{}
|
||||
cfgEmpty.SetToken("")
|
||||
cfgEmpty.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
chEmpty := &WeComBotChannel{
|
||||
config: cfgEmpty,
|
||||
}
|
||||
|
||||
if verifySignature(chEmpty.config.Token(), "any_sig", "any_ts", "any_nonce", "any_msg") {
|
||||
t.Error("empty token should reject verification (fail-closed)")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotDecryptMessage(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("decrypt without AES key", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.SetEncodingAESKey("")
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
// Without AES key, message should be base64 decoded only
|
||||
plainText := "hello world"
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(plainText))
|
||||
|
||||
result, err := decryptMessage(encoded, ch.config.EncodingAESKey())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != plainText {
|
||||
t.Errorf("decryptMessage() = %q, want %q", result, plainText)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("decrypt with AES key", func(t *testing.T) {
|
||||
aesKey := generateTestAESKey()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.SetEncodingAESKey(aesKey)
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
originalMsg := "<xml><Content>Hello</Content></xml>"
|
||||
encrypted, err := encryptTestMessage(originalMsg, aesKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to encrypt test message: %v", err)
|
||||
}
|
||||
|
||||
result, err := decryptMessage(encrypted, ch.config.EncodingAESKey())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != originalMsg {
|
||||
t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid base64", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.SetEncodingAESKey("")
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
_, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey())
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid base64, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid AES key", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.SetEncodingAESKey("invalid_key")
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
_, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey())
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid AES key, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotPKCS7Unpad(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte{},
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "valid padding 3 bytes",
|
||||
input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...),
|
||||
expected: []byte("hello"),
|
||||
},
|
||||
{
|
||||
name: "valid padding 16 bytes (full block)",
|
||||
input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...),
|
||||
expected: []byte("123456789012345"),
|
||||
},
|
||||
{
|
||||
name: "invalid padding larger than data",
|
||||
input: []byte{20},
|
||||
expected: nil, // should return error
|
||||
},
|
||||
{
|
||||
name: "invalid padding zero",
|
||||
input: append([]byte("test"), byte(0)),
|
||||
expected: nil, // should return error
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := pkcs7Unpad(tt.input)
|
||||
if tt.expected == nil {
|
||||
// This case should return an error
|
||||
if err == nil {
|
||||
t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("pkcs7Unpad() unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(result, tt.expected) {
|
||||
t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComBotHandleVerification(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
aesKey := generateTestAESKey()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey(aesKey)
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("valid verification request", func(t *testing.T) {
|
||||
echostr := "test_echostr_123"
|
||||
encryptedEchostr, _ := encryptTestMessage(echostr, aesKey)
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encryptedEchostr)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr,
|
||||
nil,
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleVerification(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if w.Body.String() != echostr {
|
||||
t.Errorf("response body = %q, want %q", w.Body.String(), echostr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing parameters", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=sig×tamp=ts", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleVerification(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid signature", func(t *testing.T) {
|
||||
echostr := "test_echostr"
|
||||
encryptedEchostr, _ := encryptTestMessage(echostr, aesKey)
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr,
|
||||
nil,
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleVerification(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
aesKey := generateTestAESKey()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey(aesKey)
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
}{
|
||||
Encrypt: encrypted,
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encrypted)
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
t.Run("valid direct message callback", func(t *testing.T) {
|
||||
w := runBotMessageCallback(t, `{
|
||||
"msgid": "test_msg_id_123",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chattype": "single",
|
||||
"from": {"userid": "user123"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello World"}
|
||||
}`)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if w.Body.String() != "success" {
|
||||
t.Errorf("response body = %q, want %q", w.Body.String(), "success")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid group message callback", func(t *testing.T) {
|
||||
w := runBotMessageCallback(t, `{
|
||||
"msgid": "test_msg_id_456",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chatid": "group_chat_id_123",
|
||||
"chattype": "group",
|
||||
"from": {"userid": "user456"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello Group"}
|
||||
}`)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if w.Body.String() != "success" {
|
||||
t.Errorf("response body = %q, want %q", w.Body.String(), "success")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing parameters", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=sig", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid XML", func(t *testing.T) {
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, "")
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
strings.NewReader("invalid xml"),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid signature", func(t *testing.T) {
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
}{
|
||||
Encrypt: "encrypted_data",
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotProcessMessage(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("process direct text message", func(t *testing.T) {
|
||||
msg := WeComBotMessage{
|
||||
MsgID: "test_msg_id_123",
|
||||
AIBotID: "test_aibot_id",
|
||||
ChatType: "single",
|
||||
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
MsgType: "text",
|
||||
}
|
||||
msg.From.UserID = "user123"
|
||||
msg.Text.Content = "Hello World"
|
||||
|
||||
// Should not panic
|
||||
ch.processMessage(context.Background(), msg)
|
||||
})
|
||||
|
||||
t.Run("process group text message", func(t *testing.T) {
|
||||
msg := WeComBotMessage{
|
||||
MsgID: "test_msg_id_456",
|
||||
AIBotID: "test_aibot_id",
|
||||
ChatID: "group_chat_id_123",
|
||||
ChatType: "group",
|
||||
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
MsgType: "text",
|
||||
}
|
||||
msg.From.UserID = "user456"
|
||||
msg.Text.Content = "Hello Group"
|
||||
|
||||
// Should not panic
|
||||
ch.processMessage(context.Background(), msg)
|
||||
})
|
||||
|
||||
t.Run("process voice message", func(t *testing.T) {
|
||||
msg := WeComBotMessage{
|
||||
MsgID: "test_msg_id_789",
|
||||
AIBotID: "test_aibot_id",
|
||||
ChatType: "single",
|
||||
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
MsgType: "voice",
|
||||
}
|
||||
msg.From.UserID = "user123"
|
||||
msg.Voice.Content = "Voice message text"
|
||||
|
||||
// Should not panic
|
||||
ch.processMessage(context.Background(), msg)
|
||||
})
|
||||
|
||||
t.Run("skip unsupported message type", func(t *testing.T) {
|
||||
msg := WeComBotMessage{
|
||||
MsgID: "test_msg_id_000",
|
||||
AIBotID: "test_aibot_id",
|
||||
ChatType: "single",
|
||||
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
MsgType: "video",
|
||||
}
|
||||
msg.From.UserID = "user123"
|
||||
|
||||
// Should not panic
|
||||
ch.processMessage(context.Background(), msg)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotHandleWebhook(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("GET request calls verification", func(t *testing.T) {
|
||||
echostr := "test_echostr"
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(echostr))
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encoded)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded,
|
||||
nil,
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleWebhook(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("POST request calls message callback", func(t *testing.T) {
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
}{
|
||||
Encrypt: base64.StdEncoding.EncodeToString([]byte("test")),
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encryptedWrapper.Encrypt)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleWebhook(w, req)
|
||||
|
||||
// Should not be method not allowed
|
||||
if w.Code == http.StatusMethodNotAllowed {
|
||||
t.Error("POST request should not return Method Not Allowed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported method", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPut, "/webhook/wecom", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleWebhook(w, req)
|
||||
|
||||
if w.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotHandleHealth(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health/wecom", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleHealth(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want %q", contentType, "application/json")
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "status") || !strings.Contains(body, "running") {
|
||||
t.Errorf("response body should contain status and running fields, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComBotReplyMessage(t *testing.T) {
|
||||
msg := WeComBotReplyMessage{
|
||||
MsgType: "text",
|
||||
}
|
||||
msg.Text.Content = "Hello World"
|
||||
|
||||
if msg.MsgType != "text" {
|
||||
t.Errorf("MsgType = %q, want %q", msg.MsgType, "text")
|
||||
}
|
||||
if msg.Text.Content != "Hello World" {
|
||||
t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComBotMessageStructure(t *testing.T) {
|
||||
jsonData := `{
|
||||
"msgid": "test_msg_id_123",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chatid": "group_chat_id_123",
|
||||
"chattype": "group",
|
||||
"from": {"userid": "user123"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello World"}
|
||||
}`
|
||||
|
||||
var msg WeComBotMessage
|
||||
err := json.Unmarshal([]byte(jsonData), &msg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal JSON: %v", err)
|
||||
}
|
||||
|
||||
if msg.MsgID != "test_msg_id_123" {
|
||||
t.Errorf("MsgID = %q, want %q", msg.MsgID, "test_msg_id_123")
|
||||
}
|
||||
if msg.AIBotID != "test_aibot_id" {
|
||||
t.Errorf("AIBotID = %q, want %q", msg.AIBotID, "test_aibot_id")
|
||||
}
|
||||
if msg.ChatID != "group_chat_id_123" {
|
||||
t.Errorf("ChatID = %q, want %q", msg.ChatID, "group_chat_id_123")
|
||||
}
|
||||
if msg.ChatType != "group" {
|
||||
t.Errorf("ChatType = %q, want %q", msg.ChatType, "group")
|
||||
}
|
||||
if msg.From.UserID != "user123" {
|
||||
t.Errorf("From.UserID = %q, want %q", msg.From.UserID, "user123")
|
||||
}
|
||||
if msg.MsgType != "text" {
|
||||
t.Errorf("MsgType = %q, want %q", msg.MsgType, "text")
|
||||
}
|
||||
if msg.Text.Content != "Hello World" {
|
||||
t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World")
|
||||
}
|
||||
}
|
||||
@@ -1,199 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// blockSize is the PKCS7 block size used by WeCom (32)
|
||||
const blockSize = 32
|
||||
|
||||
// computeSignature computes the WeCom message signature from the given parameters.
|
||||
// It sorts [token, timestamp, nonce, encrypt], concatenates them and returns the SHA1 hex digest.
|
||||
func computeSignature(token, timestamp, nonce, encrypt string) string {
|
||||
params := []string{token, timestamp, nonce, encrypt}
|
||||
sort.Strings(params)
|
||||
str := strings.Join(params, "")
|
||||
hash := sha1.Sum([]byte(str))
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
|
||||
// verifySignature verifies the message signature for WeCom
|
||||
// This is a common function used by both WeCom Bot and WeCom App
|
||||
func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature
|
||||
}
|
||||
|
||||
// decryptMessage decrypts the encrypted message using AES
|
||||
// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id
|
||||
func decryptMessage(encryptedMsg, encodingAESKey string) (string, error) {
|
||||
return decryptMessageWithVerify(encryptedMsg, encodingAESKey, "")
|
||||
}
|
||||
|
||||
// decryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid
|
||||
// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification.
|
||||
func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) {
|
||||
if encodingAESKey == "" {
|
||||
// No encryption, return as is (base64 decode)
|
||||
decoded, err := base64.StdEncoding.DecodeString(encryptedMsg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(decoded), nil
|
||||
}
|
||||
|
||||
aesKey, err := decodeWeComAESKey(encodingAESKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode message: %w", err)
|
||||
}
|
||||
|
||||
plainText, err := decryptAESCBC(aesKey, cipherText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return unpackWeComFrame(plainText, receiveid)
|
||||
}
|
||||
|
||||
// decodeWeComAESKey base64-decodes the 43-character EncodingAESKey (trailing "=" is
|
||||
// appended automatically) and validates that the result is exactly 32 bytes.
|
||||
// It is the single place that handles this repeated pattern in both encrypt and decrypt paths.
|
||||
func decodeWeComAESKey(encodingAESKey string) ([]byte, error) {
|
||||
aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode AES key: %w", err)
|
||||
}
|
||||
if len(aesKey) != 32 {
|
||||
return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey))
|
||||
}
|
||||
return aesKey, nil
|
||||
}
|
||||
|
||||
// encryptAESCBC encrypts plaintext using AES-CBC with the given key, mirroring
|
||||
// decryptAESCBC. IV = aesKey[:aes.BlockSize]. The caller must PKCS7-pad the
|
||||
// plaintext to a multiple of aes.BlockSize before calling.
|
||||
func encryptAESCBC(aesKey, plaintext []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
iv := aesKey[:aes.BlockSize]
|
||||
ciphertext := make([]byte, len(plaintext))
|
||||
cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// packWeComFrame builds the WeCom wire format:
|
||||
//
|
||||
// random(16 ASCII digits) + msg_len(4, big-endian) + msg + receiveid
|
||||
func packWeComFrame(msg, receiveid string) ([]byte, error) {
|
||||
randomBytes := make([]byte, 16)
|
||||
for i := range 16 {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(10))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random: %w", err)
|
||||
}
|
||||
randomBytes[i] = byte('0' + n.Int64())
|
||||
}
|
||||
msgBytes := []byte(msg)
|
||||
msgLenBytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(msgLenBytes, uint32(len(msgBytes)))
|
||||
var buf bytes.Buffer
|
||||
buf.Write(randomBytes)
|
||||
buf.Write(msgLenBytes)
|
||||
buf.Write(msgBytes)
|
||||
buf.WriteString(receiveid)
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// unpackWeComFrame parses the WeCom wire format produced by packWeComFrame.
|
||||
// If receiveid is non-empty it verifies the frame's trailing receiveid field.
|
||||
func unpackWeComFrame(data []byte, receiveid string) (string, error) {
|
||||
if len(data) < 20 {
|
||||
return "", fmt.Errorf("decrypted frame too short: %d bytes", len(data))
|
||||
}
|
||||
msgLen := binary.BigEndian.Uint32(data[16:20])
|
||||
if int(msgLen) > len(data)-20 {
|
||||
return "", fmt.Errorf("invalid message length: %d", msgLen)
|
||||
}
|
||||
msg := data[20 : 20+msgLen]
|
||||
if receiveid != "" && len(data) > 20+int(msgLen) {
|
||||
actualReceiveID := string(data[20+msgLen:])
|
||||
if actualReceiveID != receiveid {
|
||||
return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID)
|
||||
}
|
||||
}
|
||||
return string(msg), nil
|
||||
}
|
||||
|
||||
// decryptAESCBC decrypts ciphertext using AES-CBC with the given key.
|
||||
// IV = aesKey[:aes.BlockSize]. PKCS7 padding is stripped from the returned plaintext.
|
||||
func decryptAESCBC(aesKey, ciphertext []byte) ([]byte, error) {
|
||||
if len(ciphertext) == 0 {
|
||||
return nil, fmt.Errorf("ciphertext is empty")
|
||||
}
|
||||
if len(ciphertext)%aes.BlockSize != 0 {
|
||||
return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext))
|
||||
}
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
iv := aesKey[:aes.BlockSize]
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext)
|
||||
plaintext, err = pkcs7Unpad(plaintext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unpad: %w", err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// pkcs7Pad adds PKCS7 padding
|
||||
func pkcs7Pad(data []byte, blockSize int) []byte {
|
||||
padding := blockSize - (len(data) % blockSize)
|
||||
if padding == 0 {
|
||||
padding = blockSize
|
||||
}
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(data, padText...)
|
||||
}
|
||||
|
||||
// pkcs7Unpad removes PKCS7 padding with validation
|
||||
func pkcs7Unpad(data []byte) ([]byte, error) {
|
||||
if len(data) == 0 {
|
||||
return data, nil
|
||||
}
|
||||
padding := int(data[len(data)-1])
|
||||
// WeCom uses 32-byte block size for PKCS7 padding
|
||||
if padding == 0 || padding > blockSize {
|
||||
return nil, fmt.Errorf("invalid padding size: %d", padding)
|
||||
}
|
||||
if padding > len(data) {
|
||||
return nil, fmt.Errorf("padding size larger than data")
|
||||
}
|
||||
// Verify all padding bytes
|
||||
for i := range padding {
|
||||
if data[len(data)-1-i] != byte(padding) {
|
||||
return nil, fmt.Errorf("invalid padding byte at position %d", i)
|
||||
}
|
||||
}
|
||||
return data[:len(data)-padding], nil
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import "sync"
|
||||
|
||||
const wecomMaxProcessedMessages = 1000
|
||||
|
||||
// MessageDeduplicator provides thread-safe message deduplication using a circular queue (ring buffer)
|
||||
// combined with a hash map. This ensures fast O(1) lookups while naturally evicting the oldest
|
||||
// messages without causing "amnesia cliffs" when the limit is reached.
|
||||
type MessageDeduplicator struct {
|
||||
mu sync.Mutex
|
||||
msgs map[string]bool
|
||||
ring []string
|
||||
idx int
|
||||
max int
|
||||
}
|
||||
|
||||
// NewMessageDeduplicator creates a new deduplicator with the specified capacity.
|
||||
func NewMessageDeduplicator(maxEntries int) *MessageDeduplicator {
|
||||
if maxEntries <= 0 {
|
||||
maxEntries = wecomMaxProcessedMessages
|
||||
}
|
||||
return &MessageDeduplicator{
|
||||
msgs: make(map[string]bool, maxEntries),
|
||||
ring: make([]string, maxEntries),
|
||||
max: maxEntries,
|
||||
}
|
||||
}
|
||||
|
||||
// MarkMessageProcessed marks msgID as processed and returns false for duplicates.
|
||||
func (d *MessageDeduplicator) MarkMessageProcessed(msgID string) bool {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
// 1. Check for duplicate
|
||||
if d.msgs[msgID] {
|
||||
return false
|
||||
}
|
||||
|
||||
// 2. Evict the oldest message at our current ring position (if any)
|
||||
oldestID := d.ring[d.idx]
|
||||
if oldestID != "" {
|
||||
delete(d.msgs, oldestID)
|
||||
}
|
||||
|
||||
// 3. Store the new message
|
||||
d.msgs[msgID] = true
|
||||
d.ring[d.idx] = msgID
|
||||
|
||||
// 4. Advance the circle queue index
|
||||
d.idx = (d.idx + 1) % d.max
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessageDeduplicator_DuplicateDetection(t *testing.T) {
|
||||
d := NewMessageDeduplicator(wecomMaxProcessedMessages)
|
||||
|
||||
if ok := d.MarkMessageProcessed("msg-1"); !ok {
|
||||
t.Fatalf("first message should be accepted")
|
||||
}
|
||||
|
||||
if ok := d.MarkMessageProcessed("msg-1"); ok {
|
||||
t.Fatalf("duplicate message should be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageDeduplicator_ConcurrentSameMessage(t *testing.T) {
|
||||
d := NewMessageDeduplicator(wecomMaxProcessedMessages)
|
||||
|
||||
const goroutines = 64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
results := make(chan bool, goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results <- d.MarkMessageProcessed("msg-concurrent")
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
successes := 0
|
||||
for ok := range results {
|
||||
if ok {
|
||||
successes++
|
||||
}
|
||||
}
|
||||
|
||||
if successes != 1 {
|
||||
t.Fatalf("expected exactly 1 successful mark, got %d", successes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageDeduplicator_CircularQueueEviction(t *testing.T) {
|
||||
// Create a deduplicator with a very small capacity to test eviction easily.
|
||||
capacity := 3
|
||||
d := NewMessageDeduplicator(capacity)
|
||||
|
||||
// Fill the queue.
|
||||
d.MarkMessageProcessed("msg-1")
|
||||
d.MarkMessageProcessed("msg-2")
|
||||
d.MarkMessageProcessed("msg-3")
|
||||
|
||||
// At this point, the queue is full. msg-1 is the oldest.
|
||||
if len(d.msgs) != 3 {
|
||||
t.Fatalf("expected map size to be 3, got %d", len(d.msgs))
|
||||
}
|
||||
|
||||
// This should evict msg-1 and add msg-4.
|
||||
if ok := d.MarkMessageProcessed("msg-4"); !ok {
|
||||
t.Fatalf("msg-4 should be accepted")
|
||||
}
|
||||
|
||||
if len(d.msgs) != 3 {
|
||||
t.Fatalf("expected map size to remain at max capacity (3), got %d", len(d.msgs))
|
||||
}
|
||||
|
||||
// msg-1 should now be forgotten (evicted).
|
||||
if ok := d.MarkMessageProcessed("msg-1"); !ok {
|
||||
t.Fatalf("msg-1 should be accepted again because it was evicted")
|
||||
}
|
||||
|
||||
// msg-2 should have been evicted when we added msg-1 back.
|
||||
if ok := d.MarkMessageProcessed("msg-2"); !ok {
|
||||
t.Fatalf("msg-2 should be accepted again because it was evicted")
|
||||
}
|
||||
}
|
||||
@@ -8,12 +8,6 @@ import (
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("wecom", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewWeComBotChannel(cfg.Channels.WeCom, b)
|
||||
})
|
||||
channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewWeComAppChannel(cfg.Channels.WeComApp, b)
|
||||
})
|
||||
channels.RegisterFactory("wecom_aibot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewWeComAIBotChannel(cfg.Channels.WeComAIBot, b)
|
||||
return NewChannel(cfg.Channels.WeCom, b)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,802 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
const (
|
||||
wecomOutboundMediaMaxBytes = 20 << 20
|
||||
wecomOutboundImageMaxBytes = 2 << 20
|
||||
wecomOutboundVoiceMaxBytes = 2 << 20
|
||||
wecomOutboundVideoMaxBytes = 10 << 20
|
||||
wecomUploadChunkMaxBytes = 512 << 10
|
||||
wecomUploadMaxChunks = 100
|
||||
wecomUploadMinBytes = 5
|
||||
)
|
||||
|
||||
type wecomOutboundMedia struct {
|
||||
MsgType string
|
||||
MediaID string
|
||||
Title string
|
||||
Description string
|
||||
}
|
||||
|
||||
func (m *wecomOutboundMedia) respondBody() wecomRespondMsgBody {
|
||||
body := wecomRespondMsgBody{MsgType: m.MsgType}
|
||||
switch m.MsgType {
|
||||
case "file":
|
||||
body.File = &wecomMediaRefContent{MediaID: m.MediaID}
|
||||
case "image":
|
||||
body.Image = &wecomMediaRefContent{MediaID: m.MediaID}
|
||||
case "voice":
|
||||
body.Voice = &wecomMediaRefContent{MediaID: m.MediaID}
|
||||
case "video":
|
||||
body.Video = &wecomVideoContent{
|
||||
MediaID: m.MediaID,
|
||||
Title: m.Title,
|
||||
Description: m.Description,
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func (m *wecomOutboundMedia) sendBody(chatID string, chatType uint32) wecomSendMsgBody {
|
||||
body := wecomSendMsgBody{
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
MsgType: m.MsgType,
|
||||
}
|
||||
switch m.MsgType {
|
||||
case "file":
|
||||
body.File = &wecomMediaRefContent{MediaID: m.MediaID}
|
||||
case "image":
|
||||
body.Image = &wecomMediaRefContent{MediaID: m.MediaID}
|
||||
case "voice":
|
||||
body.Voice = &wecomMediaRefContent{MediaID: m.MediaID}
|
||||
case "video":
|
||||
body.Video = &wecomVideoContent{
|
||||
MediaID: m.MediaID,
|
||||
Title: m.Title,
|
||||
Description: m.Description,
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func decodeMediaAESKey(value string) ([]byte, error) {
|
||||
if value == "" {
|
||||
return nil, nil
|
||||
}
|
||||
key, err := base64.StdEncoding.DecodeString(value)
|
||||
if err == nil && len(key) == 32 {
|
||||
return key, nil
|
||||
}
|
||||
key, err = base64.StdEncoding.DecodeString(value + "=")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode AES key: %w", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("invalid AES key length %d", len(key))
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func decryptAESCBC(key, ciphertext []byte) ([]byte, error) {
|
||||
if len(ciphertext) == 0 {
|
||||
return nil, fmt.Errorf("ciphertext is empty")
|
||||
}
|
||||
if len(ciphertext)%aes.BlockSize != 0 {
|
||||
return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext))
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create cipher: %w", err)
|
||||
}
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
iv := key[:aes.BlockSize]
|
||||
cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext)
|
||||
return pkcs7Unpad(plaintext)
|
||||
}
|
||||
|
||||
func pkcs7Unpad(data []byte) ([]byte, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, fmt.Errorf("empty plaintext")
|
||||
}
|
||||
padding := int(data[len(data)-1])
|
||||
if padding == 0 || padding > 32 || padding > len(data) {
|
||||
return nil, fmt.Errorf("invalid padding size %d", padding)
|
||||
}
|
||||
for i := 0; i < padding; i++ {
|
||||
if data[len(data)-1-i] != byte(padding) {
|
||||
return nil, fmt.Errorf("invalid padding byte")
|
||||
}
|
||||
}
|
||||
return data[:len(data)-padding], nil
|
||||
}
|
||||
|
||||
func inferMediaExt(contentType, fallback string) string {
|
||||
contentType = normalizeWeComContentType(contentType)
|
||||
switch contentType {
|
||||
case "image/jpeg", "image/jpg":
|
||||
return ".jpg"
|
||||
case "image/png":
|
||||
return ".png"
|
||||
case "image/gif":
|
||||
return ".gif"
|
||||
case "image/webp":
|
||||
return ".webp"
|
||||
case "application/pdf":
|
||||
return ".pdf"
|
||||
case "video/mp4":
|
||||
return ".mp4"
|
||||
default:
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeWeComContentType(value string) string {
|
||||
value = strings.ToLower(strings.TrimSpace(value))
|
||||
if idx := strings.Index(value, ";"); idx >= 0 {
|
||||
value = strings.TrimSpace(value[:idx])
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func isGenericWeComContentType(value string) bool {
|
||||
switch normalizeWeComContentType(value) {
|
||||
case "", "application/octet-stream", "binary/octet-stream", "application/unknown", "application/binary":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeWeComFilename(name string) string {
|
||||
name = filepath.Base(strings.TrimSpace(name))
|
||||
if name == "." || name == "/" || name == "" {
|
||||
return ""
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func candidateWeComFilename(resourceURL, contentDisposition, fallbackName string) string {
|
||||
if _, params, err := mime.ParseMediaType(contentDisposition); err == nil {
|
||||
if name := sanitizeWeComFilename(params["filename"]); name != "" {
|
||||
return name
|
||||
}
|
||||
if name := sanitizeWeComFilename(params["filename*"]); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
if parsed, err := url.Parse(resourceURL); err == nil {
|
||||
query := parsed.Query()
|
||||
for _, key := range []string{"filename", "file_name", "name"} {
|
||||
if name := sanitizeWeComFilename(query.Get(key)); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
if name := sanitizeWeComFilename(parsed.Path); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
return sanitizeWeComFilename(fallbackName)
|
||||
}
|
||||
|
||||
func detectWeComFiletype(data []byte) (string, string) {
|
||||
kind, err := filetype.Match(data)
|
||||
if err != nil || kind == filetype.Unknown {
|
||||
return "", ""
|
||||
}
|
||||
ext := ""
|
||||
if kind.Extension != "" {
|
||||
ext = "." + strings.ToLower(kind.Extension)
|
||||
}
|
||||
return normalizeWeComContentType(kind.MIME.Value), ext
|
||||
}
|
||||
|
||||
func detectWeComMediaMetadata(
|
||||
data []byte,
|
||||
fallbackName, fallbackContentType, resourceURL, contentDisposition string,
|
||||
) (string, string) {
|
||||
filename := candidateWeComFilename(resourceURL, contentDisposition, fallbackName)
|
||||
if filename == "" {
|
||||
filename = "media"
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
contentType := normalizeWeComContentType(fallbackContentType)
|
||||
detectedType, detectedExt := detectWeComFiletype(data)
|
||||
|
||||
if ext != "" && isGenericWeComContentType(contentType) {
|
||||
if byExt := normalizeWeComContentType(mime.TypeByExtension(ext)); byExt != "" {
|
||||
contentType = byExt
|
||||
}
|
||||
}
|
||||
|
||||
if detectedType != "" {
|
||||
switch {
|
||||
case contentType == "":
|
||||
contentType = detectedType
|
||||
case isGenericWeComContentType(contentType):
|
||||
contentType = detectedType
|
||||
case strings.HasPrefix(detectedType, "image/") && !strings.HasPrefix(contentType, "image/"):
|
||||
contentType = detectedType
|
||||
case strings.HasPrefix(detectedType, "audio/") && !strings.HasPrefix(contentType, "audio/"):
|
||||
contentType = detectedType
|
||||
case strings.HasPrefix(detectedType, "video/") && !strings.HasPrefix(contentType, "video/"):
|
||||
contentType = detectedType
|
||||
}
|
||||
}
|
||||
|
||||
if contentType == "" && ext != "" {
|
||||
contentType = normalizeWeComContentType(mime.TypeByExtension(ext))
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = normalizeWeComContentType(http.DetectContentType(data))
|
||||
}
|
||||
|
||||
if ext == "" {
|
||||
ext = detectedExt
|
||||
}
|
||||
if ext == "" && contentType != "" {
|
||||
if exts, err := mime.ExtensionsByType(contentType); err == nil && len(exts) > 0 {
|
||||
ext = strings.ToLower(exts[0])
|
||||
}
|
||||
}
|
||||
|
||||
if filepath.Ext(filename) == "" && ext != "" {
|
||||
filename += ext
|
||||
}
|
||||
return filename, contentType
|
||||
}
|
||||
|
||||
func (c *WeComChannel) storeRemoteMedia(
|
||||
ctx context.Context,
|
||||
scope, msgID, resourceURL, aesKey, fallbackExt string,
|
||||
) (string, error) {
|
||||
store := c.GetMediaStore()
|
||||
if store == nil {
|
||||
return "", fmt.Errorf("no media store available")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
resp, err := c.mediaClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("download media: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("download media returned HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, wecomOutboundMediaMaxBytes+1))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read media: %w", err)
|
||||
}
|
||||
if len(data) > wecomOutboundMediaMaxBytes {
|
||||
return "", fmt.Errorf("media too large")
|
||||
}
|
||||
|
||||
if aesKey != "" {
|
||||
key, keyErr := decodeMediaAESKey(aesKey)
|
||||
if keyErr != nil {
|
||||
return "", keyErr
|
||||
}
|
||||
data, err = decryptAESCBC(key, data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt media: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
filename, contentType := detectWeComMediaMetadata(
|
||||
data,
|
||||
msgID+fallbackExt,
|
||||
resp.Header.Get("Content-Type"),
|
||||
resourceURL,
|
||||
resp.Header.Get("Content-Disposition"),
|
||||
)
|
||||
ext := filepath.Ext(filename)
|
||||
if ext == "" {
|
||||
ext = inferMediaExt(contentType, fallbackExt)
|
||||
}
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
|
||||
return "", fmt.Errorf("mkdir media dir: %w", mkdirErr)
|
||||
}
|
||||
tmpFile, err := os.CreateTemp(mediaDir, msgID+"-*"+ext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
if _, writeErr := tmpFile.Write(data); writeErr != nil {
|
||||
tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("write temp file: %w", writeErr)
|
||||
}
|
||||
if closeErr := tmpFile.Close(); closeErr != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("close temp file: %w", closeErr)
|
||||
}
|
||||
|
||||
ref, err := store.Store(tmpPath, media.MediaMeta{
|
||||
Filename: filename,
|
||||
ContentType: contentType,
|
||||
Source: "wecom",
|
||||
CleanupPolicy: media.CleanupPolicyDeleteOnCleanup,
|
||||
}, scope)
|
||||
if err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", err
|
||||
}
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
func detectLocalWeComContentType(localPath, hint string) string {
|
||||
contentType := normalizeWeComContentType(hint)
|
||||
if !isGenericWeComContentType(contentType) {
|
||||
return contentType
|
||||
}
|
||||
|
||||
if kind, err := filetype.MatchFile(localPath); err == nil && kind != filetype.Unknown {
|
||||
return normalizeWeComContentType(kind.MIME.Value)
|
||||
}
|
||||
|
||||
if ext := strings.ToLower(filepath.Ext(localPath)); ext != "" {
|
||||
if byExt := normalizeWeComContentType(mime.TypeByExtension(ext)); byExt != "" {
|
||||
return byExt
|
||||
}
|
||||
}
|
||||
|
||||
file, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
return contentType
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
buf := make([]byte, 512)
|
||||
n, err := file.Read(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
return contentType
|
||||
}
|
||||
if n == 0 {
|
||||
return contentType
|
||||
}
|
||||
return normalizeWeComContentType(http.DetectContentType(buf[:n]))
|
||||
}
|
||||
|
||||
func writeWeComTempFile(prefix, filename string, data []byte) (string, error) {
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
return "", fmt.Errorf("mkdir media dir: %w", err)
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
tmpFile, err := os.CreateTemp(mediaDir, prefix+"-*"+ext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
if _, err := tmpFile.Write(data); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("write temp file: %w", err)
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("close temp file: %w", err)
|
||||
}
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) downloadRemoteMediaToTemp(
|
||||
ctx context.Context,
|
||||
resourceURL, fallbackName string,
|
||||
) (string, string, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.mediaClient.Do(req)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("download media: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return "", "", "", fmt.Errorf("download media returned HTTP %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, wecomOutboundMediaMaxBytes+1))
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("read media: %w", err)
|
||||
}
|
||||
if len(data) > wecomOutboundMediaMaxBytes {
|
||||
return "", "", "", fmt.Errorf("media too large")
|
||||
}
|
||||
|
||||
filename, contentType := detectWeComMediaMetadata(
|
||||
data,
|
||||
fallbackName,
|
||||
resp.Header.Get("Content-Type"),
|
||||
resourceURL,
|
||||
resp.Header.Get("Content-Disposition"),
|
||||
)
|
||||
tmpPath, err := writeWeComTempFile("wecom-outbound", filename, data)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
return tmpPath, filename, contentType, nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) resolveOutboundPart(
|
||||
ctx context.Context,
|
||||
part bus.MediaPart,
|
||||
) (string, string, string, func(), error) {
|
||||
cleanup := func() {}
|
||||
filename := sanitizeWeComFilename(part.Filename)
|
||||
contentType := normalizeWeComContentType(part.ContentType)
|
||||
ref := strings.TrimSpace(part.Ref)
|
||||
|
||||
switch {
|
||||
case ref == "":
|
||||
return "", filename, contentType, cleanup, nil
|
||||
|
||||
case strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://"):
|
||||
localPath, name, ct, err := c.downloadRemoteMediaToTemp(ctx, ref, filename)
|
||||
if err != nil {
|
||||
return "", "", "", cleanup, err
|
||||
}
|
||||
return localPath, name, ct, func() { _ = os.Remove(localPath) }, nil
|
||||
|
||||
case strings.HasPrefix(ref, "media://"):
|
||||
store := c.GetMediaStore()
|
||||
if store == nil {
|
||||
return "", "", "", cleanup, fmt.Errorf("no media store available")
|
||||
}
|
||||
|
||||
localPath, meta, err := store.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
return "", "", "", cleanup, err
|
||||
}
|
||||
if filename == "" {
|
||||
filename = sanitizeWeComFilename(meta.Filename)
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = normalizeWeComContentType(meta.ContentType)
|
||||
}
|
||||
if strings.HasPrefix(localPath, "http://") || strings.HasPrefix(localPath, "https://") {
|
||||
tmpPath, name, ct, err := c.downloadRemoteMediaToTemp(ctx, localPath, filename)
|
||||
if err != nil {
|
||||
return "", "", "", cleanup, err
|
||||
}
|
||||
return tmpPath, name, ct, func() { _ = os.Remove(tmpPath) }, nil
|
||||
}
|
||||
if _, err := os.Stat(localPath); err != nil {
|
||||
return "", "", "", cleanup, err
|
||||
}
|
||||
if filename == "" {
|
||||
filename = sanitizeWeComFilename(filepath.Base(localPath))
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = detectLocalWeComContentType(localPath, "")
|
||||
}
|
||||
return localPath, filename, contentType, cleanup, nil
|
||||
|
||||
case strings.HasPrefix(ref, "file://"):
|
||||
u, err := url.Parse(ref)
|
||||
if err != nil {
|
||||
return "", "", "", cleanup, err
|
||||
}
|
||||
localPath := u.Path
|
||||
if _, err := os.Stat(localPath); err != nil {
|
||||
return "", "", "", cleanup, err
|
||||
}
|
||||
if filename == "" {
|
||||
filename = sanitizeWeComFilename(filepath.Base(localPath))
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = detectLocalWeComContentType(localPath, "")
|
||||
}
|
||||
return localPath, filename, contentType, cleanup, nil
|
||||
|
||||
default:
|
||||
if _, err := os.Stat(ref); err != nil {
|
||||
return "", "", "", cleanup, err
|
||||
}
|
||||
if filename == "" {
|
||||
filename = sanitizeWeComFilename(filepath.Base(ref))
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = detectLocalWeComContentType(ref, "")
|
||||
}
|
||||
return ref, filename, contentType, cleanup, nil
|
||||
}
|
||||
}
|
||||
|
||||
func canWeComSendImage(contentType, ext string, size int64) bool {
|
||||
if size > wecomOutboundImageMaxBytes {
|
||||
return false
|
||||
}
|
||||
switch normalizeWeComContentType(contentType) {
|
||||
case "image/jpeg", "image/jpg", "image/png", "image/gif":
|
||||
return true
|
||||
}
|
||||
switch strings.ToLower(ext) {
|
||||
case ".jpg", ".jpeg", ".png", ".gif":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func canWeComSendVoice(contentType, ext string, size int64) bool {
|
||||
if size > wecomOutboundVoiceMaxBytes {
|
||||
return false
|
||||
}
|
||||
contentType = normalizeWeComContentType(contentType)
|
||||
return strings.Contains(contentType, "amr") || strings.EqualFold(ext, ".amr")
|
||||
}
|
||||
|
||||
func canWeComSendVideo(contentType, ext string, size int64) bool {
|
||||
if size > wecomOutboundVideoMaxBytes {
|
||||
return false
|
||||
}
|
||||
return normalizeWeComContentType(contentType) == "video/mp4" || strings.EqualFold(ext, ".mp4")
|
||||
}
|
||||
|
||||
func outboundWeComMediaKind(partType, filename, contentType string, size int64) string {
|
||||
if size < wecomUploadMinBytes {
|
||||
return ""
|
||||
}
|
||||
|
||||
partType = strings.ToLower(strings.TrimSpace(partType))
|
||||
contentType = normalizeWeComContentType(contentType)
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
|
||||
if partType == "file" {
|
||||
if size <= wecomOutboundMediaMaxBytes {
|
||||
return "file"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
if (partType == "image" || partType == "") && canWeComSendImage(contentType, ext, size) {
|
||||
return "image"
|
||||
}
|
||||
if (partType == "audio" || partType == "voice" || partType == "") && canWeComSendVoice(contentType, ext, size) {
|
||||
return "voice"
|
||||
}
|
||||
if (partType == "video" || partType == "") && canWeComSendVideo(contentType, ext, size) {
|
||||
return "video"
|
||||
}
|
||||
if size <= wecomOutboundMediaMaxBytes {
|
||||
return "file"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func trimWeComBytes(value string, limit int) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if limit <= 0 || len(value) <= limit {
|
||||
return value
|
||||
}
|
||||
size := 0
|
||||
var out strings.Builder
|
||||
for _, r := range value {
|
||||
width := len(string(r))
|
||||
if size+width > limit {
|
||||
break
|
||||
}
|
||||
size += width
|
||||
out.WriteRune(r)
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func ensureWeComOutboundFilename(filename, localPath, contentType string) string {
|
||||
filename = sanitizeWeComFilename(filename)
|
||||
if filename == "" {
|
||||
filename = sanitizeWeComFilename(filepath.Base(localPath))
|
||||
}
|
||||
if filename == "" {
|
||||
filename = "media"
|
||||
}
|
||||
if filepath.Ext(filename) == "" {
|
||||
fallbackExt := inferMediaExt(contentType, strings.ToLower(filepath.Ext(localPath)))
|
||||
if fallbackExt != "" {
|
||||
filename += fallbackExt
|
||||
}
|
||||
}
|
||||
filename = trimWeComBytes(filename, 256)
|
||||
if filename == "" {
|
||||
return "media"
|
||||
}
|
||||
return filename
|
||||
}
|
||||
|
||||
func buildWeComVideoContent(mediaID, filename, description string) *wecomVideoContent {
|
||||
title := strings.TrimSuffix(filename, filepath.Ext(filename))
|
||||
title = trimWeComBytes(title, 64)
|
||||
if title == "" {
|
||||
title = "video"
|
||||
}
|
||||
description = trimWeComBytes(description, 512)
|
||||
return &wecomVideoContent{
|
||||
MediaID: mediaID,
|
||||
Title: title,
|
||||
Description: description,
|
||||
}
|
||||
}
|
||||
|
||||
func decodeWeComEnvelopeBody[T any](env wecomEnvelope) (T, error) {
|
||||
var out T
|
||||
if len(env.Body) == 0 {
|
||||
return out, fmt.Errorf("wecom response body is empty")
|
||||
}
|
||||
if err := json.Unmarshal(env.Body, &out); err != nil {
|
||||
return out, fmt.Errorf("decode wecom response body: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) uploadOutboundMedia(
|
||||
ctx context.Context,
|
||||
localPath, filename, contentType string,
|
||||
part bus.MediaPart,
|
||||
) (*wecomOutboundMedia, error) {
|
||||
_ = ctx
|
||||
|
||||
contentType = detectLocalWeComContentType(localPath, contentType)
|
||||
filename = ensureWeComOutboundFilename(filename, localPath, contentType)
|
||||
|
||||
data, err := os.ReadFile(localPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read media file: %w", err)
|
||||
}
|
||||
size := int64(len(data))
|
||||
kind := outboundWeComMediaKind(part.Type, filename, contentType, size)
|
||||
if kind == "" {
|
||||
return nil, fmt.Errorf("unsupported wecom media type or size for %q", filename)
|
||||
}
|
||||
|
||||
totalChunks := (len(data) + wecomUploadChunkMaxBytes - 1) / wecomUploadChunkMaxBytes
|
||||
if totalChunks <= 0 || totalChunks > wecomUploadMaxChunks {
|
||||
return nil, fmt.Errorf("wecom upload requires 1-%d chunks, got %d", wecomUploadMaxChunks, totalChunks)
|
||||
}
|
||||
|
||||
sum := md5.Sum(data)
|
||||
initEnv, err := c.sendCommandAck(wecomCommand{
|
||||
Cmd: wecomCmdUploadMediaInit,
|
||||
Headers: wecomHeaders{ReqID: randomID(10)},
|
||||
Body: wecomUploadMediaInitBody{
|
||||
Type: kind,
|
||||
Filename: filename,
|
||||
TotalSize: size,
|
||||
TotalChunks: totalChunks,
|
||||
MD5: hex.EncodeToString(sum[:]),
|
||||
},
|
||||
}, wecomUploadTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
initResp, err := decodeWeComEnvelopeBody[wecomUploadMediaInitResponse](initEnv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(initResp.UploadID) == "" {
|
||||
return nil, fmt.Errorf("wecom upload init returned empty upload_id")
|
||||
}
|
||||
|
||||
for idx, offset := 0, 0; offset < len(data); idx, offset = idx+1, offset+wecomUploadChunkMaxBytes {
|
||||
end := offset + wecomUploadChunkMaxBytes
|
||||
if end > len(data) {
|
||||
end = len(data)
|
||||
}
|
||||
sendErr := c.sendCommand(wecomCommand{
|
||||
Cmd: wecomCmdUploadMediaChunk,
|
||||
Headers: wecomHeaders{ReqID: randomID(10)},
|
||||
Body: wecomUploadMediaChunkBody{
|
||||
UploadID: initResp.UploadID,
|
||||
ChunkIndex: idx,
|
||||
Base64Data: base64.StdEncoding.EncodeToString(data[offset:end]),
|
||||
},
|
||||
}, wecomUploadTimeout)
|
||||
if sendErr != nil {
|
||||
return nil, sendErr
|
||||
}
|
||||
}
|
||||
|
||||
finishEnv, err := c.sendCommandAck(wecomCommand{
|
||||
Cmd: wecomCmdUploadMediaEnd,
|
||||
Headers: wecomHeaders{ReqID: randomID(10)},
|
||||
Body: wecomUploadMediaFinishBody{
|
||||
UploadID: initResp.UploadID,
|
||||
},
|
||||
}, wecomUploadTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
finishResp, err := decodeWeComEnvelopeBody[wecomUploadMediaFinishResponse](finishEnv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(finishResp.MediaID) == "" {
|
||||
return nil, fmt.Errorf("wecom upload finish returned empty media_id")
|
||||
}
|
||||
|
||||
uploaded := &wecomOutboundMedia{
|
||||
MsgType: kind,
|
||||
MediaID: finishResp.MediaID,
|
||||
}
|
||||
if kind == "video" {
|
||||
video := buildWeComVideoContent(finishResp.MediaID, filename, part.Caption)
|
||||
uploaded.Title = video.Title
|
||||
uploaded.Description = video.Description
|
||||
}
|
||||
return uploaded, nil
|
||||
}
|
||||
|
||||
func fallbackWeComMediaText(part bus.MediaPart, kind, filename string) string {
|
||||
var lines []string
|
||||
if caption := strings.TrimSpace(part.Caption); caption != "" {
|
||||
lines = append(lines, caption)
|
||||
}
|
||||
|
||||
label := kind
|
||||
if label == "" {
|
||||
label = "media"
|
||||
}
|
||||
if filename != "" {
|
||||
lines = append(lines, fmt.Sprintf("[%s: %s]", label, filename))
|
||||
} else {
|
||||
lines = append(lines, fmt.Sprintf("[%s attachment]", label))
|
||||
}
|
||||
|
||||
ref := strings.TrimSpace(part.Ref)
|
||||
if strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://") {
|
||||
lines = append(lines, ref)
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func (c *WeComChannel) resolveMediaRoute(chatID string) (wecomTurn, uint32, bool) {
|
||||
if turn, ok := c.getTurn(chatID); ok {
|
||||
if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration {
|
||||
return turn, turn.ChatType, true
|
||||
}
|
||||
c.deleteTurn(chatID)
|
||||
}
|
||||
if route, ok := c.routes.Get(chatID); ok {
|
||||
return wecomTurn{ChatID: route.ChatID, ChatType: route.ChatType}, route.ChatType, false
|
||||
}
|
||||
return wecomTurn{ChatID: chatID}, 0, false
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
basechannels "github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestStoreRemoteMedia_DetectsJPEGContentTypeFromBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const jpegBase64 = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAP//////////////////////////////////////////////////////////////////////////////////////" +
|
||||
"//////////////////////////////////////////////////////////////////////////////////////////////2wBDAf//////////////////////////////////////////////////////////////////////////////////////" +
|
||||
"//////////////////////////////////////////////////////////////////////////////////////////////wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAVEQEBAAAAAAAAAAAAAAAAAAAABf/aAAwDAQACEAMQAAAB6A//xAAVEAEBAAAAAAAAAAAAAAAAAAAAEf/aAAgBAQABBQJf/8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAwEBPwF//8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAgEBPwF//8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQAGPwJf/8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQABPyFf/9k="
|
||||
|
||||
jpegData := decodeTestBase64(t, jpegBase64)
|
||||
store := media.NewFileMediaStore()
|
||||
ch := &WeComChannel{
|
||||
BaseChannel: basechannels.NewBaseChannel("wecom", nil, nil, nil),
|
||||
mediaClient: &http.Client{
|
||||
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/octet-stream"}},
|
||||
Body: io.NopCloser(bytes.NewReader(jpegData)),
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
ref, err := ch.storeRemoteMedia(context.Background(), "test-scope", "msg-1", "https://wecom.example/media", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("storeRemoteMedia returned error: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = store.ReleaseAll("test-scope")
|
||||
})
|
||||
|
||||
_, meta, err := store.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve media ref: %v", err)
|
||||
}
|
||||
if meta.ContentType != "image/jpeg" {
|
||||
t.Fatalf("expected image/jpeg content type, got %q", meta.ContentType)
|
||||
}
|
||||
if !strings.HasSuffix(meta.Filename, ".jpg") && !strings.HasSuffix(meta.Filename, ".jpeg") {
|
||||
t.Fatalf("expected jpeg filename, got %q", meta.Filename)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectWeComMediaMetadata_UsesFallbackExtensionWhenBodyUnknown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
filename, contentType := detectWeComMediaMetadata([]byte("not a real image"), "msg-2.pdf", "", "", "")
|
||||
if filename != "msg-2.pdf" {
|
||||
t.Fatalf("expected fallback filename to be preserved, got %q", filename)
|
||||
}
|
||||
if contentType != "application/pdf" {
|
||||
t.Fatalf("expected application/pdf from fallback extension, got %q", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreRemoteMedia_PreservesSuffixFromURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
docxLikeData := []byte("PK\x03\x04fake office payload")
|
||||
store := media.NewFileMediaStore()
|
||||
ch := &WeComChannel{
|
||||
BaseChannel: basechannels.NewBaseChannel("wecom", nil, nil, nil),
|
||||
mediaClient: &http.Client{
|
||||
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/octet-stream"}},
|
||||
Body: io.NopCloser(bytes.NewReader(docxLikeData)),
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
ref, err := ch.storeRemoteMedia(
|
||||
context.Background(),
|
||||
"test-scope",
|
||||
"msg-docx",
|
||||
"https://wecom.example/media/report.docx?signature=1",
|
||||
"",
|
||||
".bin",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("storeRemoteMedia returned error: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = store.ReleaseAll("test-scope")
|
||||
})
|
||||
|
||||
localPath, meta, err := store.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve media ref: %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(meta.Filename, ".docx") {
|
||||
t.Fatalf("expected docx filename, got %q", meta.Filename)
|
||||
}
|
||||
if !strings.HasSuffix(strings.ToLower(localPath), ".docx") {
|
||||
t.Fatalf("expected docx temp path, got %q", localPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreRemoteMedia_PreservesSuffixFromContentDisposition(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pptxLikeData := []byte("PK\x03\x04fake office payload")
|
||||
store := media.NewFileMediaStore()
|
||||
ch := &WeComChannel{
|
||||
BaseChannel: basechannels.NewBaseChannel("wecom", nil, nil, nil),
|
||||
mediaClient: &http.Client{
|
||||
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/octet-stream"},
|
||||
"Content-Disposition": []string{`attachment; filename="slides.pptx"`},
|
||||
},
|
||||
Body: io.NopCloser(bytes.NewReader(pptxLikeData)),
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
ref, err := ch.storeRemoteMedia(
|
||||
context.Background(),
|
||||
"test-scope",
|
||||
"msg-pptx",
|
||||
"https://wecom.example/media/download",
|
||||
"",
|
||||
".bin",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("storeRemoteMedia returned error: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = store.ReleaseAll("test-scope")
|
||||
})
|
||||
|
||||
localPath, meta, err := store.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve media ref: %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(meta.Filename, ".pptx") {
|
||||
t.Fatalf("expected pptx filename, got %q", meta.Filename)
|
||||
}
|
||||
if !strings.HasSuffix(strings.ToLower(localPath), ".pptx") {
|
||||
t.Fatalf("expected pptx temp path, got %q", localPath)
|
||||
}
|
||||
}
|
||||
|
||||
func decodeTestBase64(t *testing.T, value string) []byte {
|
||||
t.Helper()
|
||||
|
||||
data, err := io.ReadAll(base64.NewDecoder(base64.StdEncoding, strings.NewReader(value)))
|
||||
if err != nil {
|
||||
t.Fatalf("decode base64 fixture: %v", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package wecom
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
const (
|
||||
wecomDefaultWebSocketURL = "wss://openws.work.weixin.qq.com"
|
||||
wecomCmdSubscribe = "aibot_subscribe"
|
||||
wecomCmdPing = "ping"
|
||||
wecomCmdMsgCallback = "aibot_msg_callback"
|
||||
wecomCmdEventCallback = "aibot_event_callback"
|
||||
wecomCmdRespondMsg = "aibot_respond_msg"
|
||||
wecomCmdSendMsg = "aibot_send_msg"
|
||||
wecomCmdUploadMediaInit = "aibot_upload_media_init"
|
||||
wecomCmdUploadMediaChunk = "aibot_upload_media_chunk"
|
||||
wecomCmdUploadMediaEnd = "aibot_upload_media_finish"
|
||||
)
|
||||
|
||||
type wecomEnvelope struct {
|
||||
Cmd string `json:"cmd,omitempty"`
|
||||
Headers wecomHeaders `json:"headers"`
|
||||
Body json.RawMessage `json:"body,omitempty"`
|
||||
ErrCode int `json:"errcode,omitempty"`
|
||||
ErrMsg string `json:"errmsg,omitempty"`
|
||||
}
|
||||
|
||||
type wecomHeaders struct {
|
||||
ReqID string `json:"req_id,omitempty"`
|
||||
}
|
||||
|
||||
type wecomCommand struct {
|
||||
Cmd string `json:"cmd"`
|
||||
Headers wecomHeaders `json:"headers"`
|
||||
Body any `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
type wecomSendMsgBody struct {
|
||||
ChatID string `json:"chatid"`
|
||||
ChatType uint32 `json:"chat_type,omitempty"`
|
||||
MsgType string `json:"msgtype"`
|
||||
Markdown *wecomMarkdownContent `json:"markdown,omitempty"`
|
||||
File *wecomMediaRefContent `json:"file,omitempty"`
|
||||
Image *wecomMediaRefContent `json:"image,omitempty"`
|
||||
Voice *wecomMediaRefContent `json:"voice,omitempty"`
|
||||
Video *wecomVideoContent `json:"video,omitempty"`
|
||||
TemplateCard map[string]any `json:"template_card,omitempty"`
|
||||
}
|
||||
|
||||
type wecomRespondMsgBody struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Stream *wecomStreamContent `json:"stream,omitempty"`
|
||||
Markdown *wecomMarkdownContent `json:"markdown,omitempty"`
|
||||
File *wecomMediaRefContent `json:"file,omitempty"`
|
||||
Image *wecomMediaRefContent `json:"image,omitempty"`
|
||||
Voice *wecomMediaRefContent `json:"voice,omitempty"`
|
||||
Video *wecomVideoContent `json:"video,omitempty"`
|
||||
TemplateCard map[string]any `json:"template_card,omitempty"`
|
||||
}
|
||||
|
||||
type wecomStreamContent struct {
|
||||
ID string `json:"id"`
|
||||
Finish bool `json:"finish"`
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type wecomMarkdownContent struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type wecomMediaRefContent struct {
|
||||
MediaID string `json:"media_id"`
|
||||
}
|
||||
|
||||
type wecomVideoContent struct {
|
||||
MediaID string `json:"media_id"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
type wecomUploadMediaInitBody struct {
|
||||
Type string `json:"type"`
|
||||
Filename string `json:"filename"`
|
||||
TotalSize int64 `json:"total_size"`
|
||||
TotalChunks int `json:"total_chunks"`
|
||||
MD5 string `json:"md5,omitempty"`
|
||||
}
|
||||
|
||||
type wecomUploadMediaInitResponse struct {
|
||||
UploadID string `json:"upload_id"`
|
||||
}
|
||||
|
||||
type wecomUploadMediaChunkBody struct {
|
||||
UploadID string `json:"upload_id"`
|
||||
ChunkIndex int `json:"chunk_index"`
|
||||
Base64Data string `json:"base64_data"`
|
||||
}
|
||||
|
||||
type wecomUploadMediaFinishBody struct {
|
||||
UploadID string `json:"upload_id"`
|
||||
}
|
||||
|
||||
type wecomUploadMediaFinishResponse struct {
|
||||
Type string `json:"type"`
|
||||
MediaID string `json:"media_id"`
|
||||
CreatedAt json.RawMessage `json:"created_at"`
|
||||
}
|
||||
|
||||
type wecomIncomingMessage struct {
|
||||
MsgID string `json:"msgid"`
|
||||
AIBotID string `json:"aibotid"`
|
||||
ChatID string `json:"chatid,omitempty"`
|
||||
ChatType string `json:"chattype,omitempty"`
|
||||
From struct {
|
||||
UserID string `json:"userid"`
|
||||
} `json:"from"`
|
||||
MsgType string `json:"msgtype"`
|
||||
Text *struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text,omitempty"`
|
||||
Image *struct {
|
||||
URL string `json:"url"`
|
||||
AESKey string `json:"aeskey,omitempty"`
|
||||
} `json:"image,omitempty"`
|
||||
File *struct {
|
||||
URL string `json:"url"`
|
||||
AESKey string `json:"aeskey,omitempty"`
|
||||
} `json:"file,omitempty"`
|
||||
Video *struct {
|
||||
URL string `json:"url"`
|
||||
AESKey string `json:"aeskey,omitempty"`
|
||||
} `json:"video,omitempty"`
|
||||
Voice *struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"voice,omitempty"`
|
||||
Mixed *struct {
|
||||
MsgItem []struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text *struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text,omitempty"`
|
||||
Image *struct {
|
||||
URL string `json:"url"`
|
||||
AESKey string `json:"aeskey,omitempty"`
|
||||
} `json:"image,omitempty"`
|
||||
File *struct {
|
||||
URL string `json:"url"`
|
||||
AESKey string `json:"aeskey,omitempty"`
|
||||
} `json:"file,omitempty"`
|
||||
} `json:"msg_item"`
|
||||
} `json:"mixed,omitempty"`
|
||||
Quote *struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text *struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text,omitempty"`
|
||||
} `json:"quote,omitempty"`
|
||||
Event *struct {
|
||||
EventType string `json:"eventtype"`
|
||||
} `json:"event,omitempty"`
|
||||
}
|
||||
|
||||
func incomingChatID(msg wecomIncomingMessage) string {
|
||||
if msg.ChatID != "" {
|
||||
return msg.ChatID
|
||||
}
|
||||
return msg.From.UserID
|
||||
}
|
||||
|
||||
func incomingChatTypeCode(kind string) uint32 {
|
||||
if kind == "group" {
|
||||
return 2
|
||||
}
|
||||
return 1
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type wecomRoute struct {
|
||||
ReqID string `json:"req_id"`
|
||||
ChatID string `json:"chat_id"`
|
||||
ChatType uint32 `json:"chat_type"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
type reqIDStore struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
routes map[string]wecomRoute
|
||||
}
|
||||
|
||||
func newReqIDStore(path string) *reqIDStore {
|
||||
if path == "" {
|
||||
path = defaultReqIDStorePath()
|
||||
}
|
||||
s := &reqIDStore{
|
||||
path: path,
|
||||
routes: make(map[string]wecomRoute),
|
||||
}
|
||||
_ = s.load()
|
||||
return s
|
||||
}
|
||||
|
||||
func defaultReqIDStorePath() string {
|
||||
if home, err := os.UserHomeDir(); err == nil && home != "" {
|
||||
return filepath.Join(home, ".picoclaw", "wecom", "reqid-store.json")
|
||||
}
|
||||
return filepath.Join(os.TempDir(), "picoclaw-wecom-reqid-store.json")
|
||||
}
|
||||
|
||||
func (s *reqIDStore) Put(chatID, reqID string, chatType uint32, ttl time.Duration) error {
|
||||
if reqID == "" || chatID == "" {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.deleteExpiredLocked(time.Now())
|
||||
s.routes[chatID] = wecomRoute{
|
||||
ReqID: reqID,
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *reqIDStore) Get(chatID string) (wecomRoute, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.deleteExpiredLocked(time.Now())
|
||||
route, ok := s.routes[chatID]
|
||||
return route, ok
|
||||
}
|
||||
|
||||
func (s *reqIDStore) Delete(chatID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.routes, chatID)
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *reqIDStore) load() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var routes map[string]wecomRoute
|
||||
if err := json.Unmarshal(data, &routes); err != nil {
|
||||
return err
|
||||
}
|
||||
s.routes = routes
|
||||
s.deleteExpiredLocked(time.Now())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *reqIDStore) deleteExpiredLocked(now time.Time) {
|
||||
for chatID, route := range s.routes {
|
||||
if !route.ExpiresAt.IsZero() && now.After(route.ExpiresAt) {
|
||||
delete(s.routes, chatID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *reqIDStore) saveLocked() error {
|
||||
if err := os.MkdirAll(filepath.Dir(s.path), 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.MarshalIndent(s.routes, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(s.path, data, 0o600)
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestReqIDStorePersistsRoutes(t *testing.T) {
|
||||
storePath := filepath.Join(t.TempDir(), "reqids.json")
|
||||
store := newReqIDStore(storePath)
|
||||
if err := store.Put("chat-1", "req-1", 2, time.Hour); err != nil {
|
||||
t.Fatalf("Put() error = %v", err)
|
||||
}
|
||||
|
||||
reloaded := newReqIDStore(storePath)
|
||||
route, ok := reloaded.Get("chat-1")
|
||||
if !ok {
|
||||
t.Fatal("expected persisted route to be loaded")
|
||||
}
|
||||
if route.ChatID != "chat-1" || route.ReqID != "req-1" || route.ChatType != 2 {
|
||||
t.Fatalf("loaded route = %+v", route)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,970 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
wecomConnectTimeout = 15 * time.Second
|
||||
wecomCommandTimeout = 10 * time.Second
|
||||
wecomUploadTimeout = 30 * time.Second
|
||||
wecomHeartbeatInterval = 30 * time.Second
|
||||
wecomStreamMaxDuration = 5*time.Minute + 30*time.Second
|
||||
wecomStreamMinInterval = 500 * time.Millisecond
|
||||
wecomRouteTTL = 30 * time.Minute
|
||||
wecomMediaTimeout = 30 * time.Second
|
||||
wecomRecentMessageMax = 1000
|
||||
)
|
||||
|
||||
type WeComChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.WeComConfig
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
conn *websocket.Conn
|
||||
connMu sync.Mutex
|
||||
|
||||
pendingMu sync.Mutex
|
||||
pending map[string]chan wecomEnvelope
|
||||
|
||||
turnsMu sync.Mutex
|
||||
turns map[string][]wecomTurn
|
||||
|
||||
recent *recentMessageSet
|
||||
routes *reqIDStore
|
||||
mediaClient *http.Client
|
||||
commandSend func(wecomCommand, time.Duration) (wecomEnvelope, error)
|
||||
}
|
||||
|
||||
type wecomTurn struct {
|
||||
ReqID string
|
||||
ChatID string
|
||||
ChatType uint32
|
||||
StreamID string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type wecomStreamer struct {
|
||||
channel *WeComChannel
|
||||
chatID string
|
||||
turn wecomTurn
|
||||
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
lastSentAt time.Time
|
||||
content string
|
||||
}
|
||||
|
||||
type recentMessageSet struct {
|
||||
mu sync.Mutex
|
||||
seen map[string]struct{}
|
||||
ring []string
|
||||
idx int
|
||||
}
|
||||
|
||||
func newRecentMessageSet(capacity int) *recentMessageSet {
|
||||
if capacity <= 0 {
|
||||
capacity = wecomRecentMessageMax
|
||||
}
|
||||
return &recentMessageSet{
|
||||
seen: make(map[string]struct{}, capacity),
|
||||
ring: make([]string, capacity),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *recentMessageSet) Mark(id string) bool {
|
||||
if id == "" {
|
||||
return true
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, ok := s.seen[id]; ok {
|
||||
return false
|
||||
}
|
||||
if old := s.ring[s.idx]; old != "" {
|
||||
delete(s.seen, old)
|
||||
}
|
||||
s.ring[s.idx] = id
|
||||
s.idx = (s.idx + 1) % len(s.ring)
|
||||
s.seen[id] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
func NewChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComChannel, error) {
|
||||
if cfg.BotID == "" || cfg.Secret() == "" {
|
||||
return nil, fmt.Errorf("wecom bot_id and secret are required")
|
||||
}
|
||||
if cfg.WebSocketURL == "" {
|
||||
cfg.WebSocketURL = wecomDefaultWebSocketURL
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel(
|
||||
"wecom",
|
||||
cfg,
|
||||
messageBus,
|
||||
cfg.AllowFrom,
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
ch := &WeComChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
pending: make(map[string]chan wecomEnvelope),
|
||||
turns: make(map[string][]wecomTurn),
|
||||
recent: newRecentMessageSet(wecomRecentMessageMax),
|
||||
routes: newReqIDStore(""),
|
||||
mediaClient: &http.Client{Timeout: wecomMediaTimeout},
|
||||
}
|
||||
ch.SetOwner(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) Name() string { return "wecom" }
|
||||
|
||||
func (c *WeComChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("wecom", "Starting WeCom channel...")
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
c.SetRunning(true)
|
||||
go c.connectLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) Stop(_ context.Context) error {
|
||||
logger.InfoC("wecom", "Stopping WeCom channel...")
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
c.connMu.Lock()
|
||||
if c.conn != nil {
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.connMu.Unlock()
|
||||
c.clearTurns()
|
||||
c.SetRunning(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) BeginStream(_ context.Context, chatID string) (channels.Streamer, error) {
|
||||
if !c.IsRunning() {
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
|
||||
turn, ok := c.getTurn(chatID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("wecom streaming unavailable: no active turn")
|
||||
}
|
||||
if time.Since(turn.CreatedAt) > wecomStreamMaxDuration {
|
||||
c.consumeTurn(chatID, turn)
|
||||
return nil, fmt.Errorf("wecom streaming unavailable: turn expired")
|
||||
}
|
||||
|
||||
return &wecomStreamer{
|
||||
channel: c,
|
||||
chatID: chatID,
|
||||
turn: turn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
content := strings.TrimSpace(msg.Content)
|
||||
if content == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if turn, ok := c.getTurn(msg.ChatID); ok {
|
||||
if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration {
|
||||
if err := c.sendStreamReply(turn, content); err == nil {
|
||||
c.consumeTurn(msg.ChatID, turn)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
c.consumeTurn(msg.ChatID, turn)
|
||||
}
|
||||
|
||||
if route, ok := c.routes.Get(msg.ChatID); ok {
|
||||
if err := c.sendActivePush(route.ChatID, route.ChatType, content); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.sendActivePush(msg.ChatID, 0, content); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
route, chatType, hasTurn := c.resolveMediaRoute(msg.ChatID)
|
||||
chatID := route.ChatID
|
||||
if chatID == "" {
|
||||
chatID = msg.ChatID
|
||||
}
|
||||
|
||||
for _, part := range msg.Parts {
|
||||
if strings.TrimSpace(part.Ref) == "" {
|
||||
if caption := strings.TrimSpace(part.Caption); caption != "" {
|
||||
if err := c.sendActivePush(chatID, chatType, caption); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
localPath, filename, contentType, cleanup, err := c.resolveOutboundPart(ctx, part)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wecom resolve media %q: %v: %w", part.Ref, err, channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
func() {
|
||||
if cleanup != nil {
|
||||
defer cleanup()
|
||||
}
|
||||
|
||||
uploaded, uploadErr := c.uploadOutboundMedia(ctx, localPath, filename, contentType, part)
|
||||
if uploadErr != nil {
|
||||
logger.WarnCF("wecom", "Falling back to placeholder after media upload failure", map[string]any{
|
||||
"chat_id": chatID,
|
||||
"ref": part.Ref,
|
||||
"filename": filename,
|
||||
"content_type": contentType,
|
||||
"error": uploadErr.Error(),
|
||||
})
|
||||
if hasTurn {
|
||||
if finishErr := c.sendStreamChunk(route, true, ""); finishErr != nil {
|
||||
err = finishErr
|
||||
return
|
||||
}
|
||||
c.deleteTurn(msg.ChatID)
|
||||
hasTurn = false
|
||||
}
|
||||
err = c.sendActivePush(chatID, chatType, fallbackWeComMediaText(part, "", filename))
|
||||
return
|
||||
}
|
||||
|
||||
if hasTurn {
|
||||
err = c.sendTurnMedia(route, uploaded)
|
||||
c.deleteTurn(msg.ChatID)
|
||||
hasTurn = false
|
||||
} else {
|
||||
err = c.sendActiveMedia(chatID, chatType, uploaded)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if caption := strings.TrimSpace(part.Caption); caption != "" {
|
||||
err = c.sendActivePush(chatID, chatType, caption)
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) connectLoop() {
|
||||
backoff := time.Second
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if err := c.runConnection(); err != nil {
|
||||
logger.WarnCF("wecom", "WeCom connection lost", map[string]any{
|
||||
"error": err.Error(),
|
||||
"backoff": backoff.String(),
|
||||
})
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
}
|
||||
if backoff < time.Minute {
|
||||
backoff *= 2
|
||||
if backoff > time.Minute {
|
||||
backoff = time.Minute
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) runConnection() error {
|
||||
dialCtx, cancel := context.WithTimeout(c.ctx, wecomConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, resp, err := websocket.DefaultDialer.DialContext(dialCtx, c.config.WebSocketURL, nil)
|
||||
if resp != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
|
||||
}
|
||||
|
||||
c.connMu.Lock()
|
||||
c.conn = conn
|
||||
c.connMu.Unlock()
|
||||
defer func() {
|
||||
c.connMu.Lock()
|
||||
if c.conn == conn {
|
||||
c.conn = nil
|
||||
}
|
||||
c.connMu.Unlock()
|
||||
_ = conn.Close()
|
||||
c.clearTurns()
|
||||
}()
|
||||
|
||||
readErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
readErrCh <- c.readLoop(conn)
|
||||
}()
|
||||
|
||||
if writeErr := c.writeAndWait(conn, wecomCommand{
|
||||
Cmd: wecomCmdSubscribe,
|
||||
Headers: wecomHeaders{ReqID: randomID(10)},
|
||||
Body: map[string]string{
|
||||
"bot_id": c.config.BotID,
|
||||
"secret": c.config.Secret(),
|
||||
},
|
||||
}, wecomCommandTimeout); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
|
||||
heartbeatDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(heartbeatDone)
|
||||
c.heartbeatLoop(conn)
|
||||
}()
|
||||
|
||||
err = <-readErrCh
|
||||
_ = conn.Close()
|
||||
<-heartbeatDone
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *WeComChannel) heartbeatLoop(conn *websocket.Conn) {
|
||||
ticker := time.NewTicker(wecomHeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := c.writeAndWait(conn, wecomCommand{
|
||||
Cmd: wecomCmdPing,
|
||||
Headers: wecomHeaders{ReqID: randomID(10)},
|
||||
}, wecomCommandTimeout); err != nil {
|
||||
logger.WarnCF("wecom", "Heartbeat failed", map[string]any{"error": err.Error()})
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) readLoop(conn *websocket.Conn) error {
|
||||
for {
|
||||
_, raw, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
|
||||
}
|
||||
}
|
||||
|
||||
var env wecomEnvelope
|
||||
if err := json.Unmarshal(raw, &env); err != nil {
|
||||
logger.WarnCF("wecom", "Failed to parse WebSocket message", map[string]any{"error": err.Error()})
|
||||
continue
|
||||
}
|
||||
|
||||
if env.Cmd == "" && env.Headers.ReqID != "" {
|
||||
c.pendingMu.Lock()
|
||||
ch, ok := c.pending[env.Headers.ReqID]
|
||||
if ok {
|
||||
delete(c.pending, env.Headers.ReqID)
|
||||
}
|
||||
c.pendingMu.Unlock()
|
||||
if ok {
|
||||
ch <- env
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
go c.handleEnvelope(env)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) handleEnvelope(env wecomEnvelope) {
|
||||
switch env.Cmd {
|
||||
case wecomCmdMsgCallback:
|
||||
c.handleMessageCallback(env)
|
||||
case wecomCmdEventCallback:
|
||||
c.handleEventCallback(env)
|
||||
default:
|
||||
logger.DebugCF("wecom", "Ignoring unsupported WeCom command", map[string]any{"cmd": env.Cmd})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) handleEventCallback(env wecomEnvelope) {
|
||||
var msg wecomIncomingMessage
|
||||
if err := json.Unmarshal(env.Body, &msg); err != nil {
|
||||
logger.WarnCF("wecom", "Failed to parse WeCom event callback", map[string]any{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) handleMessageCallback(env wecomEnvelope) {
|
||||
var msg wecomIncomingMessage
|
||||
if err := json.Unmarshal(env.Body, &msg); err != nil {
|
||||
logger.WarnCF("wecom", "Failed to parse WeCom message callback", map[string]any{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !c.recent.Mark(msg.MsgID) {
|
||||
return
|
||||
}
|
||||
|
||||
reqID := env.Headers.ReqID
|
||||
if reqID == "" {
|
||||
logger.WarnC("wecom", "WeCom message callback missing req_id")
|
||||
return
|
||||
}
|
||||
if msg.Event != nil && msg.Event.EventType != "" {
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.dispatchIncoming(reqID, msg); err != nil {
|
||||
logger.WarnCF("wecom", "Failed to dispatch WeCom message", map[string]any{
|
||||
"req_id": reqID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
_ = c.respondImmediate(reqID, "The WeCom message could not be processed.")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage) error {
|
||||
senderID := msg.From.UserID
|
||||
if senderID == "" {
|
||||
senderID = "unknown"
|
||||
}
|
||||
actualChatID := incomingChatID(msg)
|
||||
chatType := incomingChatTypeCode(msg.ChatType)
|
||||
peerKind := "direct"
|
||||
if msg.ChatType == "group" {
|
||||
peerKind = "group"
|
||||
}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "wecom",
|
||||
PlatformID: senderID,
|
||||
CanonicalID: identity.BuildCanonicalID("wecom", senderID),
|
||||
DisplayName: senderID,
|
||||
}
|
||||
|
||||
var (
|
||||
content string
|
||||
quoteText string
|
||||
mediaRefs []string
|
||||
err error
|
||||
)
|
||||
scope := channels.BuildMediaScope("wecom", actualChatID, msg.MsgID)
|
||||
switch msg.MsgType {
|
||||
case "text":
|
||||
if msg.Text != nil {
|
||||
content = strings.TrimSpace(msg.Text.Content)
|
||||
}
|
||||
case "voice":
|
||||
if msg.Voice != nil {
|
||||
content = strings.TrimSpace(msg.Voice.Content)
|
||||
}
|
||||
case "image":
|
||||
content = "[image]"
|
||||
mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{
|
||||
url: msg.Image.URL,
|
||||
aesKey: msg.Image.AESKey,
|
||||
}, "image", ".jpg")
|
||||
case "file":
|
||||
content = "[file]"
|
||||
mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{
|
||||
url: msg.File.URL,
|
||||
aesKey: msg.File.AESKey,
|
||||
}, "file", ".bin")
|
||||
case "video":
|
||||
content = "[video]"
|
||||
mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{
|
||||
url: msg.Video.URL,
|
||||
aesKey: msg.Video.AESKey,
|
||||
}, "video", ".mp4")
|
||||
case "mixed":
|
||||
content, mediaRefs, err = c.collectMixedMedia(c.ctx, scope, msg)
|
||||
default:
|
||||
return c.respondImmediate(reqID, "Unsupported WeCom message type: "+msg.MsgType)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Quote != nil && msg.Quote.Text != nil {
|
||||
quoteText = strings.TrimSpace(msg.Quote.Text.Content)
|
||||
if content == "" {
|
||||
content = quoteText
|
||||
}
|
||||
}
|
||||
if content == "" && len(mediaRefs) == 0 {
|
||||
return c.respondImmediate(reqID, "The WeCom message did not contain usable content.")
|
||||
}
|
||||
|
||||
turn := wecomTurn{
|
||||
ReqID: reqID,
|
||||
ChatID: actualChatID,
|
||||
ChatType: chatType,
|
||||
StreamID: randomID(10),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
c.queueTurn(actualChatID, turn)
|
||||
if err := c.routes.Put(actualChatID, reqID, chatType, wecomRouteTTL); err != nil {
|
||||
logger.WarnCF("wecom", "Failed to persist req_id route", map[string]any{
|
||||
"chat_id": actualChatID,
|
||||
"req_id": reqID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
opening := ""
|
||||
if c.config.SendThinkingMessage {
|
||||
opening = "Processing..."
|
||||
}
|
||||
if err := c.sendStreamChunk(turn, false, opening); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: actualChatID}
|
||||
metadata := map[string]string{
|
||||
"channel": "wecom",
|
||||
"req_id": reqID,
|
||||
"chat_id": actualChatID,
|
||||
"chat_type": msg.ChatType,
|
||||
"msg_id": msg.MsgID,
|
||||
"msg_type": msg.MsgType,
|
||||
}
|
||||
if quoteText != "" {
|
||||
metadata["quote_text"] = quoteText
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, msg.MsgID, senderID, actualChatID, content, mediaRefs, metadata, sender)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) collectSingleMedia(
|
||||
ctx context.Context,
|
||||
scope, msgID string,
|
||||
payload interface {
|
||||
GetURL() string
|
||||
GetAESKey() string
|
||||
},
|
||||
label, fallbackExt string,
|
||||
) ([]string, error) {
|
||||
if payload == nil || payload.GetURL() == "" {
|
||||
return nil, fmt.Errorf("%s payload is empty", label)
|
||||
}
|
||||
ref, err := c.storeRemoteMedia(ctx, scope, msgID, payload.GetURL(), payload.GetAESKey(), fallbackExt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []string{ref}, nil
|
||||
}
|
||||
|
||||
type mediaPayload struct {
|
||||
url string
|
||||
aesKey string
|
||||
}
|
||||
|
||||
func (p *mediaPayload) GetURL() string { return p.url }
|
||||
func (p *mediaPayload) GetAESKey() string { return p.aesKey }
|
||||
|
||||
func (c *WeComChannel) collectMixedMedia(
|
||||
ctx context.Context,
|
||||
scope string,
|
||||
msg wecomIncomingMessage,
|
||||
) (string, []string, error) {
|
||||
if msg.Mixed == nil {
|
||||
return "", nil, fmt.Errorf("mixed message is empty")
|
||||
}
|
||||
|
||||
var textParts []string
|
||||
var refs []string
|
||||
for idx, item := range msg.Mixed.MsgItem {
|
||||
switch item.MsgType {
|
||||
case "text":
|
||||
if item.Text != nil && strings.TrimSpace(item.Text.Content) != "" {
|
||||
textParts = append(textParts, strings.TrimSpace(item.Text.Content))
|
||||
}
|
||||
case "image":
|
||||
if item.Image != nil && item.Image.URL != "" {
|
||||
ref, err := c.storeRemoteMedia(
|
||||
ctx,
|
||||
scope,
|
||||
fmt.Sprintf("%s-%d", msg.MsgID, idx),
|
||||
item.Image.URL,
|
||||
item.Image.AESKey,
|
||||
".jpg",
|
||||
)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
case "file":
|
||||
if item.File != nil && item.File.URL != "" {
|
||||
ref, err := c.storeRemoteMedia(
|
||||
ctx,
|
||||
scope,
|
||||
fmt.Sprintf("%s-%d", msg.MsgID, idx),
|
||||
item.File.URL,
|
||||
item.File.AESKey,
|
||||
".bin",
|
||||
)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content := strings.Join(textParts, "\n")
|
||||
if content == "" && len(refs) > 0 {
|
||||
content = "[media]"
|
||||
}
|
||||
return content, refs, nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) respondImmediate(reqID, content string) error {
|
||||
turn := wecomTurn{
|
||||
ReqID: reqID,
|
||||
StreamID: randomID(10),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
return c.sendStreamChunk(turn, true, content)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendStreamReply(turn wecomTurn, content string) error {
|
||||
return c.sendStreamChunk(turn, true, content)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendStreamChunk(turn wecomTurn, finish bool, content string) error {
|
||||
return c.sendCommand(wecomCommand{
|
||||
Cmd: wecomCmdRespondMsg,
|
||||
Headers: wecomHeaders{ReqID: turn.ReqID},
|
||||
Body: wecomRespondMsgBody{
|
||||
MsgType: "stream",
|
||||
Stream: &wecomStreamContent{
|
||||
ID: turn.StreamID,
|
||||
Finish: finish,
|
||||
Content: content,
|
||||
},
|
||||
},
|
||||
}, wecomCommandTimeout)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendTurnMedia(turn wecomTurn, uploaded *wecomOutboundMedia) error {
|
||||
if uploaded == nil {
|
||||
return fmt.Errorf("wecom outbound media is nil: %w", channels.ErrSendFailed)
|
||||
}
|
||||
if err := c.sendCommand(wecomCommand{
|
||||
Cmd: wecomCmdRespondMsg,
|
||||
Headers: wecomHeaders{ReqID: turn.ReqID},
|
||||
Body: uploaded.respondBody(),
|
||||
}, wecomCommandTimeout); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.sendStreamChunk(turn, true, "")
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content string) error {
|
||||
if strings.TrimSpace(chatID) == "" {
|
||||
return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed)
|
||||
}
|
||||
return c.sendCommand(wecomCommand{
|
||||
Cmd: wecomCmdSendMsg,
|
||||
Headers: wecomHeaders{ReqID: randomID(10)},
|
||||
Body: wecomSendMsgBody{
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
MsgType: "markdown",
|
||||
Markdown: &wecomMarkdownContent{Content: content},
|
||||
},
|
||||
}, wecomCommandTimeout)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendActiveMedia(chatID string, chatType uint32, uploaded *wecomOutboundMedia) error {
|
||||
if strings.TrimSpace(chatID) == "" {
|
||||
return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed)
|
||||
}
|
||||
if uploaded == nil {
|
||||
return fmt.Errorf("wecom outbound media is nil: %w", channels.ErrSendFailed)
|
||||
}
|
||||
return c.sendCommand(wecomCommand{
|
||||
Cmd: wecomCmdSendMsg,
|
||||
Headers: wecomHeaders{ReqID: randomID(10)},
|
||||
Body: uploaded.sendBody(chatID, chatType),
|
||||
}, wecomCommandTimeout)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendCommand(cmd wecomCommand, timeout time.Duration) error {
|
||||
_, err := c.sendCommandAck(cmd, timeout)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendCommandAck(cmd wecomCommand, timeout time.Duration) (wecomEnvelope, error) {
|
||||
if c.commandSend != nil {
|
||||
return c.commandSend(cmd, timeout)
|
||||
}
|
||||
return c.writeCurrentAck(cmd, timeout)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) writeCurrentAck(cmd wecomCommand, timeout time.Duration) (wecomEnvelope, error) {
|
||||
c.connMu.Lock()
|
||||
conn := c.conn
|
||||
c.connMu.Unlock()
|
||||
if conn == nil {
|
||||
return wecomEnvelope{}, fmt.Errorf("wecom websocket not connected: %w", channels.ErrTemporary)
|
||||
}
|
||||
return c.writeAndWaitAck(conn, cmd, timeout)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) writeAndWait(conn *websocket.Conn, cmd wecomCommand, timeout time.Duration) error {
|
||||
_, err := c.writeAndWaitAck(conn, cmd, timeout)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *WeComChannel) writeAndWaitAck(
|
||||
conn *websocket.Conn,
|
||||
cmd wecomCommand,
|
||||
timeout time.Duration,
|
||||
) (wecomEnvelope, error) {
|
||||
if cmd.Headers.ReqID == "" {
|
||||
cmd.Headers.ReqID = randomID(10)
|
||||
}
|
||||
waitCh := make(chan wecomEnvelope, 1)
|
||||
c.pendingMu.Lock()
|
||||
c.pending[cmd.Headers.ReqID] = waitCh
|
||||
c.pendingMu.Unlock()
|
||||
defer func() {
|
||||
c.pendingMu.Lock()
|
||||
delete(c.pending, cmd.Headers.ReqID)
|
||||
c.pendingMu.Unlock()
|
||||
}()
|
||||
|
||||
data, err := json.Marshal(cmd)
|
||||
if err != nil {
|
||||
return wecomEnvelope{}, fmt.Errorf("%w: %v", channels.ErrSendFailed, err)
|
||||
}
|
||||
c.connMu.Lock()
|
||||
err = conn.WriteMessage(websocket.TextMessage, data)
|
||||
c.connMu.Unlock()
|
||||
if err != nil {
|
||||
return wecomEnvelope{}, fmt.Errorf("%w: %v", channels.ErrTemporary, err)
|
||||
}
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case env := <-waitCh:
|
||||
if env.ErrCode != 0 {
|
||||
return wecomEnvelope{}, fmt.Errorf(
|
||||
"%w: wecom errcode=%d errmsg=%s",
|
||||
channels.ErrTemporary,
|
||||
env.ErrCode,
|
||||
env.ErrMsg,
|
||||
)
|
||||
}
|
||||
return env, nil
|
||||
case <-timer.C:
|
||||
return wecomEnvelope{}, fmt.Errorf("%w: timeout waiting for WeCom ack", channels.ErrTemporary)
|
||||
case <-c.ctx.Done():
|
||||
return wecomEnvelope{}, c.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) getTurn(chatID string) (wecomTurn, bool) {
|
||||
c.turnsMu.Lock()
|
||||
defer c.turnsMu.Unlock()
|
||||
queue := c.turns[chatID]
|
||||
if len(queue) == 0 {
|
||||
return wecomTurn{}, false
|
||||
}
|
||||
return queue[0], true
|
||||
}
|
||||
|
||||
func (c *WeComChannel) deleteTurn(chatID string) {
|
||||
c.turnsMu.Lock()
|
||||
defer c.turnsMu.Unlock()
|
||||
queue := c.turns[chatID]
|
||||
if len(queue) <= 1 {
|
||||
delete(c.turns, chatID)
|
||||
return
|
||||
}
|
||||
c.turns[chatID] = queue[1:]
|
||||
}
|
||||
|
||||
func (c *WeComChannel) queueTurn(chatID string, turn wecomTurn) {
|
||||
c.turnsMu.Lock()
|
||||
defer c.turnsMu.Unlock()
|
||||
c.turns[chatID] = append(c.turns[chatID], turn)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) consumeTurn(chatID string, turn wecomTurn) bool {
|
||||
c.turnsMu.Lock()
|
||||
defer c.turnsMu.Unlock()
|
||||
|
||||
queue := c.turns[chatID]
|
||||
if len(queue) == 0 {
|
||||
return false
|
||||
}
|
||||
current := queue[0]
|
||||
if current.ReqID != turn.ReqID || current.StreamID != turn.StreamID {
|
||||
return false
|
||||
}
|
||||
if len(queue) == 1 {
|
||||
delete(c.turns, chatID)
|
||||
return true
|
||||
}
|
||||
c.turns[chatID] = queue[1:]
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *WeComChannel) clearTurns() {
|
||||
c.turnsMu.Lock()
|
||||
c.turns = make(map[string][]wecomTurn)
|
||||
c.turnsMu.Unlock()
|
||||
}
|
||||
|
||||
func randomID(n int) string {
|
||||
const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
if n <= 0 {
|
||||
n = 10
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
for i := range buf {
|
||||
v, _ := rand.Int(rand.Reader, big.NewInt(int64(len(alphabet))))
|
||||
buf[i] = alphabet[v.Int64()]
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func (s *wecomStreamer) Update(ctx context.Context, content string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.closed {
|
||||
return nil
|
||||
}
|
||||
if err := s.validateActiveTurn(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !s.lastSentAt.IsZero() {
|
||||
wait := time.Until(s.lastSentAt.Add(wecomStreamMinInterval))
|
||||
if wait > 0 {
|
||||
timer := time.NewTimer(wait)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.channel.sendStreamChunk(s.turn, false, content); err != nil {
|
||||
return err
|
||||
}
|
||||
s.content = content
|
||||
s.lastSentAt = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wecomStreamer) Finalize(ctx context.Context, content string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.closed {
|
||||
return nil
|
||||
}
|
||||
if err := s.validateActiveTurn(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.channel.sendStreamChunk(s.turn, true, content); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.content = content
|
||||
s.closed = true
|
||||
s.channel.consumeTurn(s.chatID, s.turn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wecomStreamer) Cancel(_ context.Context) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.closed {
|
||||
return
|
||||
}
|
||||
if s.validateActiveTurn() == nil {
|
||||
_ = s.channel.sendStreamChunk(s.turn, true, s.content)
|
||||
s.channel.consumeTurn(s.chatID, s.turn)
|
||||
}
|
||||
s.closed = true
|
||||
}
|
||||
|
||||
func (s *wecomStreamer) validateActiveTurn() error {
|
||||
if time.Since(s.turn.CreatedAt) > wecomStreamMaxDuration {
|
||||
s.channel.consumeTurn(s.chatID, s.turn)
|
||||
return fmt.Errorf("wecom streaming unavailable: turn expired")
|
||||
}
|
||||
current, ok := s.channel.getTurn(s.chatID)
|
||||
if !ok || current.ReqID != s.turn.ReqID || current.StreamID != s.turn.StreamID {
|
||||
return fmt.Errorf("wecom streaming unavailable: turn no longer active")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,660 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := newTestWeComChannel(t, messageBus)
|
||||
|
||||
var commands []wecomCommand
|
||||
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
|
||||
commands = append(commands, cmd)
|
||||
return wecomTestAck(nil), nil
|
||||
}
|
||||
|
||||
msg := wecomIncomingMessage{
|
||||
MsgID: "msg-1",
|
||||
ChatID: "chat-1",
|
||||
ChatType: "direct",
|
||||
MsgType: "text",
|
||||
Text: &struct {
|
||||
Content string `json:"content"`
|
||||
}{Content: "hello"},
|
||||
}
|
||||
msg.From.UserID = "user-1"
|
||||
|
||||
if err := ch.dispatchIncoming("req-1", msg); err != nil {
|
||||
t.Fatalf("dispatchIncoming() error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case inbound := <-messageBus.InboundChan():
|
||||
if inbound.ChatID != "chat-1" {
|
||||
t.Fatalf("inbound ChatID = %q, want chat-1", inbound.ChatID)
|
||||
}
|
||||
if inbound.MessageID != "msg-1" {
|
||||
t.Fatalf("inbound MessageID = %q, want msg-1", inbound.MessageID)
|
||||
}
|
||||
if inbound.Peer.ID != "chat-1" {
|
||||
t.Fatalf("inbound Peer.ID = %q, want chat-1", inbound.Peer.ID)
|
||||
}
|
||||
if inbound.Metadata["req_id"] != "req-1" {
|
||||
t.Fatalf("inbound req_id = %q, want req-1", inbound.Metadata["req_id"])
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected inbound message to be published")
|
||||
}
|
||||
|
||||
turn, ok := ch.getTurn("chat-1")
|
||||
if !ok {
|
||||
t.Fatal("expected queued turn for chat-1")
|
||||
}
|
||||
if turn.ReqID != "req-1" {
|
||||
t.Fatalf("turn.ReqID = %q, want req-1", turn.ReqID)
|
||||
}
|
||||
|
||||
route, ok := ch.routes.Get("chat-1")
|
||||
if !ok {
|
||||
t.Fatal("expected persisted route for chat-1")
|
||||
}
|
||||
if route.ReqID != "req-1" || route.ChatType != 1 {
|
||||
t.Fatalf("route = %+v", route)
|
||||
}
|
||||
|
||||
if len(commands) != 1 {
|
||||
t.Fatalf("expected 1 opening command, got %d", len(commands))
|
||||
}
|
||||
if commands[0].Cmd != wecomCmdRespondMsg {
|
||||
t.Fatalf("opening command = %q, want %q", commands[0].Cmd, wecomCmdRespondMsg)
|
||||
}
|
||||
if commands[0].Headers.ReqID != "req-1" {
|
||||
t.Fatalf("opening req_id = %q, want req-1", commands[0].Headers.ReqID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewChannel_DoesNotRegisterMessageSplitLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := newTestWeComChannel(t, bus.NewMessageBus())
|
||||
if got := ch.MaxMessageLength(); got != 0 {
|
||||
t.Fatalf("MaxMessageLength() = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBeginStream_UpdateAndFinalize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := newTestWeComChannel(t, bus.NewMessageBus())
|
||||
ch.SetRunning(true)
|
||||
ch.queueTurn("chat-1", wecomTurn{
|
||||
ReqID: "req-1",
|
||||
ChatID: "chat-1",
|
||||
ChatType: 1,
|
||||
StreamID: "stream-1",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
var commands []wecomCommand
|
||||
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
|
||||
commands = append(commands, cmd)
|
||||
return wecomTestAck(nil), nil
|
||||
}
|
||||
|
||||
streamer, err := ch.BeginStream(context.Background(), "chat-1")
|
||||
if err != nil {
|
||||
t.Fatalf("BeginStream() error = %v", err)
|
||||
}
|
||||
if err := streamer.Update(context.Background(), "draft"); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
if err := streamer.Finalize(context.Background(), "final"); err != nil {
|
||||
t.Fatalf("Finalize() error = %v", err)
|
||||
}
|
||||
|
||||
if len(commands) != 2 {
|
||||
t.Fatalf("expected 2 commands, got %d", len(commands))
|
||||
}
|
||||
for i, wantFinish := range []bool{false, true} {
|
||||
if commands[i].Cmd != wecomCmdRespondMsg {
|
||||
t.Fatalf("command[%d].Cmd = %q, want %q", i, commands[i].Cmd, wecomCmdRespondMsg)
|
||||
}
|
||||
body, ok := commands[i].Body.(wecomRespondMsgBody)
|
||||
if !ok {
|
||||
t.Fatalf("command[%d] body type = %T", i, commands[i].Body)
|
||||
}
|
||||
if body.Stream == nil {
|
||||
t.Fatalf("command[%d] missing stream body", i)
|
||||
}
|
||||
if body.Stream.ID != "stream-1" {
|
||||
t.Fatalf("command[%d] stream id = %q, want stream-1", i, body.Stream.ID)
|
||||
}
|
||||
if body.Stream.Finish != wantFinish {
|
||||
t.Fatalf("command[%d] finish = %v, want %v", i, body.Stream.Finish, wantFinish)
|
||||
}
|
||||
}
|
||||
if body := commands[0].Body.(wecomRespondMsgBody); body.Stream.Content != "draft" {
|
||||
t.Fatalf("update content = %q, want draft", body.Stream.Content)
|
||||
}
|
||||
if body := commands[1].Body.(wecomRespondMsgBody); body.Stream.Content != "final" {
|
||||
t.Fatalf("final content = %q, want final", body.Stream.Content)
|
||||
}
|
||||
if _, ok := ch.getTurn("chat-1"); ok {
|
||||
t.Fatal("expected turn to be consumed after Finalize")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := newTestWeComChannel(t, bus.NewMessageBus())
|
||||
ch.SetRunning(true)
|
||||
ch.queueTurn("chat-1", wecomTurn{
|
||||
ReqID: "req-1",
|
||||
ChatID: "chat-1",
|
||||
ChatType: 1,
|
||||
StreamID: "stream-1",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
ch.queueTurn("chat-1", wecomTurn{
|
||||
ReqID: "req-2",
|
||||
ChatID: "chat-1",
|
||||
ChatType: 1,
|
||||
StreamID: "stream-2",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
if err := ch.routes.Put("chat-1", "req-2", 1, time.Hour); err != nil {
|
||||
t.Fatalf("Put() error = %v", err)
|
||||
}
|
||||
|
||||
var commands []wecomCommand
|
||||
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
|
||||
commands = append(commands, cmd)
|
||||
if len(commands) == 1 && cmd.Cmd == wecomCmdRespondMsg {
|
||||
return wecomEnvelope{}, errors.New("stream send failed")
|
||||
}
|
||||
return wecomTestAck(nil), nil
|
||||
}
|
||||
|
||||
if err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
Channel: "wecom",
|
||||
ChatID: "chat-1",
|
||||
Content: "hello",
|
||||
}); err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
|
||||
if len(commands) != 2 {
|
||||
t.Fatalf("expected 2 commands, got %d", len(commands))
|
||||
}
|
||||
if commands[0].Cmd != wecomCmdRespondMsg || commands[0].Headers.ReqID != "req-1" {
|
||||
t.Fatalf("first command = %+v", commands[0])
|
||||
}
|
||||
if commands[1].Cmd != wecomCmdSendMsg {
|
||||
t.Fatalf("second command = %q, want %q", commands[1].Cmd, wecomCmdSendMsg)
|
||||
}
|
||||
body, ok := commands[1].Body.(wecomSendMsgBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected send body type %T", commands[1].Body)
|
||||
}
|
||||
if body.ChatID != "chat-1" {
|
||||
t.Fatalf("send chatid = %q, want chat-1", body.ChatID)
|
||||
}
|
||||
if body.ChatType != 1 {
|
||||
t.Fatalf("send chat_type = %d, want 1", body.ChatType)
|
||||
}
|
||||
|
||||
nextTurn, ok := ch.getTurn("chat-1")
|
||||
if !ok {
|
||||
t.Fatal("expected second turn to remain queued")
|
||||
}
|
||||
if nextTurn.ReqID != "req-2" {
|
||||
t.Fatalf("next queued req_id = %q, want req-2", nextTurn.ReqID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_DoesNotSplitStreamReply(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := newTestWeComChannel(t, bus.NewMessageBus())
|
||||
ch.SetRunning(true)
|
||||
ch.queueTurn("chat-1", wecomTurn{
|
||||
ReqID: "req-1",
|
||||
ChatID: "chat-1",
|
||||
ChatType: 1,
|
||||
StreamID: "stream-1",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
var commands []wecomCommand
|
||||
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
|
||||
commands = append(commands, cmd)
|
||||
return wecomTestAck(nil), nil
|
||||
}
|
||||
|
||||
content := strings.Repeat("\u4e2d", 30000)
|
||||
if err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
Channel: "wecom",
|
||||
ChatID: "chat-1",
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
|
||||
if len(commands) != 1 {
|
||||
t.Fatalf("expected 1 stream command, got %d", len(commands))
|
||||
}
|
||||
body, ok := commands[0].Body.(wecomRespondMsgBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected body type %T", commands[0].Body)
|
||||
}
|
||||
if body.Stream == nil || !body.Stream.Finish {
|
||||
t.Fatalf("stream body = %+v", body.Stream)
|
||||
}
|
||||
if body.Stream.Content != content {
|
||||
t.Fatalf("stream content length = %d, want %d", len(body.Stream.Content), len(content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_DoesNotSplitActivePush(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := newTestWeComChannel(t, bus.NewMessageBus())
|
||||
ch.SetRunning(true)
|
||||
|
||||
var commands []wecomCommand
|
||||
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
|
||||
commands = append(commands, cmd)
|
||||
return wecomTestAck(nil), nil
|
||||
}
|
||||
|
||||
content := strings.Repeat("a", 30000)
|
||||
if err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
Channel: "wecom",
|
||||
ChatID: "chat-1",
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
|
||||
if len(commands) != 1 {
|
||||
t.Fatalf("expected 1 send command, got %d", len(commands))
|
||||
}
|
||||
if commands[0].Cmd != wecomCmdSendMsg {
|
||||
t.Fatalf("command = %q, want %q", commands[0].Cmd, wecomCmdSendMsg)
|
||||
}
|
||||
body, ok := commands[0].Body.(wecomSendMsgBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected body type %T", commands[0].Body)
|
||||
}
|
||||
if body.Markdown == nil || body.Markdown.Content != content {
|
||||
t.Fatalf("markdown content length = %d, want %d", len(body.Markdown.Content), len(content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_SendsActiveImage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := newTestWeComChannel(t, bus.NewMessageBus())
|
||||
ch.SetRunning(true)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
imageData := wecomTestJPEGData(t)
|
||||
imagePath := filepath.Join(t.TempDir(), "photo.jpg")
|
||||
if err := os.WriteFile(imagePath, imageData, 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
ref, err := store.Store(imagePath, media.MediaMeta{
|
||||
Filename: "photo.jpg",
|
||||
ContentType: "image/jpeg",
|
||||
Source: "test",
|
||||
CleanupPolicy: media.CleanupPolicyForgetOnly,
|
||||
}, "scope-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Store() error = %v", err)
|
||||
}
|
||||
|
||||
var commands []wecomCommand
|
||||
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
|
||||
commands = append(commands, cmd)
|
||||
switch cmd.Cmd {
|
||||
case wecomCmdUploadMediaInit:
|
||||
return wecomTestAck(wecomUploadMediaInitResponse{UploadID: "upload-1"}), nil
|
||||
case wecomCmdUploadMediaEnd:
|
||||
return wecomTestAck(wecomUploadMediaFinishResponse{
|
||||
Type: "image",
|
||||
MediaID: "media-1",
|
||||
}), nil
|
||||
default:
|
||||
return wecomTestAck(nil), nil
|
||||
}
|
||||
}
|
||||
|
||||
err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
Channel: "wecom",
|
||||
ChatID: "chat-1",
|
||||
Parts: []bus.MediaPart{{
|
||||
Ref: ref,
|
||||
Type: "image",
|
||||
Filename: "photo.jpg",
|
||||
ContentType: "image/jpeg",
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
|
||||
if len(commands) != 4 {
|
||||
t.Fatalf("expected 4 commands, got %d", len(commands))
|
||||
}
|
||||
if commands[0].Cmd != wecomCmdUploadMediaInit {
|
||||
t.Fatalf("first command = %q, want %q", commands[0].Cmd, wecomCmdUploadMediaInit)
|
||||
}
|
||||
initBody, ok := commands[0].Body.(wecomUploadMediaInitBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected init body type %T", commands[0].Body)
|
||||
}
|
||||
if initBody.Type != "image" || initBody.Filename != "photo.jpg" || initBody.TotalChunks != 1 {
|
||||
t.Fatalf("init body = %+v", initBody)
|
||||
}
|
||||
if commands[1].Cmd != wecomCmdUploadMediaChunk {
|
||||
t.Fatalf("second command = %q, want %q", commands[1].Cmd, wecomCmdUploadMediaChunk)
|
||||
}
|
||||
chunkBody, ok := commands[1].Body.(wecomUploadMediaChunkBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected chunk body type %T", commands[1].Body)
|
||||
}
|
||||
if chunkBody.UploadID != "upload-1" || chunkBody.ChunkIndex != 0 || chunkBody.Base64Data == "" {
|
||||
t.Fatalf("chunk body = %+v", chunkBody)
|
||||
}
|
||||
if commands[2].Cmd != wecomCmdUploadMediaEnd {
|
||||
t.Fatalf("third command = %q, want %q", commands[2].Cmd, wecomCmdUploadMediaEnd)
|
||||
}
|
||||
if commands[3].Cmd != wecomCmdSendMsg {
|
||||
t.Fatalf("fourth command = %q, want %q", commands[3].Cmd, wecomCmdSendMsg)
|
||||
}
|
||||
|
||||
body, ok := commands[3].Body.(wecomSendMsgBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected send body type %T", commands[3].Body)
|
||||
}
|
||||
if body.MsgType != "image" || body.Image == nil {
|
||||
t.Fatalf("send body = %+v", body)
|
||||
}
|
||||
if body.ChatID != "chat-1" {
|
||||
t.Fatalf("send chatid = %q, want chat-1", body.ChatID)
|
||||
}
|
||||
if body.Image.MediaID != "media-1" {
|
||||
t.Fatalf("image media_id = %q, want media-1", body.Image.MediaID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_UsesTurnImageAndFinishesStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := newTestWeComChannel(t, bus.NewMessageBus())
|
||||
ch.SetRunning(true)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
imageData := wecomTestJPEGData(t)
|
||||
imagePath := filepath.Join(t.TempDir(), "reply.jpg")
|
||||
if err := os.WriteFile(imagePath, imageData, 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
ref, err := store.Store(imagePath, media.MediaMeta{
|
||||
Filename: "reply.jpg",
|
||||
ContentType: "image/jpeg",
|
||||
Source: "test",
|
||||
CleanupPolicy: media.CleanupPolicyForgetOnly,
|
||||
}, "scope-2")
|
||||
if err != nil {
|
||||
t.Fatalf("Store() error = %v", err)
|
||||
}
|
||||
|
||||
ch.queueTurn("chat-1", wecomTurn{
|
||||
ReqID: "req-1",
|
||||
ChatID: "chat-1",
|
||||
ChatType: 1,
|
||||
StreamID: "stream-1",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
putErr := ch.routes.Put("chat-1", "req-1", 1, time.Hour)
|
||||
if putErr != nil {
|
||||
t.Fatalf("Put() error = %v", putErr)
|
||||
}
|
||||
|
||||
var commands []wecomCommand
|
||||
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
|
||||
commands = append(commands, cmd)
|
||||
switch cmd.Cmd {
|
||||
case wecomCmdUploadMediaInit:
|
||||
return wecomTestAck(wecomUploadMediaInitResponse{UploadID: "upload-2"}), nil
|
||||
case wecomCmdUploadMediaEnd:
|
||||
return wecomTestAck(wecomUploadMediaFinishResponse{
|
||||
Type: "image",
|
||||
MediaID: "media-2",
|
||||
}), nil
|
||||
default:
|
||||
return wecomTestAck(nil), nil
|
||||
}
|
||||
}
|
||||
|
||||
err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
Channel: "wecom",
|
||||
ChatID: "chat-1",
|
||||
Parts: []bus.MediaPart{{
|
||||
Ref: ref,
|
||||
Type: "image",
|
||||
Filename: "reply.jpg",
|
||||
ContentType: "image/jpeg",
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
|
||||
if len(commands) != 5 {
|
||||
t.Fatalf("expected 5 commands, got %d", len(commands))
|
||||
}
|
||||
if commands[0].Cmd != wecomCmdUploadMediaInit {
|
||||
t.Fatalf("first command = %+v", commands[0])
|
||||
}
|
||||
if commands[1].Cmd != wecomCmdUploadMediaChunk {
|
||||
t.Fatalf("second command = %+v", commands[1])
|
||||
}
|
||||
if commands[2].Cmd != wecomCmdUploadMediaEnd {
|
||||
t.Fatalf("third command = %+v", commands[2])
|
||||
}
|
||||
if commands[3].Cmd != wecomCmdRespondMsg || commands[3].Headers.ReqID != "req-1" {
|
||||
t.Fatalf("fourth command = %+v", commands[3])
|
||||
}
|
||||
if commands[4].Cmd != wecomCmdRespondMsg || commands[4].Headers.ReqID != "req-1" {
|
||||
t.Fatalf("fifth command = %+v", commands[4])
|
||||
}
|
||||
|
||||
imageBody, ok := commands[3].Body.(wecomRespondMsgBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected image body type %T", commands[3].Body)
|
||||
}
|
||||
if imageBody.MsgType != "image" || imageBody.Image == nil {
|
||||
t.Fatalf("image body = %+v", imageBody)
|
||||
}
|
||||
if imageBody.Image.MediaID != "media-2" {
|
||||
t.Fatalf("image media_id = %q, want media-2", imageBody.Image.MediaID)
|
||||
}
|
||||
|
||||
streamBody, ok := commands[4].Body.(wecomRespondMsgBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected finish body type %T", commands[4].Body)
|
||||
}
|
||||
if streamBody.MsgType != "stream" || streamBody.Stream == nil || !streamBody.Stream.Finish {
|
||||
t.Fatalf("finish body = %+v", streamBody)
|
||||
}
|
||||
|
||||
if _, ok := ch.getTurn("chat-1"); ok {
|
||||
t.Fatal("expected turn to be removed after media send")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_SendsActiveFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := newTestWeComChannel(t, bus.NewMessageBus())
|
||||
ch.SetRunning(true)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
filePath := filepath.Join(t.TempDir(), "report.pdf")
|
||||
if err := os.WriteFile(filePath, []byte("%PDF-1.4"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
ref, err := store.Store(filePath, media.MediaMeta{
|
||||
Filename: "report.pdf",
|
||||
ContentType: "application/pdf",
|
||||
Source: "test",
|
||||
CleanupPolicy: media.CleanupPolicyForgetOnly,
|
||||
}, "scope-3")
|
||||
if err != nil {
|
||||
t.Fatalf("Store() error = %v", err)
|
||||
}
|
||||
|
||||
var commands []wecomCommand
|
||||
ch.commandSend = func(cmd wecomCommand, _ time.Duration) (wecomEnvelope, error) {
|
||||
commands = append(commands, cmd)
|
||||
switch cmd.Cmd {
|
||||
case wecomCmdUploadMediaInit:
|
||||
return wecomTestAck(wecomUploadMediaInitResponse{UploadID: "upload-3"}), nil
|
||||
case wecomCmdUploadMediaEnd:
|
||||
return wecomTestAck(wecomUploadMediaFinishResponse{
|
||||
Type: "file",
|
||||
MediaID: "media-3",
|
||||
}), nil
|
||||
default:
|
||||
return wecomTestAck(nil), nil
|
||||
}
|
||||
}
|
||||
|
||||
err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
Channel: "wecom",
|
||||
ChatID: "chat-2",
|
||||
Parts: []bus.MediaPart{{
|
||||
Ref: ref,
|
||||
Type: "file",
|
||||
Filename: "report.pdf",
|
||||
ContentType: "application/pdf",
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
|
||||
if len(commands) != 4 {
|
||||
t.Fatalf("expected 4 commands, got %d", len(commands))
|
||||
}
|
||||
if commands[0].Cmd != wecomCmdUploadMediaInit {
|
||||
t.Fatalf("first command = %q, want %q", commands[0].Cmd, wecomCmdUploadMediaInit)
|
||||
}
|
||||
initBody, ok := commands[0].Body.(wecomUploadMediaInitBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected init body type %T", commands[0].Body)
|
||||
}
|
||||
if initBody.Type != "file" || initBody.Filename != "report.pdf" {
|
||||
t.Fatalf("init body = %+v", initBody)
|
||||
}
|
||||
if commands[1].Cmd != wecomCmdUploadMediaChunk {
|
||||
t.Fatalf("second command = %q, want %q", commands[1].Cmd, wecomCmdUploadMediaChunk)
|
||||
}
|
||||
if commands[2].Cmd != wecomCmdUploadMediaEnd {
|
||||
t.Fatalf("third command = %q, want %q", commands[2].Cmd, wecomCmdUploadMediaEnd)
|
||||
}
|
||||
if commands[3].Cmd != wecomCmdSendMsg {
|
||||
t.Fatalf("fourth command = %q, want %q", commands[3].Cmd, wecomCmdSendMsg)
|
||||
}
|
||||
|
||||
body, ok := commands[3].Body.(wecomSendMsgBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected body type %T", commands[3].Body)
|
||||
}
|
||||
if body.MsgType != "file" || body.File == nil {
|
||||
t.Fatalf("body = %+v", body)
|
||||
}
|
||||
if body.File.MediaID != "media-3" {
|
||||
t.Fatalf("file media_id = %q, want media-3", body.File.MediaID)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestWeComChannel(t *testing.T, messageBus *bus.MessageBus) *WeComChannel {
|
||||
t.Helper()
|
||||
|
||||
cfg := config.WeComConfig{BotID: "bot-1"}
|
||||
cfg.SetSecret("secret-1")
|
||||
ch, err := NewChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("NewChannel() error = %v", err)
|
||||
}
|
||||
ch.ctx = context.Background()
|
||||
ch.routes = newReqIDStore(filepath.Join(t.TempDir(), "reqids.json"))
|
||||
return ch
|
||||
}
|
||||
|
||||
func wecomTestJPEGData(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
|
||||
const jpegBase64 = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAP//////////////////////////////////////////////////////////////////////////////////////" +
|
||||
"//////////////////////////////////////////////////////////////////////////////////////////////2wBDAf//////////////////////////////////////////////////////////////////////////////////////" +
|
||||
"//////////////////////////////////////////////////////////////////////////////////////////////wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAVEQEBAAAAAAAAAAAAAAAAAAAABf/aAAwDAQACEAMQAAAB6A//xAAVEAEBAAAAAAAAAAAAAAAAAAAAEf/aAAgBAQABBQJf/8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAwEBPwF//8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAgEBPwF//8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQAGPwJf/8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQABPyFf/9k="
|
||||
|
||||
return decodeTestBase64(t, jpegBase64)
|
||||
}
|
||||
|
||||
func TestDecodeWeComUploadFinish_AcceptsNumericCreatedAt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resp, err := decodeWeComEnvelopeBody[wecomUploadMediaFinishResponse](wecomEnvelope{
|
||||
Body: json.RawMessage(`{"type":"file","media_id":"media-1","created_at":1380000000}`),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("decodeWeComEnvelopeBody() error = %v", err)
|
||||
}
|
||||
if resp.Type != "file" || resp.MediaID != "media-1" {
|
||||
t.Fatalf("response = %+v", resp)
|
||||
}
|
||||
if string(resp.CreatedAt) != "1380000000" {
|
||||
t.Fatalf("created_at = %s, want 1380000000", string(resp.CreatedAt))
|
||||
}
|
||||
}
|
||||
|
||||
func wecomTestAck(body any) wecomEnvelope {
|
||||
var raw []byte
|
||||
if body != nil {
|
||||
encoded, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
raw = encoded
|
||||
}
|
||||
return wecomEnvelope{
|
||||
ErrCode: 0,
|
||||
ErrMsg: "ok",
|
||||
Body: raw,
|
||||
}
|
||||
}
|
||||
@@ -99,13 +99,7 @@ Examples:
|
||||
- `ref:channels.line.channel_secret`
|
||||
- `ref:channels.line.channel_access_token`
|
||||
- `ref:channels.onebot.access_token`
|
||||
- `ref:channels.wecom.token`
|
||||
- `ref:channels.wecom.encoding_aes_key`
|
||||
- `ref:channels.wecom_app.corp_secret`
|
||||
- `ref:channels.wecom_app.token`
|
||||
- `ref:channels.wecom_app.encoding_aes_key`
|
||||
- `ref:channels.wecom_aibot.token`
|
||||
- `ref:channels.wecom_aibot.encoding_aes_key`
|
||||
- `ref:channels.wecom.secret`
|
||||
- `ref:channels.pico.token`
|
||||
- `ref:channels.irc.password`
|
||||
- `ref:channels.irc.nickserv_password`
|
||||
|
||||
+22
-181
@@ -321,10 +321,7 @@ type AgentDefaults struct {
|
||||
ToolFeedback ToolFeedbackConfig `json:"tool_feedback,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
DefaultWeComAIBotProcessingMessage = "⏳ Processing, please wait. The results will be sent shortly."
|
||||
)
|
||||
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
|
||||
func (d *AgentDefaults) GetMaxMediaSize() int {
|
||||
if d.MaxMediaSize > 0 {
|
||||
@@ -364,9 +361,7 @@ type ChannelsConfig struct {
|
||||
Matrix MatrixConfig `json:"matrix"`
|
||||
LINE LINEConfig `json:"line"`
|
||||
OneBot OneBotConfig `json:"onebot"`
|
||||
WeCom WeComConfig `json:"wecom"`
|
||||
WeComApp WeComAppConfig `json:"wecom_app"`
|
||||
WeComAIBot WeComAIBotConfig `json:"wecom_aibot"`
|
||||
WeCom WeComConfig `json:"wecom" envPrefix:"PICOCLAW_CHANNELS_WECOM_"`
|
||||
Weixin WeixinConfig `json:"weixin"`
|
||||
Pico PicoConfig `json:"pico"`
|
||||
PicoClient PicoClientConfig `json:"pico_client"`
|
||||
@@ -680,136 +675,28 @@ func (c *OneBotConfig) SetAccessToken(token string) {
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
type WeComGroupConfig struct {
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from,omitempty"`
|
||||
}
|
||||
|
||||
type WeComConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"`
|
||||
token string
|
||||
encodingAESKey string
|
||||
WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"`
|
||||
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"`
|
||||
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"`
|
||||
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"`
|
||||
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"`
|
||||
secDirty bool
|
||||
Enabled bool `json:"enabled" env:"ENABLED"`
|
||||
BotID string `json:"bot_id" env:"BOT_ID"`
|
||||
secret string
|
||||
WebSocketURL string `json:"websocket_url,omitempty" env:"WEBSOCKET_URL"`
|
||||
SendThinkingMessage bool `json:"send_thinking_message" env:"SEND_THINKING_MESSAGE"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"ALLOW_FROM"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"REASONING_CHANNEL_ID"`
|
||||
secDirty bool
|
||||
}
|
||||
|
||||
// Token returns the WeCom token
|
||||
func (c *WeComConfig) Token() string {
|
||||
return c.token
|
||||
}
|
||||
|
||||
// SetToken sets the WeCom token
|
||||
func (c *WeComConfig) SetToken(token string) {
|
||||
c.token = token
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
// EncodingAESKey returns the WeCom encoding AES key
|
||||
func (c *WeComConfig) EncodingAESKey() string {
|
||||
return c.encodingAESKey
|
||||
}
|
||||
|
||||
// SetEncodingAESKey sets the WeCom encoding AES key
|
||||
func (c *WeComConfig) SetEncodingAESKey(key string) {
|
||||
c.encodingAESKey = key
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
type WeComAppConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"`
|
||||
CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"`
|
||||
corpSecret string
|
||||
AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"`
|
||||
token string
|
||||
encodingAESKey string
|
||||
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"`
|
||||
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"`
|
||||
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"`
|
||||
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"`
|
||||
secDirty bool
|
||||
}
|
||||
|
||||
// CorpSecret returns the corporate secret for WeCom app
|
||||
func (c *WeComAppConfig) CorpSecret() string {
|
||||
return c.corpSecret
|
||||
}
|
||||
|
||||
// SetCorpSecret sets the corporate secret for WeCom app
|
||||
func (c *WeComAppConfig) SetCorpSecret(secret string) {
|
||||
c.corpSecret = secret
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
// Token returns the webhook token for WeCom app
|
||||
func (c *WeComAppConfig) Token() string {
|
||||
return c.token
|
||||
}
|
||||
|
||||
// SetToken sets the webhook token for WeCom app
|
||||
func (c *WeComAppConfig) SetToken(token string) {
|
||||
c.token = token
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
// EncodingAESKey returns the encoding AES key for WeCom app
|
||||
func (c *WeComAppConfig) EncodingAESKey() string {
|
||||
return c.encodingAESKey
|
||||
}
|
||||
|
||||
// SetEncodingAESKey sets the encoding AES key for WeCom app
|
||||
func (c *WeComAppConfig) SetEncodingAESKey(key string) {
|
||||
c.encodingAESKey = key
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
type WeComAIBotConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"`
|
||||
BotID string `json:"bot_id,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_BOT_ID"`
|
||||
secret string
|
||||
token string
|
||||
encodingAESKey string
|
||||
WebhookPath string `json:"webhook_path,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"`
|
||||
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"`
|
||||
MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` // Maximum streaming steps
|
||||
WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` // Sent on enter_chat event; empty = no welcome
|
||||
ProcessingMessage string `json:"processing_message,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_PROCESSING_MESSAGE"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"`
|
||||
secDirty bool
|
||||
}
|
||||
|
||||
// Token returns the webhook token for WeCom AI bot
|
||||
func (c *WeComAIBotConfig) Token() string {
|
||||
return c.token
|
||||
}
|
||||
|
||||
// EncodingAESKey returns the encoding AES key for WeCom AI bot
|
||||
func (c *WeComAIBotConfig) EncodingAESKey() string {
|
||||
return c.encodingAESKey
|
||||
}
|
||||
|
||||
// SetToken sets the token for WeCom AI bot
|
||||
func (c *WeComAIBotConfig) SetToken(token string) {
|
||||
c.token = token
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
// SetEncodingAESKey sets the encoding AES key for WeCom AI bot
|
||||
func (c *WeComAIBotConfig) SetEncodingAESKey(key string) {
|
||||
c.encodingAESKey = key
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
func (c *WeComAIBotConfig) Secret() string {
|
||||
// Secret returns the WeCom bot secret.
|
||||
func (c *WeComConfig) Secret() string {
|
||||
return c.secret
|
||||
}
|
||||
|
||||
func (c *WeComAIBotConfig) SetSecret(secret string) {
|
||||
// SetSecret sets the WeCom bot secret.
|
||||
func (c *WeComConfig) SetSecret(secret string) {
|
||||
c.secret = secret
|
||||
c.secDirty = true
|
||||
}
|
||||
@@ -1625,39 +1512,10 @@ func applySecurityConfig(cfg *Config, sec *SecurityConfig) error {
|
||||
cfg.Channels.OneBot.accessToken = sec.Channels.OneBot.AccessToken
|
||||
}
|
||||
|
||||
// Handle WeCom token and encoding key
|
||||
// Handle WeCom bot secret
|
||||
if sec.Channels.WeCom != nil {
|
||||
if sec.Channels.WeCom.Token != "" {
|
||||
cfg.Channels.WeCom.token = sec.Channels.WeCom.Token
|
||||
}
|
||||
if sec.Channels.WeCom.EncodingAESKey != "" {
|
||||
cfg.Channels.WeCom.encodingAESKey = sec.Channels.WeCom.EncodingAESKey
|
||||
}
|
||||
}
|
||||
|
||||
// Handle WeCom App credentials
|
||||
if sec.Channels.WeComApp != nil {
|
||||
if sec.Channels.WeComApp.CorpSecret != "" {
|
||||
cfg.Channels.WeComApp.corpSecret = sec.Channels.WeComApp.CorpSecret
|
||||
}
|
||||
if sec.Channels.WeComApp.Token != "" {
|
||||
cfg.Channels.WeComApp.token = sec.Channels.WeComApp.Token
|
||||
}
|
||||
if sec.Channels.WeComApp.EncodingAESKey != "" {
|
||||
cfg.Channels.WeComApp.encodingAESKey = sec.Channels.WeComApp.EncodingAESKey
|
||||
}
|
||||
}
|
||||
|
||||
// Handle WeCom AI Bot credentials
|
||||
if sec.Channels.WeComAIBot != nil {
|
||||
if sec.Channels.WeComAIBot.Token != "" {
|
||||
cfg.Channels.WeComAIBot.token = sec.Channels.WeComAIBot.Token
|
||||
}
|
||||
if sec.Channels.WeComAIBot.EncodingAESKey != "" {
|
||||
cfg.Channels.WeComAIBot.encodingAESKey = sec.Channels.WeComAIBot.EncodingAESKey
|
||||
}
|
||||
if sec.Channels.WeComAIBot.Secret != "" {
|
||||
cfg.Channels.WeComAIBot.secret = sec.Channels.WeComAIBot.Secret
|
||||
if sec.Channels.WeCom.Secret != "" {
|
||||
cfg.Channels.WeCom.secret = sec.Channels.WeCom.Secret
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1881,27 +1739,10 @@ func SaveConfig(path string, cfg *Config) error {
|
||||
}
|
||||
if cfg.Channels.WeCom.secDirty {
|
||||
cfg.security.Channels.WeCom = &WeComSecurity{
|
||||
Token: cfg.Channels.WeCom.Token(),
|
||||
EncodingAESKey: cfg.Channels.WeCom.EncodingAESKey(),
|
||||
Secret: cfg.Channels.WeCom.Secret(),
|
||||
}
|
||||
cfg.Channels.WeCom.secDirty = false
|
||||
}
|
||||
if cfg.Channels.WeComApp.secDirty {
|
||||
cfg.security.Channels.WeComApp = &WeComAppSecurity{
|
||||
CorpSecret: cfg.Channels.WeComApp.CorpSecret(),
|
||||
Token: cfg.Channels.WeComApp.Token(),
|
||||
EncodingAESKey: cfg.Channels.WeComApp.EncodingAESKey(),
|
||||
}
|
||||
cfg.Channels.WeComApp.secDirty = false
|
||||
}
|
||||
if cfg.Channels.WeComAIBot.secDirty {
|
||||
cfg.security.Channels.WeComAIBot = &WeComAIBotSecurity{
|
||||
Token: cfg.Channels.WeComAIBot.Token(),
|
||||
EncodingAESKey: cfg.Channels.WeComAIBot.EncodingAESKey(),
|
||||
Secret: cfg.Channels.WeComAIBot.Secret(),
|
||||
}
|
||||
cfg.Channels.WeComAIBot.secDirty = false
|
||||
}
|
||||
if cfg.Tools.Web.Brave.secDirty {
|
||||
cfg.security.Web.Brave = &BraveSecurity{
|
||||
APIKeys: cfg.Tools.Web.Brave.APIKeys(),
|
||||
|
||||
+63
-153
@@ -85,23 +85,21 @@ type toolsConfigV0 struct {
|
||||
}
|
||||
|
||||
type channelsConfigV0 struct {
|
||||
WhatsApp WhatsAppConfig `json:"whatsapp"`
|
||||
Telegram telegramConfigV0 `json:"telegram"`
|
||||
Feishu feishuConfigV0 `json:"feishu"`
|
||||
Discord discordConfigV0 `json:"discord"`
|
||||
MaixCam maixcamConfigV0 `json:"maixcam"`
|
||||
Weixin weixinConfigV0 `json:"weixin"`
|
||||
QQ qqConfigV0 `json:"qq"`
|
||||
DingTalk dingtalkConfigV0 `json:"dingtalk"`
|
||||
Slack slackConfigV0 `json:"slack"`
|
||||
Matrix matrixConfigV0 `json:"matrix"`
|
||||
LINE lineConfigV0 `json:"line"`
|
||||
OneBot onebotConfigV0 `json:"onebot"`
|
||||
WeCom wecomConfigV0 `json:"wecom"`
|
||||
WeComApp wecomappConfigV0 `json:"wecom_app"`
|
||||
WeComAIBot wecomaibotConfigV0 `json:"wecom_aibot"`
|
||||
Pico picoConfigV0 `json:"pico"`
|
||||
IRC ircConfigV0 `json:"irc"`
|
||||
WhatsApp WhatsAppConfig `json:"whatsapp"`
|
||||
Telegram telegramConfigV0 `json:"telegram"`
|
||||
Feishu feishuConfigV0 `json:"feishu"`
|
||||
Discord discordConfigV0 `json:"discord"`
|
||||
MaixCam maixcamConfigV0 `json:"maixcam"`
|
||||
Weixin weixinConfigV0 `json:"weixin"`
|
||||
QQ qqConfigV0 `json:"qq"`
|
||||
DingTalk dingtalkConfigV0 `json:"dingtalk"`
|
||||
Slack slackConfigV0 `json:"slack"`
|
||||
Matrix matrixConfigV0 `json:"matrix"`
|
||||
LINE lineConfigV0 `json:"line"`
|
||||
OneBot onebotConfigV0 `json:"onebot"`
|
||||
WeCom wecomConfigV0 `json:"wecom" envPrefix:"PICOCLAW_CHANNELS_WECOM_"`
|
||||
Pico picoConfigV0 `json:"pico"`
|
||||
IRC ircConfigV0 `json:"irc"`
|
||||
}
|
||||
|
||||
func (v *channelsConfigV0) ToChannelsConfig() (ChannelsConfig, ChannelsSecurity) {
|
||||
@@ -117,45 +115,39 @@ func (v *channelsConfigV0) ToChannelsConfig() (ChannelsConfig, ChannelsSecurity)
|
||||
line, lineSecurity := v.LINE.ToLINEConfig()
|
||||
onebot, onebotSecurity := v.OneBot.ToOneBotConfig()
|
||||
wecom, wecomSecurity := v.WeCom.ToWeComConfig()
|
||||
wecomapp, wecomappSecurity := v.WeComApp.ToWeComAppConfig()
|
||||
wecomaibot, wecomaibotSecurity := v.WeComAIBot.ToWeComAIBotConfig()
|
||||
pico, picoSecurity := v.Pico.ToPicoConfig()
|
||||
irc, ircSecurity := v.IRC.ToIRCConfig()
|
||||
|
||||
return ChannelsConfig{
|
||||
WhatsApp: v.WhatsApp,
|
||||
Telegram: telegram,
|
||||
Feishu: feishu,
|
||||
Discord: discord,
|
||||
MaixCam: maixcam,
|
||||
QQ: qq,
|
||||
Weixin: weixin,
|
||||
DingTalk: dingtalk,
|
||||
Slack: slack,
|
||||
Matrix: matrix,
|
||||
LINE: line,
|
||||
OneBot: onebot,
|
||||
WeCom: wecom,
|
||||
WeComApp: wecomapp,
|
||||
WeComAIBot: wecomaibot,
|
||||
Pico: pico,
|
||||
IRC: irc,
|
||||
WhatsApp: v.WhatsApp,
|
||||
Telegram: telegram,
|
||||
Feishu: feishu,
|
||||
Discord: discord,
|
||||
MaixCam: maixcam,
|
||||
QQ: qq,
|
||||
Weixin: weixin,
|
||||
DingTalk: dingtalk,
|
||||
Slack: slack,
|
||||
Matrix: matrix,
|
||||
LINE: line,
|
||||
OneBot: onebot,
|
||||
WeCom: wecom,
|
||||
Pico: pico,
|
||||
IRC: irc,
|
||||
}, ChannelsSecurity{
|
||||
Telegram: telegramSecurity,
|
||||
Feishu: feishuSecurity,
|
||||
Discord: discordSecurity,
|
||||
QQ: qqSecurity,
|
||||
Weixin: weixinSecurity,
|
||||
DingTalk: dingtalkSecurity,
|
||||
Slack: slackSecurity,
|
||||
Matrix: matrixSecurity,
|
||||
LINE: lineSecurity,
|
||||
OneBot: onebotSecurity,
|
||||
WeCom: wecomSecurity,
|
||||
WeComApp: wecomappSecurity,
|
||||
WeComAIBot: wecomaibotSecurity,
|
||||
Pico: picoSecurity,
|
||||
IRC: ircSecurity,
|
||||
Telegram: telegramSecurity,
|
||||
Feishu: feishuSecurity,
|
||||
Discord: discordSecurity,
|
||||
QQ: qqSecurity,
|
||||
Weixin: weixinSecurity,
|
||||
DingTalk: dingtalkSecurity,
|
||||
Slack: slackSecurity,
|
||||
Matrix: matrixSecurity,
|
||||
LINE: lineSecurity,
|
||||
OneBot: onebotSecurity,
|
||||
WeCom: wecomSecurity,
|
||||
Pico: picoSecurity,
|
||||
IRC: ircSecurity,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -473,39 +465,32 @@ func (v *onebotConfigV0) ToOneBotConfig() (OneBotConfig, *OneBotSecurity) {
|
||||
}
|
||||
|
||||
type wecomConfigV0 struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"`
|
||||
EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"`
|
||||
WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"`
|
||||
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"`
|
||||
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"`
|
||||
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"`
|
||||
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"`
|
||||
Enabled bool `json:"enabled" env:"ENABLED"`
|
||||
BotID string `json:"bot_id" env:"BOT_ID"`
|
||||
Secret string `json:"secret" env:"SECRET"`
|
||||
WebSocketURL string `json:"websocket_url,omitempty" env:"WEBSOCKET_URL"`
|
||||
SendThinkingMessage bool `json:"send_thinking_message" env:"SEND_THINKING_MESSAGE"`
|
||||
DMPolicy string `json:"dm_policy,omitempty" env:"DM_POLICY"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"ALLOW_FROM"`
|
||||
GroupPolicy string `json:"group_policy,omitempty" env:"GROUP_POLICY"`
|
||||
GroupAllowFrom FlexibleStringSlice `json:"group_allow_from,omitempty" env:"GROUP_ALLOW_FROM"`
|
||||
Groups map[string]WeComGroupConfig `json:"groups,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
func (v *wecomConfigV0) ToWeComConfig() (WeComConfig, *WeComSecurity) {
|
||||
var sec *WeComSecurity
|
||||
if v.Token != "" || v.EncodingAESKey != "" {
|
||||
sec = &WeComSecurity{
|
||||
Token: v.Token,
|
||||
EncodingAESKey: v.EncodingAESKey,
|
||||
}
|
||||
if v.Secret != "" {
|
||||
sec = &WeComSecurity{Secret: v.Secret}
|
||||
}
|
||||
return WeComConfig{
|
||||
Enabled: v.Enabled,
|
||||
token: v.Token,
|
||||
encodingAESKey: v.EncodingAESKey,
|
||||
WebhookURL: v.WebhookURL,
|
||||
WebhookHost: v.WebhookHost,
|
||||
WebhookPort: v.WebhookPort,
|
||||
WebhookPath: v.WebhookPath,
|
||||
AllowFrom: v.AllowFrom,
|
||||
ReplyTimeout: v.ReplyTimeout,
|
||||
GroupTrigger: v.GroupTrigger,
|
||||
ReasoningChannelID: v.ReasoningChannelID,
|
||||
Enabled: v.Enabled,
|
||||
BotID: v.BotID,
|
||||
secret: v.Secret,
|
||||
WebSocketURL: v.WebSocketURL,
|
||||
SendThinkingMessage: v.SendThinkingMessage,
|
||||
AllowFrom: v.AllowFrom,
|
||||
ReasoningChannelID: v.ReasoningChannelID,
|
||||
}, sec
|
||||
}
|
||||
|
||||
@@ -537,81 +522,6 @@ func (v *weixinConfigV0) ToWeiXinConfig() (WeixinConfig, *WeixinSecurity) {
|
||||
}, sec
|
||||
}
|
||||
|
||||
type wecomappConfigV0 struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"`
|
||||
CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"`
|
||||
CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"`
|
||||
AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"`
|
||||
EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"`
|
||||
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"`
|
||||
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"`
|
||||
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"`
|
||||
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
func (v *wecomappConfigV0) ToWeComAppConfig() (WeComAppConfig, *WeComAppSecurity) {
|
||||
var sec *WeComAppSecurity
|
||||
if v.CorpSecret != "" || v.Token != "" || v.EncodingAESKey != "" {
|
||||
sec = &WeComAppSecurity{
|
||||
CorpSecret: v.CorpSecret,
|
||||
Token: v.Token,
|
||||
EncodingAESKey: v.EncodingAESKey,
|
||||
}
|
||||
}
|
||||
return WeComAppConfig{
|
||||
Enabled: v.Enabled,
|
||||
CorpID: v.CorpID,
|
||||
corpSecret: v.CorpSecret,
|
||||
AgentID: v.AgentID,
|
||||
token: v.Token,
|
||||
encodingAESKey: v.EncodingAESKey,
|
||||
WebhookHost: v.WebhookHost,
|
||||
WebhookPort: v.WebhookPort,
|
||||
WebhookPath: v.WebhookPath,
|
||||
AllowFrom: v.AllowFrom,
|
||||
ReplyTimeout: v.ReplyTimeout,
|
||||
GroupTrigger: v.GroupTrigger,
|
||||
ReasoningChannelID: v.ReasoningChannelID,
|
||||
}, sec
|
||||
}
|
||||
|
||||
type wecomaibotConfigV0 struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"`
|
||||
Secret string `json:"secret" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_SECRET"`
|
||||
EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"`
|
||||
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"`
|
||||
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"`
|
||||
MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"`
|
||||
WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
func (v *wecomaibotConfigV0) ToWeComAIBotConfig() (WeComAIBotConfig, *WeComAIBotSecurity) {
|
||||
var sec *WeComAIBotSecurity
|
||||
if v.Token != "" || v.Secret != "" || v.EncodingAESKey != "" {
|
||||
sec = &WeComAIBotSecurity{
|
||||
Token: v.Token,
|
||||
Secret: v.Secret,
|
||||
EncodingAESKey: v.EncodingAESKey,
|
||||
}
|
||||
}
|
||||
return WeComAIBotConfig{
|
||||
Enabled: v.Enabled,
|
||||
WebhookPath: v.WebhookPath,
|
||||
AllowFrom: v.AllowFrom,
|
||||
ReplyTimeout: v.ReplyTimeout,
|
||||
MaxSteps: v.MaxSteps,
|
||||
WelcomeMessage: v.WelcomeMessage,
|
||||
ReasoningChannelID: v.ReasoningChannelID,
|
||||
}, sec
|
||||
}
|
||||
|
||||
type picoConfigV0 struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"`
|
||||
|
||||
@@ -1372,8 +1372,7 @@ func TestFilterSensitiveData_AllTokenTypes(t *testing.T) {
|
||||
Feishu: &FeishuSecurity{AppSecret: "feishu-app-secret-123", EncryptKey: "feishu-encrypt-key"},
|
||||
DingTalk: &DingTalkSecurity{ClientSecret: "dingtalk-client-secret"},
|
||||
OneBot: &OneBotSecurity{AccessToken: "onebot-access-token"},
|
||||
WeCom: &WeComSecurity{Token: "wecom-token", EncodingAESKey: "wecom-aes-key"},
|
||||
WeComApp: &WeComAppSecurity{CorpSecret: "wecom-app-secret", Token: "wecom-app-token"},
|
||||
WeCom: &WeComSecurity{Secret: "wecom-secret"},
|
||||
Pico: &PicoSecurity{Token: "pico-token-abc123"},
|
||||
IRC: &IRCSecurity{
|
||||
Password: "irc-password",
|
||||
|
||||
+5
-26
@@ -131,32 +131,11 @@ func DefaultConfig() *Config {
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
WeCom: WeComConfig{
|
||||
Enabled: false,
|
||||
WebhookURL: "",
|
||||
WebhookHost: "0.0.0.0",
|
||||
WebhookPort: 18793,
|
||||
WebhookPath: "/webhook/wecom",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
ReplyTimeout: 5,
|
||||
},
|
||||
WeComApp: WeComAppConfig{
|
||||
Enabled: false,
|
||||
CorpID: "",
|
||||
AgentID: 0,
|
||||
WebhookHost: "0.0.0.0",
|
||||
WebhookPort: 18792,
|
||||
WebhookPath: "/webhook/wecom-app",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
ReplyTimeout: 5,
|
||||
},
|
||||
WeComAIBot: WeComAIBotConfig{
|
||||
Enabled: false,
|
||||
WebhookPath: "/webhook/wecom-aibot",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
ReplyTimeout: 5,
|
||||
MaxSteps: 10,
|
||||
WelcomeMessage: "Hello! I'm your AI assistant. How can I help you today?",
|
||||
ProcessingMessage: DefaultWeComAIBotProcessingMessage,
|
||||
Enabled: false,
|
||||
BotID: "",
|
||||
WebSocketURL: "wss://openws.work.weixin.qq.com",
|
||||
SendThinkingMessage: true,
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Weixin: WeixinConfig{
|
||||
Enabled: false,
|
||||
|
||||
@@ -153,13 +153,7 @@ Both single and multiple keys should use the array format.
|
||||
- ref:channels.line.channel_secret
|
||||
- ref:channels.line.channel_access_token
|
||||
- ref:channels.onebot.access_token
|
||||
- ref:channels.wecom.token
|
||||
- ref:channels.wecom.encoding_aes_key
|
||||
- ref:channels.wecom_app.corp_secret
|
||||
- ref:channels.wecom_app.token
|
||||
- ref:channels.wecom_app.encoding_aes_key
|
||||
- ref:channels.wecom_aibot.token
|
||||
- ref:channels.wecom_aibot.encoding_aes_key
|
||||
- ref:channels.wecom.secret
|
||||
- ref:channels.pico.token
|
||||
- ref:channels.irc.password
|
||||
- ref:channels.irc.nickserv_password
|
||||
|
||||
+15
-38
@@ -69,21 +69,19 @@ type ModelSecurityEntry struct {
|
||||
|
||||
// ChannelsSecurity stores channel-related security data
|
||||
type ChannelsSecurity struct {
|
||||
Telegram *TelegramSecurity `yaml:"telegram,omitempty"`
|
||||
Feishu *FeishuSecurity `yaml:"feishu,omitempty"`
|
||||
Discord *DiscordSecurity `yaml:"discord,omitempty"`
|
||||
Weixin *WeixinSecurity `yaml:"weixin,omitempty"`
|
||||
QQ *QQSecurity `yaml:"qq,omitempty"`
|
||||
DingTalk *DingTalkSecurity `yaml:"dingtalk,omitempty"`
|
||||
Slack *SlackSecurity `yaml:"slack,omitempty"`
|
||||
Matrix *MatrixSecurity `yaml:"matrix,omitempty"`
|
||||
LINE *LINESecurity `yaml:"line,omitempty"`
|
||||
OneBot *OneBotSecurity `yaml:"onebot,omitempty"`
|
||||
WeCom *WeComSecurity `yaml:"wecom,omitempty"`
|
||||
WeComApp *WeComAppSecurity `yaml:"wecom_app,omitempty"`
|
||||
WeComAIBot *WeComAIBotSecurity `yaml:"wecom_aibot,omitempty"`
|
||||
Pico *PicoSecurity `yaml:"pico,omitempty"`
|
||||
IRC *IRCSecurity `yaml:"irc,omitempty"`
|
||||
Telegram *TelegramSecurity `yaml:"telegram,omitempty"`
|
||||
Feishu *FeishuSecurity `yaml:"feishu,omitempty"`
|
||||
Discord *DiscordSecurity `yaml:"discord,omitempty"`
|
||||
Weixin *WeixinSecurity `yaml:"weixin,omitempty"`
|
||||
QQ *QQSecurity `yaml:"qq,omitempty"`
|
||||
DingTalk *DingTalkSecurity `yaml:"dingtalk,omitempty"`
|
||||
Slack *SlackSecurity `yaml:"slack,omitempty"`
|
||||
Matrix *MatrixSecurity `yaml:"matrix,omitempty"`
|
||||
LINE *LINESecurity `yaml:"line,omitempty"`
|
||||
OneBot *OneBotSecurity `yaml:"onebot,omitempty"`
|
||||
WeCom *WeComSecurity `yaml:"wecom,omitempty"`
|
||||
Pico *PicoSecurity `yaml:"pico,omitempty"`
|
||||
IRC *IRCSecurity `yaml:"irc,omitempty"`
|
||||
}
|
||||
|
||||
type TelegramSecurity struct {
|
||||
@@ -131,20 +129,7 @@ type OneBotSecurity struct {
|
||||
}
|
||||
|
||||
type WeComSecurity struct {
|
||||
Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"`
|
||||
EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"`
|
||||
}
|
||||
|
||||
type WeComAppSecurity struct {
|
||||
CorpSecret string `yaml:"corp_secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"`
|
||||
Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"`
|
||||
EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"`
|
||||
}
|
||||
|
||||
type WeComAIBotSecurity struct {
|
||||
Secret string `yaml:"secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_SECRET"`
|
||||
Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"`
|
||||
EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"`
|
||||
Secret string `yaml:"secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_SECRET"`
|
||||
}
|
||||
|
||||
type PicoSecurity struct {
|
||||
@@ -334,17 +319,9 @@ func mergeChannelsSecurity(dst, src *ChannelsSecurity) {
|
||||
if src.OneBot != nil && src.OneBot.AccessToken != "" {
|
||||
dst.OneBot = src.OneBot
|
||||
}
|
||||
if src.WeCom != nil && (src.WeCom.Token != "" || src.WeCom.EncodingAESKey != "") {
|
||||
if src.WeCom != nil && src.WeCom.Secret != "" {
|
||||
dst.WeCom = src.WeCom
|
||||
}
|
||||
if src.WeComApp != nil &&
|
||||
(src.WeComApp.CorpSecret != "" || src.WeComApp.Token != "" || src.WeComApp.EncodingAESKey != "") {
|
||||
dst.WeComApp = src.WeComApp
|
||||
}
|
||||
if src.WeComAIBot != nil &&
|
||||
(src.WeComAIBot.Secret != "" || src.WeComAIBot.Token != "" || src.WeComAIBot.EncodingAESKey != "") {
|
||||
dst.WeComAIBot = src.WeComAIBot
|
||||
}
|
||||
if src.Pico != nil && src.Pico.Token != "" {
|
||||
dst.Pico = src.Pico
|
||||
}
|
||||
|
||||
@@ -240,15 +240,7 @@ func TestAllSecurityKeysAccessible(t *testing.T) {
|
||||
},
|
||||
"wecom": {
|
||||
"enabled": true,
|
||||
"webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook"
|
||||
},
|
||||
"wecom_app": {
|
||||
"enabled": true,
|
||||
"corp_id": "test_corp_id",
|
||||
"agent_id": 123456
|
||||
},
|
||||
"wecom_aibot": {
|
||||
"enabled": true
|
||||
"bot_id": "test_wecom_bot_id"
|
||||
},
|
||||
"pico": {
|
||||
"enabled": true
|
||||
@@ -315,15 +307,7 @@ channels:
|
||||
onebot:
|
||||
access_token: "onebot_test_access_token"
|
||||
wecom:
|
||||
token: "wecom_test_webhook_token"
|
||||
encoding_aes_key: "wecom_test_aes_key"
|
||||
wecom_app:
|
||||
corp_secret: "wecom_app_test_corp_secret"
|
||||
token: "wecom_app_test_token"
|
||||
encoding_aes_key: "wecom_app_test_aes_key"
|
||||
wecom_aibot:
|
||||
token: "wecom_aibot_test_token"
|
||||
encoding_aes_key: "wecom_aibot_test_aes_key"
|
||||
secret: "wecom_test_secret"
|
||||
pico:
|
||||
token: "pico_test_token"
|
||||
irc:
|
||||
@@ -409,24 +393,10 @@ skills:
|
||||
t.Logf("OneBot AccessToken(): %s", cfg.Channels.OneBot.AccessToken())
|
||||
|
||||
// WeCom
|
||||
assert.Equal(t, "wecom_test_webhook_token", cfg.Channels.WeCom.Token())
|
||||
assert.Equal(t, "wecom_test_aes_key", cfg.Channels.WeCom.EncodingAESKey())
|
||||
t.Logf("WeCom Token(): %s", cfg.Channels.WeCom.Token())
|
||||
t.Logf("WeCom EncodingAESKey(): %s", cfg.Channels.WeCom.EncodingAESKey())
|
||||
|
||||
// WeCom App
|
||||
assert.Equal(t, "wecom_app_test_corp_secret", cfg.Channels.WeComApp.CorpSecret())
|
||||
assert.Equal(t, "wecom_app_test_token", cfg.Channels.WeComApp.Token())
|
||||
assert.Equal(t, "wecom_app_test_aes_key", cfg.Channels.WeComApp.EncodingAESKey())
|
||||
t.Logf("WeComApp CorpSecret(): %s", cfg.Channels.WeComApp.CorpSecret())
|
||||
t.Logf("WeComApp Token(): %s", cfg.Channels.WeComApp.Token())
|
||||
t.Logf("WeComApp EncodingAESKey(): %s", cfg.Channels.WeComApp.EncodingAESKey())
|
||||
|
||||
// WeCom AI Bot
|
||||
assert.Equal(t, "wecom_aibot_test_token", cfg.Channels.WeComAIBot.Token())
|
||||
assert.Equal(t, "wecom_aibot_test_aes_key", cfg.Channels.WeComAIBot.EncodingAESKey())
|
||||
t.Logf("WeComAIBot Token(): %s", cfg.Channels.WeComAIBot.Token())
|
||||
t.Logf("WeComAIBot EncodingAESKey(): %s", cfg.Channels.WeComAIBot.EncodingAESKey())
|
||||
assert.Equal(t, "test_wecom_bot_id", cfg.Channels.WeCom.BotID)
|
||||
assert.Equal(t, "wecom_test_secret", cfg.Channels.WeCom.Secret())
|
||||
t.Logf("WeCom BotID: %s", cfg.Channels.WeCom.BotID)
|
||||
t.Logf("WeCom Secret(): %s", cfg.Channels.WeCom.Secret())
|
||||
|
||||
// Pico
|
||||
assert.Equal(t, "pico_test_token", cfg.Channels.Pico.Token())
|
||||
|
||||
@@ -13,17 +13,16 @@ var migrateableDirs = []string{
|
||||
}
|
||||
|
||||
var supportedChannels = map[string]bool{
|
||||
"whatsapp": true,
|
||||
"telegram": true,
|
||||
"feishu": true,
|
||||
"discord": true,
|
||||
"maixcam": true,
|
||||
"qq": true,
|
||||
"dingtalk": true,
|
||||
"slack": true,
|
||||
"matrix": true,
|
||||
"line": true,
|
||||
"onebot": true,
|
||||
"wecom": true,
|
||||
"wecom_app": true,
|
||||
"whatsapp": true,
|
||||
"telegram": true,
|
||||
"feishu": true,
|
||||
"discord": true,
|
||||
"maixcam": true,
|
||||
"qq": true,
|
||||
"dingtalk": true,
|
||||
"slack": true,
|
||||
"matrix": true,
|
||||
"line": true,
|
||||
"onebot": true,
|
||||
"wecom": true,
|
||||
}
|
||||
|
||||
@@ -22,8 +22,6 @@ var channelCatalog = []channelCatalogItem{
|
||||
{Name: "qq", ConfigKey: "qq"},
|
||||
{Name: "onebot", ConfigKey: "onebot"},
|
||||
{Name: "wecom", ConfigKey: "wecom"},
|
||||
{Name: "wecom_app", ConfigKey: "wecom_app"},
|
||||
{Name: "wecom_aibot", ConfigKey: "wecom_aibot"},
|
||||
{Name: "whatsapp", ConfigKey: "whatsapp", Variant: "bridge"},
|
||||
{Name: "whatsapp_native", ConfigKey: "whatsapp", Variant: "native"},
|
||||
{Name: "pico", ConfigKey: "pico"},
|
||||
|
||||
@@ -209,6 +209,15 @@ func validateConfig(cfg *config.Config) []string {
|
||||
errs = append(errs, "channels.discord.token is required when discord channel is enabled")
|
||||
}
|
||||
|
||||
if cfg.Channels.WeCom.Enabled {
|
||||
if cfg.Channels.WeCom.BotID == "" {
|
||||
errs = append(errs, "channels.wecom.bot_id is required when wecom channel is enabled")
|
||||
}
|
||||
if cfg.Channels.WeCom.Secret() == "" {
|
||||
errs = append(errs, "channels.wecom.secret is required when wecom channel is enabled")
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Tools.Exec.Enabled {
|
||||
if cfg.Tools.Exec.EnableDenyPatterns {
|
||||
errs = append(
|
||||
|
||||
@@ -145,13 +145,7 @@ function isConfigured(
|
||||
case "weixin":
|
||||
return asString(config.account_id) !== ""
|
||||
case "wecom":
|
||||
return asString(config.token) !== ""
|
||||
case "wecom_app":
|
||||
return (
|
||||
asString(config.corp_id) !== "" && asString(config.corp_secret) !== ""
|
||||
)
|
||||
case "wecom_aibot":
|
||||
return asString(config.token) !== ""
|
||||
return asString(config.bot_id) !== ""
|
||||
case "whatsapp":
|
||||
return asString(config.bridge_url) !== ""
|
||||
case "whatsapp_native":
|
||||
@@ -192,11 +186,7 @@ function getRequiredFieldKeys(channelName: string): string[] {
|
||||
case "onebot":
|
||||
return ["ws_url"]
|
||||
case "wecom":
|
||||
return ["token"]
|
||||
case "wecom_app":
|
||||
return ["corp_id", "corp_secret"]
|
||||
case "wecom_aibot":
|
||||
return ["token"]
|
||||
return ["bot_id", "secret"]
|
||||
case "whatsapp":
|
||||
return ["bridge_url"]
|
||||
case "pico":
|
||||
|
||||
@@ -28,6 +28,7 @@ const SECRET_FIELDS = new Set([
|
||||
"encoding_aes_key",
|
||||
"encrypt_key",
|
||||
"verification_token",
|
||||
"secret",
|
||||
"password",
|
||||
"nickserv_password",
|
||||
"sasl_password",
|
||||
@@ -44,6 +45,7 @@ const OBJECT_FIELDS = new Set([
|
||||
"allow_token_query",
|
||||
"allow_from",
|
||||
"allow_origins",
|
||||
"groups",
|
||||
])
|
||||
|
||||
function formatLabel(key: string): string {
|
||||
@@ -118,6 +120,14 @@ export function GenericForm({
|
||||
app_id: t("channels.form.desc.appId"),
|
||||
client_id: t("channels.form.desc.clientId"),
|
||||
corp_id: t("channels.form.desc.corpId"),
|
||||
bot_id: t("channels.form.desc.appId"),
|
||||
websocket_url: t("channels.form.desc.wsUrl"),
|
||||
dm_policy: t("channels.form.desc.genericField", { field: "DM policy" }),
|
||||
group_policy: t("channels.form.desc.genericField", { field: "group policy" }),
|
||||
group_allow_from: t("channels.form.desc.allowFrom"),
|
||||
send_thinking_message: t("channels.form.desc.genericField", {
|
||||
field: "thinking message behavior",
|
||||
}),
|
||||
agent_id: t("channels.form.desc.agentId"),
|
||||
webhook_url: t("channels.form.desc.webhookUrl"),
|
||||
webhook_host: t("channels.form.desc.webhookHost"),
|
||||
|
||||
@@ -32,8 +32,6 @@ const CHANNEL_IMPORTANCE_TAIL = [
|
||||
"slack",
|
||||
"line",
|
||||
"wecom",
|
||||
"wecom_app",
|
||||
"wecom_aibot",
|
||||
"dingtalk",
|
||||
"qq",
|
||||
"onebot",
|
||||
@@ -78,8 +76,6 @@ const CHANNEL_ICON_MAP: Record<
|
||||
qq: IconBrandQq,
|
||||
weixin: IconBrandWechat,
|
||||
wecom: IconBrandWechat,
|
||||
wecom_app: IconBrandWechat,
|
||||
wecom_aibot: IconBrandWechat,
|
||||
whatsapp: IconBrandWhatsapp,
|
||||
whatsapp_native: IconBrandWhatsapp,
|
||||
matrix: IconBrandMatrix,
|
||||
|
||||
@@ -233,8 +233,6 @@
|
||||
"qq": "QQ",
|
||||
"onebot": "OneBot",
|
||||
"wecom": "WeCom",
|
||||
"wecom_app": "WeCom App",
|
||||
"wecom_aibot": "WeCom AI Bot",
|
||||
"whatsapp": "WhatsApp",
|
||||
"whatsapp_native": "WhatsApp Native",
|
||||
"pico": "Web",
|
||||
|
||||
@@ -233,8 +233,6 @@
|
||||
"qq": "QQ",
|
||||
"onebot": "OneBot",
|
||||
"wecom": "企业微信",
|
||||
"wecom_app": "企业微信应用",
|
||||
"wecom_aibot": "企业微信 AI 机器人",
|
||||
"whatsapp": "WhatsApp",
|
||||
"whatsapp_native": "WhatsApp Native",
|
||||
"pico": "Web",
|
||||
|
||||
Reference in New Issue
Block a user